tesseract  5.0.0
lstmtrainer.cpp
// File:        lstmtrainer.cpp
// Description: Top-level line trainer class for LSTM-based networks.
// Author:      Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#define _USE_MATH_DEFINES // needed to get definition of M_SQRT1_2

// Include automatically generated configuration file if running autoconf.
#ifdef HAVE_CONFIG_H
#  include "config_auto.h"
#endif

#include <cmath>
#include <string>
#include "lstmtrainer.h"

#include <allheaders.h>
#include "boxread.h"
#include "ctc.h"
#include "imagedata.h"
#include "input.h"
#include "networkbuilder.h"
#include "ratngs.h"
#include "recodebeam.h"
#ifdef INCLUDE_TENSORFLOW
#  include "tfnetwork.h"
#endif
#include "tprintf.h"

namespace tesseract {

// Min actual error rate increase to constitute divergence.
const double kMinDivergenceRate = 50.0;
// Min iterations since last best before acting on a stall.
const int kMinStallIterations = 10000;
// Fraction of current char error rate that sub_trainer_ has to be ahead
// before we declare the sub_trainer_ a success and switch to it.
const double kSubTrainerMarginFraction = 3.0 / 128;
// Factor to reduce learning rate on divergence.
const double kLearningRateDecay = M_SQRT1_2;
// LR adjustment iterations.
const int kNumAdjustmentIterations = 100;
// How often to add data to the error_graph_.
const int kErrorGraphInterval = 1000;
// Number of training images to train between calls to MaintainCheckpoints.
const int kNumPagesPerBatch = 100;
// Min percent error rate to consider start-up phase over.
const int kMinStartedErrorRate = 75;
// Error rate at which to transition to stage 1.
const double kStageTransitionThreshold = 10.0;
// Confidence beyond which the truth is more likely wrong than the recognizer.
const double kHighConfidence = 0.9375; // 15/16.
// Fraction of weight sign-changing total to constitute a definite improvement.
const double kImprovementFraction = 15.0 / 16.0;
// Fraction of last written best to make it worth writing another.
const double kBestCheckpointFraction = 31.0 / 32.0;
#ifndef GRAPHICS_DISABLED
// Scale factor for display of target activations of CTC.
const int kTargetXScale = 5;
const int kTargetYScale = 100;
#endif // !GRAPHICS_DISABLED

LSTMTrainer::LSTMTrainer()
    : randomly_rotate_(false), training_data_(0), sub_trainer_(nullptr) {
  EmptyConstructor();
  debug_interval_ = 0;
}

LSTMTrainer::LSTMTrainer(const char *model_base, const char *checkpoint_name,
                         int debug_interval, int64_t max_memory)
    : randomly_rotate_(false),
      training_data_(max_memory),
      sub_trainer_(nullptr) {
  EmptyConstructor();
  debug_interval_ = debug_interval;
  model_base_ = model_base;
  checkpoint_name_ = checkpoint_name;
}

LSTMTrainer::~LSTMTrainer() {
#ifndef GRAPHICS_DISABLED
  delete align_win_;
  delete target_win_;
  delete ctc_win_;
  delete recon_win_;
#endif
}

// Tries to deserialize a trainer from the given file and silently returns
// false in case of failure.
bool LSTMTrainer::TryLoadingCheckpoint(const char *filename,
                                       const char *old_traineddata) {
  std::vector<char> data;
  if (!LoadDataFromFile(filename, &data)) {
    return false;
  }
  tprintf("Loaded file %s, unpacking...\n", filename);
  if (!ReadTrainingDump(data, *this)) {
    return false;
  }
  if (IsIntMode()) {
    tprintf("Error, %s is an integer (fast) model, cannot continue training\n",
            filename);
    return false;
  }
  if (((old_traineddata == nullptr || *old_traineddata == '\0') &&
       network_->NumOutputs() == recoder_.code_range()) ||
      filename == old_traineddata) {
    return true; // Normal checkpoint load complete.
  }
  tprintf("Code range changed from %d to %d!\n", network_->NumOutputs(),
          recoder_.code_range());
  if (old_traineddata == nullptr || *old_traineddata == '\0') {
    tprintf("Must supply the old traineddata for code conversion!\n");
    return false;
  }
  TessdataManager old_mgr;
  ASSERT_HOST(old_mgr.Init(old_traineddata));
  TFile fp;
  if (!old_mgr.GetComponent(TESSDATA_LSTM_UNICHARSET, &fp)) {
    return false;
  }
  UNICHARSET old_chset;
  if (!old_chset.load_from_file(&fp, false)) {
    return false;
  }
  if (!old_mgr.GetComponent(TESSDATA_LSTM_RECODER, &fp)) {
    return false;
  }
  UnicharCompress old_recoder;
  if (!old_recoder.DeSerialize(&fp)) {
    return false;
  }
  std::vector<int> code_map = MapRecoder(old_chset, old_recoder);
  // Set the null_char_ to the new value.
  int old_null_char = null_char_;
  SetNullChar();
  // Map the softmax(s) in the network.
  network_->RemapOutputs(old_recoder.code_range(), code_map);
  tprintf("Previous null char=%d mapped to %d\n", old_null_char, null_char_);
  return true;
}

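// Illustrative sketch (not part of the original file): resuming training from
// a checkpoint written by MaintainCheckpoints. The file paths are
// hypothetical; pass the old traineddata instead of nullptr only when the
// unicharset/recoder has changed since the checkpoint was written.
//   LSTMTrainer trainer("base_model", "base_model_checkpoint", 0, 0);
//   if (trainer.TryLoadingCheckpoint("base_model_checkpoint", nullptr)) {
//     // Iteration counters and error history are restored; training can
//     // continue where it left off.
//   }
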
// Initializes the trainer with a network_spec in the network description
// language. net_flags control network behavior according to the NetworkFlags
// enum. There isn't really much difference between them - only where the
// effects are implemented.
// For other args see NetworkBuilder::InitNetwork.
// Note: Be sure to call InitCharSet before InitNetwork!
bool LSTMTrainer::InitNetwork(const char *network_spec, int append_index,
                              int net_flags, float weight_range,
                              float learning_rate, float momentum,
                              float adam_beta) {
  mgr_.SetVersionString(mgr_.VersionString() + ":" + network_spec);
  adam_beta_ = adam_beta;
  learning_rate_ = learning_rate;
  momentum_ = momentum;
  SetNullChar();
  if (!NetworkBuilder::InitNetwork(recoder_.code_range(), network_spec,
                                   append_index, net_flags, weight_range,
                                   &randomizer_, &network_)) {
    return false;
  }
  network_str_ += network_spec;
  tprintf("Built network:%s from request %s\n", network_->spec().c_str(),
          network_spec);
  tprintf(
      "Training parameters:\n Debug interval = %d,"
      " weights = %g, learning rate = %g, momentum=%g\n",
      debug_interval_, weight_range, learning_rate_, momentum_);
  tprintf("null char=%d\n", null_char_);
  return true;
}

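// Illustrative sketch (not part of the original file): building a fresh
// network from a VGSL spec. The spec string and hyperparameter values are
// examples only; InitCharSet must have been called first so that
// recoder_.code_range() is valid, and the O1c output size must match it.
//   trainer.InitNetwork(
//       "[1,36,0,1 Ct3,3,16 Mp3,3 Lfys48 Lfx96 Lrx96 Lfx256 O1c111]",
//       -1 /*append_index*/, NF_LAYER_SPECIFIC_LR, 0.1f /*weight_range*/,
//       0.001f /*learning_rate*/, 0.5f /*momentum*/, 0.999f /*adam_beta*/);
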
// Initializes a trainer from a serialized TFNetworkModel proto.
// Returns the global step of TensorFlow graph or 0 if failed.
#ifdef INCLUDE_TENSORFLOW
int LSTMTrainer::InitTensorFlowNetwork(const std::string &tf_proto) {
  delete network_;
  TFNetwork *tf_net = new TFNetwork("TensorFlow");
  training_iteration_ = tf_net->InitFromProtoStr(tf_proto);
  if (training_iteration_ == 0) {
    tprintf("InitFromProtoStr failed!!\n");
    return 0;
  }
  network_ = tf_net;
  ASSERT_HOST(recoder_.code_range() == tf_net->num_classes());
  return training_iteration_;
}
#endif

// Resets all the iteration counters for fine tuning or training a head,
// where we want the error reporting to reset.
void LSTMTrainer::InitIterations() {
  sample_iteration_ = 0;
  training_iteration_ = 0;
  learning_iteration_ = 0;
  prev_sample_iteration_ = 0;
  best_error_rate_ = 100.0;
  best_iteration_ = 0;
  worst_error_rate_ = 0.0;
  worst_iteration_ = 0;
  stall_iteration_ = kMinStallIterations;
  best_error_history_.clear();
  best_error_iterations_.clear();
  improvement_steps_ = kMinStallIterations;
  perfect_delay_ = 0;
  last_perfect_training_iteration_ = 0;
  for (int i = 0; i < ET_COUNT; ++i) {
    best_error_rates_[i] = 100.0;
    worst_error_rates_[i] = 0.0;
    error_buffers_[i].clear();
    error_buffers_[i].resize(kRollingBufferSize_);
    error_rates_[i] = 100.0;
  }
  error_rate_of_last_saved_best_ = kMinStartedErrorRate;
}

// If the training sample is usable, grid searches for the optimal
// dict_ratio/cert_offset, and returns the results in a string of space-
// separated triplets of ratio,offset=worderr.
Trainability LSTMTrainer::GridSearchDictParams(
    const ImageData *trainingdata, int iteration, double min_dict_ratio,
    double dict_ratio_step, double max_dict_ratio, double min_cert_offset,
    double cert_offset_step, double max_cert_offset, std::string &results) {
  sample_iteration_ = iteration;
  NetworkIO fwd_outputs, targets;
  Trainability result =
      PrepareForBackward(trainingdata, &fwd_outputs, &targets);
  if (result == UNENCODABLE || result == HI_PRECISION_ERR || dict_ == nullptr) {
    return result;
  }

  // Encode/decode the truth to get the normalization.
  std::vector<int> truth_labels, ocr_labels, xcoords;
  ASSERT_HOST(EncodeString(trainingdata->transcription(), &truth_labels));
  // NO-dict error.
  RecodeBeamSearch base_search(recoder_, null_char_, SimpleTextOutput(),
                               nullptr);
  base_search.Decode(fwd_outputs, 1.0, 0.0, RecodeBeamSearch::kMinCertainty,
                     nullptr);
  base_search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
  std::string truth_text = DecodeLabels(truth_labels);
  std::string ocr_text = DecodeLabels(ocr_labels);
  double baseline_error = ComputeWordError(&truth_text, &ocr_text);
  results += "0,0=" + std::to_string(baseline_error);

  RecodeBeamSearch search(recoder_, null_char_, SimpleTextOutput(), dict_);
  for (double r = min_dict_ratio; r < max_dict_ratio; r += dict_ratio_step) {
    for (double c = min_cert_offset; c < max_cert_offset;
         c += cert_offset_step) {
      search.Decode(fwd_outputs, r, c, RecodeBeamSearch::kMinCertainty,
                    nullptr);
      search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
      truth_text = DecodeLabels(truth_labels);
      ocr_text = DecodeLabels(ocr_labels);
      // This is destructive on both strings.
      double word_error = ComputeWordError(&truth_text, &ocr_text);
      if ((r == min_dict_ratio && c == min_cert_offset) ||
          !std::isfinite(word_error)) {
        std::string t = DecodeLabels(truth_labels);
        std::string o = DecodeLabels(ocr_labels);
        tprintf("r=%g, c=%g, truth=%s, ocr=%s, wderr=%g, truth[0]=%d\n", r, c,
                t.c_str(), o.c_str(), word_error, truth_labels[0]);
      }
      results += " " + std::to_string(r);
      results += "," + std::to_string(c);
      results += "=" + std::to_string(word_error);
    }
  }
  return result;
}

// Provides output on the distribution of weight values.
void LSTMTrainer::DebugNetwork() {
  network_->DebugWeights();
}

// Loads a set of lstmf files that were created using the lstm.train config to
// tesseract into memory ready for training. Returns false if nothing was
// loaded.
bool LSTMTrainer::LoadAllTrainingData(const std::vector<std::string> &filenames,
                                      CachingStrategy cache_strategy,
                                      bool randomly_rotate) {
  randomly_rotate_ = randomly_rotate;
  training_data_.Clear();
  return training_data_.LoadDocuments(filenames, cache_strategy,
                                      LoadDataFromFile);
}

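// Illustrative sketch (not part of the original file): loading .lstmf files
// (as produced by tesseract with the lstm.train config) into the document
// cache. The filename is hypothetical; CS_SEQUENTIAL is one of the
// CachingStrategy values from imagedata.h.
//   std::vector<std::string> filenames{"eng.Arial.exp0.lstmf"};
//   if (!trainer.LoadAllTrainingData(filenames, CS_SEQUENTIAL, false)) {
//     tprintf("Failed to load training data!\n");
//   }
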
// Keeps track of best and locally worst char error_rate and launches tests
// using tester, when a new min or max is reached.
// Writes checkpoints at appropriate times and builds and returns a log message
// to indicate progress. Returns false if nothing interesting happened.
bool LSTMTrainer::MaintainCheckpoints(const TestCallback &tester,
                                      std::string &log_msg) {
  PrepareLogMsg(log_msg);
  double error_rate = CharError();
  int iteration = learning_iteration();
  if (iteration >= stall_iteration_ &&
      error_rate > best_error_rate_ * (1.0 + kSubTrainerMarginFraction) &&
      best_error_rate_ < kMinStartedErrorRate && !best_trainer_.empty()) {
    // It hasn't got any better in a long while, and is a margin worse than the
    // best, so go back to the best model and try a different learning rate.
    StartSubtrainer(log_msg);
  }
  SubTrainerResult sub_trainer_result = STR_NONE;
  if (sub_trainer_ != nullptr) {
    sub_trainer_result = UpdateSubtrainer(log_msg);
    if (sub_trainer_result == STR_REPLACED) {
      // Reset the inputs, as we have overwritten *this.
      error_rate = CharError();
      iteration = learning_iteration();
      PrepareLogMsg(log_msg);
    }
  }
  bool result = true; // Something interesting happened.
  std::vector<char> rec_model_data;
  if (error_rate < best_error_rate_) {
    SaveRecognitionDump(&rec_model_data);
    log_msg += " New best BCER = " + std::to_string(error_rate);
    log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
    // If sub_trainer_ is not nullptr, either *this beat it to a new best, or it
    // just overwrote *this. In either case, we have finished with it.
    sub_trainer_.reset();
    stall_iteration_ = learning_iteration() + kMinStallIterations;
    if (TransitionTrainingStage(kStageTransitionThreshold)) {
      log_msg +=
          " Transitioned to stage " + std::to_string(CurrentTrainingStage());
    }
    SaveTrainingDump(NO_BEST_TRAINER, *this, &best_trainer_);
    if (error_rate < error_rate_of_last_saved_best_ * kBestCheckpointFraction) {
      std::string best_model_name = DumpFilename();
      if (!SaveDataToFile(best_trainer_, best_model_name.c_str())) {
        log_msg += " failed to write best model:";
      } else {
        log_msg += " wrote best model:";
        error_rate_of_last_saved_best_ = best_error_rate_;
      }
      log_msg += best_model_name;
    }
  } else if (error_rate > worst_error_rate_) {
    SaveRecognitionDump(&rec_model_data);
    log_msg += " New worst BCER = " + std::to_string(error_rate);
    log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
    if (worst_error_rate_ > best_error_rate_ + kMinDivergenceRate &&
        best_error_rate_ < kMinStartedErrorRate && !best_trainer_.empty()) {
      // Error rate has ballooned. Go back to the best model.
      log_msg += "\nDivergence! ";
      // Copy best_trainer_ before reading it, as it will get overwritten.
      std::vector<char> revert_data(best_trainer_);
      if (ReadTrainingDump(revert_data, *this)) {
        LogIterations("Reverted to", log_msg);
        ReduceLearningRates(this, log_msg);
      } else {
        LogIterations("Failed to Revert at", log_msg);
      }
      // If it fails again, we will wait twice as long before reverting again.
      stall_iteration_ = iteration + 2 * (iteration - learning_iteration());
      // Re-save the best trainer with the new learning rates and stall
      // iteration.
      SaveTrainingDump(NO_BEST_TRAINER, *this, &best_trainer_);
    }
  } else {
    // Something interesting happened only if the sub_trainer_ was trained.
    result = sub_trainer_result != STR_NONE;
  }
  if (checkpoint_name_.length() > 0) {
    // Write a current checkpoint.
    std::vector<char> checkpoint;
    if (!SaveTrainingDump(FULL, *this, &checkpoint) ||
        !SaveDataToFile(checkpoint, checkpoint_name_.c_str())) {
      log_msg += " failed to write checkpoint.";
    } else {
      log_msg += " wrote checkpoint.";
    }
  }
  log_msg += "\n";
  return result;
}

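// Illustrative sketch (not part of the original file) of the outer training
// loop, modeled on how the lstmtraining tool drives this class: train a batch
// of lines, then let MaintainCheckpoints handle logging, checkpointing, the
// sub_trainer_ and divergence recovery. max_iterations and tester are
// assumptions of the example.
//   do {
//     for (int i = 0; i < kNumPagesPerBatch; ++i) {
//       trainer.TrainOnLine(&trainer, false);
//     }
//     std::string log_msg;
//     trainer.MaintainCheckpoints(tester, log_msg);
//     tprintf("%s", log_msg.c_str());
//   } while (trainer.training_iteration() < max_iterations);
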
// Builds a string containing a progress message with current error rates.
void LSTMTrainer::PrepareLogMsg(std::string &log_msg) const {
  LogIterations("At", log_msg);
  log_msg += ", Mean rms=" + std::to_string(error_rates_[ET_RMS]);
  log_msg += "%, delta=" + std::to_string(error_rates_[ET_DELTA]);
  log_msg += "%, BCER train=" + std::to_string(error_rates_[ET_CHAR_ERROR]);
  log_msg += "%, BWER train=" + std::to_string(error_rates_[ET_WORD_RECERR]);
  log_msg += "%, skip ratio=" + std::to_string(error_rates_[ET_SKIP_RATIO]);
  log_msg += "%, ";
}

// Appends <intro_str> iteration learning_iteration()/training_iteration()/
// sample_iteration() to the log_msg.
void LSTMTrainer::LogIterations(const char *intro_str,
                                std::string &log_msg) const {
  log_msg += intro_str;
  log_msg += " iteration " + std::to_string(learning_iteration());
  log_msg += "/" + std::to_string(training_iteration());
  log_msg += "/" + std::to_string(sample_iteration());
}

// Returns true and increments the training_stage_ if the error rate has just
// passed through the given threshold for the first time.
bool LSTMTrainer::TransitionTrainingStage(float error_threshold) {
  if (best_error_rate_ < error_threshold &&
      training_stage_ + 1 < num_training_stages_) {
    ++training_stage_;
    return true;
  }
  return false;
}

// Writes to the given file. Returns false in case of error.
bool LSTMTrainer::Serialize(SerializeAmount serialize_amount,
                            const TessdataManager *mgr, TFile *fp) const {
  if (!LSTMRecognizer::Serialize(mgr, fp)) {
    return false;
  }
  if (!fp->Serialize(&learning_iteration_)) {
    return false;
  }
  if (!fp->Serialize(&prev_sample_iteration_)) {
    return false;
  }
  if (!fp->Serialize(&perfect_delay_)) {
    return false;
  }
  if (!fp->Serialize(&last_perfect_training_iteration_)) {
    return false;
  }
  for (const auto &error_buffer : error_buffers_) {
    if (!fp->Serialize(error_buffer)) {
      return false;
    }
  }
  if (!fp->Serialize(&error_rates_[0], countof(error_rates_))) {
    return false;
  }
  if (!fp->Serialize(&training_stage_)) {
    return false;
  }
  uint8_t amount = serialize_amount;
  if (!fp->Serialize(&amount)) {
    return false;
  }
  if (serialize_amount == LIGHT) {
    return true; // We are done.
  }
  if (!fp->Serialize(&best_error_rate_)) {
    return false;
  }
  if (!fp->Serialize(&best_error_rates_[0], countof(best_error_rates_))) {
    return false;
  }
  if (!fp->Serialize(&best_iteration_)) {
    return false;
  }
  if (!fp->Serialize(&worst_error_rate_)) {
    return false;
  }
  if (!fp->Serialize(&worst_error_rates_[0], countof(worst_error_rates_))) {
    return false;
  }
  if (!fp->Serialize(&worst_iteration_)) {
    return false;
  }
  if (!fp->Serialize(&stall_iteration_)) {
    return false;
  }
  if (!fp->Serialize(best_model_data_)) {
    return false;
  }
  if (!fp->Serialize(worst_model_data_)) {
    return false;
  }
  if (serialize_amount != NO_BEST_TRAINER && !fp->Serialize(best_trainer_)) {
    return false;
  }
  std::vector<char> sub_data;
  if (sub_trainer_ != nullptr &&
      !SaveTrainingDump(LIGHT, *sub_trainer_, &sub_data)) {
    return false;
  }
  if (!fp->Serialize(sub_data)) {
    return false;
  }
  if (!fp->Serialize(best_error_history_)) {
    return false;
  }
  if (!fp->Serialize(best_error_iterations_)) {
    return false;
  }
  return fp->Serialize(&improvement_steps_);
}

// Reads from the given file. Returns false in case of error.
// NOTE: It is assumed that the trainer is never read cross-endian.
bool LSTMTrainer::DeSerialize(const TessdataManager *mgr, TFile *fp) {
  if (!LSTMRecognizer::DeSerialize(mgr, fp)) {
    return false;
  }
  if (!fp->DeSerialize(&learning_iteration_)) {
    // Special case. If we successfully decoded the recognizer, but fail here
    // then it means we were just given a recognizer, so issue a warning and
    // allow it.
    tprintf("Warning: LSTMTrainer deserialized an LSTMRecognizer!\n");
    learning_iteration_ = 0;
    network_->SetEnableTraining(TS_ENABLED);
    return true;
  }
  if (!fp->DeSerialize(&prev_sample_iteration_)) {
    return false;
  }
  if (!fp->DeSerialize(&perfect_delay_)) {
    return false;
  }
  if (!fp->DeSerialize(&last_perfect_training_iteration_)) {
    return false;
  }
  for (auto &error_buffer : error_buffers_) {
    if (!fp->DeSerialize(error_buffer)) {
      return false;
    }
  }
  if (!fp->DeSerialize(&error_rates_[0], countof(error_rates_))) {
    return false;
  }
  if (!fp->DeSerialize(&training_stage_)) {
    return false;
  }
  uint8_t amount;
  if (!fp->DeSerialize(&amount)) {
    return false;
  }
  if (amount == LIGHT) {
    return true; // Don't read the rest.
  }
  if (!fp->DeSerialize(&best_error_rate_)) {
    return false;
  }
  if (!fp->DeSerialize(&best_error_rates_[0], countof(best_error_rates_))) {
    return false;
  }
  if (!fp->DeSerialize(&best_iteration_)) {
    return false;
  }
  if (!fp->DeSerialize(&worst_error_rate_)) {
    return false;
  }
  if (!fp->DeSerialize(&worst_error_rates_[0], countof(worst_error_rates_))) {
    return false;
  }
  if (!fp->DeSerialize(&worst_iteration_)) {
    return false;
  }
  if (!fp->DeSerialize(&stall_iteration_)) {
    return false;
  }
  if (!fp->DeSerialize(best_model_data_)) {
    return false;
  }
  if (!fp->DeSerialize(worst_model_data_)) {
    return false;
  }
  if (amount != NO_BEST_TRAINER && !fp->DeSerialize(best_trainer_)) {
    return false;
  }
  std::vector<char> sub_data;
  if (!fp->DeSerialize(sub_data)) {
    return false;
  }
  if (sub_data.empty()) {
    sub_trainer_ = nullptr;
  } else {
    sub_trainer_ = std::make_unique<LSTMTrainer>();
    if (!ReadTrainingDump(sub_data, *sub_trainer_)) {
      return false;
    }
  }
  if (!fp->DeSerialize(best_error_history_)) {
    return false;
  }
  if (!fp->DeSerialize(best_error_iterations_)) {
    return false;
  }
  return fp->DeSerialize(&improvement_steps_);
}

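// Illustrative sketch (not part of the original file): Serialize/DeSerialize
// support the in-memory round trip used throughout this file, e.g. for
// sub_trainer_ copies and checkpoint reverts.
//   std::vector<char> dump;
//   trainer.SaveTrainingDump(LIGHT, trainer, &dump);  // Serialize to memory.
//   LSTMTrainer copy;
//   trainer.ReadTrainingDump(dump, copy);             // Restore into copy.
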
// De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
// learning rates (by scaling reduction, or layer specific, according to
// NF_LAYER_SPECIFIC_LR).
void LSTMTrainer::StartSubtrainer(std::string &log_msg) {
  sub_trainer_ = std::make_unique<LSTMTrainer>();
  if (!ReadTrainingDump(best_trainer_, *sub_trainer_)) {
    log_msg += " Failed to revert to previous best for trial!";
    sub_trainer_.reset();
  } else {
    log_msg += " Trial sub_trainer_ from iteration " +
               std::to_string(sub_trainer_->training_iteration());
    // Reduce learning rate so it doesn't diverge this time.
    sub_trainer_->ReduceLearningRates(this, log_msg);
    // If it fails again, we will wait twice as long before reverting again.
    int stall_offset =
        learning_iteration() - sub_trainer_->learning_iteration();
    stall_iteration_ = learning_iteration() + 2 * stall_offset;
    sub_trainer_->stall_iteration_ = stall_iteration_;
    // Re-save the best trainer with the new learning rates and stall
    // iteration.
    SaveTrainingDump(NO_BEST_TRAINER, *sub_trainer_, &best_trainer_);
  }
}

// While the sub_trainer_ is behind the current training iteration and its
// training error is at least kSubTrainerMarginFraction better than the
// current training error, trains the sub_trainer_, and returns STR_UPDATED if
// it did anything. If it catches up, and has a better error rate than the
// current best, as well as a margin over the current error rate, then the
// trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
// returned. STR_NONE is returned if the subtrainer wasn't good enough to
// receive any training iterations.
SubTrainerResult LSTMTrainer::UpdateSubtrainer(std::string &log_msg) {
  double training_error = CharError();
  double sub_error = sub_trainer_->CharError();
  double sub_margin = (training_error - sub_error) / sub_error;
  if (sub_margin >= kSubTrainerMarginFraction) {
    log_msg += " sub_trainer=" + std::to_string(sub_error);
    log_msg += " margin=" + std::to_string(100.0 * sub_margin);
    log_msg += "\n";
    // Catch up to current iteration.
    int end_iteration = training_iteration();
    while (sub_trainer_->training_iteration() < end_iteration &&
           sub_margin >= kSubTrainerMarginFraction) {
      int target_iteration =
          sub_trainer_->training_iteration() + kNumPagesPerBatch;
      while (sub_trainer_->training_iteration() < target_iteration) {
        sub_trainer_->TrainOnLine(this, false);
      }
      std::string batch_log = "Sub:";
      sub_trainer_->PrepareLogMsg(batch_log);
      batch_log += "\n";
      tprintf("UpdateSubtrainer:%s", batch_log.c_str());
      log_msg += batch_log;
      sub_error = sub_trainer_->CharError();
      sub_margin = (training_error - sub_error) / sub_error;
    }
    if (sub_error < best_error_rate_ &&
        sub_margin >= kSubTrainerMarginFraction) {
      // The sub_trainer_ has won the race to a new best. Switch to it.
      std::vector<char> updated_trainer;
      SaveTrainingDump(LIGHT, *sub_trainer_, &updated_trainer);
      ReadTrainingDump(updated_trainer, *this);
      log_msg += " Sub trainer wins at iteration " +
                 std::to_string(training_iteration());
      log_msg += "\n";
      return STR_REPLACED;
    }
    return STR_UPDATED;
  }
  return STR_NONE;
}

// Reduces network learning rates, either for everything, or for layers
// independently, according to NF_LAYER_SPECIFIC_LR.
void LSTMTrainer::ReduceLearningRates(LSTMTrainer *samples_trainer,
                                      std::string &log_msg) {
  if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
    int num_reduced = ReduceLayerLearningRates(
        kLearningRateDecay, kNumAdjustmentIterations, samples_trainer);
    log_msg +=
        "\nReduced learning rate on layers: " + std::to_string(num_reduced);
  } else {
    ScaleLearningRate(kLearningRateDecay);
    log_msg += "\nReduced learning rate to :" + std::to_string(learning_rate_);
  }
  log_msg += "\n";
}

// Considers reducing the learning rate independently for each layer down by
// factor(<1), or leaving it the same, by double-training the given number of
// samples and minimizing the amount of changing of sign of weight updates.
// Even if it looks like all weights should remain the same, an adjustment
// will be made to guarantee a different result when reverting to an old best.
// Returns the number of layer learning rates that were reduced.
int LSTMTrainer::ReduceLayerLearningRates(TFloat factor, int num_samples,
                                          LSTMTrainer *samples_trainer) {
  enum WhichWay {
    LR_DOWN, // Learning rate will go down by factor.
    LR_SAME, // Learning rate will stay the same.
    LR_COUNT // Size of arrays.
  };
  std::vector<std::string> layers = EnumerateLayers();
  int num_layers = layers.size();
  std::vector<int> num_weights(num_layers);
  std::vector<TFloat> bad_sums[LR_COUNT];
  std::vector<TFloat> ok_sums[LR_COUNT];
  for (int i = 0; i < LR_COUNT; ++i) {
    bad_sums[i].resize(num_layers, 0.0);
    ok_sums[i].resize(num_layers, 0.0);
  }
  auto momentum_factor = 1 / (1 - momentum_);
  std::vector<char> orig_trainer;
  samples_trainer->SaveTrainingDump(LIGHT, *this, &orig_trainer);
  for (int i = 0; i < num_layers; ++i) {
    Network *layer = GetLayer(layers[i]);
    num_weights[i] = layer->IsTraining() ? layer->num_weights() : 0;
  }
  int iteration = sample_iteration();
  for (int s = 0; s < num_samples; ++s) {
    // Which way will we modify the learning rate?
    for (int ww = 0; ww < LR_COUNT; ++ww) {
      // Transfer momentum to learning rate and adjust by the ww factor.
      auto ww_factor = momentum_factor;
      if (ww == LR_DOWN) {
        ww_factor *= factor;
      }
      // Make a copy of *this, so we can mess about without damaging anything.
      LSTMTrainer copy_trainer;
      samples_trainer->ReadTrainingDump(orig_trainer, copy_trainer);
      // Clear the updates, doing nothing else.
      copy_trainer.network_->Update(0.0, 0.0, 0.0, 0);
      // Adjust the learning rate in each layer.
      for (int i = 0; i < num_layers; ++i) {
        if (num_weights[i] == 0) {
          continue;
        }
        copy_trainer.ScaleLayerLearningRate(layers[i], ww_factor);
      }
      copy_trainer.SetIteration(iteration);
      // Train on the sample, but keep the update in updates_ instead of
      // applying to the weights.
      const ImageData *trainingdata =
          copy_trainer.TrainOnLine(samples_trainer, true);
      if (trainingdata == nullptr) {
        continue;
      }
      // We'll now use this trainer again for each layer.
      std::vector<char> updated_trainer;
      samples_trainer->SaveTrainingDump(LIGHT, copy_trainer, &updated_trainer);
      for (int i = 0; i < num_layers; ++i) {
        if (num_weights[i] == 0) {
          continue;
        }
        LSTMTrainer layer_trainer;
        samples_trainer->ReadTrainingDump(updated_trainer, layer_trainer);
        Network *layer = layer_trainer.GetLayer(layers[i]);
        // Update the weights in just the layer, using Adam if enabled.
        layer->Update(0.0, momentum_, adam_beta_,
                      layer_trainer.training_iteration_ + 1);
        // Zero the updates matrix again.
        layer->Update(0.0, 0.0, 0.0, 0);
        // Train again on the same sample, again holding back the updates.
        layer_trainer.TrainOnLine(trainingdata, true);
        // Count the sign changes in the updates in layer vs in copy_trainer.
        float before_bad = bad_sums[ww][i];
        float before_ok = ok_sums[ww][i];
        layer->CountAlternators(*copy_trainer.GetLayer(layers[i]),
                                &ok_sums[ww][i], &bad_sums[ww][i]);
        float bad_frac =
            bad_sums[ww][i] + ok_sums[ww][i] - before_bad - before_ok;
        if (bad_frac > 0.0f) {
          bad_frac = (bad_sums[ww][i] - before_bad) / bad_frac;
        }
      }
    }
    ++iteration;
  }
  int num_lowered = 0;
  for (int i = 0; i < num_layers; ++i) {
    if (num_weights[i] == 0) {
      continue;
    }
    Network *layer = GetLayer(layers[i]);
    float lr = GetLayerLearningRate(layers[i]);
    TFloat total_down = bad_sums[LR_DOWN][i] + ok_sums[LR_DOWN][i];
    TFloat total_same = bad_sums[LR_SAME][i] + ok_sums[LR_SAME][i];
    TFloat frac_down = bad_sums[LR_DOWN][i] / total_down;
    TFloat frac_same = bad_sums[LR_SAME][i] / total_same;
    tprintf("Layer %d=%s: lr %g->%g%%, lr %g->%g%%", i, layer->name().c_str(),
            lr * factor, 100.0 * frac_down, lr, 100.0 * frac_same);
    if (frac_down < frac_same * kImprovementFraction) {
      tprintf(" REDUCED\n");
      ScaleLayerLearningRate(layers[i], factor);
      ++num_lowered;
    } else {
      tprintf(" SAME\n");
    }
  }
  if (num_lowered == 0) {
    // Just lower everything to make sure.
    for (int i = 0; i < num_layers; ++i) {
      if (num_weights[i] > 0) {
        ScaleLayerLearningRate(layers[i], factor);
        ++num_lowered;
      }
    }
  }
  return num_lowered;
}

// Converts the string to integer class labels, with appropriate null_char_s
// in between if not in SimpleTextOutput mode. Returns false on failure.
/* static */
bool LSTMTrainer::EncodeString(const std::string &str,
                               const UNICHARSET &unicharset,
                               const UnicharCompress *recoder, bool simple_text,
                               int null_char, std::vector<int> *labels) {
  if (str.empty()) {
    tprintf("Empty truth string!\n");
    return false;
  }
  unsigned err_index;
  std::vector<int> internal_labels;
  labels->clear();
  if (!simple_text) {
    labels->push_back(null_char);
  }
  std::string cleaned = unicharset.CleanupString(str.c_str());
  if (unicharset.encode_string(cleaned.c_str(), true, &internal_labels, nullptr,
                               &err_index)) {
    bool success = true;
    for (auto internal_label : internal_labels) {
      if (recoder != nullptr) {
        // Re-encode labels via recoder.
        RecodedCharID code;
        int len = recoder->EncodeUnichar(internal_label, &code);
        if (len > 0) {
          for (int j = 0; j < len; ++j) {
            labels->push_back(code(j));
            if (!simple_text) {
              labels->push_back(null_char);
            }
          }
        } else {
          success = false;
          err_index = 0;
          break;
        }
      } else {
        labels->push_back(internal_label);
        if (!simple_text) {
          labels->push_back(null_char);
        }
      }
    }
    if (success) {
      return true;
    }
  }
  tprintf("Encoding of string failed! Failure bytes:");
  while (err_index < cleaned.size()) {
    tprintf(" %x", cleaned[err_index++] & 0xff);
  }
  tprintf("\n");
  return false;
}

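// Worked example (not part of the original file, values hypothetical): in
// non-simple_text mode the labels are interleaved with null_char, so a truth
// string whose unichars recode to {17, 23} with null_char=110 comes back as
// {110, 17, 110, 23, 110}; in simple_text mode it would be just {17, 23}.
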
// Performs forward-backward on the given trainingdata.
// Returns a Trainability enum to indicate the suitability of the sample.
Trainability LSTMTrainer::TrainOnLine(const ImageData *trainingdata,
                                      bool batch) {
  NetworkIO fwd_outputs, targets;
  Trainability trainable =
      PrepareForBackward(trainingdata, &fwd_outputs, &targets);
  ++sample_iteration_;
  if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
    return trainable; // Sample was unusable.
  }
  bool debug =
      debug_interval_ > 0 && training_iteration() % debug_interval_ == 0;
  // Run backprop on the output.
  NetworkIO bp_deltas;
  if (network_->IsTraining() &&
      (trainable != PERFECT ||
       training_iteration() >
           last_perfect_training_iteration_ + perfect_delay_)) {
    network_->Backward(debug, targets, &scratch_space_, &bp_deltas);
    network_->Update(learning_rate_, batch ? -1.0f : momentum_, adam_beta_,
                     training_iteration_ + 1);
  }
#ifndef GRAPHICS_DISABLED
  if (debug_interval_ == 1 && debug_win_ != nullptr) {
    delete debug_win_->AwaitEvent(SVET_CLICK);
  }
#endif // !GRAPHICS_DISABLED
  // Roll the memory of past means.
  RollErrorBuffers();
  return trainable;
}

// Prepares the ground truth, runs forward, and prepares the targets.
// Returns a Trainability enum to indicate the suitability of the sample.
Trainability LSTMTrainer::PrepareForBackward(const ImageData *trainingdata,
                                             NetworkIO *fwd_outputs,
                                             NetworkIO *targets) {
  if (trainingdata == nullptr) {
    tprintf("Null trainingdata.\n");
    return UNENCODABLE;
  }
  // Ensure repeatability of random elements even across checkpoints.
  bool debug =
      debug_interval_ > 0 && training_iteration() % debug_interval_ == 0;
  std::vector<int> truth_labels;
  if (!EncodeString(trainingdata->transcription(), &truth_labels)) {
    tprintf("Can't encode transcription: '%s' in language '%s'\n",
            trainingdata->transcription().c_str(),
            trainingdata->language().c_str());
    return UNENCODABLE;
  }
  bool upside_down = false;
  if (randomly_rotate_) {
    // This ensures consistent training results.
    SetRandomSeed();
    upside_down = randomizer_.SignedRand(1.0) > 0.0;
    if (upside_down) {
      // Modify the truth labels to match the rotation:
      // Apart from space and null, increment the label. This changes the
      // script-id to the same script-id but upside-down.
      // The labels need to be reversed in order, as the first is now the last.
      for (auto &truth_label : truth_labels) {
        if (truth_label != UNICHAR_SPACE && truth_label != null_char_) {
          ++truth_label;
        }
      }
      std::reverse(truth_labels.begin(), truth_labels.end());
    }
  }
  unsigned w = 0;
  while (w < truth_labels.size() &&
         (truth_labels[w] == UNICHAR_SPACE || truth_labels[w] == null_char_)) {
    ++w;
  }
  if (w == truth_labels.size()) {
    tprintf("Blank transcription: %s\n", trainingdata->transcription().c_str());
    return UNENCODABLE;
  }
  float image_scale;
  NetworkIO inputs;
  bool invert = trainingdata->boxes().empty();
  if (!RecognizeLine(*trainingdata, invert, debug, invert, upside_down,
                     &image_scale, &inputs, fwd_outputs)) {
    tprintf("Image %s not trainable\n", trainingdata->imagefilename().c_str());
    return UNENCODABLE;
  }
  targets->Resize(*fwd_outputs, network_->NumOutputs());
  LossType loss_type = OutputLossType();
  if (loss_type == LT_SOFTMAX) {
    if (!ComputeTextTargets(*fwd_outputs, truth_labels, targets)) {
      tprintf("Compute simple targets failed for %s!\n",
              trainingdata->imagefilename().c_str());
      return UNENCODABLE;
    }
  } else if (loss_type == LT_CTC) {
    if (!ComputeCTCTargets(truth_labels, fwd_outputs, targets)) {
      tprintf("Compute CTC targets failed for %s!\n",
              trainingdata->imagefilename().c_str());
      return UNENCODABLE;
    }
  } else {
    tprintf("Logistic outputs not implemented yet!\n");
    return UNENCODABLE;
  }
  std::vector<int> ocr_labels;
  std::vector<int> xcoords;
  LabelsFromOutputs(*fwd_outputs, &ocr_labels, &xcoords);
  // CTC does not produce correct target labels to begin with.
  if (loss_type != LT_CTC) {
    LabelsFromOutputs(*targets, &truth_labels, &xcoords);
  }
  if (!DebugLSTMTraining(inputs, *trainingdata, *fwd_outputs, truth_labels,
                         *targets)) {
    tprintf("Input width was %d\n", inputs.Width());
    return UNENCODABLE;
  }
  std::string ocr_text = DecodeLabels(ocr_labels);
  std::string truth_text = DecodeLabels(truth_labels);
  targets->SubtractAllFromFloat(*fwd_outputs);
  if (debug_interval_ != 0) {
    if (truth_text != ocr_text) {
      tprintf("Iteration %d: BEST OCR TEXT : %s\n", training_iteration(),
              ocr_text.c_str());
    }
  }
  double char_error = ComputeCharError(truth_labels, ocr_labels);
  double word_error = ComputeWordError(&truth_text, &ocr_text);
  double delta_error = ComputeErrorRates(*targets, char_error, word_error);
  if (debug_interval_ != 0) {
    tprintf("File %s line %d %s:\n", trainingdata->imagefilename().c_str(),
            trainingdata->page_number(), delta_error == 0.0 ? "(Perfect)" : "");
  }
  if (delta_error == 0.0) {
    return PERFECT;
  }
  if (targets->AnySuspiciousTruth(kHighConfidence)) {
    return HI_PRECISION_ERR;
  }
  return TRAINABLE;
}

// Writes the trainer to memory, so that the current training state can be
// restored. *this must always be the master trainer that retains the only
// copy of the training data and language model. trainer is the model that is
// actually serialized.
bool LSTMTrainer::SaveTrainingDump(SerializeAmount serialize_amount,
                                   const LSTMTrainer &trainer,
                                   std::vector<char> *data) const {
  TFile fp;
  fp.OpenWrite(data);
  return trainer.Serialize(serialize_amount, &mgr_, &fp);
}

// Restores the model to *this.
bool LSTMTrainer::ReadLocalTrainingDump(const TessdataManager *mgr,
                                        const char *data, int size) {
  if (size == 0) {
    tprintf("Warning: data size is 0 in LSTMTrainer::ReadLocalTrainingDump\n");
    return false;
  }
  TFile fp;
  fp.Open(data, size);
  return DeSerialize(mgr, &fp);
}

// Writes the full recognition traineddata to the given filename.
bool LSTMTrainer::SaveTraineddata(const char *filename) {
  std::vector<char> recognizer_data;
  SaveRecognitionDump(&recognizer_data);
  mgr_.OverwriteEntry(TESSDATA_LSTM, &recognizer_data[0],
                      recognizer_data.size());
  return mgr_.SaveFile(filename, SaveDataToFile);
}

// Writes the recognizer to memory, so that it can be used for testing later.
void LSTMTrainer::SaveRecognitionDump(std::vector<char> *data) const {
  TFile fp;
  fp.OpenWrite(data);
  network_->SetEnableTraining(TS_TEMP_DISABLE);
  ASSERT_HOST(LSTMRecognizer::Serialize(&mgr_, &fp));
  network_->SetEnableTraining(TS_RE_ENABLE);
}

// Returns a suitable filename for a training dump, based on the model_base_,
// best_error_rate_, best_iteration_ and training_iteration_.
std::string LSTMTrainer::DumpFilename() const {
  std::string filename;
  filename += model_base_.c_str();
  filename += "_" + std::to_string(best_error_rate_);
  filename += "_" + std::to_string(best_iteration_);
  filename += "_" + std::to_string(training_iteration_);
  filename += ".checkpoint";
  return filename;
}

// Fills the whole error buffer of the given type with the given value.
void LSTMTrainer::FillErrorBuffer(double new_error, ErrorTypes type) {
  for (int i = 0; i < kRollingBufferSize_; ++i) {
    error_buffers_[type][i] = new_error;
  }
  error_rates_[type] = 100.0 * new_error;
}

// Helper generates a map from each current recoder_ code (ie softmax index)
// to the corresponding old_recoder code, or -1 if there isn't one.
std::vector<int> LSTMTrainer::MapRecoder(
    const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const {
  int num_new_codes = recoder_.code_range();
  int num_new_unichars = GetUnicharset().size();
  std::vector<int> code_map(num_new_codes, -1);
  for (int c = 0; c < num_new_codes; ++c) {
    int old_code = -1;
    // Find all new unichar_ids that recode to something that includes c.
    // The <= is to include the null char, which may be beyond the unicharset.
    for (int uid = 0; uid <= num_new_unichars; ++uid) {
      RecodedCharID codes;
      int length = recoder_.EncodeUnichar(uid, &codes);
      int code_index = 0;
      while (code_index < length && codes(code_index) != c) {
        ++code_index;
      }
      if (code_index == length) {
        continue;
      }
      // The old unicharset must have the same unichar.
      int old_uid =
          uid < num_new_unichars
              ? old_chset.unichar_to_id(GetUnicharset().id_to_unichar(uid))
              : old_chset.size() - 1;
      if (old_uid == INVALID_UNICHAR_ID) {
        continue;
      }
      // The encoding of old_uid at the same code_index is the old code.
      RecodedCharID old_codes;
      if (code_index < old_recoder.EncodeUnichar(old_uid, &old_codes)) {
        old_code = old_codes(code_index);
        break;
      }
    }
    code_map[c] = old_code;
  }
  return code_map;
}

// Private version of InitCharSet above finishes the job after initializing
// the mgr_ data member.
void LSTMTrainer::InitCharSet() {
  EmptyConstructor();
  training_flags_ = TF_COMPRESS_UNICHARSET;
  // Initialize the unicharset and recoder.
  if (!LoadCharsets(&mgr_)) {
    ASSERT_HOST(
        "Must provide a traineddata containing lstm_unicharset and"
        " lstm_recoder!\n" != nullptr);
  }
  SetNullChar();
}

// Helper computes and sets the null_char_.
void LSTMTrainer::SetNullChar() {
  null_char_ = GetUnicharset().has_special_codes() ? UNICHAR_BROKEN
                                                   : GetUnicharset().size();
  RecodedCharID code;
  recoder_.EncodeUnichar(null_char_, &code);
  null_char_ = code(0);
}

// Factored sub-constructor sets up reasonable default values.
void LSTMTrainer::EmptyConstructor() {
#ifndef GRAPHICS_DISABLED
  align_win_ = nullptr;
  target_win_ = nullptr;
  ctc_win_ = nullptr;
  recon_win_ = nullptr;
#endif
  checkpoint_iteration_ = 0;
  training_stage_ = 0;
  num_training_stages_ = 2;
  InitIterations();
}

// Outputs the string and periodically displays the given network inputs
// as an image in the given window, and the corresponding labels at the
// corresponding x_starts.
// Returns false if the truth string is empty.
bool LSTMTrainer::DebugLSTMTraining(const NetworkIO &inputs,
                                    const ImageData &trainingdata,
                                    const NetworkIO &fwd_outputs,
                                    const std::vector<int> &truth_labels,
                                    const NetworkIO &outputs) {
  const std::string &truth_text = DecodeLabels(truth_labels);
  if (truth_text.empty()) {
    tprintf("Empty truth string at decode time!\n");
    return false;
  }
  if (debug_interval_ != 0) {
    // Get class labels, xcoords and string.
    std::vector<int> labels;
    std::vector<int> xcoords;
    LabelsFromOutputs(outputs, &labels, &xcoords);
    std::string text = DecodeLabels(labels);
    tprintf("Iteration %d: GROUND TRUTH : %s\n", training_iteration(),
            truth_text.c_str());
    if (truth_text != text) {
      tprintf("Iteration %d: ALIGNED TRUTH : %s\n", training_iteration(),
              text.c_str());
    }
    if (debug_interval_ > 0 && training_iteration() % debug_interval_ == 0) {
      tprintf("TRAINING activation path for truth string %s\n",
              truth_text.c_str());
      DebugActivationPath(outputs, labels, xcoords);
#ifndef GRAPHICS_DISABLED
      DisplayForward(inputs, labels, xcoords, "LSTMTraining", &align_win_);
      if (OutputLossType() == LT_CTC) {
        DisplayTargets(fwd_outputs, "CTC Outputs", &ctc_win_);
        DisplayTargets(outputs, "CTC Targets", &target_win_);
      }
#endif
    }
  }
  return true;
}

#ifndef GRAPHICS_DISABLED

// Displays the network targets as a line graph.
void LSTMTrainer::DisplayTargets(const NetworkIO &targets,
                                 const char *window_name, ScrollView **window) {
  int width = targets.Width();
  int num_features = targets.NumFeatures();
  Network::ClearWindow(true, window_name, width * kTargetXScale, kTargetYScale,
                       window);
  for (int c = 0; c < num_features; ++c) {
    int color = c % (ScrollView::GREEN_YELLOW - 1) + 2;
    (*window)->Pen(static_cast<ScrollView::Color>(color));
    int start_t = -1;
    for (int t = 0; t < width; ++t) {
      double target = targets.f(t)[c];
      target *= kTargetYScale;
      if (target >= 1) {
        if (start_t < 0) {
          (*window)->SetCursor(t - 1, 0);
          start_t = t;
        }
        (*window)->DrawTo(t, target);
      } else if (start_t >= 0) {
        (*window)->DrawTo(t, 0);
        (*window)->DrawTo(start_t - 1, 0);
        start_t = -1;
      }
    }
    if (start_t >= 0) {
      (*window)->DrawTo(width, 0);
      (*window)->DrawTo(start_t - 1, 0);
    }
  }
  (*window)->Update();
}

#endif // !GRAPHICS_DISABLED

// Builds a no-compromises target where the first positions should be the
// truth labels and the rest is padded with the null_char_.
bool LSTMTrainer::ComputeTextTargets(const NetworkIO &float_outputs,
                                     const std::vector<int> &truth_labels,
                                     NetworkIO *targets) {
  if (truth_labels.size() > targets->Width()) {
    tprintf("Error: transcription %s too long to fit into target of width %d\n",
            DecodeLabels(truth_labels).c_str(), targets->Width());
    return false;
  }
  size_t i = 0;
  for (auto truth_label : truth_labels) {
    targets->SetActivations(i, truth_label, 1.0);
    ++i;
  }
  for (i = truth_labels.size(); i < targets->Width(); ++i) {
    targets->SetActivations(i, null_char_, 1.0);
  }
  return true;
}

// Builds a target using standard CTC. truth_labels should be pre-padded with
// nulls wherever desired. They don't have to be between all labels.
// outputs is input-output, as it gets clipped to minimum probability.
bool LSTMTrainer::ComputeCTCTargets(const std::vector<int> &truth_labels,
                                    NetworkIO *outputs, NetworkIO *targets) {
  // Bottom-clip outputs to a minimum probability.
  CTC::NormalizeProbs(outputs);
  return CTC::ComputeCTCTargets(truth_labels, null_char_,
                                outputs->float_array(), targets);
}

// Computes network errors, and stores the results in the rolling buffers,
// along with the supplied text_error.
// Returns the delta error of the current sample (not running average.)
double LSTMTrainer::ComputeErrorRates(const NetworkIO &deltas,
                                      double char_error, double word_error) {
  UpdateErrorBuffer(ComputeRMSError(deltas), ET_RMS);
  // Delta error is the fraction of timesteps with >0.5 error in the top choice
  // score. If zero, then the top choice characters are guaranteed correct,
  // even when there is residue in the RMS error.
  double delta_error = ComputeWinnerError(deltas);
  UpdateErrorBuffer(delta_error, ET_DELTA);
  UpdateErrorBuffer(word_error, ET_WORD_RECERR);
  UpdateErrorBuffer(char_error, ET_CHAR_ERROR);
  // Skip ratio measures the difference between sample_iteration_ and
  // training_iteration_, which reflects the number of unusable samples,
  // usually due to unencodable truth text, or the text not fitting in the
  // space for the output.
  double skip_count = sample_iteration_ - prev_sample_iteration_;
  UpdateErrorBuffer(skip_count, ET_SKIP_RATIO);
  return delta_error;
}

// Computes the network activation RMS error rate.
double LSTMTrainer::ComputeRMSError(const NetworkIO &deltas) {
  double total_error = 0.0;
  int width = deltas.Width();
  int num_classes = deltas.NumFeatures();
  for (int t = 0; t < width; ++t) {
    const float *class_errs = deltas.f(t);
    for (int c = 0; c < num_classes; ++c) {
      double error = class_errs[c];
      total_error += error * error;
    }
  }
  return sqrt(total_error / (width * num_classes));
}

// Computes network activation winner error rate. (Number of values that are
// in error by >= 0.5 divided by number of time-steps.) More closely related
// to final character error than RMS, but still directly calculable from
// just the deltas. Because of the binary nature of the targets, zero winner
// error is a sufficient but not necessary condition for zero char error.
double LSTMTrainer::ComputeWinnerError(const NetworkIO &deltas) {
  int num_errors = 0;
  int width = deltas.Width();
  int num_classes = deltas.NumFeatures();
  for (int t = 0; t < width; ++t) {
    const float *class_errs = deltas.f(t);
    for (int c = 0; c < num_classes; ++c) {
      float abs_delta = std::fabs(class_errs[c]);
      // TODO(rays) Filtering cases where the delta is very large to cut out
      // GT errors doesn't work. Find a better way or get better truth.
      if (0.5 <= abs_delta) {
        ++num_errors;
      }
    }
  }
  return static_cast<double>(num_errors) / width;
}

// Computes a very simple bag of chars char error rate.
double LSTMTrainer::ComputeCharError(const std::vector<int> &truth_str,
                                     const std::vector<int> &ocr_str) {
  std::vector<int> label_counts(NumOutputs());
  unsigned truth_size = 0;
  for (auto ch : truth_str) {
    if (ch != null_char_) {
      ++label_counts[ch];
      ++truth_size;
    }
  }
  for (auto ch : ocr_str) {
    if (ch != null_char_) {
      --label_counts[ch];
    }
  }
  unsigned char_errors = 0;
  for (auto label_count : label_counts) {
    char_errors += abs(label_count);
  }
  // Limit BCER to interval [0,1] and avoid division by zero.
  if (truth_size <= char_errors) {
    return (char_errors == 0) ? 0.0 : 1.0;
  }
  return static_cast<double>(char_errors) / truth_size;
}

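// Worked example (not part of the original file): truth labels {a, b, c} and
// OCR labels {a, b, d} leave label_counts at a:0, b:0, c:+1, d:-1, so
// char_errors = 2 (one dropped 'c', one spurious 'd') and the BCER is
// 2 / truth_size = 2/3.
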
// Computes word recall error rate using a very simple bag of words algorithm.
// NOTE that this is destructive on both input strings.
double LSTMTrainer::ComputeWordError(std::string *truth_str,
                                     std::string *ocr_str) {
  using StrMap = std::unordered_map<std::string, int, std::hash<std::string>>;
  std::vector<std::string> truth_words = split(*truth_str, ' ');
  if (truth_words.empty()) {
    return 0.0;
  }
  std::vector<std::string> ocr_words = split(*ocr_str, ' ');
  StrMap word_counts;
  for (const auto &truth_word : truth_words) {
    std::string truth_word_string(truth_word.c_str());
    auto it = word_counts.find(truth_word_string);
    if (it == word_counts.end()) {
      word_counts.insert(std::make_pair(truth_word_string, 1));
    } else {
      ++it->second;
    }
  }
  for (const auto &ocr_word : ocr_words) {
    std::string ocr_word_string(ocr_word.c_str());
    auto it = word_counts.find(ocr_word_string);
    if (it == word_counts.end()) {
      word_counts.insert(std::make_pair(ocr_word_string, -1));
    } else {
      --it->second;
    }
  }
  int word_recall_errs = 0;
  for (const auto &word_count : word_counts) {
    if (word_count.second > 0) {
      word_recall_errs += word_count.second;
    }
  }
  return static_cast<double>(word_recall_errs) / truth_words.size();
}

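// Worked example (not part of the original file): truth "the cat sat" vs OCR
// "the the sat" leaves word_counts at the:-1, cat:+1, sat:0; only positive
// residues (missed truth words) count, so the recall error is 1/3.
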
// Updates the error buffer and corresponding mean of the given type with
// the new_error.
void LSTMTrainer::UpdateErrorBuffer(double new_error, ErrorTypes type) {
  int index = training_iteration_ % kRollingBufferSize_;
  error_buffers_[type][index] = new_error;
  // Compute the mean error.
  int mean_count =
      std::min<int>(training_iteration_ + 1, error_buffers_[type].size());
  double buffer_sum = 0.0;
  for (int i = 0; i < mean_count; ++i) {
    buffer_sum += error_buffers_[type][i];
  }
  double mean = buffer_sum / mean_count;
  // Trim precision to 1/1000 of 1%.
  error_rates_[type] = IntCastRounded(100000.0 * mean) / 1000.0;
}

// Rolls error buffers and reports the current means.
void LSTMTrainer::RollErrorBuffers() {
  prev_sample_iteration_ = sample_iteration_;
  if (NewSingleError(ET_DELTA) > 0.0) {
    ++learning_iteration_;
  } else {
    last_perfect_training_iteration_ = training_iteration_;
  }
  ++training_iteration_;
  if (debug_interval_ != 0) {
    tprintf("Mean rms=%g%%, delta=%g%%, train=%g%%(%g%%), skip ratio=%g%%\n",
            error_rates_[ET_RMS], error_rates_[ET_DELTA],
            error_rates_[ET_CHAR_ERROR], error_rates_[ET_WORD_RECERR],
            error_rates_[ET_SKIP_RATIO]);
  }
}

// Given that error_rate is either a new min or max, updates the best/worst
// error rates, and record of progress.
// Tester is an externally supplied callback function that tests on some
// data set with a given model and records the error rates in a graph.
std::string LSTMTrainer::UpdateErrorGraph(int iteration, double error_rate,
                                          const std::vector<char> &model_data,
                                          const TestCallback &tester) {
  if (error_rate > best_error_rate_ &&
      iteration < best_iteration_ + kErrorGraphInterval) {
    // Too soon to record a new point.
    if (tester != nullptr && !worst_model_data_.empty()) {
      mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0],
                          worst_model_data_.size());
      return tester(worst_iteration_, nullptr, mgr_, CurrentTrainingStage());
    } else {
      return "";
    }
  }
  std::string result;
  // NOTE: there are 2 asymmetries here:
  // 1. We are computing the global minimum, but the local maximum in between.
  // 2. If the tester returns an empty string, indicating that it is busy,
  //    call it repeatedly on new local maxima to test the previous min, but
  //    not the other way around, as there is little point testing the maxima
  //    between very frequent minima.
  if (error_rate < best_error_rate_) {
    // This is a new (global) minimum.
    if (tester != nullptr && !worst_model_data_.empty()) {
      mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0],
                          worst_model_data_.size());
      result = tester(worst_iteration_, worst_error_rates_, mgr_,
                      CurrentTrainingStage());
      worst_model_data_.clear();
      best_model_data_ = model_data;
    }
    best_error_rate_ = error_rate;
    memcpy(best_error_rates_, error_rates_, sizeof(error_rates_));
    best_iteration_ = iteration;
    best_error_history_.push_back(error_rate);
    best_error_iterations_.push_back(iteration);
    // Compute 2% decay time.
    double two_percent_more = error_rate + 2.0;
    int i;
    for (i = best_error_history_.size() - 1;
         i >= 0 && best_error_history_[i] < two_percent_more; --i) {
    }
    int old_iteration = i >= 0 ? best_error_iterations_[i] : 0;
    improvement_steps_ = iteration - old_iteration;
    tprintf("2 Percent improvement time=%d, best error was %g @ %d\n",
            improvement_steps_, i >= 0 ? best_error_history_[i] : 100.0,
            old_iteration);
  } else if (error_rate > best_error_rate_) {
    // This is a new (local) maximum.
    if (tester != nullptr) {
      if (!best_model_data_.empty()) {
        mgr_.OverwriteEntry(TESSDATA_LSTM, &best_model_data_[0],
                            best_model_data_.size());
        result = tester(best_iteration_, best_error_rates_, mgr_,
                        CurrentTrainingStage());
      } else if (!worst_model_data_.empty()) {
        // Allow for multiple data points with "worst" error rate.
        mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0],
                            worst_model_data_.size());
        result = tester(worst_iteration_, worst_error_rates_, mgr_,
                        CurrentTrainingStage());
      }
      if (result.length() > 0) {
        best_model_data_.clear();
      }
      worst_model_data_ = model_data;
    }
  }
  worst_error_rate_ = error_rate;
  memcpy(worst_error_rates_, error_rates_, sizeof(error_rates_));
  worst_iteration_ = iteration;
  return result;
}

} // namespace tesseract.
#define ASSERT_HOST(x)
Definition: errcode.h:59
const std::vector< std::string > split(const std::string &s, char c)
Definition: helpers.h:41
@ TF_COMPRESS_UNICHARSET
@ ET_WORD_RECERR
Definition: lstmtrainer.h:43
@ ET_SKIP_RATIO
Definition: lstmtrainer.h:45
@ ET_CHAR_ERROR
Definition: lstmtrainer.h:44
@ HI_PRECISION_ERR
Definition: lstmtrainer.h:54
const double kLearningRateDecay
Definition: lstmtrainer.cpp:52
const double kImprovementFraction
Definition: lstmtrainer.cpp:66
const int kTargetYScale
Definition: lstmtrainer.cpp:72
@ STR_REPLACED
Definition: lstmtrainer.h:69
const int kMinStartedErrorRate
Definition: lstmtrainer.cpp:60
void tprintf(const char *format,...)
Definition: tprintf.cpp:41
int IntCastRounded(double x)
Definition: helpers.h:175
@ SVET_CLICK
Definition: scrollview.h:55
@ TESSDATA_LSTM_UNICHARSET
@ TESSDATA_LSTM_RECODER
const double kSubTrainerMarginFraction
Definition: lstmtrainer.cpp:50
std::function< std::string(int, const double *, const TessdataManager &, int)> TestCallback
Definition: lstmtrainer.h:77
const int kErrorGraphInterval
Definition: lstmtrainer.cpp:56
constexpr size_t countof(T const (&)[N]) noexcept
Definition: serialis.h:42
bool SaveDataToFile(const GenericVector< char > &data, const char *filename)
@ UNICHAR_SPACE
Definition: unicharset.h:36
@ UNICHAR_BROKEN
Definition: unicharset.h:38
@ TS_TEMP_DISABLE
Definition: network.h:95
@ TS_ENABLED
Definition: network.h:93
@ TS_RE_ENABLE
Definition: network.h:97
double TFloat
Definition: tesstypes.h:39
@ NF_LAYER_SPECIFIC_LR
Definition: network.h:85
LIST search(LIST list, void *key, int_compare is_equal)
Definition: oldlist.cpp:211
const double kMinDivergenceRate
Definition: lstmtrainer.cpp:45
const int kNumAdjustmentIterations
Definition: lstmtrainer.cpp:54
const double kHighConfidence
Definition: lstmtrainer.cpp:64
CachingStrategy
Definition: imagedata.h:42
const double kBestCheckpointFraction
Definition: lstmtrainer.cpp:68
const int kNumPagesPerBatch
Definition: lstmtrainer.cpp:58
const int kTargetXScale
Definition: lstmtrainer.cpp:71
const int kMinStallIterations
Definition: lstmtrainer.cpp:47
const double kStageTransitionThreshold
Definition: lstmtrainer.cpp:62
@ NO_BEST_TRAINER
Definition: lstmtrainer.h:61
bool LoadDataFromFile(const char *filename, GenericVector< char > *data)
int page_number() const
Definition: imagedata.h:89
const std::string & transcription() const
Definition: imagedata.h:104
const std::string & language() const
Definition: imagedata.h:98
const std::string & imagefilename() const
Definition: imagedata.h:83
const std::vector< TBOX > & boxes() const
Definition: imagedata.h:107
TESS_API bool LoadDocuments(const std::vector< std::string > &filenames, CachingStrategy cache_strategy, FileReader reader)
Definition: imagedata.cpp:614
double SignedRand(double range)
Definition: helpers.h:76
void OpenWrite(std::vector< char > *data)
Definition: serialis.cpp:246
bool DeSerialize(std::string &data)
Definition: serialis.cpp:94
bool Serialize(const std::string &data)
Definition: serialis.cpp:107
bool Open(const char *filename, FileReader reader)
Definition: serialis.cpp:140
void OverwriteEntry(TessdataType type, const char *data, int size)
std::string VersionString() const
void SetVersionString(const std::string &v_str)
bool GetComponent(TessdataType type, TFile *fp)
bool SaveFile(const char *filename, FileWriter writer) const
bool Init(const char *data_file_name)
int EncodeUnichar(unsigned unichar_id, RecodedCharID *code) const
bool encode_string(const char *str, bool give_up_on_failure, std::vector< UNICHAR_ID > *encoding, std::vector< char > *lengths, unsigned *encoded_length) const
Definition: unicharset.cpp:239
bool has_special_codes() const
Definition: unicharset.h:757
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:391
UNICHAR_ID unichar_to_id(const char *const unichar_repr) const
Definition: unicharset.cpp:186
size_t size() const
Definition: unicharset.h:355
static std::string CleanupString(const char *utf8_str)
Definition: unicharset.h:265
void DebugActivationPath(const NetworkIO &outputs, const std::vector< int > &labels, const std::vector< int > &xcoords)
LossType OutputLossType() const
std::string DecodeLabels(const std::vector< int > &labels)
NetworkScratch scratch_space_
bool LoadCharsets(const TessdataManager *mgr)
const UNICHARSET & GetUnicharset() const
void LabelsFromOutputs(const NetworkIO &outputs, std::vector< int > *labels, std::vector< int > *xcoords)
void DisplayForward(const NetworkIO &inputs, const std::vector< int > &labels, const std::vector< int > &label_coords, const char *window_name, ScrollView **window)
void SetIteration(int iteration)
void ScaleLearningRate(double factor)
void ScaleLayerLearningRate(const std::string &id, double factor)
std::vector< std::string > EnumerateLayers() const
float GetLayerLearningRate(const std::string &id) const
bool Serialize(const TessdataManager *mgr, TFile *fp) const
Network * GetLayer(const std::string &id) const
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0, int lstm_choice_amount=5)
virtual void Update([[maybe_unused]] float learning_rate, [[maybe_unused]] float momentum, [[maybe_unused]] float adam_beta, [[maybe_unused]] int num_samples)
Definition: network.h:235
virtual void CountAlternators([[maybe_unused]] const Network &other, [[maybe_unused]] TFloat *same, [[maybe_unused]] TFloat *changed) const
Definition: network.h:242
int NumOutputs() const
Definition: network.h:125
int num_weights() const
Definition: network.h:119
const std::string & name() const
Definition: network.h:140
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:350
virtual void SetEnableTraining(TrainingState state)
Definition: network.cpp:113
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)=0
bool IsTraining() const
Definition: network.h:113
virtual void DebugWeights()=0
virtual int RemapOutputs([[maybe_unused]] int old_no, [[maybe_unused]] const std::vector< int > &code_map)
Definition: network.h:190
bool TestFlag(NetworkFlags flag) const
Definition: network.h:146
virtual std::string spec() const
Definition: network.h:143
void Resize(const NetworkIO &src, int num_features)
Definition: networkio.h:45
float * f(int t)
Definition: networkio.h:111
const GENERIC_2D_ARRAY< float > & float_array() const
Definition: networkio.h:135
int Width() const
Definition: networkio.h:103
void SetActivations(int t, int label, float ok_score)
Definition: networkio.cpp:557
bool AnySuspiciousTruth(float confidence_thr) const
Definition: networkio.cpp:600
void SubtractAllFromFloat(const NetworkIO &src)
Definition: networkio.cpp:847
int NumFeatures() const
Definition: networkio.h:107
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
Definition: recodebeam.cpp:89
void ExtractBestPathAsLabels(std::vector< int > *labels, std::vector< int > *xcoords) const
Definition: recodebeam.cpp:207
static constexpr float kMinCertainty
Definition: recodebeam.h:246
static bool ComputeCTCTargets(const std::vector< int > &truth_labels, int null_char, const GENERIC_2D_ARRAY< float > &outputs, NetworkIO *targets)
Definition: ctc.cpp:53
static void NormalizeProbs(NetworkIO *probs)
Definition: ctc.h:36
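Note: NormalizeProbs rescales the network outputs so each timestep's class scores form a probability distribution. A hedged sketch of the idea, with a (timesteps x classes) matrix of floats standing in for the NetworkIO the real method operates on:

#include <cassert>
#include <vector>

// Rescale each row (one timestep) so its entries sum to 1.
void normalize_rows(std::vector<std::vector<float>> &probs) {
  for (auto &row : probs) {
    float total = 0.0f;
    for (float p : row) total += p;
    if (total > 0.0f) {
      for (float &p : row) p /= total;
    }
  }
}

int main() {
  std::vector<std::vector<float>> probs = {{2.0f, 2.0f}, {1.0f, 3.0f}};
  normalize_rows(probs);
  assert(probs[0][0] == 0.5f && probs[1][1] == 0.75f);
  return 0;
}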
static bool InitNetwork(int num_outputs, const char *network_spec, int append_index, int net_flags, float weight_range, TRand *randomizer, Network **network)
bool TransitionTrainingStage(float error_threshold)
std::vector< int32_t > best_error_iterations_
Definition: lstmtrainer.h:462
std::vector< char > worst_model_data_
Definition: lstmtrainer.h:449
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
bool ReadLocalTrainingDump(const TessdataManager *mgr, const char *data, int size)
std::string UpdateErrorGraph(int iteration, double error_rate, const std::vector< char > &model_data, const TestCallback &tester)
ScrollView * target_win_
Definition: lstmtrainer.h:409
bool EncodeString(const std::string &str, std::vector< int > *labels) const
Definition: lstmtrainer.h:253
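Note: EncodeString turns a transcription into the integer label sequence the network is trained against. The toy sketch below shows only the string-to-labels idea with a hypothetical flat char map; the real encoder goes through UNICHARSET (and optionally the recoder) and handles multi-byte UTF-8, so treat everything here as illustrative.

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Toy encoder: map each character to a label, failing on unknowns.
bool encode_string_toy(const std::string &str,
                       const std::unordered_map<char, int> &charset,
                       std::vector<int> *labels) {
  labels->clear();
  for (char c : str) {
    auto it = charset.find(c);
    if (it == charset.end()) return false;  // unencodable character
    labels->push_back(it->second);
  }
  return true;
}

int main() {
  std::unordered_map<char, int> charset = {{'a', 1}, {'b', 2}, {' ', 3}};
  std::vector<int> labels;
  assert(encode_string_toy("ab a", charset, &labels));
  assert((labels == std::vector<int>{1, 2, 3, 1}));
  return 0;
}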
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:485
bool LoadAllTrainingData(const std::vector< std::string > &filenames, CachingStrategy cache_strategy, bool randomly_rotate)
double ComputeErrorRates(const NetworkIO &deltas, double char_error, double word_error)
int InitTensorFlowNetwork(const std::string &tf_proto)
double ComputeWordError(std::string *truth_str, std::string *ocr_str)
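Note: a hedged sketch of a bag-of-words word error, i.e. the fraction of truth words with no matching word in the OCR output. ComputeWordError is in this spirit, but treat the details below (word splitting, count matching) as assumptions rather than the exact implementation.

#include <cassert>
#include <sstream>
#include <string>
#include <unordered_map>

double word_error_toy(const std::string &truth, const std::string &ocr) {
  std::unordered_map<std::string, int> counts;  // truth word multiset
  int truth_words = 0;
  std::istringstream t(truth);
  for (std::string w; t >> w;) {
    ++counts[w];
    ++truth_words;
  }
  if (truth_words == 0) return 0.0;
  std::istringstream o(ocr);
  for (std::string w; o >> w;) {  // cancel each OCR word against truth
    auto it = counts.find(w);
    if (it != counts.end() && it->second > 0) --it->second;
  }
  int unmatched = 0;
  for (const auto &kv : counts) unmatched += kv.second;
  return static_cast<double>(unmatched) / truth_words;
}

int main() {
  assert(word_error_toy("the cat sat", "the cat sat") == 0.0);
  assert(word_error_toy("the cat sat", "the bat sat") > 0.0);
  return 0;
}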
void ReduceLearningRates(LSTMTrainer *samples_trainer, std::string &log_msg)
std::string model_base_
Definition: lstmtrainer.h:420
double NewSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:157
double CharError() const
Definition: lstmtrainer.h:132
void PrepareLogMsg(std::string &log_msg) const
bool ComputeCTCTargets(const std::vector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
std::vector< char > best_trainer_
Definition: lstmtrainer.h:451
double worst_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:442
bool MaintainCheckpoints(const TestCallback &tester, std::string &log_msg)
void SaveRecognitionDump(std::vector< char > *data) const
bool Serialize(SerializeAmount serialize_amount, const TessdataManager *mgr, TFile *fp) const
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
Definition: lstmtrainer.h:267
bool ComputeTextTargets(const NetworkIO &outputs, const std::vector< int > &truth_labels, NetworkIO *targets)
float error_rate_of_last_saved_best_
Definition: lstmtrainer.h:456
ScrollView * recon_win_
Definition: lstmtrainer.h:413
void FillErrorBuffer(double new_error, ErrorTypes type)
void LogIterations(const char *intro_str, std::string &log_msg) const
int learning_iteration() const
Definition: lstmtrainer.h:144
bool SaveTraineddata(const char *filename)
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
double ComputeRMSError(const NetworkIO &deltas)
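Note: the RMS error is the square root of the mean squared output delta over all timesteps and classes. A minimal sketch, with a flat vector of floats standing in for the NetworkIO of deltas the real method reads:

#include <cassert>
#include <cmath>
#include <vector>

// sqrt of the mean squared delta; 0 for an empty delta set.
double rms_error(const std::vector<float> &deltas) {
  if (deltas.empty()) return 0.0;
  double sum_sq = 0.0;
  for (float d : deltas) sum_sq += static_cast<double>(d) * d;
  return std::sqrt(sum_sq / deltas.size());
}

int main() {
  double e = rms_error({3.0f, -4.0f});  // sqrt((9 + 16) / 2) = sqrt(12.5)
  assert(std::fabs(e - std::sqrt(12.5)) < 1e-12);
  return 0;
}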
Trainability GridSearchDictParams(const ImageData *trainingdata, int iteration, double min_dict_ratio, double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, std::string &results)
double ComputeWinnerError(const NetworkIO &deltas)
std::string checkpoint_name_
Definition: lstmtrainer.h:422
bool InitNetwork(const char *network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum, float adam_beta)
void StartSubtrainer(std::string &log_msg)
SubTrainerResult UpdateSubtrainer(std::string &log_msg)
void UpdateErrorBuffer(double new_error, ErrorTypes type)
ScrollView * ctc_win_
Definition: lstmtrainer.h:411
int CurrentTrainingStage() const
Definition: lstmtrainer.h:216
std::string DumpFilename() const
std::vector< char > best_model_data_
Definition: lstmtrainer.h:448
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer &trainer, std::vector< char > *data) const
bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const std::vector< int > &truth_labels, const NetworkIO &outputs)
DocumentCache training_data_
Definition: lstmtrainer.h:425
static const int kRollingBufferSize_
Definition: lstmtrainer.h:482
std::vector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:483
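Note: error_buffers_ together with kRollingBufferSize_ and UpdateErrorBuffer (above) indicates a fixed-size rolling window of recent per-sample errors whose mean serves as the smoothed error rate. A hedged standalone sketch of that pattern, with buffer size 4 standing in for kRollingBufferSize_:

#include <cassert>
#include <vector>

class RollingError {
 public:
  explicit RollingError(int size) : buffer_(size, 0.0), index_(0) {}
  // Overwrite the oldest slot with the newest per-sample error.
  void Update(double new_error) {
    buffer_[index_] = new_error;
    index_ = (index_ + 1) % buffer_.size();
  }
  // Smoothed error: mean over the whole window.
  double Mean() const {
    double total = 0.0;
    for (double e : buffer_) total += e;
    return total / buffer_.size();
  }
 private:
  std::vector<double> buffer_;
  size_t index_;
};

int main() {
  RollingError errors(4);
  errors.Update(1.0);
  errors.Update(1.0);
  assert(errors.Mean() == 0.5);  // two of four slots filled
  return 0;
}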
std::unique_ptr< LSTMTrainer > sub_trainer_
Definition: lstmtrainer.h:454
double ComputeCharError(const std::vector< int > &truth_str, const std::vector< int > &ocr_str)
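Note: the textbook way to score character error between truth and OCR label sequences is Levenshtein distance normalized by the truth length. The sketch below is that generic CER formulation; ComputeCharError may count errors differently, so this is illustrative only.

#include <algorithm>
#include <cassert>
#include <vector>

double char_error(const std::vector<int> &truth, const std::vector<int> &ocr) {
  const size_t n = truth.size(), m = ocr.size();
  std::vector<std::vector<size_t>> d(n + 1, std::vector<size_t>(m + 1));
  for (size_t i = 0; i <= n; ++i) d[i][0] = i;
  for (size_t j = 0; j <= m; ++j) d[0][j] = j;
  for (size_t i = 1; i <= n; ++i) {
    for (size_t j = 1; j <= m; ++j) {
      size_t cost = (truth[i - 1] == ocr[j - 1]) ? 0 : 1;
      d[i][j] = std::min({d[i - 1][j] + 1,           // deletion
                          d[i][j - 1] + 1,           // insertion
                          d[i - 1][j - 1] + cost});  // substitution
    }
  }
  return n == 0 ? 0.0 : static_cast<double>(d[n][m]) / n;
}

int main() {
  assert(char_error({1, 2, 3}, {1, 2, 3}) == 0.0);
  assert(char_error({1, 2, 3}, {1, 9, 3}) == 1.0 / 3.0);
  return 0;
}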
void DisplayTargets(const NetworkIO &targets, const char *window_name, ScrollView **window)
bool ReadTrainingDump(const std::vector< char > &data, LSTMTrainer &trainer) const
Definition: lstmtrainer.h:299
int ReduceLayerLearningRates(TFloat factor, int num_samples, LSTMTrainer *samples_trainer)
std::vector< double > best_error_history_
Definition: lstmtrainer.h:461
TessdataManager mgr_
Definition: lstmtrainer.h:487
ScrollView * align_win_
Definition: lstmtrainer.h:407
bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata)
double best_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:436
std::vector< int > MapRecoder(const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const
SVEvent * AwaitEvent(SVEventType type)
Definition: scrollview.cpp:445