tesseract  5.0.0
static_shape.h
Go to the documentation of this file.
1 // File: static_shape.h
3 // Description: Defines the size of the 4-d tensor input/output from a network.
4 // Author: Ray Smith
5 // Created: Fri Oct 14 09:07:31 PST 2016
6 //
7 // (C) Copyright 2016, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #ifndef TESSERACT_LSTM_STATIC_SHAPE_H_
20 #define TESSERACT_LSTM_STATIC_SHAPE_H_
21 
22 #include "serialis.h" // for TFile
23 #include "tprintf.h" // for tprintf
24 
25 namespace tesseract {
26 
// Selects the loss function applied during training and/or the decoding
// method applied at runtime. The numeric values are written by
// StaticShape::Serialize, so they are spelled out explicitly here.
enum LossType {
  // Undefined.
  LT_NONE = 0,
  // Softmax with standard CTC for training/decoding.
  LT_CTC = 1,
  // Outputs sum to 1 in fixed positions.
  LT_SOFTMAX = 2,
  // Logistic outputs with independent values.
  LT_LOGISTIC = 3,
};
35 
36 // Simple class to hold the tensor shape that is known at network build time
37 // and the LossType of the loss function.
38 class StaticShape {
39 public:
40  StaticShape() : batch_(0), height_(0), width_(0), depth_(0), loss_type_(LT_NONE) {}
41  int batch() const {
42  return batch_;
43  }
44  void set_batch(int value) {
45  batch_ = value;
46  }
47  int height() const {
48  return height_;
49  }
50  void set_height(int value) {
51  height_ = value;
52  }
53  int width() const {
54  return width_;
55  }
56  void set_width(int value) {
57  width_ = value;
58  }
59  int depth() const {
60  return depth_;
61  }
62  void set_depth(int value) {
63  depth_ = value;
64  }
65  LossType loss_type() const {
66  return loss_type_;
67  }
68  void set_loss_type(LossType value) {
69  loss_type_ = value;
70  }
71  void SetShape(int batch, int height, int width, int depth) {
72  batch_ = batch;
73  height_ = height;
74  width_ = width;
75  depth_ = depth;
76  }
77 
78  void Print() const {
79  tprintf("Batch=%d, Height=%d, Width=%d, Depth=%d, loss=%d\n", batch_, height_, width_, depth_,
80  loss_type_);
81  }
82 
83  bool DeSerialize(TFile *fp) {
84  int32_t tmp = LT_NONE;
85  bool result = fp->DeSerialize(&batch_) && fp->DeSerialize(&height_) &&
86  fp->DeSerialize(&width_) && fp->DeSerialize(&depth_) && fp->DeSerialize(&tmp);
87  loss_type_ = static_cast<LossType>(tmp);
88  return result;
89  }
90 
91  bool Serialize(TFile *fp) const {
92  int32_t tmp = loss_type_;
93  return fp->Serialize(&batch_) && fp->Serialize(&height_) && fp->Serialize(&width_) &&
94  fp->Serialize(&depth_) && fp->Serialize(&tmp);
95  }
96 
97 private:
98  // Size of the 4-D tensor input/output to a network. A value of zero is
99  // allowed for all except depth_ and means to be determined at runtime, and
100  // regarded as variable.
101  // Number of elements in a batch, or number of frames in a video stream.
102  int32_t batch_;
103  // Height of the image.
104  int32_t height_;
105  // Width of the image.
106  int32_t width_;
107  // Depth of the image. (Number of "nodes").
108  int32_t depth_;
109  // How to train/interpret the output.
110  LossType loss_type_;
111 };
112 
113 } // namespace tesseract
114 
115 #endif // TESSERACT_LSTM_STATIC_SHAPE_H_
void tprintf(const char *format,...)
Definition: tprintf.cpp:41
bool DeSerialize(std::string &data)
Definition: serialis.cpp:94
bool Serialize(const std::string &data)
Definition: serialis.cpp:107
void set_batch(int value)
Definition: static_shape.h:44
void set_loss_type(LossType value)
Definition: static_shape.h:68
void SetShape(int batch, int height, int width, int depth)
Definition: static_shape.h:71
void set_depth(int value)
Definition: static_shape.h:62
LossType loss_type() const
Definition: static_shape.h:65
bool Serialize(TFile *fp) const
Definition: static_shape.h:91
void set_width(int value)
Definition: static_shape.h:56
void set_height(int value)
Definition: static_shape.h:50
bool DeSerialize(TFile *fp)
Definition: static_shape.h:83