tesseract  5.0.0
lstmrecognizer.cpp
Go to the documentation of this file.
1 // File: lstmrecognizer.cpp
3 // Description: Top-level line recognizer class for LSTM-based networks.
4 // Author: Ray Smith
5 //
6 // (C) Copyright 2013, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
17 
18 // Include automatically generated configuration file if running autoconf.
19 #ifdef HAVE_CONFIG_H
20 # include "config_auto.h"
21 #endif
22 
23 #include "lstmrecognizer.h"
24 
25 #include <allheaders.h>
26 #include "dict.h"
27 #include "genericheap.h"
28 #include "helpers.h"
29 #include "imagedata.h"
30 #include "input.h"
31 #include "lstm.h"
32 #include "normalis.h"
33 #include "pageres.h"
34 #include "ratngs.h"
35 #include "recodebeam.h"
36 #include "scrollview.h"
37 #include "statistc.h"
38 #include "tprintf.h"
39 
40 #include <unordered_set>
41 #include <vector>
42 
43 namespace tesseract {
44 
// Default dict/non-dict word ratio handed to the beam search decoder.
constexpr double kDictRatio = 2.25;
// Default certainty offset handed to the beam search decoder so that the
// dictionary gets a chance.
constexpr double kCertOffset = -0.085;
49 
50 LSTMRecognizer::LSTMRecognizer(const std::string &language_data_path_prefix)
52  ccutil_.language_data_path_prefix = language_data_path_prefix;
53 }
54 
56  : network_(nullptr)
57  , training_flags_(0)
58  , training_iteration_(0)
59  , sample_iteration_(0)
60  , null_char_(UNICHAR_BROKEN)
61  , learning_rate_(0.0f)
62  , momentum_(0.0f)
63  , adam_beta_(0.0f)
64  , dict_(nullptr)
65  , search_(nullptr)
66  , debug_win_(nullptr) {}
67 
69  delete network_;
70  delete dict_;
71  delete search_;
72 }
73 
74 // Loads a model from mgr, including the dictionary only if lang is not null.
75 bool LSTMRecognizer::Load(const ParamsVectors *params, const std::string &lang,
76  TessdataManager *mgr) {
77  TFile fp;
78  if (!mgr->GetComponent(TESSDATA_LSTM, &fp)) {
79  return false;
80  }
81  if (!DeSerialize(mgr, &fp)) {
82  return false;
83  }
84  if (lang.empty()) {
85  return true;
86  }
87  // Allow it to run without a dictionary.
88  LoadDictionary(params, lang, mgr);
89  return true;
90 }
91 
92 // Writes to the given file. Returns false in case of error.
93 bool LSTMRecognizer::Serialize(const TessdataManager *mgr, TFile *fp) const {
94  bool include_charsets = mgr == nullptr || !mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) ||
96  if (!network_->Serialize(fp)) {
97  return false;
98  }
99  if (include_charsets && !GetUnicharset().save_to_file(fp)) {
100  return false;
101  }
102  if (!fp->Serialize(network_str_)) {
103  return false;
104  }
105  if (!fp->Serialize(&training_flags_)) {
106  return false;
107  }
108  if (!fp->Serialize(&training_iteration_)) {
109  return false;
110  }
111  if (!fp->Serialize(&sample_iteration_)) {
112  return false;
113  }
114  if (!fp->Serialize(&null_char_)) {
115  return false;
116  }
117  if (!fp->Serialize(&adam_beta_)) {
118  return false;
119  }
120  if (!fp->Serialize(&learning_rate_)) {
121  return false;
122  }
123  if (!fp->Serialize(&momentum_)) {
124  return false;
125  }
126  if (include_charsets && IsRecoding() && !recoder_.Serialize(fp)) {
127  return false;
128  }
129  return true;
130 }
131 
132 // Reads from the given file. Returns false in case of error.
134  delete network_;
136  if (network_ == nullptr) {
137  return false;
138  }
139  bool include_charsets = mgr == nullptr || !mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) ||
141  if (include_charsets && !ccutil_.unicharset.load_from_file(fp, false)) {
142  return false;
143  }
144  if (!fp->DeSerialize(network_str_)) {
145  return false;
146  }
147  if (!fp->DeSerialize(&training_flags_)) {
148  return false;
149  }
150  if (!fp->DeSerialize(&training_iteration_)) {
151  return false;
152  }
153  if (!fp->DeSerialize(&sample_iteration_)) {
154  return false;
155  }
156  if (!fp->DeSerialize(&null_char_)) {
157  return false;
158  }
159  if (!fp->DeSerialize(&adam_beta_)) {
160  return false;
161  }
162  if (!fp->DeSerialize(&learning_rate_)) {
163  return false;
164  }
165  if (!fp->DeSerialize(&momentum_)) {
166  return false;
167  }
168  if (include_charsets && !LoadRecoder(fp)) {
169  return false;
170  }
171  if (!include_charsets && !LoadCharsets(mgr)) {
172  return false;
173  }
176  return true;
177 }
178 
179 // Loads the charsets from mgr.
181  TFile fp;
182  if (!mgr->GetComponent(TESSDATA_LSTM_UNICHARSET, &fp)) {
183  return false;
184  }
185  if (!ccutil_.unicharset.load_from_file(&fp, false)) {
186  return false;
187  }
188  if (!mgr->GetComponent(TESSDATA_LSTM_RECODER, &fp)) {
189  return false;
190  }
191  if (!LoadRecoder(&fp)) {
192  return false;
193  }
194  return true;
195 }
196 
197 // Loads the Recoder.
199  if (IsRecoding()) {
200  if (!recoder_.DeSerialize(fp)) {
201  return false;
202  }
203  RecodedCharID code;
205  if (code(0) != UNICHAR_SPACE) {
206  tprintf("Space was garbled in recoding!!\n");
207  return false;
208  }
209  } else {
212  }
213  return true;
214 }
215 
216 // Loads the dictionary if possible from the traineddata file.
217 // Prints a warning message, and returns false but otherwise fails silently
218 // and continues to work without it if loading fails.
219 // Note that dictionary load is independent from DeSerialize, but dependent
220 // on the unicharset matching. This enables training to deserialize a model
221 // from checkpoint or restore without having to go back and reload the
222 // dictionary.
223 // Some parameters have to be passed in (from langdata/config/api via Tesseract)
224 bool LSTMRecognizer::LoadDictionary(const ParamsVectors *params, const std::string &lang,
225  TessdataManager *mgr) {
226  delete dict_;
227  dict_ = new Dict(&ccutil_);
228  dict_->user_words_file.ResetFrom(params);
229  dict_->user_words_suffix.ResetFrom(params);
230  dict_->user_patterns_file.ResetFrom(params);
231  dict_->user_patterns_suffix.ResetFrom(params);
233  dict_->LoadLSTM(lang, mgr);
234  if (dict_->FinishLoad()) {
235  return true; // Success.
236  }
237  if (log_level <= 0) {
238  tprintf("Failed to load any lstm-specific dictionaries for lang %s!!\n", lang.c_str());
239  }
240  delete dict_;
241  dict_ = nullptr;
242  return false;
243 }
244 
245 // Recognizes the line image, contained within image_data, returning the
246 // ratings matrix and matching box_word for each WERD_RES in the output.
247 void LSTMRecognizer::RecognizeLine(const ImageData &image_data, bool invert, bool debug,
248  double worst_dict_cert, const TBOX &line_box,
249  PointerVector<WERD_RES> *words, int lstm_choice_mode,
250  int lstm_choice_amount) {
251  NetworkIO outputs;
252  float scale_factor;
253  NetworkIO inputs;
254  if (!RecognizeLine(image_data, invert, debug, false, false, &scale_factor, &inputs, &outputs)) {
255  return;
256  }
257  if (search_ == nullptr) {
259  }
260  search_->excludedUnichars.clear();
261  search_->Decode(outputs, kDictRatio, kCertOffset, worst_dict_cert, &GetUnicharset(),
262  lstm_choice_mode);
263  search_->ExtractBestPathAsWords(line_box, scale_factor, debug, &GetUnicharset(), words,
264  lstm_choice_mode);
265  if (lstm_choice_mode) {
267  for (int i = 0; i < lstm_choice_amount; ++i) {
268  search_->DecodeSecondaryBeams(outputs, kDictRatio, kCertOffset, worst_dict_cert,
269  &GetUnicharset(), lstm_choice_mode);
271  }
273  unsigned char_it = 0;
274  for (size_t i = 0; i < words->size(); ++i) {
275  for (int j = 0; j < words->at(i)->end; ++j) {
276  if (char_it < search_->ctc_choices.size()) {
277  words->at(i)->CTC_symbol_choices.push_back(search_->ctc_choices[char_it]);
278  }
279  if (char_it < search_->segmentedTimesteps.size()) {
280  words->at(i)->segmented_timesteps.push_back(search_->segmentedTimesteps[char_it]);
281  }
282  ++char_it;
283  }
284  words->at(i)->timesteps =
285  search_->combineSegmentedTimesteps(&words->at(i)->segmented_timesteps);
286  }
287  search_->segmentedTimesteps.clear();
288  search_->ctc_choices.clear();
289  search_->excludedUnichars.clear();
290  }
291 }
292 
293 // Helper computes min and mean best results in the output.
294 void LSTMRecognizer::OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output,
295  float *sd) {
296  const int kOutputScale = INT8_MAX;
297  STATS stats(0, kOutputScale + 1);
298  for (int t = 0; t < outputs.Width(); ++t) {
299  int best_label = outputs.BestLabel(t, nullptr);
300  if (best_label != null_char_) {
301  float best_output = outputs.f(t)[best_label];
302  stats.add(static_cast<int>(kOutputScale * best_output), 1);
303  }
304  }
305  // If the output is all nulls it could be that the photometric interpretation
306  // is wrong, so make it look bad, so the other way can win, even if not great.
307  if (stats.get_total() == 0) {
308  *min_output = 0.0f;
309  *mean_output = 0.0f;
310  *sd = 1.0f;
311  } else {
312  *min_output = static_cast<float>(stats.min_bucket()) / kOutputScale;
313  *mean_output = stats.mean() / kOutputScale;
314  *sd = stats.sd() / kOutputScale;
315  }
316 }
317 
318 // Recognizes the image_data, returning the labels,
319 // scores, and corresponding pairs of start, end x-coords in coords.
320 bool LSTMRecognizer::RecognizeLine(const ImageData &image_data, bool invert, bool debug,
321  bool re_invert, bool upside_down, float *scale_factor,
322  NetworkIO *inputs, NetworkIO *outputs) {
323  // This ensures consistent recognition results.
324  SetRandomSeed();
325  int min_width = network_->XScaleFactor();
326  Image pix = Input::PrepareLSTMInputs(image_data, network_, min_width, &randomizer_, scale_factor);
327  if (pix == nullptr) {
328  tprintf("Line cannot be recognized!!\n");
329  return false;
330  }
331  // Maximum width of image to train on.
332  const int kMaxImageWidth = 128 * pixGetHeight(pix);
333  if (network_->IsTraining() && pixGetWidth(pix) > kMaxImageWidth) {
334  tprintf("Image too large to learn!! Size = %dx%d\n", pixGetWidth(pix), pixGetHeight(pix));
335  pix.destroy();
336  return false;
337  }
338  if (upside_down) {
339  pixRotate180(pix, pix);
340  }
341  // Reduction factor from image to coords.
342  *scale_factor = min_width / *scale_factor;
343  inputs->set_int_mode(IsIntMode());
344  SetRandomSeed();
346  network_->Forward(debug, *inputs, nullptr, &scratch_space_, outputs);
347  // Check for auto inversion.
348  if (invert) {
349  float pos_min, pos_mean, pos_sd;
350  OutputStats(*outputs, &pos_min, &pos_mean, &pos_sd);
351  if (pos_mean < 0.5f) {
352  // Run again inverted and see if it is any better.
353  NetworkIO inv_inputs, inv_outputs;
354  inv_inputs.set_int_mode(IsIntMode());
355  SetRandomSeed();
356  pixInvert(pix, pix);
357  Input::PreparePixInput(network_->InputShape(), pix, &randomizer_, &inv_inputs);
358  network_->Forward(debug, inv_inputs, nullptr, &scratch_space_, &inv_outputs);
359  float inv_min, inv_mean, inv_sd;
360  OutputStats(inv_outputs, &inv_min, &inv_mean, &inv_sd);
361  if (inv_mean > pos_mean) {
362  // Inverted did better. Use inverted data.
363  if (debug) {
364  tprintf("Inverting image: old min=%g, mean=%g, sd=%g, inv %g,%g,%g\n", pos_min, pos_mean,
365  pos_sd, inv_min, inv_mean, inv_sd);
366  }
367  *outputs = inv_outputs;
368  *inputs = inv_inputs;
369  } else if (re_invert) {
370  // Inverting was not an improvement, so undo and run again, so the
371  // outputs match the best forward result.
372  SetRandomSeed();
373  network_->Forward(debug, *inputs, nullptr, &scratch_space_, outputs);
374  }
375  }
376  }
377 
378  pix.destroy();
379  if (debug) {
380  std::vector<int> labels, coords;
381  LabelsFromOutputs(*outputs, &labels, &coords);
382 #ifndef GRAPHICS_DISABLED
383  DisplayForward(*inputs, labels, coords, "LSTMForward", &debug_win_);
384 #endif
385  DebugActivationPath(*outputs, labels, coords);
386  }
387  return true;
388 }
389 
390 // Converts an array of labels to utf-8, whether or not the labels are
391 // augmented with character boundaries.
392 std::string LSTMRecognizer::DecodeLabels(const std::vector<int> &labels) {
393  std::string result;
394  unsigned end = 1;
395  for (unsigned start = 0; start < labels.size(); start = end) {
396  if (labels[start] == null_char_) {
397  end = start + 1;
398  } else {
399  result += DecodeLabel(labels, start, &end, nullptr);
400  }
401  }
402  return result;
403 }
404 
405 #ifndef GRAPHICS_DISABLED
406 
407 // Displays the forward results in a window with the characters and
408 // boundaries as determined by the labels and label_coords.
409 void LSTMRecognizer::DisplayForward(const NetworkIO &inputs, const std::vector<int> &labels,
410  const std::vector<int> &label_coords, const char *window_name,
411  ScrollView **window) {
412  Image input_pix = inputs.ToPix();
413  Network::ClearWindow(false, window_name, pixGetWidth(input_pix), pixGetHeight(input_pix), window);
414  int line_height = Network::DisplayImage(input_pix, *window);
415  DisplayLSTMOutput(labels, label_coords, line_height, *window);
416 }
417 
418 // Displays the labels and cuts at the corresponding xcoords.
419 // Size of labels should match xcoords.
420 void LSTMRecognizer::DisplayLSTMOutput(const std::vector<int> &labels,
421  const std::vector<int> &xcoords, int height,
422  ScrollView *window) {
423  int x_scale = network_->XScaleFactor();
424  window->TextAttributes("Arial", height / 4, false, false, false);
425  unsigned end = 1;
426  for (unsigned start = 0; start < labels.size(); start = end) {
427  int xpos = xcoords[start] * x_scale;
428  if (labels[start] == null_char_) {
429  end = start + 1;
430  window->Pen(ScrollView::RED);
431  } else {
432  window->Pen(ScrollView::GREEN);
433  const char *str = DecodeLabel(labels, start, &end, nullptr);
434  if (*str == '\\') {
435  str = "\\\\";
436  }
437  xpos = xcoords[(start + end) / 2] * x_scale;
438  window->Text(xpos, height, str);
439  }
440  window->Line(xpos, 0, xpos, height * 3 / 2);
441  }
442  window->Update();
443 }
444 
445 #endif // !GRAPHICS_DISABLED
446 
447 // Prints debug output detailing the activation path that is implied by the
448 // label_coords.
449 void LSTMRecognizer::DebugActivationPath(const NetworkIO &outputs, const std::vector<int> &labels,
450  const std::vector<int> &xcoords) {
451  if (xcoords[0] > 0) {
452  DebugActivationRange(outputs, "<null>", null_char_, 0, xcoords[0]);
453  }
454  unsigned end = 1;
455  for (unsigned start = 0; start < labels.size(); start = end) {
456  if (labels[start] == null_char_) {
457  end = start + 1;
458  DebugActivationRange(outputs, "<null>", null_char_, xcoords[start], xcoords[end]);
459  continue;
460  } else {
461  int decoded;
462  const char *label = DecodeLabel(labels, start, &end, &decoded);
463  DebugActivationRange(outputs, label, labels[start], xcoords[start], xcoords[start + 1]);
464  for (unsigned i = start + 1; i < end; ++i) {
465  DebugActivationRange(outputs, DecodeSingleLabel(labels[i]), labels[i], xcoords[i],
466  xcoords[i + 1]);
467  }
468  }
469  }
470 }
471 
472 // Prints debug output detailing activations and 2nd choice over a range
473 // of positions.
474 void LSTMRecognizer::DebugActivationRange(const NetworkIO &outputs, const char *label,
475  int best_choice, int x_start, int x_end) {
476  tprintf("%s=%d On [%d, %d), scores=", label, best_choice, x_start, x_end);
477  double max_score = 0.0;
478  double mean_score = 0.0;
479  const int width = x_end - x_start;
480  for (int x = x_start; x < x_end; ++x) {
481  const float *line = outputs.f(x);
482  const double score = line[best_choice] * 100.0;
483  if (score > max_score) {
484  max_score = score;
485  }
486  mean_score += score / width;
487  int best_c = 0;
488  double best_score = 0.0;
489  for (int c = 0; c < outputs.NumFeatures(); ++c) {
490  if (c != best_choice && line[c] > best_score) {
491  best_c = c;
492  best_score = line[c];
493  }
494  }
495  tprintf(" %.3g(%s=%d=%.3g)", score, DecodeSingleLabel(best_c), best_c, best_score * 100.0);
496  }
497  tprintf(", Mean=%g, max=%g\n", mean_score, max_score);
498 }
499 
500 // Helper returns true if the null_char is the winner at t, and it beats the
501 // null_threshold, or the next choice is space, in which case we will use the
502 // null anyway.
503 #if 0 // TODO: unused, remove if still unused after 2020.
504 static bool NullIsBest(const NetworkIO& output, float null_thr,
505  int null_char, int t) {
506  if (output.f(t)[null_char] >= null_thr) return true;
507  if (output.BestLabel(t, null_char, null_char, nullptr) != UNICHAR_SPACE)
508  return false;
509  return output.f(t)[null_char] > output.f(t)[UNICHAR_SPACE];
510 }
511 #endif
512 
513 // Converts the network output to a sequence of labels. Outputs labels, scores
514 // and start xcoords of each char, and each null_char_, with an additional
515 // final xcoord for the end of the output.
516 // The conversion method is determined by internal state.
517 void LSTMRecognizer::LabelsFromOutputs(const NetworkIO &outputs, std::vector<int> *labels,
518  std::vector<int> *xcoords) {
519  if (SimpleTextOutput()) {
520  LabelsViaSimpleText(outputs, labels, xcoords);
521  } else {
522  LabelsViaReEncode(outputs, labels, xcoords);
523  }
524 }
525 
526 // As LabelsViaCTC except that this function constructs the best path that
527 // contains only legal sequences of subcodes for CJK.
528 void LSTMRecognizer::LabelsViaReEncode(const NetworkIO &output, std::vector<int> *labels,
529  std::vector<int> *xcoords) {
530  if (search_ == nullptr) {
532  }
533  search_->Decode(output, 1.0, 0.0, RecodeBeamSearch::kMinCertainty, nullptr);
534  search_->ExtractBestPathAsLabels(labels, xcoords);
535 }
536 
537 // Converts the network output to a sequence of labels, with scores, using
538 // the simple character model (each position is a char, and the null_char_ is
539 // mainly intended for tail padding.)
540 void LSTMRecognizer::LabelsViaSimpleText(const NetworkIO &output, std::vector<int> *labels,
541  std::vector<int> *xcoords) {
542  labels->clear();
543  xcoords->clear();
544  const int width = output.Width();
545  for (int t = 0; t < width; ++t) {
546  float score = 0.0f;
547  const int label = output.BestLabel(t, &score);
548  if (label != null_char_) {
549  labels->push_back(label);
550  xcoords->push_back(t);
551  }
552  }
553  xcoords->push_back(width);
554 }
555 
556 // Returns a string corresponding to the label starting at start. Sets *end
557 // to the next start and if non-null, *decoded to the unichar id.
558 const char *LSTMRecognizer::DecodeLabel(const std::vector<int> &labels, unsigned start, unsigned *end,
559  int *decoded) {
560  *end = start + 1;
561  if (IsRecoding()) {
562  // Decode labels via recoder_.
563  RecodedCharID code;
564  if (labels[start] == null_char_) {
565  if (decoded != nullptr) {
566  code.Set(0, null_char_);
567  *decoded = recoder_.DecodeUnichar(code);
568  }
569  return "<null>";
570  }
571  unsigned index = start;
572  while (index < labels.size() && code.length() < RecodedCharID::kMaxCodeLen) {
573  code.Set(code.length(), labels[index++]);
574  while (index < labels.size() && labels[index] == null_char_) {
575  ++index;
576  }
577  int uni_id = recoder_.DecodeUnichar(code);
578  // If the next label isn't a valid first code, then we need to continue
579  // extending even if we have a valid uni_id from this prefix.
580  if (uni_id != INVALID_UNICHAR_ID &&
581  (index == labels.size() || code.length() == RecodedCharID::kMaxCodeLen ||
582  recoder_.IsValidFirstCode(labels[index]))) {
583  *end = index;
584  if (decoded != nullptr) {
585  *decoded = uni_id;
586  }
587  if (uni_id == UNICHAR_SPACE) {
588  return " ";
589  }
590  return GetUnicharset().get_normed_unichar(uni_id);
591  }
592  }
593  return "<Undecodable>";
594  } else {
595  if (decoded != nullptr) {
596  *decoded = labels[start];
597  }
598  if (labels[start] == null_char_) {
599  return "<null>";
600  }
601  if (labels[start] == UNICHAR_SPACE) {
602  return " ";
603  }
604  return GetUnicharset().get_normed_unichar(labels[start]);
605  }
606 }
607 
608 // Returns a string corresponding to a given single label id, falling back to
609 // a default of ".." for part of a multi-label unichar-id.
610 const char *LSTMRecognizer::DecodeSingleLabel(int label) {
611  if (label == null_char_) {
612  return "<null>";
613  }
614  if (IsRecoding()) {
615  // Decode label via recoder_.
616  RecodedCharID code;
617  code.Set(0, label);
618  label = recoder_.DecodeUnichar(code);
619  if (label == INVALID_UNICHAR_ID) {
620  return ".."; // Part of a bigger code.
621  }
622  }
623  if (label == UNICHAR_SPACE) {
624  return " ";
625  }
626  return GetUnicharset().get_normed_unichar(label);
627 }
628 
629 } // namespace tesseract.
@ TF_COMPRESS_UNICHARSET
void tprintf(const char *format,...)
Definition: tprintf.cpp:41
@ TESSDATA_LSTM_UNICHARSET
@ TESSDATA_LSTM_RECODER
const double kCertOffset
@ UNICHAR_SPACE
Definition: unicharset.h:36
@ UNICHAR_BROKEN
Definition: unicharset.h:38
const double kDictRatio
int log_level
Definition: tprintf.cpp:36
void destroy()
Definition: image.cpp:32
void add(int32_t value, int32_t count)
Definition: statistc.cpp:99
int32_t get_total() const
Definition: statistc.h:85
int32_t min_bucket() const
Definition: statistc.cpp:205
double sd() const
Definition: statistc.cpp:149
double mean() const
Definition: statistc.cpp:133
std::string language_data_path_prefix
Definition: ccutil.h:60
UNICHARSET unicharset
Definition: ccutil.h:61
unsigned size() const
Definition: genericvector.h:74
T & at(int index) const
Definition: genericvector.h:93
bool DeSerialize(std::string &data)
Definition: serialis.cpp:94
bool Serialize(const std::string &data)
Definition: serialis.cpp:107
bool GetComponent(TessdataType type, TFile *fp)
bool IsComponentAvailable(TessdataType type) const
void Set(int index, int value)
static const int kMaxCodeLen
int EncodeUnichar(unsigned unichar_id, RecodedCharID *code) const
bool IsValidFirstCode(int code) const
void SetupPassThrough(const UNICHARSET &unicharset)
int DecodeUnichar(const RecodedCharID &code) const
bool Serialize(TFile *fp) const
const char * get_normed_unichar(UNICHAR_ID unichar_id) const
Definition: unicharset.h:860
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:391
static DawgCache * GlobalDawgCache()
Definition: dict.cpp:172
void LoadLSTM(const std::string &lang, TessdataManager *data_file)
Definition: dict.cpp:291
void SetupForLoad(DawgCache *dawg_cache)
Definition: dict.cpp:180
bool FinishLoad()
Definition: dict.cpp:357
static Image PrepareLSTMInputs(const ImageData &image_data, const Network *network, int min_width, TRand *randomizer, float *image_scale)
Definition: input.cpp:81
static void PreparePixInput(const StaticShape &shape, const Image pix, TRand *randomizer, NetworkIO *input)
Definition: input.cpp:107
void DebugActivationPath(const NetworkIO &outputs, const std::vector< int > &labels, const std::vector< int > &xcoords)
std::string DecodeLabels(const std::vector< int > &labels)
NetworkScratch scratch_space_
void LabelsViaReEncode(const NetworkIO &output, std::vector< int > *labels, std::vector< int > *xcoords)
const char * DecodeSingleLabel(int label)
bool LoadCharsets(const TessdataManager *mgr)
const UNICHARSET & GetUnicharset() const
void LabelsFromOutputs(const NetworkIO &outputs, std::vector< int > *labels, std::vector< int > *xcoords)
void DisplayForward(const NetworkIO &inputs, const std::vector< int > &labels, const std::vector< int > &label_coords, const char *window_name, ScrollView **window)
void OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
const char * DecodeLabel(const std::vector< int > &labels, unsigned start, unsigned *end, int *decoded)
bool Load(const ParamsVectors *params, const std::string &lang, TessdataManager *mgr)
RecodeBeamSearch * search_
void LabelsViaSimpleText(const NetworkIO &output, std::vector< int > *labels, std::vector< int > *xcoords)
void DisplayLSTMOutput(const std::vector< int > &labels, const std::vector< int > &xcoords, int height, ScrollView *window)
bool Serialize(const TessdataManager *mgr, TFile *fp) const
void DebugActivationRange(const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
bool LoadDictionary(const ParamsVectors *params, const std::string &lang, TessdataManager *mgr)
void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0, int lstm_choice_amount=5)
virtual int XScaleFactor() const
Definition: network.h:214
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)=0
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:350
static Network * CreateFromFile(TFile *fp)
Definition: network.cpp:217
bool IsTraining() const
Definition: network.h:113
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:158
static int DisplayImage(Image pix, ScrollView *window)
Definition: network.cpp:378
virtual void CacheXScaleFactor([[maybe_unused]] int factor)
Definition: network.h:220
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:145
virtual StaticShape InputShape() const
Definition: network.h:129
Image ToPix() const
Definition: networkio.cpp:300
float * f(int t)
Definition: networkio.h:111
int Width() const
Definition: networkio.h:103
void set_int_mode(bool is_quantized)
Definition: networkio.h:126
int NumFeatures() const
Definition: networkio.h:107
int BestLabel(int t, float *score) const
Definition: networkio.h:163
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
Definition: recodebeam.cpp:89
std::vector< std::vector< std::pair< const char *, float > > > ctc_choices
Definition: recodebeam.h:237
void DecodeSecondaryBeams(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
Definition: recodebeam.cpp:118
std::vector< std::vector< std::vector< std::pair< const char *, float > > > > segmentedTimesteps
Definition: recodebeam.h:235
void extractSymbolChoices(const UNICHARSET *unicharset)
Definition: recodebeam.cpp:415
std::vector< std::unordered_set< int > > excludedUnichars
Definition: recodebeam.h:239
std::vector< std::vector< std::pair< const char *, float > > > combineSegmentedTimesteps(std::vector< std::vector< std::vector< std::pair< const char *, float >>>> *segmentedTimesteps)
Definition: recodebeam.cpp:181
void ExtractBestPathAsLabels(std::vector< int > *labels, std::vector< int > *xcoords) const
Definition: recodebeam.cpp:207
static constexpr float kMinCertainty
Definition: recodebeam.h:246
void ExtractBestPathAsWords(const TBOX &line_box, float scale_factor, bool debug, const UNICHARSET *unicharset, PointerVector< WERD_RES > *words, int lstm_choice_mode=0)
Definition: recodebeam.cpp:245
void Line(int x1, int y1, int x2, int y2)
Definition: scrollview.cpp:511
void TextAttributes(const char *font, int pixel_size, bool bold, bool italic, bool underlined)
Definition: scrollview.cpp:623
void Text(int x, int y, const char *mystring)
Definition: scrollview.cpp:648
void Pen(Color color)
Definition: scrollview.cpp:723
static void Update()
Definition: scrollview.cpp:713