tesseract  5.0.0
lstmtrainer.h
Go to the documentation of this file.
1 // File: lstmtrainer.h
3 // Description: Top-level line trainer class for LSTM-based networks.
4 // Author: Ray Smith
5 //
6 // (C) Copyright 2013, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
17 
18 #ifndef TESSERACT_LSTM_LSTMTRAINER_H_
19 #define TESSERACT_LSTM_LSTMTRAINER_H_
20 
21 #include "export.h"
22 
23 #include "imagedata.h" // for DocumentCache
24 #include "lstmrecognizer.h"
25 #include "rect.h"
26 
27 #include <functional> // for std::function
28 
29 namespace tesseract {
30 
// Forward declarations of network node types (defined in their own headers);
// only pointers/references to these appear in this header.
class LSTM;
class LSTMTester;
class LSTMTrainer;
class Parallel;
class Reversed;
class Series;
class Softmax;
38 
// Enum for the types of errors that are counted in the rolling error buffers.
enum ErrorTypes {
  ET_RMS,         // RMS activation error.
  ET_DELTA,       // Number of big errors in deltas.
  ET_WORD_RECERR, // Output text string word recall error.
  ET_CHAR_ERROR,  // Output text string total char error.
  ET_SKIP_RATIO,  // Fraction of samples skipped.
  ET_COUNT        // For array sizing.
};
48 
// Enum for the trainability_ flags.
// NOTE(review): the `enum Trainability {` line was lost in extraction and is
// restored here; the enumerators below are unchanged and keep their order.
enum Trainability {
  TRAINABLE,        // Non-zero delta error.
  PERFECT,          // Zero delta error.
  UNENCODABLE,      // Not trainable due to coding/alignment trouble.
  HI_PRECISION_ERR, // Hi confidence disagreement.
  NOT_BOXED,        // Early in training and has no character boxes.
};
57 
// Enum to define the amount of data to get serialized.
// NOTE(review): the `enum SerializeAmount {` line was lost in extraction and
// is restored here; the enumerators below are unchanged and keep their order.
enum SerializeAmount {
  LIGHT,           // Minimal data for remote training.
  NO_BEST_TRAINER, // Save an empty vector in place of best_trainer_.
  FULL,            // All data including best_trainer_.
};
64 
// Enum to indicate how the sub_trainer_ training went.
// NOTE(review): the `enum SubTrainerResult {` line was lost in extraction and
// is restored here; the enumerators below are unchanged and keep their order.
enum SubTrainerResult {
  STR_NONE,    // Did nothing as not good enough.
  STR_UPDATED, // Subtrainer was updated, but didn't replace *this.
  STR_REPLACED // Subtrainer replaced *this.
};
71 
72 class LSTMTrainer;
73 // Function to compute and record error rates on some external test set(s).
74 // Args are: iteration, mean errors, model, training stage.
75 // Returns a string containing logging information about the tests.
76 using TestCallback = std::function<std::string(int, const double *,
77  const TessdataManager &, int)>;
78 
79 // Trainer class for LSTM networks. Most of the effort is in creating the
80 // ideal target outputs from the transcription. A box file is used if it is
81 // available, otherwise estimates of the char widths from the unicharset are
82 // used to guide a DP search for the best fit to the transcription.
83 class TESS_UNICHARSET_TRAINING_API LSTMTrainer : public LSTMRecognizer {
84 public:
85  LSTMTrainer();
86  LSTMTrainer(const char *model_base, const char *checkpoint_name,
87  int debug_interval, int64_t max_memory);
88  virtual ~LSTMTrainer();
89 
90  // Tries to deserialize a trainer from the given file and silently returns
91  // false in case of failure. If old_traineddata is not null, then it is
92  // assumed that the character set is to be re-mapped from old_traineddata to
93  // the new, with consequent change in weight matrices etc.
94  bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata);
95 
96  // Initializes the character set encode/decode mechanism directly from a
97  // previously setup traineddata containing dawgs, UNICHARSET and
98  // UnicharCompress. Note: Call before InitNetwork!
99  bool InitCharSet(const std::string &traineddata_path) {
100  bool success = mgr_.Init(traineddata_path.c_str());
101  if (success) {
102  InitCharSet();
103  }
104  return success;
105  }
106  void InitCharSet(const TessdataManager &mgr) {
107  mgr_ = mgr;
108  InitCharSet();
109  }
110 
111  // Initializes the trainer with a network_spec in the network description
112  // net_flags control network behavior according to the NetworkFlags enum.
113  // There isn't really much difference between them - only where the effects
114  // are implemented.
115  // For other args see NetworkBuilder::InitNetwork.
116  // Note: Be sure to call InitCharSet before InitNetwork!
117  bool InitNetwork(const char *network_spec, int append_index, int net_flags,
118  float weight_range, float learning_rate, float momentum,
119  float adam_beta);
120  // Initializes a trainer from a serialized TFNetworkModel proto.
121  // Returns the global step of TensorFlow graph or 0 if failed.
122  // Building a compatible TF graph: See tfnetwork.proto.
123  int InitTensorFlowNetwork(const std::string &tf_proto);
124  // Resets all the iteration counters for fine tuning or training a head,
125  // where we want the error reporting to reset.
126  void InitIterations();
127 
128  // Accessors.
129  double ActivationError() const {
130  return error_rates_[ET_DELTA];
131  }
132  double CharError() const {
133  return error_rates_[ET_CHAR_ERROR];
134  }
135  const double *error_rates() const {
136  return error_rates_;
137  }
138  double best_error_rate() const {
139  return best_error_rate_;
140  }
141  int best_iteration() const {
142  return best_iteration_;
143  }
144  int learning_iteration() const {
145  return learning_iteration_;
146  }
147  int32_t improvement_steps() const {
148  return improvement_steps_;
149  }
150  void set_perfect_delay(int delay) {
151  perfect_delay_ = delay;
152  }
153  const std::vector<char> &best_trainer() const {
154  return best_trainer_;
155  }
156  // Returns the error that was just calculated by PrepareForBackward.
157  double NewSingleError(ErrorTypes type) const {
158  return error_buffers_[type][training_iteration() % kRollingBufferSize_];
159  }
160  // Returns the error that was just calculated by TrainOnLine. Since
161  // TrainOnLine rolls the error buffers, this is one further back than
162  // NewSingleError.
163  double LastSingleError(ErrorTypes type) const {
164  return error_buffers_[type]
165  [(training_iteration() + kRollingBufferSize_ - 1) %
166  kRollingBufferSize_];
167  }
168  const DocumentCache &training_data() const {
169  return training_data_;
170  }
172  return &training_data_;
173  }
174 
175  // If the training sample is usable, grid searches for the optimal
176  // dict_ratio/cert_offset, and returns the results in a string of space-
177  // separated triplets of ratio,offset=worderr.
178  Trainability GridSearchDictParams(
179  const ImageData *trainingdata, int iteration, double min_dict_ratio,
180  double dict_ratio_step, double max_dict_ratio, double min_cert_offset,
181  double cert_offset_step, double max_cert_offset, std::string &results);
182 
183  // Provides output on the distribution of weight values.
184  void DebugNetwork();
185 
186  // Loads a set of lstmf files that were created using the lstm.train config to
187  // tesseract into memory ready for training. Returns false if nothing was
188  // loaded.
189  bool LoadAllTrainingData(const std::vector<std::string> &filenames,
190  CachingStrategy cache_strategy,
191  bool randomly_rotate);
192 
193  // Keeps track of best and locally worst error rate, using internally computed
194  // values. See MaintainCheckpointsSpecific for more detail.
195  bool MaintainCheckpoints(const TestCallback &tester, std::string &log_msg);
196  // Keeps track of best and locally worst error_rate (whatever it is) and
197  // launches tests using rec_model, when a new min or max is reached.
198  // Writes checkpoints using train_model at appropriate times and builds and
199  // returns a log message to indicate progress. Returns false if nothing
200  // interesting happened.
201  bool MaintainCheckpointsSpecific(int iteration,
202  const std::vector<char> *train_model,
203  const std::vector<char> *rec_model,
204  TestCallback tester, std::string &log_msg);
205  // Builds a string containing a progress message with current error rates.
206  void PrepareLogMsg(std::string &log_msg) const;
207  // Appends <intro_str> iteration learning_iteration()/training_iteration()/
208  // sample_iteration() to the log_msg.
209  void LogIterations(const char *intro_str, std::string &log_msg) const;
210 
211  // TODO(rays) Add curriculum learning.
212  // Returns true and increments the training_stage_ if the error rate has just
213  // passed through the given threshold for the first time.
214  bool TransitionTrainingStage(float error_threshold);
215  // Returns the current training stage.
216  int CurrentTrainingStage() const {
217  return training_stage_;
218  }
219 
220  // Writes to the given file. Returns false in case of error.
221  bool Serialize(SerializeAmount serialize_amount, const TessdataManager *mgr,
222  TFile *fp) const;
223  // Reads from the given file. Returns false in case of error.
224  bool DeSerialize(const TessdataManager *mgr, TFile *fp);
225 
226  // De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
227  // learning rates (by scaling reduction, or layer specific, according to
228  // NF_LAYER_SPECIFIC_LR).
229  void StartSubtrainer(std::string &log_msg);
230  // While the sub_trainer_ is behind the current training iteration and its
231  // training error is at least kSubTrainerMarginFraction better than the
232  // current training error, trains the sub_trainer_, and returns STR_UPDATED if
233  // it did anything. If it catches up, and has a better error rate than the
234  // current best, as well as a margin over the current error rate, then the
235  // trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
236  // returned. STR_NONE is returned if the subtrainer wasn't good enough to
237  // receive any training iterations.
238  SubTrainerResult UpdateSubtrainer(std::string &log_msg);
239  // Reduces network learning rates, either for everything, or for layers
240  // independently, according to NF_LAYER_SPECIFIC_LR.
241  void ReduceLearningRates(LSTMTrainer *samples_trainer, std::string &log_msg);
242  // Considers reducing the learning rate independently for each layer down by
243  // factor(<1), or leaving it the same, by double-training the given number of
244  // samples and minimizing the amount of changing of sign of weight updates.
245  // Even if it looks like all weights should remain the same, an adjustment
246  // will be made to guarantee a different result when reverting to an old best.
247  // Returns the number of layer learning rates that were reduced.
248  int ReduceLayerLearningRates(TFloat factor, int num_samples,
249  LSTMTrainer *samples_trainer);
250 
251  // Converts the string to integer class labels, with appropriate null_char_s
252  // in between if not in SimpleTextOutput mode. Returns false on failure.
253  bool EncodeString(const std::string &str, std::vector<int> *labels) const {
254  return EncodeString(str, GetUnicharset(),
255  IsRecoding() ? &recoder_ : nullptr, SimpleTextOutput(),
256  null_char_, labels);
257  }
258  // Static version operates on supplied unicharset, encoder, simple_text.
259  static bool EncodeString(const std::string &str, const UNICHARSET &unicharset,
260  const UnicharCompress *recoder, bool simple_text,
261  int null_char, std::vector<int> *labels);
262 
263  // Performs forward-backward on the given trainingdata.
264  // Returns the sample that was used or nullptr if the next sample was deemed
265  // unusable. samples_trainer could be this or an alternative trainer that
266  // holds the training samples.
267  const ImageData *TrainOnLine(LSTMTrainer *samples_trainer, bool batch) {
268  int sample_index = sample_iteration();
269  const ImageData *image =
270  samples_trainer->training_data_.GetPageBySerial(sample_index);
271  if (image != nullptr) {
272  Trainability trainable = TrainOnLine(image, batch);
273  if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
274  return nullptr; // Sample was unusable.
275  }
276  } else {
277  ++sample_iteration_;
278  }
279  return image;
280  }
281  Trainability TrainOnLine(const ImageData *trainingdata, bool batch);
282 
283  // Prepares the ground truth, runs forward, and prepares the targets.
284  // Returns a Trainability enum to indicate the suitability of the sample.
285  Trainability PrepareForBackward(const ImageData *trainingdata,
286  NetworkIO *fwd_outputs, NetworkIO *targets);
287 
288  // Writes the trainer to memory, so that the current training state can be
289  // restored. *this must always be the master trainer that retains the only
290  // copy of the training data and language model. trainer is the model that is
291  // actually serialized.
292  bool SaveTrainingDump(SerializeAmount serialize_amount,
293  const LSTMTrainer &trainer,
294  std::vector<char> *data) const;
295 
296  // Reads previously saved trainer from memory. *this must always be the
297  // master trainer that retains the only copy of the training data and
298  // language model. trainer is the model that is restored.
299  bool ReadTrainingDump(const std::vector<char> &data,
300  LSTMTrainer &trainer) const {
301  if (data.empty()) {
302  return false;
303  }
304  return ReadSizedTrainingDump(&data[0], data.size(), trainer);
305  }
306  bool ReadSizedTrainingDump(const char *data, int size,
307  LSTMTrainer &trainer) const {
308  return trainer.ReadLocalTrainingDump(&mgr_, data, size);
309  }
310  // Restores the model to *this.
311  bool ReadLocalTrainingDump(const TessdataManager *mgr, const char *data,
312  int size);
313 
314  // Sets up the data for MaintainCheckpoints from a light ReadTrainingDump.
316 
317  // Writes the full recognition traineddata to the given filename.
318  bool SaveTraineddata(const char *filename);
319 
320  // Writes the recognizer to memory, so that it can be used for testing later.
321  void SaveRecognitionDump(std::vector<char> *data) const;
322 
323  // Returns a suitable filename for a training dump, based on the model_base_,
324  // the iteration and the error rates.
325  std::string DumpFilename() const;
326 
327  // Fills the whole error buffer of the given type with the given value.
328  void FillErrorBuffer(double new_error, ErrorTypes type);
329  // Helper generates a map from each current recoder_ code (ie softmax index)
330  // to the corresponding old_recoder code, or -1 if there isn't one.
331  std::vector<int> MapRecoder(const UNICHARSET &old_chset,
332  const UnicharCompress &old_recoder) const;
333 
334 protected:
335  // Private version of InitCharSet above finishes the job after initializing
336  // the mgr_ data member.
337  void InitCharSet();
338  // Helper computes and sets the null_char_.
339  void SetNullChar();
340 
341  // Factored sub-constructor sets up reasonable default values.
342  void EmptyConstructor();
343 
344  // Outputs the string and periodically displays the given network inputs
345  // as an image in the given window, and the corresponding labels at the
346  // corresponding x_starts.
347  // Returns false if the truth string is empty.
348  bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata,
349  const NetworkIO &fwd_outputs,
350  const std::vector<int> &truth_labels,
351  const NetworkIO &outputs);
352  // Displays the network targets as line a line graph.
353  void DisplayTargets(const NetworkIO &targets, const char *window_name,
354  ScrollView **window);
355 
356  // Builds a no-compromises target where the first positions should be the
357  // truth labels and the rest is padded with the null_char_.
358  bool ComputeTextTargets(const NetworkIO &outputs,
359  const std::vector<int> &truth_labels,
360  NetworkIO *targets);
361 
362  // Builds a target using standard CTC. truth_labels should be pre-padded with
363  // nulls wherever desired. They don't have to be between all labels.
364  // outputs is input-output, as it gets clipped to minimum probability.
365  bool ComputeCTCTargets(const std::vector<int> &truth_labels,
366  NetworkIO *outputs, NetworkIO *targets);
367 
368  // Computes network errors, and stores the results in the rolling buffers,
369  // along with the supplied text_error.
370  // Returns the delta error of the current sample (not running average.)
371  double ComputeErrorRates(const NetworkIO &deltas, double char_error,
372  double word_error);
373 
374  // Computes the network activation RMS error rate.
375  double ComputeRMSError(const NetworkIO &deltas);
376 
377  // Computes network activation winner error rate. (Number of values that are
378  // in error by >= 0.5 divided by number of time-steps.) More closely related
379  // to final character error than RMS, but still directly calculable from
380  // just the deltas. Because of the binary nature of the targets, zero winner
381  // error is a sufficient but not necessary condition for zero char error.
382  double ComputeWinnerError(const NetworkIO &deltas);
383 
384  // Computes a very simple bag of chars char error rate.
385  double ComputeCharError(const std::vector<int> &truth_str,
386  const std::vector<int> &ocr_str);
387  // Computes a very simple bag of words word recall error rate.
388  // NOTE that this is destructive on both input strings.
389  double ComputeWordError(std::string *truth_str, std::string *ocr_str);
390 
391  // Updates the error buffer and corresponding mean of the given type with
392  // the new_error.
393  void UpdateErrorBuffer(double new_error, ErrorTypes type);
394 
395  // Rolls error buffers and reports the current means.
396  void RollErrorBuffers();
397 
398  // Given that error_rate is either a new min or max, updates the best/worst
399  // error rates, and record of progress.
400  std::string UpdateErrorGraph(int iteration, double error_rate,
401  const std::vector<char> &model_data,
402  const TestCallback &tester);
403 
404 protected:
405 #ifndef GRAPHICS_DISABLED
406  // Alignment display window.
408  // CTC target display window.
410  // CTC output display window.
412  // Reconstructed image window.
414 #endif
415  // How often to display a debug image.
417  // Iteration at which the last checkpoint was dumped.
419  // Basename of files to save best models to.
420  std::string model_base_;
421  // Checkpoint filename.
422  std::string checkpoint_name_;
423  // Training data.
426  // Name to use when saving best_trainer_.
427  std::string best_model_name_;
428  // Number of available training stages.
430 
431  // ===Serialized data to ensure that a restart produces the same results.===
432  // These members are only serialized when serialize_amount != LIGHT.
433  // Best error rate so far.
435  // Snapshot of all error rates at best_iteration_.
436  double best_error_rates_[ET_COUNT];
437  // Iteration of best_error_rate_.
439  // Worst error rate since best_error_rate_.
441  // Snapshot of all error rates at worst_iteration_.
442  double worst_error_rates_[ET_COUNT];
443  // Iteration of worst_error_rate_.
445  // Iteration at which the process will be thought stalled.
447  // Saved recognition models for computing test error for graph points.
448  std::vector<char> best_model_data_;
449  std::vector<char> worst_model_data_;
450  // Saved trainer for reverting back to last known best.
451  std::vector<char> best_trainer_;
452  // A subsidiary trainer running with a different learning rate until either
453  // *this or sub_trainer_ hits a new best.
454  std::unique_ptr<LSTMTrainer> sub_trainer_;
455  // Error rate at which last best model was dumped.
457  // Current stage of training.
459  // History of best error rate against iteration. Used for computing the
460  // number of steps to each 2% improvement.
461  std::vector<double> best_error_history_;
462  std::vector<int32_t> best_error_iterations_;
463  // Number of iterations since the best_error_rate_ was 2% more than it is now.
465  // Number of iterations that yielded a non-zero delta error and thus provided
466  // significant learning. learning_iteration_ <= training_iteration_.
467  // learning_iteration_ is used to measure rate of learning progress.
469  // Saved value of sample_iteration_ before looking for the the next sample.
471  // How often to include a PERFECT training sample in backprop.
472  // A PERFECT training sample is used if the current
473  // training_iteration_ > last_perfect_training_iteration_ + perfect_delay_,
474  // so with perfect_delay_ == 0, all samples are used, and with
475  // perfect_delay_ == 4, at most 1 in 5 samples will be perfect.
477  // Value of training_iteration_ at which the last PERFECT training sample
478  // was used in back prop.
480  // Rolling buffers storing recent training errors are indexed by
481  // training_iteration % kRollingBufferSize_.
482  static const int kRollingBufferSize_ = 1000;
483  std::vector<double> error_buffers_[ET_COUNT];
484  // Rounded mean percent trailing training errors in the buffers.
485  double error_rates_[ET_COUNT]; // RMS training error.
486  // Traineddata file with optional dawgs + UNICHARSET and recoder.
488 };
489 
490 } // namespace tesseract.
491 
492 #endif // TESSERACT_LSTM_LSTMTRAINER_H_
@ ET_WORD_RECERR
Definition: lstmtrainer.h:43
@ ET_SKIP_RATIO
Definition: lstmtrainer.h:45
@ ET_CHAR_ERROR
Definition: lstmtrainer.h:44
@ HI_PRECISION_ERR
Definition: lstmtrainer.h:54
@ STR_REPLACED
Definition: lstmtrainer.h:69
bool DeSerialize(bool swap, FILE *fp, std::vector< T > &data)
Definition: helpers.h:220
bool Serialize(FILE *fp, const std::vector< T > &data)
Definition: helpers.h:251
std::function< std::string(int, const double *, const TessdataManager &, int)> TestCallback
Definition: lstmtrainer.h:77
double TFloat
Definition: tesstypes.h:39
CachingStrategy
Definition: imagedata.h:42
@ NO_BEST_TRAINER
Definition: lstmtrainer.h:61
const ImageData * GetPageBySerial(int serial)
Definition: imagedata.h:317
std::vector< int32_t > best_error_iterations_
Definition: lstmtrainer.h:462
std::vector< char > worst_model_data_
Definition: lstmtrainer.h:449
bool ReadLocalTrainingDump(const TessdataManager *mgr, const char *data, int size)
ScrollView * target_win_
Definition: lstmtrainer.h:409
bool EncodeString(const std::string &str, std::vector< int > *labels) const
Definition: lstmtrainer.h:253
DocumentCache * mutable_training_data()
Definition: lstmtrainer.h:171
bool InitCharSet(const std::string &traineddata_path)
Definition: lstmtrainer.h:99
int InitTensorFlowNetwork(const std::string &tf_proto)
const double * error_rates() const
Definition: lstmtrainer.h:135
std::string model_base_
Definition: lstmtrainer.h:420
std::string best_model_name_
Definition: lstmtrainer.h:427
double NewSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:157
double CharError() const
Definition: lstmtrainer.h:132
std::vector< char > best_trainer_
Definition: lstmtrainer.h:451
double best_error_rate() const
Definition: lstmtrainer.h:138
double LastSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:163
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
Definition: lstmtrainer.h:267
float error_rate_of_last_saved_best_
Definition: lstmtrainer.h:456
ScrollView * recon_win_
Definition: lstmtrainer.h:413
const DocumentCache & training_data() const
Definition: lstmtrainer.h:168
int learning_iteration() const
Definition: lstmtrainer.h:144
void set_perfect_delay(int delay)
Definition: lstmtrainer.h:150
std::string checkpoint_name_
Definition: lstmtrainer.h:422
ScrollView * ctc_win_
Definition: lstmtrainer.h:411
const std::vector< char > & best_trainer() const
Definition: lstmtrainer.h:153
int CurrentTrainingStage() const
Definition: lstmtrainer.h:216
double ActivationError() const
Definition: lstmtrainer.h:129
std::vector< char > best_model_data_
Definition: lstmtrainer.h:448
bool ReadSizedTrainingDump(const char *data, int size, LSTMTrainer &trainer) const
Definition: lstmtrainer.h:306
void InitCharSet(const TessdataManager &mgr)
Definition: lstmtrainer.h:106
DocumentCache training_data_
Definition: lstmtrainer.h:425
std::unique_ptr< LSTMTrainer > sub_trainer_
Definition: lstmtrainer.h:454
bool ReadTrainingDump(const std::vector< char > &data, LSTMTrainer &trainer) const
Definition: lstmtrainer.h:299
int32_t improvement_steps() const
Definition: lstmtrainer.h:147
bool MaintainCheckpointsSpecific(int iteration, const std::vector< char > *train_model, const std::vector< char > *rec_model, TestCallback tester, std::string &log_msg)
std::vector< double > best_error_history_
Definition: lstmtrainer.h:461
int best_iteration() const
Definition: lstmtrainer.h:141
TessdataManager mgr_
Definition: lstmtrainer.h:487
ScrollView * align_win_
Definition: lstmtrainer.h:407