tesseract  5.0.0
lstmrecognizer.h
Go to the documentation of this file.
1 // File: lstmrecognizer.h
3 // Description: Top-level line recognizer class for LSTM-based networks.
4 // Author: Ray Smith
5 //
6 // (C) Copyright 2013, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
17 
18 #ifndef TESSERACT_LSTM_LSTMRECOGNIZER_H_
19 #define TESSERACT_LSTM_LSTMRECOGNIZER_H_
20 
21 #include "ccutil.h"
22 #include "helpers.h"
23 #include "matrix.h"
24 #include "network.h"
25 #include "networkscratch.h"
26 #include "params.h"
27 #include "recodebeam.h"
28 #include "series.h"
29 #include "unicharcompress.h"
30 
31 class BLOB_CHOICE_IT;
32 struct Pix;
33 class ROW_RES;
34 class ScrollView;
35 class TBOX;
36 class WERD_RES;
37 
38 namespace tesseract {
39 
40 class Dict;
41 class ImageData;
42 
43 // Enum indicating training mode control flags.
47 };
48 
49 // Top-level line recognizer class for LSTM-based networks.
50 // Note that a sub-class, LSTMTrainer is used for training.
52 public:
// Constructs a recognizer from a language data path prefix.
// NOTE(review): how the prefix is used is defined in the implementation file,
// which is not visible here.
54  LSTMRecognizer(const std::string &language_data_path_prefix);
// Destructor; defined out of line in the implementation file.
55  ~LSTMRecognizer();
56 
57  int NumOutputs() const {
58  return network_->NumOutputs();
59  }
60 
61  // Return the training iterations.
62  int training_iteration() const {
63  return training_iteration_;
64  }
65 
66  // Return the sample iterations.
67  int sample_iteration() const {
68  return sample_iteration_;
69  }
70 
71  // Return the learning rate.
72  float learning_rate() const {
73  return learning_rate_;
74  }
75 
77  if (network_ == nullptr) {
78  return LT_NONE;
79  }
80  StaticShape shape;
81  shape = network_->OutputShape(shape);
82  return shape.loss_type();
83  }
84  bool SimpleTextOutput() const {
85  return OutputLossType() == LT_SOFTMAX;
86  }
87  bool IsIntMode() const {
88  return (training_flags_ & TF_INT_MODE) != 0;
89  }
90  // True if recoder_ is active to re-encode text to a smaller space.
91  bool IsRecoding() const {
92  return (training_flags_ & TF_COMPRESS_UNICHARSET) != 0;
93  }
94  // Returns true if the network is a TensorFlow network.
95  bool IsTensorFlow() const {
96  return network_->type() == NT_TENSORFLOW;
97  }
98  // Returns a vector of layer ids that can be passed to other layer functions
99  // to access a specific layer.
100  std::vector<std::string> EnumerateLayers() const {
101  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
102  auto *series = static_cast<Series *>(network_);
103  std::vector<std::string> layers;
104  series->EnumerateLayers(nullptr, layers);
105  return layers;
106  }
107  // Returns a specific layer from its id (from EnumerateLayers).
108  Network *GetLayer(const std::string &id) const {
109  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
110  ASSERT_HOST(id.length() > 1 && id[0] == ':');
111  auto *series = static_cast<Series *>(network_);
112  return series->GetLayer(&id[1]);
113  }
114  // Returns the learning rate of the layer from its id.
115  float GetLayerLearningRate(const std::string &id) const {
116  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
117  if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
118  ASSERT_HOST(id.length() > 1 && id[0] == ':');
119  auto *series = static_cast<Series *>(network_);
120  return series->LayerLearningRate(&id[1]);
121  } else {
122  return learning_rate_;
123  }
124  }
125 
126  // Return the network string.
127  const char *GetNetwork() const {
128  return network_str_.c_str();
129  }
130 
131  // Return the adam beta.
132  float GetAdamBeta() const {
133  return adam_beta_;
134  }
135 
136  // Return the momentum.
137  float GetMomentum() const {
138  return momentum_;
139  }
140 
141  // Multiplies the all the learning rate(s) by the given factor.
142  void ScaleLearningRate(double factor) {
143  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
144  learning_rate_ *= factor;
145  if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
146  std::vector<std::string> layers = EnumerateLayers();
147  for (auto &layer : layers) {
148  ScaleLayerLearningRate(layer, factor);
149  }
150  }
151  }
152  // Multiplies the learning rate of the layer with id, by the given factor.
153  void ScaleLayerLearningRate(const std::string &id, double factor) {
154  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
155  ASSERT_HOST(id.length() > 1 && id[0] == ':');
156  auto *series = static_cast<Series *>(network_);
157  series->ScaleLayerLearningRate(&id[1], factor);
158  }
159 
160  // Set the all the learning rate(s) to the given value.
161  void SetLearningRate(float learning_rate)
162  {
163  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
164  learning_rate_ = learning_rate;
165  if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
166  for (auto &id : EnumerateLayers()) {
167  SetLayerLearningRate(id, learning_rate);
168  }
169  }
170  }
171  // Set the learning rate of the layer with id, by the given value.
172  void SetLayerLearningRate(const std::string &id, float learning_rate)
173  {
174  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
175  ASSERT_HOST(id.length() > 1 && id[0] == ':');
176  auto *series = static_cast<Series *>(network_);
177  series->SetLayerLearningRate(&id[1], learning_rate);
178  }
179 
180  // Converts the network to int if not already.
181  void ConvertToInt() {
182  if ((training_flags_ & TF_INT_MODE) == 0) {
183  network_->ConvertToInt();
184  training_flags_ |= TF_INT_MODE;
185  }
186  }
187 
188  // Provides access to the UNICHARSET that this classifier works with.
189  const UNICHARSET &GetUnicharset() const {
190  return ccutil_.unicharset;
191  }
193  return ccutil_.unicharset;
194  }
195  // Provides access to the UnicharCompress that this classifier works with.
196  const UnicharCompress &GetRecoder() const {
197  return recoder_;
198  }
199  // Provides access to the Dict that this classifier works with.
200  const Dict *GetDict() const {
201  return dict_;
202  }
204  return dict_;
205  }
206  // Sets the sample iteration to the given value. The sample_iteration_
207  // determines the seed for the random number generator. The training
208  // iteration is incremented only by a successful training iteration.
209  void SetIteration(int iteration) {
210  sample_iteration_ = iteration;
211  }
212  // Accessors for textline image normalization.
213  int NumInputs() const {
214  return network_->NumInputs();
215  }
216 
217  // Return the null char index.
218  int null_char() const {
219  return null_char_;
220  }
221 
222  // Loads a model from mgr, including the dictionary only if lang is given.
// NOTE(review): lang is a std::string reference and so cannot be null;
// presumably "not empty" is meant -- confirm against the implementation.
223  bool Load(const ParamsVectors *params, const std::string &lang, TessdataManager *mgr);
224 
225  // Writes to the given file. Returns false in case of error.
226  // If mgr contains a unicharset and recoder, then they are not encoded to fp.
227  bool Serialize(const TessdataManager *mgr, TFile *fp) const;
228  // Reads from the given file. Returns false in case of error.
229  // If mgr contains a unicharset and recoder, then they are taken from there,
230  // otherwise, they are part of the serialization in fp.
231  bool DeSerialize(const TessdataManager *mgr, TFile *fp);
232  // Loads the charsets from mgr.
// NOTE(review): presumably returns false on failure, like the neighbouring
// loaders -- confirm against the implementation.
233  bool LoadCharsets(const TessdataManager *mgr);
234  // Loads the Recoder from fp.
235  bool LoadRecoder(TFile *fp);
236  // Loads the dictionary if possible from the traineddata file.
237  // If loading fails, prints a warning message and returns false, but the
238  // recognizer otherwise continues to work without a dictionary.
239  // Note that dictionary load is independent from DeSerialize, but dependent
240  // on the unicharset matching. This enables training to deserialize a model
241  // from checkpoint or restore without having to go back and reload the
242  // dictionary.
243  bool LoadDictionary(const ParamsVectors *params, const std::string &lang, TessdataManager *mgr);
244 
245  // Recognizes the line image, contained within image_data, returning the
246  // recognized tesseract WERD_RES for the words.
247  // If invert, tries inverted as well if the normal interpretation doesn't
248  // produce a good enough result. The line_box is used for computing the
249  // box_word in the output words. worst_dict_cert is the worst certainty that
250  // will be used in a dictionary word.
// NOTE(review): lstm_choice_mode and lstm_choice_amount (defaults 0 and 5)
// are not described here; see the implementation for their meaning.
251  void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert,
252  const TBOX &line_box, PointerVector<WERD_RES> *words, int lstm_choice_mode = 0,
253  int lstm_choice_amount = 5);
254 
255  // Helper computes min and mean best results in the output.
// NOTE(review): the sd out-parameter is presumably the standard deviation of
// the same values -- confirm against the implementation.
256  void OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output, float *sd);
257  // Recognizes the image_data, returning the labels,
258  // scores, and corresponding pairs of start, end x-coords in coords.
// NOTE(review): this overload has no labels/scores/coords parameters; results
// are delivered through inputs/outputs, so the sentence above looks stale.
259  // Returned in scale_factor is the reduction factor
260  // between the image and the output coords, for computing bounding boxes.
261  // If re_invert is true, the input is inverted back to its original
262  // photometric interpretation if inversion is attempted but fails to
263  // improve the results. This ensures that outputs contains the correct
264  // forward outputs for the best photometric interpretation.
265  // inputs is filled with the used inputs to the network.
266  bool RecognizeLine(const ImageData &image_data, bool invert, bool debug, bool re_invert,
267  bool upside_down, float *scale_factor, NetworkIO *inputs, NetworkIO *outputs);
268 
269  // Converts an array of labels to utf-8, whether or not the labels are
270  // augmented with character boundaries.
271  std::string DecodeLabels(const std::vector<int> &labels);
272 
273  // Displays the forward results in a window with the characters and
274  // boundaries as determined by the labels and label_coords.
275  void DisplayForward(const NetworkIO &inputs, const std::vector<int> &labels,
276  const std::vector<int> &label_coords, const char *window_name,
277  ScrollView **window);
278  // Converts the network output to a sequence of labels. Outputs labels, scores
279  // and start xcoords of each char, and each null_char_, with an additional
280  // final xcoord for the end of the output.
// NOTE(review): there is no scores parameter in this signature; only labels
// and xcoords are output.
281  // The conversion method is determined by internal state.
282  void LabelsFromOutputs(const NetworkIO &outputs, std::vector<int> *labels,
283  std::vector<int> *xcoords);
284 
285 protected:
286  // Sets the random seed from the sample_iteration_;
287  void SetRandomSeed() {
288  int64_t seed = static_cast<int64_t>(sample_iteration_) * 0x10000001;
289  randomizer_.set_seed(seed);
290  randomizer_.IntRand();
291  }
292 
293  // Displays the labels and cuts at the corresponding xcoords.
294  // Size of labels should match xcoords.
295  void DisplayLSTMOutput(const std::vector<int> &labels, const std::vector<int> &xcoords,
296  int height, ScrollView *window);
297 
298  // Prints debug output detailing the activation path that is implied by the
299  // xcoords.
300  void DebugActivationPath(const NetworkIO &outputs, const std::vector<int> &labels,
301  const std::vector<int> &xcoords);
302 
303  // Prints debug output detailing activations and 2nd choice over a range
304  // of positions.
305  void DebugActivationRange(const NetworkIO &outputs, const char *label, int best_choice,
306  int x_start, int x_end);
307 
308  // As LabelsViaCTC except that this function constructs the best path that
309  // contains only legal sequences of subcodes for recoder_.
// NOTE(review): LabelsViaCTC is not declared in the visible part of this
// header; the reference may be stale.
310  void LabelsViaReEncode(const NetworkIO &output, std::vector<int> *labels,
311  std::vector<int> *xcoords);
312  // Converts the network output to a sequence of labels, with scores, using
313  // the simple character model (each position is a char, and the null_char_ is
314  // mainly intended for tail padding.)
// NOTE(review): there is no scores parameter in this signature; only labels
// and xcoords are output.
315  void LabelsViaSimpleText(const NetworkIO &output, std::vector<int> *labels,
316  std::vector<int> *xcoords);
317 
318  // Returns a string corresponding to the label starting at start. Sets *end
319  // to the next start and if non-null, *decoded to the unichar id.
320  const char *DecodeLabel(const std::vector<int> &labels, unsigned start, unsigned *end, int *decoded);
321 
322  // Returns a string corresponding to a given single label id, falling back to
323  // a default of ".." for part of a multi-label unichar-id.
324  const char *DecodeSingleLabel(int label);
325 
326 protected:
327  // The network hierarchy.
329  // The unicharset. Only the unicharset element is serialized.
330  // Has to be a CCUtil, so Dict can point to it.
332  // For backward compatibility, recoder_ is serialized iff
333  // training_flags_ & TF_COMPRESS_UNICHARSET.
334  // Further encode/decode ccutil_.unicharset's ids to simplify the unicharset.
336 
337  // ==Training parameters that are serialized to provide a record of them.==
338  std::string network_str_;
339  // Flags used to determine the training method of the network.
340  // See enum TrainingFlags above.
342  // Number of actual backward training steps used.
344  // Index into training sample set. sample_iteration >= training_iteration_.
346  // Index in softmax of null character. May take the value UNICHAR_BROKEN or
347  // ccutil_.unicharset.size().
348  int32_t null_char_;
349  // Learning rate and momentum multipliers of deltas in backprop.
351  float momentum_;
352  // Smoothing factor for 2nd moment of gradients.
353  float adam_beta_;
354 
355  // === NOT SERIALIZED.
358  // Language model (optional) to use with the beam search.
360  // Beam search held between uses to optimize memory allocation/use.
362 
363  // == Debugging parameters.==
364  // Recognition debug display window.
366 };
367 
368 } // namespace tesseract.
369 
370 #endif // TESSERACT_LSTM_LSTMRECOGNIZER_H_
#define ASSERT_HOST(x)
Definition: errcode.h:59
@ TBOX
@ TF_COMPRESS_UNICHARSET
bool DeSerialize(bool swap, FILE *fp, std::vector< T > &data)
Definition: helpers.h:220
bool Serialize(FILE *fp, const std::vector< T > &data)
Definition: helpers.h:251
@ NT_TENSORFLOW
Definition: network.h:76
@ NT_SERIES
Definition: network.h:52
@ NF_LAYER_SPECIFIC_LR
Definition: network.h:85
LossType OutputLossType() const
void SetLayerLearningRate(const std::string &id, float learning_rate)
NetworkScratch scratch_space_
const UNICHARSET & GetUnicharset() const
void SetIteration(int iteration)
RecodeBeamSearch * search_
const char * GetNetwork() const
void ScaleLearningRate(double factor)
void ScaleLayerLearningRate(const std::string &id, double factor)
const Dict * GetDict() const
std::vector< std::string > EnumerateLayers() const
float GetLayerLearningRate(const std::string &id) const
const UnicharCompress & GetRecoder() const
UNICHARSET & GetUnicharset()
Network * GetLayer(const std::string &id) const
void SetLearningRate(float learning_rate)
float LayerLearningRate(const char *id)
Definition: plumbing.h:112
TESS_API void EnumerateLayers(const std::string *prefix, std::vector< std::string > &layers) const
Definition: plumbing.cpp:144
void SetLayerLearningRate(const char *id, float learning_rate)
Definition: plumbing.h:125
void ScaleLayerLearningRate(const char *id, double factor)
Definition: plumbing.h:118
TESS_API Network * GetLayer(const char *id) const
Definition: plumbing.cpp:161
LossType loss_type() const
Definition: static_shape.h:65
#define TESS_API
Definition: export.h:34