tesseract  5.0.0
tesseract::LSTMRecognizer Class Reference

#include <lstmrecognizer.h>

Inheritance diagram for tesseract::LSTMRecognizer:
tesseract::LSTMTrainer

Public Member Functions

 LSTMRecognizer ()
 
 LSTMRecognizer (const std::string &language_data_path_prefix)
 
 ~LSTMRecognizer ()
 
int NumOutputs () const
 
int training_iteration () const
 
int sample_iteration () const
 
float learning_rate () const
 
LossType OutputLossType () const
 
bool SimpleTextOutput () const
 
bool IsIntMode () const
 
bool IsRecoding () const
 
bool IsTensorFlow () const
 
std::vector< std::string > EnumerateLayers () const
 
Network * GetLayer (const std::string &id) const
 
float GetLayerLearningRate (const std::string &id) const
 
const char * GetNetwork () const
 
float GetAdamBeta () const
 
float GetMomentum () const
 
void ScaleLearningRate (double factor)
 
void ScaleLayerLearningRate (const std::string &id, double factor)
 
void SetLearningRate (float learning_rate)
 
void SetLayerLearningRate (const std::string &id, float learning_rate)
 
void ConvertToInt ()
 
const UNICHARSET & GetUnicharset () const
 
UNICHARSET & GetUnicharset ()
 
const UnicharCompress & GetRecoder () const
 
const Dict * GetDict () const
 
Dict * GetDict ()
 
void SetIteration (int iteration)
 
int NumInputs () const
 
int null_char () const
 
bool Load (const ParamsVectors *params, const std::string &lang, TessdataManager *mgr)
 
bool Serialize (const TessdataManager *mgr, TFile *fp) const
 
bool DeSerialize (const TessdataManager *mgr, TFile *fp)
 
bool LoadCharsets (const TessdataManager *mgr)
 
bool LoadRecoder (TFile *fp)
 
bool LoadDictionary (const ParamsVectors *params, const std::string &lang, TessdataManager *mgr)
 
void RecognizeLine (const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0, int lstm_choice_amount=5)
 
void OutputStats (const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
 
bool RecognizeLine (const ImageData &image_data, bool invert, bool debug, bool re_invert, bool upside_down, float *scale_factor, NetworkIO *inputs, NetworkIO *outputs)
 
std::string DecodeLabels (const std::vector< int > &labels)
 
void DisplayForward (const NetworkIO &inputs, const std::vector< int > &labels, const std::vector< int > &label_coords, const char *window_name, ScrollView **window)
 
void LabelsFromOutputs (const NetworkIO &outputs, std::vector< int > *labels, std::vector< int > *xcoords)
 

Protected Member Functions

void SetRandomSeed ()
 
void DisplayLSTMOutput (const std::vector< int > &labels, const std::vector< int > &xcoords, int height, ScrollView *window)
 
void DebugActivationPath (const NetworkIO &outputs, const std::vector< int > &labels, const std::vector< int > &xcoords)
 
void DebugActivationRange (const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)
 
void LabelsViaReEncode (const NetworkIO &output, std::vector< int > *labels, std::vector< int > *xcoords)
 
void LabelsViaSimpleText (const NetworkIO &output, std::vector< int > *labels, std::vector< int > *xcoords)
 
const char * DecodeLabel (const std::vector< int > &labels, unsigned start, unsigned *end, int *decoded)
 
const char * DecodeSingleLabel (int label)
 

Protected Attributes

Network * network_
 
CCUtil ccutil_
 
UnicharCompress recoder_
 
std::string network_str_
 
int32_t training_flags_
 
int32_t training_iteration_
 
int32_t sample_iteration_
 
int32_t null_char_
 
float learning_rate_
 
float momentum_
 
float adam_beta_
 
TRand randomizer_
 
NetworkScratch scratch_space_
 
Dict * dict_
 
RecodeBeamSearch * search_
 
ScrollView * debug_win_
 

Detailed Description

Definition at line 51 of file lstmrecognizer.h.

Constructor & Destructor Documentation

◆ LSTMRecognizer() [1/2]

tesseract::LSTMRecognizer::LSTMRecognizer ( )

Definition at line 55 of file lstmrecognizer.cpp.

56  : network_(nullptr)
57  , training_flags_(0)
61  , learning_rate_(0.0f)
62  , momentum_(0.0f)
63  , adam_beta_(0.0f)
64  , dict_(nullptr)
65  , search_(nullptr)
66  , debug_win_(nullptr) {}
@ UNICHAR_BROKEN
Definition: unicharset.h:38
RecodeBeamSearch * search_

◆ LSTMRecognizer() [2/2]

tesseract::LSTMRecognizer::LSTMRecognizer ( const std::string &  language_data_path_prefix)

Definition at line 50 of file lstmrecognizer.cpp.

52  ccutil_.language_data_path_prefix = language_data_path_prefix;
53 }
std::string language_data_path_prefix
Definition: ccutil.h:60

◆ ~LSTMRecognizer()

tesseract::LSTMRecognizer::~LSTMRecognizer ( )

Definition at line 68 of file lstmrecognizer.cpp.

68  {
69  delete network_;
70  delete dict_;
71  delete search_;
72 }

Member Function Documentation

◆ ConvertToInt()

void tesseract::LSTMRecognizer::ConvertToInt ( )
inline

Definition at line 181 of file lstmrecognizer.h.

181  {
182  if ((training_flags_ & TF_INT_MODE) == 0) {
185  }
186  }
virtual void ConvertToInt()
Definition: network.h:196

◆ DebugActivationPath()

void tesseract::LSTMRecognizer::DebugActivationPath ( const NetworkIO &  outputs,
const std::vector< int > &  labels,
const std::vector< int > &  xcoords 
)
protected

Definition at line 449 of file lstmrecognizer.cpp.

450  {
451  if (xcoords[0] > 0) {
452  DebugActivationRange(outputs, "<null>", null_char_, 0, xcoords[0]);
453  }
454  unsigned end = 1;
455  for (unsigned start = 0; start < labels.size(); start = end) {
456  if (labels[start] == null_char_) {
457  end = start + 1;
458  DebugActivationRange(outputs, "<null>", null_char_, xcoords[start], xcoords[end]);
459  continue;
460  } else {
461  int decoded;
462  const char *label = DecodeLabel(labels, start, &end, &decoded);
463  DebugActivationRange(outputs, label, labels[start], xcoords[start], xcoords[start + 1]);
464  for (unsigned i = start + 1; i < end; ++i) {
465  DebugActivationRange(outputs, DecodeSingleLabel(labels[i]), labels[i], xcoords[i],
466  xcoords[i + 1]);
467  }
468  }
469  }
470 }
const char * DecodeSingleLabel(int label)
const char * DecodeLabel(const std::vector< int > &labels, unsigned start, unsigned *end, int *decoded)
void DebugActivationRange(const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)

◆ DebugActivationRange()

void tesseract::LSTMRecognizer::DebugActivationRange ( const NetworkIO &  outputs,
const char *  label,
int  best_choice,
int  x_start,
int  x_end 
)
protected

Definition at line 474 of file lstmrecognizer.cpp.

475  {
476  tprintf("%s=%d On [%d, %d), scores=", label, best_choice, x_start, x_end);
477  double max_score = 0.0;
478  double mean_score = 0.0;
479  const int width = x_end - x_start;
480  for (int x = x_start; x < x_end; ++x) {
481  const float *line = outputs.f(x);
482  const double score = line[best_choice] * 100.0;
483  if (score > max_score) {
484  max_score = score;
485  }
486  mean_score += score / width;
487  int best_c = 0;
488  double best_score = 0.0;
489  for (int c = 0; c < outputs.NumFeatures(); ++c) {
490  if (c != best_choice && line[c] > best_score) {
491  best_c = c;
492  best_score = line[c];
493  }
494  }
495  tprintf(" %.3g(%s=%d=%.3g)", score, DecodeSingleLabel(best_c), best_c, best_score * 100.0);
496  }
497  tprintf(", Mean=%g, max=%g\n", mean_score, max_score);
498 }
void tprintf(const char *format,...)
Definition: tprintf.cpp:41

◆ DecodeLabel()

const char * tesseract::LSTMRecognizer::DecodeLabel ( const std::vector< int > &  labels,
unsigned  start,
unsigned *  end,
int *  decoded 
)
protected

Definition at line 558 of file lstmrecognizer.cpp.

559  {
560  *end = start + 1;
561  if (IsRecoding()) {
562  // Decode labels via recoder_.
563  RecodedCharID code;
564  if (labels[start] == null_char_) {
565  if (decoded != nullptr) {
566  code.Set(0, null_char_);
567  *decoded = recoder_.DecodeUnichar(code);
568  }
569  return "<null>";
570  }
571  unsigned index = start;
572  while (index < labels.size() && code.length() < RecodedCharID::kMaxCodeLen) {
573  code.Set(code.length(), labels[index++]);
574  while (index < labels.size() && labels[index] == null_char_) {
575  ++index;
576  }
577  int uni_id = recoder_.DecodeUnichar(code);
578  // If the next label isn't a valid first code, then we need to continue
579  // extending even if we have a valid uni_id from this prefix.
580  if (uni_id != INVALID_UNICHAR_ID &&
581  (index == labels.size() || code.length() == RecodedCharID::kMaxCodeLen ||
582  recoder_.IsValidFirstCode(labels[index]))) {
583  *end = index;
584  if (decoded != nullptr) {
585  *decoded = uni_id;
586  }
587  if (uni_id == UNICHAR_SPACE) {
588  return " ";
589  }
590  return GetUnicharset().get_normed_unichar(uni_id);
591  }
592  }
593  return "<Undecodable>";
594  } else {
595  if (decoded != nullptr) {
596  *decoded = labels[start];
597  }
598  if (labels[start] == null_char_) {
599  return "<null>";
600  }
601  if (labels[start] == UNICHAR_SPACE) {
602  return " ";
603  }
604  return GetUnicharset().get_normed_unichar(labels[start]);
605  }
606 }
@ UNICHAR_SPACE
Definition: unicharset.h:36
static const int kMaxCodeLen
bool IsValidFirstCode(int code) const
int DecodeUnichar(const RecodedCharID &code) const
const char * get_normed_unichar(UNICHAR_ID unichar_id) const
Definition: unicharset.h:860
const UNICHARSET & GetUnicharset() const

◆ DecodeLabels()

std::string tesseract::LSTMRecognizer::DecodeLabels ( const std::vector< int > &  labels)

Definition at line 392 of file lstmrecognizer.cpp.

392  {
393  std::string result;
394  unsigned end = 1;
395  for (unsigned start = 0; start < labels.size(); start = end) {
396  if (labels[start] == null_char_) {
397  end = start + 1;
398  } else {
399  result += DecodeLabel(labels, start, &end, nullptr);
400  }
401  }
402  return result;
403 }

◆ DecodeSingleLabel()

const char * tesseract::LSTMRecognizer::DecodeSingleLabel ( int  label)
protected

Definition at line 610 of file lstmrecognizer.cpp.

610  {
611  if (label == null_char_) {
612  return "<null>";
613  }
614  if (IsRecoding()) {
615  // Decode label via recoder_.
616  RecodedCharID code;
617  code.Set(0, label);
618  label = recoder_.DecodeUnichar(code);
619  if (label == INVALID_UNICHAR_ID) {
620  return ".."; // Part of a bigger code.
621  }
622  }
623  if (label == UNICHAR_SPACE) {
624  return " ";
625  }
626  return GetUnicharset().get_normed_unichar(label);
627 }

◆ DeSerialize()

bool tesseract::LSTMRecognizer::DeSerialize ( const TessdataManager *  mgr,
TFile *  fp 
)

Definition at line 133 of file lstmrecognizer.cpp.

133  {
134  delete network_;
135  network_ = Network::CreateFromFile(fp);
136  if (network_ == nullptr) {
137  return false;
138  }
139  bool include_charsets = mgr == nullptr || !mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) ||
140  !mgr->IsComponentAvailable(TESSDATA_LSTM_UNICHARSET);
141  if (include_charsets && !ccutil_.unicharset.load_from_file(fp, false)) {
142  return false;
143  }
144  if (!fp->DeSerialize(network_str_)) {
145  return false;
146  }
147  if (!fp->DeSerialize(&training_flags_)) {
148  return false;
149  }
150  if (!fp->DeSerialize(&training_iteration_)) {
151  return false;
152  }
153  if (!fp->DeSerialize(&sample_iteration_)) {
154  return false;
155  }
156  if (!fp->DeSerialize(&null_char_)) {
157  return false;
158  }
159  if (!fp->DeSerialize(&adam_beta_)) {
160  return false;
161  }
162  if (!fp->DeSerialize(&learning_rate_)) {
163  return false;
164  }
165  if (!fp->DeSerialize(&momentum_)) {
166  return false;
167  }
168  if (include_charsets && !LoadRecoder(fp)) {
169  return false;
170  }
171  if (!include_charsets && !LoadCharsets(mgr)) {
172  return false;
173  }
176  return true;
177 }
@ TESSDATA_LSTM_UNICHARSET
@ TESSDATA_LSTM_RECODER
UNICHARSET unicharset
Definition: ccutil.h:61
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:391
bool LoadCharsets(const TessdataManager *mgr)
virtual int XScaleFactor() const
Definition: network.h:214
static Network * CreateFromFile(TFile *fp)
Definition: network.cpp:217
virtual void CacheXScaleFactor([[maybe_unused]] int factor)
Definition: network.h:220
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:145

◆ DisplayForward()

void tesseract::LSTMRecognizer::DisplayForward ( const NetworkIO &  inputs,
const std::vector< int > &  labels,
const std::vector< int > &  label_coords,
const char *  window_name,
ScrollView **  window 
)

Definition at line 409 of file lstmrecognizer.cpp.

411  {
412  Image input_pix = inputs.ToPix();
413  Network::ClearWindow(false, window_name, pixGetWidth(input_pix), pixGetHeight(input_pix), window);
414  int line_height = Network::DisplayImage(input_pix, *window);
415  DisplayLSTMOutput(labels, label_coords, line_height, *window);
416 }
void DisplayLSTMOutput(const std::vector< int > &labels, const std::vector< int > &xcoords, int height, ScrollView *window)
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:350
static int DisplayImage(Image pix, ScrollView *window)
Definition: network.cpp:378

◆ DisplayLSTMOutput()

void tesseract::LSTMRecognizer::DisplayLSTMOutput ( const std::vector< int > &  labels,
const std::vector< int > &  xcoords,
int  height,
ScrollView *  window 
)
protected

Definition at line 420 of file lstmrecognizer.cpp.

422  {
423  int x_scale = network_->XScaleFactor();
424  window->TextAttributes("Arial", height / 4, false, false, false);
425  unsigned end = 1;
426  for (unsigned start = 0; start < labels.size(); start = end) {
427  int xpos = xcoords[start] * x_scale;
428  if (labels[start] == null_char_) {
429  end = start + 1;
430  window->Pen(ScrollView::RED);
431  } else {
432  window->Pen(ScrollView::GREEN);
433  const char *str = DecodeLabel(labels, start, &end, nullptr);
434  if (*str == '\\') {
435  str = "\\\\";
436  }
437  xpos = xcoords[(start + end) / 2] * x_scale;
438  window->Text(xpos, height, str);
439  }
440  window->Line(xpos, 0, xpos, height * 3 / 2);
441  }
442  window->Update();
443 }

◆ EnumerateLayers()

std::vector<std::string> tesseract::LSTMRecognizer::EnumerateLayers ( ) const
inline

Definition at line 100 of file lstmrecognizer.h.

100  {
101  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
102  auto *series = static_cast<Series *>(network_);
103  std::vector<std::string> layers;
104  series->EnumerateLayers(nullptr, layers);
105  return layers;
106  }
#define ASSERT_HOST(x)
Definition: errcode.h:59
@ NT_SERIES
Definition: network.h:52
NetworkType type() const
Definition: network.h:110

◆ GetAdamBeta()

float tesseract::LSTMRecognizer::GetAdamBeta ( ) const
inline

Definition at line 132 of file lstmrecognizer.h.

132  {
133  return adam_beta_;
134  }

◆ GetDict() [1/2]

Dict* tesseract::LSTMRecognizer::GetDict ( )
inline

Definition at line 203 of file lstmrecognizer.h.

203  {
204  return dict_;
205  }

◆ GetDict() [2/2]

const Dict* tesseract::LSTMRecognizer::GetDict ( ) const
inline

Definition at line 200 of file lstmrecognizer.h.

200  {
201  return dict_;
202  }

◆ GetLayer()

Network* tesseract::LSTMRecognizer::GetLayer ( const std::string &  id) const
inline

Definition at line 108 of file lstmrecognizer.h.

108  {
109  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
110  ASSERT_HOST(id.length() > 1 && id[0] == ':');
111  auto *series = static_cast<Series *>(network_);
112  return series->GetLayer(&id[1]);
113  }

◆ GetLayerLearningRate()

float tesseract::LSTMRecognizer::GetLayerLearningRate ( const std::string &  id) const
inline

Definition at line 115 of file lstmrecognizer.h.

115  {
116  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
117  if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
118  ASSERT_HOST(id.length() > 1 && id[0] == ':');
119  auto *series = static_cast<Series *>(network_);
120  return series->LayerLearningRate(&id[1]);
121  } else {
122  return learning_rate_;
123  }
124  }
@ NF_LAYER_SPECIFIC_LR
Definition: network.h:85
bool TestFlag(NetworkFlags flag) const
Definition: network.h:146

◆ GetMomentum()

float tesseract::LSTMRecognizer::GetMomentum ( ) const
inline

Definition at line 137 of file lstmrecognizer.h.

137  {
138  return momentum_;
139  }

◆ GetNetwork()

const char* tesseract::LSTMRecognizer::GetNetwork ( ) const
inline

Definition at line 127 of file lstmrecognizer.h.

127  {
128  return network_str_.c_str();
129  }

◆ GetRecoder()

const UnicharCompress& tesseract::LSTMRecognizer::GetRecoder ( ) const
inline

Definition at line 196 of file lstmrecognizer.h.

196  {
197  return recoder_;
198  }

◆ GetUnicharset() [1/2]

UNICHARSET& tesseract::LSTMRecognizer::GetUnicharset ( )
inline

Definition at line 192 of file lstmrecognizer.h.

192  {
193  return ccutil_.unicharset;
194  }

◆ GetUnicharset() [2/2]

const UNICHARSET& tesseract::LSTMRecognizer::GetUnicharset ( ) const
inline

Definition at line 189 of file lstmrecognizer.h.

189  {
190  return ccutil_.unicharset;
191  }

◆ IsIntMode()

bool tesseract::LSTMRecognizer::IsIntMode ( ) const
inline

Definition at line 87 of file lstmrecognizer.h.

87  {
88  return (training_flags_ & TF_INT_MODE) != 0;
89  }

◆ IsRecoding()

bool tesseract::LSTMRecognizer::IsRecoding ( ) const
inline

Definition at line 91 of file lstmrecognizer.h.

91  {
92  return (training_flags_ & TF_COMPRESS_UNICHARSET) != 0;
93  }
@ TF_COMPRESS_UNICHARSET

◆ IsTensorFlow()

bool tesseract::LSTMRecognizer::IsTensorFlow ( ) const
inline

Definition at line 95 of file lstmrecognizer.h.

95  {
96  return network_->type() == NT_TENSORFLOW;
97  }
@ NT_TENSORFLOW
Definition: network.h:76

◆ LabelsFromOutputs()

void tesseract::LSTMRecognizer::LabelsFromOutputs ( const NetworkIO &  outputs,
std::vector< int > *  labels,
std::vector< int > *  xcoords 
)

Definition at line 517 of file lstmrecognizer.cpp.

518  {
519  if (SimpleTextOutput()) {
520  LabelsViaSimpleText(outputs, labels, xcoords);
521  } else {
522  LabelsViaReEncode(outputs, labels, xcoords);
523  }
524 }
void LabelsViaReEncode(const NetworkIO &output, std::vector< int > *labels, std::vector< int > *xcoords)
void LabelsViaSimpleText(const NetworkIO &output, std::vector< int > *labels, std::vector< int > *xcoords)

◆ LabelsViaReEncode()

void tesseract::LSTMRecognizer::LabelsViaReEncode ( const NetworkIO &  output,
std::vector< int > *  labels,
std::vector< int > *  xcoords 
)
protected

Definition at line 528 of file lstmrecognizer.cpp.

529  {
530  if (search_ == nullptr) {
531  search_ = new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_);
532  }
533  search_->Decode(output, 1.0, 0.0, RecodeBeamSearch::kMinCertainty, nullptr);
534  search_->ExtractBestPathAsLabels(labels, xcoords);
535 }
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
Definition: recodebeam.cpp:89
void ExtractBestPathAsLabels(std::vector< int > *labels, std::vector< int > *xcoords) const
Definition: recodebeam.cpp:207
static constexpr float kMinCertainty
Definition: recodebeam.h:246

◆ LabelsViaSimpleText()

void tesseract::LSTMRecognizer::LabelsViaSimpleText ( const NetworkIO &  output,
std::vector< int > *  labels,
std::vector< int > *  xcoords 
)
protected

Definition at line 540 of file lstmrecognizer.cpp.

541  {
542  labels->clear();
543  xcoords->clear();
544  const int width = output.Width();
545  for (int t = 0; t < width; ++t) {
546  float score = 0.0f;
547  const int label = output.BestLabel(t, &score);
548  if (label != null_char_) {
549  labels->push_back(label);
550  xcoords->push_back(t);
551  }
552  }
553  xcoords->push_back(width);
554 }

◆ learning_rate()

float tesseract::LSTMRecognizer::learning_rate ( ) const
inline

Definition at line 72 of file lstmrecognizer.h.

72  {
73  return learning_rate_;
74  }

◆ Load()

bool tesseract::LSTMRecognizer::Load ( const ParamsVectors *  params,
const std::string &  lang,
TessdataManager *  mgr 
)

Definition at line 75 of file lstmrecognizer.cpp.

76  {
77  TFile fp;
78  if (!mgr->GetComponent(TESSDATA_LSTM, &fp)) {
79  return false;
80  }
81  if (!DeSerialize(mgr, &fp)) {
82  return false;
83  }
84  if (lang.empty()) {
85  return true;
86  }
87  // Allow it to run without a dictionary.
88  LoadDictionary(params, lang, mgr);
89  return true;
90 }
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
bool LoadDictionary(const ParamsVectors *params, const std::string &lang, TessdataManager *mgr)

◆ LoadCharsets()

bool tesseract::LSTMRecognizer::LoadCharsets ( const TessdataManager *  mgr)

Definition at line 180 of file lstmrecognizer.cpp.

180  {
181  TFile fp;
182  if (!mgr->GetComponent(TESSDATA_LSTM_UNICHARSET, &fp)) {
183  return false;
184  }
185  if (!ccutil_.unicharset.load_from_file(&fp, false)) {
186  return false;
187  }
188  if (!mgr->GetComponent(TESSDATA_LSTM_RECODER, &fp)) {
189  return false;
190  }
191  if (!LoadRecoder(&fp)) {
192  return false;
193  }
194  return true;
195 }

◆ LoadDictionary()

bool tesseract::LSTMRecognizer::LoadDictionary ( const ParamsVectors *  params,
const std::string &  lang,
TessdataManager *  mgr 
)

Definition at line 224 of file lstmrecognizer.cpp.

225  {
226  delete dict_;
227  dict_ = new Dict(&ccutil_);
228  dict_->user_words_file.ResetFrom(params);
229  dict_->user_words_suffix.ResetFrom(params);
230  dict_->user_patterns_file.ResetFrom(params);
231  dict_->user_patterns_suffix.ResetFrom(params);
233  dict_->LoadLSTM(lang, mgr);
234  if (dict_->FinishLoad()) {
235  return true; // Success.
236  }
237  if (log_level <= 0) {
238  tprintf("Failed to load any lstm-specific dictionaries for lang %s!!\n", lang.c_str());
239  }
240  delete dict_;
241  dict_ = nullptr;
242  return false;
243 }
int log_level
Definition: tprintf.cpp:36
static DawgCache * GlobalDawgCache()
Definition: dict.cpp:172
void LoadLSTM(const std::string &lang, TessdataManager *data_file)
Definition: dict.cpp:291
void SetupForLoad(DawgCache *dawg_cache)
Definition: dict.cpp:180
bool FinishLoad()
Definition: dict.cpp:357

◆ LoadRecoder()

bool tesseract::LSTMRecognizer::LoadRecoder ( TFile *  fp)

Definition at line 198 of file lstmrecognizer.cpp.

198  {
199  if (IsRecoding()) {
200  if (!recoder_.DeSerialize(fp)) {
201  return false;
202  }
203  RecodedCharID code;
205  if (code(0) != UNICHAR_SPACE) {
206  tprintf("Space was garbled in recoding!!\n");
207  return false;
208  }
209  } else {
212  }
213  return true;
214 }
int EncodeUnichar(unsigned unichar_id, RecodedCharID *code) const
void SetupPassThrough(const UNICHARSET &unicharset)

◆ null_char()

int tesseract::LSTMRecognizer::null_char ( ) const
inline

Definition at line 218 of file lstmrecognizer.h.

218  {
219  return null_char_;
220  }

◆ NumInputs()

int tesseract::LSTMRecognizer::NumInputs ( ) const
inline

Definition at line 213 of file lstmrecognizer.h.

213  {
214  return network_->NumInputs();
215  }
int NumInputs() const
Definition: network.h:122

◆ NumOutputs()

int tesseract::LSTMRecognizer::NumOutputs ( ) const
inline

Definition at line 57 of file lstmrecognizer.h.

57  {
58  return network_->NumOutputs();
59  }
int NumOutputs() const
Definition: network.h:125

◆ OutputLossType()

LossType tesseract::LSTMRecognizer::OutputLossType ( ) const
inline

Definition at line 76 of file lstmrecognizer.h.

76  {
77  if (network_ == nullptr) {
78  return LT_NONE;
79  }
80  StaticShape shape;
81  shape = network_->OutputShape(shape);
82  return shape.loss_type();
83  }
virtual StaticShape OutputShape(const StaticShape &input_shape) const
Definition: network.h:135
LossType loss_type() const
Definition: static_shape.h:65

◆ OutputStats()

void tesseract::LSTMRecognizer::OutputStats ( const NetworkIO &  outputs,
float *  min_output,
float *  mean_output,
float *  sd 
)

Definition at line 294 of file lstmrecognizer.cpp.

295  {
296  const int kOutputScale = INT8_MAX;
297  STATS stats(0, kOutputScale + 1);
298  for (int t = 0; t < outputs.Width(); ++t) {
299  int best_label = outputs.BestLabel(t, nullptr);
300  if (best_label != null_char_) {
301  float best_output = outputs.f(t)[best_label];
302  stats.add(static_cast<int>(kOutputScale * best_output), 1);
303  }
304  }
305  // If the output is all nulls it could be that the photometric interpretation
306  // is wrong, so make it look bad, so the other way can win, even if not great.
307  if (stats.get_total() == 0) {
308  *min_output = 0.0f;
309  *mean_output = 0.0f;
310  *sd = 1.0f;
311  } else {
312  *min_output = static_cast<float>(stats.min_bucket()) / kOutputScale;
313  *mean_output = stats.mean() / kOutputScale;
314  *sd = stats.sd() / kOutputScale;
315  }
316 }

◆ RecognizeLine() [1/2]

bool tesseract::LSTMRecognizer::RecognizeLine ( const ImageData &  image_data,
bool  invert,
bool  debug,
bool  re_invert,
bool  upside_down,
float *  scale_factor,
NetworkIO *  inputs,
NetworkIO *  outputs 
)

Definition at line 320 of file lstmrecognizer.cpp.

322  {
323  // This ensures consistent recognition results.
324  SetRandomSeed();
325  int min_width = network_->XScaleFactor();
326  Image pix = Input::PrepareLSTMInputs(image_data, network_, min_width, &randomizer_, scale_factor);
327  if (pix == nullptr) {
328  tprintf("Line cannot be recognized!!\n");
329  return false;
330  }
331  // Maximum width of image to train on.
332  const int kMaxImageWidth = 128 * pixGetHeight(pix);
333  if (network_->IsTraining() && pixGetWidth(pix) > kMaxImageWidth) {
334  tprintf("Image too large to learn!! Size = %dx%d\n", pixGetWidth(pix), pixGetHeight(pix));
335  pix.destroy();
336  return false;
337  }
338  if (upside_down) {
339  pixRotate180(pix, pix);
340  }
341  // Reduction factor from image to coords.
342  *scale_factor = min_width / *scale_factor;
343  inputs->set_int_mode(IsIntMode());
344  SetRandomSeed();
346  network_->Forward(debug, *inputs, nullptr, &scratch_space_, outputs);
347  // Check for auto inversion.
348  if (invert) {
349  float pos_min, pos_mean, pos_sd;
350  OutputStats(*outputs, &pos_min, &pos_mean, &pos_sd);
351  if (pos_mean < 0.5f) {
352  // Run again inverted and see if it is any better.
353  NetworkIO inv_inputs, inv_outputs;
354  inv_inputs.set_int_mode(IsIntMode());
355  SetRandomSeed();
356  pixInvert(pix, pix);
357  Input::PreparePixInput(network_->InputShape(), pix, &randomizer_, &inv_inputs);
358  network_->Forward(debug, inv_inputs, nullptr, &scratch_space_, &inv_outputs);
359  float inv_min, inv_mean, inv_sd;
360  OutputStats(inv_outputs, &inv_min, &inv_mean, &inv_sd);
361  if (inv_mean > pos_mean) {
362  // Inverted did better. Use inverted data.
363  if (debug) {
364  tprintf("Inverting image: old min=%g, mean=%g, sd=%g, inv %g,%g,%g\n", pos_min, pos_mean,
365  pos_sd, inv_min, inv_mean, inv_sd);
366  }
367  *outputs = inv_outputs;
368  *inputs = inv_inputs;
369  } else if (re_invert) {
370  // Inverting was not an improvement, so undo and run again, so the
371  // outputs match the best forward result.
372  SetRandomSeed();
373  network_->Forward(debug, *inputs, nullptr, &scratch_space_, outputs);
374  }
375  }
376  }
377 
378  pix.destroy();
379  if (debug) {
380  std::vector<int> labels, coords;
381  LabelsFromOutputs(*outputs, &labels, &coords);
382 #ifndef GRAPHICS_DISABLED
383  DisplayForward(*inputs, labels, coords, "LSTMForward", &debug_win_);
384 #endif
385  DebugActivationPath(*outputs, labels, coords);
386  }
387  return true;
388 }
static Image PrepareLSTMInputs(const ImageData &image_data, const Network *network, int min_width, TRand *randomizer, float *image_scale)
Definition: input.cpp:81
static void PreparePixInput(const StaticShape &shape, const Image pix, TRand *randomizer, NetworkIO *input)
Definition: input.cpp:107
void DebugActivationPath(const NetworkIO &outputs, const std::vector< int > &labels, const std::vector< int > &xcoords)
NetworkScratch scratch_space_
void LabelsFromOutputs(const NetworkIO &outputs, std::vector< int > *labels, std::vector< int > *xcoords)
void DisplayForward(const NetworkIO &inputs, const std::vector< int > &labels, const std::vector< int > &label_coords, const char *window_name, ScrollView **window)
void OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)=0
bool IsTraining() const
Definition: network.h:113
virtual StaticShape InputShape() const
Definition: network.h:129

◆ RecognizeLine() [2/2]

void tesseract::LSTMRecognizer::RecognizeLine ( const ImageData &  image_data,
bool  invert,
bool  debug,
double  worst_dict_cert,
const TBOX &  line_box,
PointerVector< WERD_RES > *  words,
int  lstm_choice_mode = 0,
int  lstm_choice_amount = 5 
)

Definition at line 247 of file lstmrecognizer.cpp.

250  {
251  NetworkIO outputs;
252  float scale_factor;
253  NetworkIO inputs;
254  if (!RecognizeLine(image_data, invert, debug, false, false, &scale_factor, &inputs, &outputs)) {
255  return;
256  }
257  if (search_ == nullptr) {
258  search_ = new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_);
259  }
260  search_->excludedUnichars.clear();
261  search_->Decode(outputs, kDictRatio, kCertOffset, worst_dict_cert, &GetUnicharset(),
262  lstm_choice_mode);
263  search_->ExtractBestPathAsWords(line_box, scale_factor, debug, &GetUnicharset(), words,
264  lstm_choice_mode);
265  if (lstm_choice_mode) {
267  for (int i = 0; i < lstm_choice_amount; ++i) {
268  search_->DecodeSecondaryBeams(outputs, kDictRatio, kCertOffset, worst_dict_cert,
269  &GetUnicharset(), lstm_choice_mode);
271  }
273  unsigned char_it = 0;
274  for (size_t i = 0; i < words->size(); ++i) {
275  for (int j = 0; j < words->at(i)->end; ++j) {
276  if (char_it < search_->ctc_choices.size()) {
277  words->at(i)->CTC_symbol_choices.push_back(search_->ctc_choices[char_it]);
278  }
279  if (char_it < search_->segmentedTimesteps.size()) {
280  words->at(i)->segmented_timesteps.push_back(search_->segmentedTimesteps[char_it]);
281  }
282  ++char_it;
283  }
284  words->at(i)->timesteps =
285  search_->combineSegmentedTimesteps(&words->at(i)->segmented_timesteps);
286  }
287  search_->segmentedTimesteps.clear();
288  search_->ctc_choices.clear();
289  search_->excludedUnichars.clear();
290  }
291 }
const double kCertOffset
const double kDictRatio
void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0, int lstm_choice_amount=5)
std::vector< std::vector< std::pair< const char *, float > > > ctc_choices
Definition: recodebeam.h:237
void DecodeSecondaryBeams(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
Definition: recodebeam.cpp:118
std::vector< std::vector< std::vector< std::pair< const char *, float > > > > segmentedTimesteps
Definition: recodebeam.h:235
void extractSymbolChoices(const UNICHARSET *unicharset)
Definition: recodebeam.cpp:415
std::vector< std::unordered_set< int > > excludedUnichars
Definition: recodebeam.h:239
std::vector< std::vector< std::pair< const char *, float > > > combineSegmentedTimesteps(std::vector< std::vector< std::vector< std::pair< const char *, float >>>> *segmentedTimesteps)
Definition: recodebeam.cpp:181
void ExtractBestPathAsWords(const TBOX &line_box, float scale_factor, bool debug, const UNICHARSET *unicharset, PointerVector< WERD_RES > *words, int lstm_choice_mode=0)
Definition: recodebeam.cpp:245

◆ sample_iteration()

int tesseract::LSTMRecognizer::sample_iteration ( ) const
inline

Definition at line 67 of file lstmrecognizer.h.

67  {
68  return sample_iteration_;
69  }

◆ ScaleLayerLearningRate()

void tesseract::LSTMRecognizer::ScaleLayerLearningRate ( const std::string &  id,
double  factor 
)
inline

Definition at line 153 of file lstmrecognizer.h.

153  {
154  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
155  ASSERT_HOST(id.length() > 1 && id[0] == ':');
156  auto *series = static_cast<Series *>(network_);
157  series->ScaleLayerLearningRate(&id[1], factor);
158  }

◆ ScaleLearningRate()

void tesseract::LSTMRecognizer::ScaleLearningRate ( double  factor)
inline

Definition at line 142 of file lstmrecognizer.h.

142  {
143  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
144  learning_rate_ *= factor;
145  if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
146  std::vector<std::string> layers = EnumerateLayers();
147  for (auto &layer : layers) {
148  ScaleLayerLearningRate(layer, factor);
149  }
150  }
151  }
void ScaleLayerLearningRate(const std::string &id, double factor)
std::vector< std::string > EnumerateLayers() const

◆ Serialize()

bool tesseract::LSTMRecognizer::Serialize ( const TessdataManager *  mgr,
TFile *  fp 
) const

Definition at line 93 of file lstmrecognizer.cpp.

93  {
94  bool include_charsets = mgr == nullptr || !mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) ||
95  !mgr->IsComponentAvailable(TESSDATA_LSTM_UNICHARSET);
96  if (!network_->Serialize(fp)) {
97  return false;
98  }
99  if (include_charsets && !GetUnicharset().save_to_file(fp)) {
100  return false;
101  }
102  if (!fp->Serialize(network_str_)) {
103  return false;
104  }
105  if (!fp->Serialize(&training_flags_)) {
106  return false;
107  }
108  if (!fp->Serialize(&training_iteration_)) {
109  return false;
110  }
111  if (!fp->Serialize(&sample_iteration_)) {
112  return false;
113  }
114  if (!fp->Serialize(&null_char_)) {
115  return false;
116  }
117  if (!fp->Serialize(&adam_beta_)) {
118  return false;
119  }
120  if (!fp->Serialize(&learning_rate_)) {
121  return false;
122  }
123  if (!fp->Serialize(&momentum_)) {
124  return false;
125  }
126  if (include_charsets && IsRecoding() && !recoder_.Serialize(fp)) {
127  return false;
128  }
129  return true;
130 }
bool Serialize(TFile *fp) const
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:158

◆ SetIteration()

void tesseract::LSTMRecognizer::SetIteration ( int  iteration)
inline

Definition at line 209 of file lstmrecognizer.h.

209  {
210  sample_iteration_ = iteration;
211  }

◆ SetLayerLearningRate()

void tesseract::LSTMRecognizer::SetLayerLearningRate ( const std::string &  id,
float  learning_rate 
)
inline

Definition at line 172 of file lstmrecognizer.h.

173  {
174  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
175  ASSERT_HOST(id.length() > 1 && id[0] == ':');
176  auto *series = static_cast<Series *>(network_);
177  series->SetLayerLearningRate(&id[1], learning_rate);
178  }

◆ SetLearningRate()

void tesseract::LSTMRecognizer::SetLearningRate ( float  learning_rate)
inline

Definition at line 161 of file lstmrecognizer.h.

162  {
163  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
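164  learning_rate_ = learning_rate;
165  if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
166  for (auto &id : EnumerateLayers()) {
167  SetLayerLearningRate(id, learning_rate);
168  }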
169  }
170  }
void SetLayerLearningRate(const std::string &id, float learning_rate)

◆ SetRandomSeed()

void tesseract::LSTMRecognizer::SetRandomSeed ( )
inline protected

Definition at line 287 of file lstmrecognizer.h.

287  {
288  int64_t seed = static_cast<int64_t>(sample_iteration_) * 0x10000001;
289  randomizer_.set_seed(seed);
290  randomizer_.IntRand();
291  }
int32_t IntRand()
Definition: helpers.h:72
void set_seed(uint64_t seed)
Definition: helpers.h:62

◆ SimpleTextOutput()

bool tesseract::LSTMRecognizer::SimpleTextOutput ( ) const
inline

Definition at line 84 of file lstmrecognizer.h.

84  {
85  return OutputLossType() == LT_SOFTMAX;
86  }
LossType OutputLossType() const

◆ training_iteration()

int tesseract::LSTMRecognizer::training_iteration ( ) const
inline

Definition at line 62 of file lstmrecognizer.h.

62  {
63  return training_iteration_;
64  }

Member Data Documentation

◆ adam_beta_

float tesseract::LSTMRecognizer::adam_beta_
protected

Definition at line 353 of file lstmrecognizer.h.

◆ ccutil_

CCUtil tesseract::LSTMRecognizer::ccutil_
protected

Definition at line 331 of file lstmrecognizer.h.

◆ debug_win_

ScrollView* tesseract::LSTMRecognizer::debug_win_
protected

Definition at line 365 of file lstmrecognizer.h.

◆ dict_

Dict* tesseract::LSTMRecognizer::dict_
protected

Definition at line 359 of file lstmrecognizer.h.

◆ learning_rate_

float tesseract::LSTMRecognizer::learning_rate_
protected

Definition at line 350 of file lstmrecognizer.h.

◆ momentum_

float tesseract::LSTMRecognizer::momentum_
protected

Definition at line 351 of file lstmrecognizer.h.

◆ network_

Network* tesseract::LSTMRecognizer::network_
protected

Definition at line 328 of file lstmrecognizer.h.

◆ network_str_

std::string tesseract::LSTMRecognizer::network_str_
protected

Definition at line 338 of file lstmrecognizer.h.

◆ null_char_

int32_t tesseract::LSTMRecognizer::null_char_
protected

Definition at line 348 of file lstmrecognizer.h.

◆ randomizer_

TRand tesseract::LSTMRecognizer::randomizer_
protected

Definition at line 356 of file lstmrecognizer.h.

◆ recoder_

UnicharCompress tesseract::LSTMRecognizer::recoder_
protected

Definition at line 335 of file lstmrecognizer.h.

◆ sample_iteration_

int32_t tesseract::LSTMRecognizer::sample_iteration_
protected

Definition at line 345 of file lstmrecognizer.h.

◆ scratch_space_

NetworkScratch tesseract::LSTMRecognizer::scratch_space_
protected

Definition at line 357 of file lstmrecognizer.h.

◆ search_

RecodeBeamSearch* tesseract::LSTMRecognizer::search_
protected

Definition at line 361 of file lstmrecognizer.h.

◆ training_flags_

int32_t tesseract::LSTMRecognizer::training_flags_
protected

Definition at line 341 of file lstmrecognizer.h.

◆ training_iteration_

int32_t tesseract::LSTMRecognizer::training_iteration_
protected

Definition at line 343 of file lstmrecognizer.h.


The documentation for this class was generated from the following files: