tesseract  5.0.0
lstmtraining.cpp
Go to the documentation of this file.
1 // File: lstmtraining.cpp
3 // Description: Training program for LSTM-based networks.
4 // Author: Ray Smith
5 //
6 // (C) Copyright 2013, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
17 
18 #include <cerrno>
19 #if defined(__USE_GNU)
20 # include <cfenv> // for feenableexcept
21 #endif
22 #include "commontraining.h"
23 #include "fileio.h" // for LoadFileLinesToStrings
24 #include "lstmtester.h"
25 #include "lstmtrainer.h"
26 #include "params.h"
27 #include "tprintf.h"
29 
30 using namespace tesseract;
31 
32 static INT_PARAM_FLAG(debug_interval, 0, "How often to display the alignment.");
33 static STRING_PARAM_FLAG(net_spec, "", "Network specification");
34 static INT_PARAM_FLAG(net_mode, 192, "Controls network behavior.");
35 static INT_PARAM_FLAG(perfect_sample_delay, 0, "How many imperfect samples between perfect ones.");
36 static DOUBLE_PARAM_FLAG(target_error_rate, 0.01, "Final error rate in percent.");
37 static DOUBLE_PARAM_FLAG(weight_range, 0.1, "Range of initial random weights.");
38 static DOUBLE_PARAM_FLAG(learning_rate, 10.0e-4, "Weight factor for new deltas.");
39 static BOOL_PARAM_FLAG(reset_learning_rate, false,
40  "Resets all stored learning rates to the value specified by --learning_rate.");
41 static DOUBLE_PARAM_FLAG(momentum, 0.5, "Decay factor for repeating deltas.");
42 static DOUBLE_PARAM_FLAG(adam_beta, 0.999, "Decay factor for repeating deltas.");
43 static INT_PARAM_FLAG(max_image_MB, 6000, "Max memory to use for images.");
44 static STRING_PARAM_FLAG(continue_from, "", "Existing model to extend");
45 static STRING_PARAM_FLAG(model_output, "lstmtrain", "Basename for output models");
46 static STRING_PARAM_FLAG(train_listfile, "",
47  "File listing training files in lstmf training format.");
48 static STRING_PARAM_FLAG(eval_listfile, "", "File listing eval files in lstmf training format.");
49 #if defined(__USE_GNU)
50 static BOOL_PARAM_FLAG(debug_float, false, "Raise error on certain float errors.");
51 #endif
52 static BOOL_PARAM_FLAG(stop_training, false, "Just convert the training model to a runtime model.");
53 static BOOL_PARAM_FLAG(convert_to_int, false, "Convert the recognition model to an integer model.");
54 static BOOL_PARAM_FLAG(sequential_training, false,
55  "Use the training files sequentially instead of round-robin.");
56 static INT_PARAM_FLAG(append_index, -1,
57  "Index in continue_from Network at which to"
58  " attach the new network defined by net_spec");
59 static BOOL_PARAM_FLAG(debug_network, false, "Get info on distribution of weight values");
60 static INT_PARAM_FLAG(max_iterations, 0, "If set, exit after this many iterations");
61 static STRING_PARAM_FLAG(traineddata, "", "Combined Dawgs/Unicharset/Recoder for language model");
62 static STRING_PARAM_FLAG(old_traineddata, "",
63  "When changing the character set, this specifies the old"
64  " character set that is to be replaced");
65 static BOOL_PARAM_FLAG(randomly_rotate, false,
66  "Train OSD and randomly turn training samples upside-down");
67 
// How many training iterations main() runs between successive calls to
// LSTMTrainer::MaintainCheckpoints (each batch ends with a checkpoint/log).
constexpr int kNumPagesPerBatch = 100;
70 
71 // Apart from command-line flags, input is a collection of lstmf files, that
72 // were previously created using tesseract with the lstm.train config file.
73 // The program iterates over the inputs, feeding the data to the network,
74 // until the error rate reaches a specified target or max_iterations is reached.
75 int main(int argc, char **argv) {
76  tesseract::CheckSharedLibraryVersion();
77  ParseArguments(&argc, &argv);
78 #if defined(__USE_GNU)
79  if (FLAGS_debug_float) {
80  // Raise SIGFPE for unwanted floating point calculations.
81  feenableexcept(FE_DIVBYZERO | FE_OVERFLOW | FE_INVALID);
82  }
83 #endif
84  if (FLAGS_model_output.empty()) {
85  tprintf("Must provide a --model_output!\n");
86  return EXIT_FAILURE;
87  }
88  if (FLAGS_traineddata.empty()) {
89  tprintf("Must provide a --traineddata see training documentation\n");
90  return EXIT_FAILURE;
91  }
92 
93  // Check write permissions.
94  std::string test_file = FLAGS_model_output.c_str();
95  test_file += "_wtest";
96  FILE *f = fopen(test_file.c_str(), "wb");
97  if (f != nullptr) {
98  fclose(f);
99  if (remove(test_file.c_str()) != 0) {
100  tprintf("Error, failed to remove %s: %s\n", test_file.c_str(), strerror(errno));
101  return EXIT_FAILURE;
102  }
103  } else {
104  tprintf("Error, model output cannot be written: %s\n", strerror(errno));
105  return EXIT_FAILURE;
106  }
107 
108  // Setup the trainer.
109  std::string checkpoint_file = FLAGS_model_output.c_str();
110  checkpoint_file += "_checkpoint";
111  std::string checkpoint_bak = checkpoint_file + ".bak";
112  tesseract::LSTMTrainer trainer(FLAGS_model_output.c_str(), checkpoint_file.c_str(),
113  FLAGS_debug_interval,
114  static_cast<int64_t>(FLAGS_max_image_MB) * 1048576);
115  if (!trainer.InitCharSet(FLAGS_traineddata.c_str())) {
116  tprintf("Error, failed to read %s\n", FLAGS_traineddata.c_str());
117  return EXIT_FAILURE;
118  }
119 
120  // Reading something from an existing model doesn't require many flags,
121  // so do it now and exit.
122  if (FLAGS_stop_training || FLAGS_debug_network) {
123  if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(), nullptr)) {
124  tprintf("Failed to read continue from: %s\n", FLAGS_continue_from.c_str());
125  return EXIT_FAILURE;
126  }
127  if (FLAGS_debug_network) {
128  trainer.DebugNetwork();
129  } else {
130  if (FLAGS_convert_to_int) {
131  trainer.ConvertToInt();
132  }
133  if (!trainer.SaveTraineddata(FLAGS_model_output.c_str())) {
134  tprintf("Failed to write recognition model : %s\n", FLAGS_model_output.c_str());
135  }
136  }
137  return EXIT_SUCCESS;
138  }
139 
140  // Get the list of files to process.
141  if (FLAGS_train_listfile.empty()) {
142  tprintf("Must supply a list of training filenames! --train_listfile\n");
143  return EXIT_FAILURE;
144  }
145  std::vector<std::string> filenames;
146  if (!tesseract::LoadFileLinesToStrings(FLAGS_train_listfile.c_str(), &filenames)) {
147  tprintf("Failed to load list of training filenames from %s\n", FLAGS_train_listfile.c_str());
148  return EXIT_FAILURE;
149  }
150 
151  // Checkpoints always take priority if they are available.
152  if (trainer.TryLoadingCheckpoint(checkpoint_file.c_str(), nullptr) ||
153  trainer.TryLoadingCheckpoint(checkpoint_bak.c_str(), nullptr)) {
154  tprintf("Successfully restored trainer from %s\n", checkpoint_file.c_str());
155  } else {
156  if (!FLAGS_continue_from.empty()) {
157  // Load a past model file to improve upon.
158  if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(),
159  FLAGS_append_index >= 0 ? FLAGS_continue_from.c_str()
160  : FLAGS_old_traineddata.c_str())) {
161  tprintf("Failed to continue from: %s\n", FLAGS_continue_from.c_str());
162  return EXIT_FAILURE;
163  }
164  tprintf("Continuing from %s\n", FLAGS_continue_from.c_str());
165  if (FLAGS_reset_learning_rate) {
166  trainer.SetLearningRate(FLAGS_learning_rate);
167  tprintf("Set learning rate to %f\n", static_cast<float>(FLAGS_learning_rate));
168  }
169  trainer.InitIterations();
170  }
171  if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
172  if (FLAGS_append_index >= 0) {
173  tprintf("Appending a new network to an old one!!");
174  if (FLAGS_continue_from.empty()) {
175  tprintf("Must set --continue_from for appending!\n");
176  return EXIT_FAILURE;
177  }
178  }
179  // We are initializing from scratch.
180  if (!trainer.InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index, FLAGS_net_mode,
181  FLAGS_weight_range, FLAGS_learning_rate, FLAGS_momentum,
182  FLAGS_adam_beta)) {
183  tprintf("Failed to create network from spec: %s\n", FLAGS_net_spec.c_str());
184  return EXIT_FAILURE;
185  }
186  trainer.set_perfect_delay(FLAGS_perfect_sample_delay);
187  }
188  }
189  if (!trainer.LoadAllTrainingData(
190  filenames,
191  FLAGS_sequential_training ? tesseract::CS_SEQUENTIAL : tesseract::CS_ROUND_ROBIN,
192  FLAGS_randomly_rotate)) {
193  tprintf("Load of images failed!!\n");
194  return EXIT_FAILURE;
195  }
196 
197  tesseract::LSTMTester tester(static_cast<int64_t>(FLAGS_max_image_MB) * 1048576);
198  tesseract::TestCallback tester_callback = nullptr;
199  if (!FLAGS_eval_listfile.empty()) {
200  using namespace std::placeholders; // for _1, _2, _3...
201  if (!tester.LoadAllEvalData(FLAGS_eval_listfile.c_str())) {
202  tprintf("Failed to load eval data from: %s\n", FLAGS_eval_listfile.c_str());
203  return EXIT_FAILURE;
204  }
205  tester_callback = std::bind(&tesseract::LSTMTester::RunEvalAsync, &tester, _1, _2, _3, _4);
206  }
207 
208  int max_iterations = FLAGS_max_iterations;
209  if (max_iterations < 0) {
210  // A negative value is interpreted as epochs
211  max_iterations = filenames.size() * (-max_iterations);
212  } else if (max_iterations == 0) {
213  // "Infinite" iterations.
214  max_iterations = INT_MAX;
215  }
216 
217  do {
218  // Train a few.
219  int iteration = trainer.training_iteration();
220  for (int target_iteration = iteration + kNumPagesPerBatch;
221  iteration < target_iteration && iteration < max_iterations;
222  iteration = trainer.training_iteration()) {
223  trainer.TrainOnLine(&trainer, false);
224  }
225  std::string log_str;
226  trainer.MaintainCheckpoints(tester_callback, log_str);
227  tprintf("%s\n", log_str.c_str());
228  } while (trainer.best_error_rate() > FLAGS_target_error_rate &&
229  (trainer.training_iteration() < max_iterations));
230  tprintf("Finished! Selected model with minimal training error rate (BCER) = %g\n",
231  trainer.best_error_rate());
232  return EXIT_SUCCESS;
233 } /* main */
#define DOUBLE_PARAM_FLAG(name, val, comment)
#define BOOL_PARAM_FLAG(name, val, comment)
#define STRING_PARAM_FLAG(name, val, comment)
int main(int argc, char **argv)
void tprintf(const char *format,...)
Definition: tprintf.cpp:41
void ParseArguments(int *argc, char ***argv)
std::function< std::string(int, const double *, const TessdataManager &, int)> TestCallback
Definition: lstmtrainer.h:77
INT_PARAM_FLAG(debug_level, 0, "Level of Trainer debugging")
@ CS_SEQUENTIAL
Definition: imagedata.h:49
@ CS_ROUND_ROBIN
Definition: imagedata.h:54
const int kNumPagesPerBatch
Definition: lstmtrainer.cpp:58
bool LoadFileLinesToStrings(const char *filename, std::vector< std::string > *lines)
Definition: fileio.h:32
void SetLearningRate(float learning_rate)
std::string RunEvalAsync(int iteration, const double *training_errors, const TessdataManager &model_mgr, int training_stage)
Definition: lstmtester.cpp:50
bool LoadAllEvalData(const char *filenames_file)
Definition: lstmtester.cpp:29
bool LoadAllTrainingData(const std::vector< std::string > &filenames, CachingStrategy cache_strategy, bool randomly_rotate)
bool InitCharSet(const std::string &traineddata_path)
Definition: lstmtrainer.h:99
bool MaintainCheckpoints(const TestCallback &tester, std::string &log_msg)
double best_error_rate() const
Definition: lstmtrainer.h:138
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
Definition: lstmtrainer.h:267
bool SaveTraineddata(const char *filename)
void set_perfect_delay(int delay)
Definition: lstmtrainer.h:150
bool InitNetwork(const char *network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum, float adam_beta)
bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata)