tesseract  5.0.0
mastertrainer_test.cc
Go to the documentation of this file.
1 // (C) Copyright 2017, Google Inc.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 // http://www.apache.org/licenses/LICENSE-2.0
6 // Unless required by applicable law or agreed to in writing, software
7 // distributed under the License is distributed on an "AS IS" BASIS,
8 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 // See the License for the specific language governing permissions and
10 // limitations under the License.
11 
12 // Although this is a trivial-looking test, it exercises a lot of code:
13 // SampleIterator has to correctly iterate over the correct characters, or
14 // it will fail.
15 // The canonical and cloud features computed by TrainingSampleSet need to
16 // be correct, along with the distance caches, organizing samples by font
17 // and class, indexing of features, distance calculations.
18 // IntFeatureDist has to work, or the canonical samples won't work.
19 // MasterTrainer's ability to read tr files and set itself up is tested.
20 // Finally the serialize/deserialize test ensures that MasterTrainer,
21 // TrainingSampleSet, TrainingSample can all serialize/deserialize correctly
22 // enough to reproduce the same results.
23 
24 #include "include_gunit.h"
25 
26 #include "commontraining.h"
27 #include "errorcounter.h"
28 #include "log.h" // for LOG
29 #include "mastertrainer.h"
30 #include "shapeclassifier.h"
31 #include "shapetable.h"
32 #include "trainingsample.h"
33 #include "unicharset.h"
34 
#include <cctype>
#include <cerrno>
#include <cstdlib>
#include <limits>
35 #include <string>
36 #include <utility>
37 #include <vector>
38 
39 using namespace tesseract;
40 
// Specs of the MockClassifier.
// The kNum*Errs values are cumulative thresholds on the running sample count:
// each one is the previous threshold plus the size of the next error band.
constexpr int kNumTopNErrs = 10;
constexpr int kNumTop2Errs = 20 + kNumTopNErrs;
constexpr int kNumTop1Errs = 30 + kNumTop2Errs;
constexpr int kNumTopTopErrs = 25 + kNumTop1Errs;
// Samples beyond this count are rejected (no answers at all).
constexpr int kNumNonReject = 1000;
constexpr int kNumCorrect = kNumNonReject - kNumTop1Errs;
// The total number of answers is given by the number of non-rejects plus
// all the multiple answers: two extra per top-2 error, one extra per top-1
// error and one extra per top-top error.
constexpr int kNumAnswers = kNumNonReject + (kNumTop2Errs - kNumTopNErrs) * 2 +
                            (kNumTop1Errs - kNumTop2Errs) + (kNumTopTopErrs - kNumTop1Errs);
52 
53 #ifndef DISABLED_LEGACY_ENGINE
// Parses str as a signed integer that fits in an int and stores it in
// *pResult. The base is auto-detected by strtol (base 0: decimal, 0x-prefixed
// hex, or 0-prefixed octal). Leading whitespace is skipped by strtol, and
// trailing whitespace (e.g. a newline at the end of a report line) is
// tolerated.
// Returns true on success. Returns false, with *pResult set to 0, when
// pResult's target cannot be filled: no digits were found, non-whitespace
// characters follow the number, or the value does not fit in an int.
// Returns false without writing anything if pResult is null.
static bool safe_strto32(const std::string &str, int *pResult) {
  if (pResult == nullptr) {
    return false;
  }
  *pResult = 0;

  const char *begin = str.c_str();
  const char *limit = begin + str.size();
  char *parsed_end = nullptr;
  errno = 0;
  const long value = std::strtol(begin, &parsed_end, 0);

  // strtol leaves the end pointer at the start when nothing was converted.
  if (parsed_end == begin) {
    return false;
  }
  // ERANGE covers overflow of long; the explicit bounds cover platforms where
  // long is wider than int.
  if (errno == ERANGE || value < std::numeric_limits<int>::min() ||
      value > std::numeric_limits<int>::max()) {
    return false;
  }
  // Anything after the number other than whitespace makes the input invalid.
  for (const char *p = parsed_end; p != limit; ++p) {
    if (!std::isspace(static_cast<unsigned char>(*p))) {
      return false;
    }
  }

  *pResult = static_cast<int>(value);
  return true;
}
59 #endif
60 
61 // Mock ShapeClassifier that cheats by looking at the correct answer, and
62 // creates a specific pattern of errors that can be tested.
64 public:
65  explicit MockClassifier(ShapeTable *shape_table)
66  : shape_table_(shape_table), num_done_(0), done_bad_font_(false) {
67  // Add a false font answer to the shape table. We pick a random unichar_id,
68  // add a new shape for it with a false font. Font must actually exist in
69  // the font table, but not match anything in the first 1000 samples.
70  false_unichar_id_ = 67;
71  false_shape_ = shape_table_->AddShape(false_unichar_id_, 25);
72  }
73  ~MockClassifier() override = default;
74 
75  // Classifies the given [training] sample, writing to results.
76  // If debug is non-zero, then various degrees of classifier dependent debug
77  // information is provided.
78  // If keep_this (a shape index) is >= 0, then the results should always
79  // contain keep_this, and (if possible) anything of intermediate confidence.
80  // The return value is the number of classes saved in results.
81  int ClassifySample(const TrainingSample &sample, Image page_pix, int debug, UNICHAR_ID keep_this,
82  std::vector<ShapeRating> *results) override {
83  results->clear();
84  // Everything except the first kNumNonReject is a reject.
85  if (++num_done_ > kNumNonReject) {
86  return 0;
87  }
88 
89  int class_id = sample.class_id();
90  int font_id = sample.font_id();
91  int shape_id = shape_table_->FindShape(class_id, font_id);
92  // Get ids of some wrong answers.
93  int wrong_id1 = shape_id > 10 ? shape_id - 1 : shape_id + 1;
94  int wrong_id2 = shape_id > 10 ? shape_id - 2 : shape_id + 2;
95  if (num_done_ <= kNumTopNErrs) {
96  // The first kNumTopNErrs are top-n errors.
97  results->push_back(ShapeRating(wrong_id1, 1.0f));
98  } else if (num_done_ <= kNumTop2Errs) {
99  // The next kNumTop2Errs - kNumTopNErrs are top-2 errors.
100  results->push_back(ShapeRating(wrong_id1, 1.0f));
101  results->push_back(ShapeRating(wrong_id2, 0.875f));
102  results->push_back(ShapeRating(shape_id, 0.75f));
103  } else if (num_done_ <= kNumTop1Errs) {
104  // The next kNumTop1Errs - kNumTop2Errs are top-1 errors.
105  results->push_back(ShapeRating(wrong_id1, 1.0f));
106  results->push_back(ShapeRating(shape_id, 0.8f));
107  } else if (num_done_ <= kNumTopTopErrs) {
108  // The next kNumTopTopErrs - kNumTop1Errs are cases where the actual top
109  // is not correct, but do not count as a top-1 error because the rating
110  // is close enough to the top answer.
111  results->push_back(ShapeRating(wrong_id1, 1.0f));
112  results->push_back(ShapeRating(shape_id, 0.99f));
113  } else if (!done_bad_font_ && class_id == false_unichar_id_) {
114  // There is a single character with a bad font.
115  results->push_back(ShapeRating(false_shape_, 1.0f));
116  done_bad_font_ = true;
117  } else {
118  // Everything else is correct.
119  results->push_back(ShapeRating(shape_id, 1.0f));
120  }
121  return results->size();
122  }
123  // Provides access to the ShapeTable that this classifier works with.
// Returns the same table pointer that was handed to the constructor; the
// mock only borrows it and never takes ownership.
124  const ShapeTable *GetShapeTable() const override {
125  return shape_table_;
126  }
127 
128 private:
129  // Borrowed pointer to the ShapeTable.
130  ShapeTable *shape_table_;
131  // Unichar_id of a random character that occurs after the first 60 samples.
132  int false_unichar_id_;
133  // Shape index of prepared false answer for false_unichar_id.
134  int false_shape_;
135  // The number of classifications we have processed.
136  int num_done_;
137  // True after the false font has been emitted.
138  bool done_bad_font_;
139 };
140 
141 const double kMin1lDistance = 0.25;
142 
143 // The fixture for testing Tesseract.
144 class MasterTrainerTest : public testing::Test {
145 #ifndef DISABLED_LEGACY_ENGINE
146 protected:
147  void SetUp() override {
148  std::locale::global(std::locale(""));
150  }
151 
152  std::string TestDataNameToPath(const std::string &name) {
153  return file::JoinPath(TESTING_DIR, name);
154  }
155  std::string TmpNameToPath(const std::string &name) {
156  return file::JoinPath(FLAGS_test_tmpdir, name);
157  }
158 
160  shape_table_ = nullptr;
161  master_trainer_ = nullptr;
162  }
// Frees the shape table held through the raw shape_table_ pointer. Deleting
// a null pointer is a no-op, so this is safe when no table was ever loaded.
163  ~MasterTrainerTest() override {
164  delete shape_table_;
165  }
166 
167  // Initializes the master_trainer_ and shape_table_.
168  // if load_from_tmp, then reloads a master trainer that was saved by a
169  // previous call in which it was false.
171  FLAGS_output_trainer = TmpNameToPath("tmp_trainer").c_str();
172  FLAGS_F = file::JoinPath(LANGDATA_DIR, "font_properties").c_str();
173  FLAGS_X = TestDataNameToPath("eng.xheights").c_str();
174  FLAGS_U = TestDataNameToPath("eng.unicharset").c_str();
175  std::string tr_file_name(TestDataNameToPath("eng.Arial.exp0.tr"));
176  const char *filelist[] = {tr_file_name.c_str(), nullptr};
177  std::string file_prefix;
178  delete shape_table_;
179  shape_table_ = nullptr;
180  master_trainer_ = LoadTrainingData(filelist, false, &shape_table_, file_prefix);
181  EXPECT_TRUE(master_trainer_ != nullptr);
182  EXPECT_TRUE(shape_table_ != nullptr);
183  }
184 
185  // EXPECTs that the distance between I and l in Arial is 0 and that the
186  // distance to 1 is significantly not 0.
187  void VerifyIl1() {
188  // Find the font id for Arial.
189  int font_id = master_trainer_->GetFontInfoId("Arial");
190  EXPECT_GE(font_id, 0);
191  // Track down the characters we are interested in.
192  int unichar_I = master_trainer_->unicharset().unichar_to_id("I");
193  EXPECT_GT(unichar_I, 0);
194  int unichar_l = master_trainer_->unicharset().unichar_to_id("l");
195  EXPECT_GT(unichar_l, 0);
196  int unichar_1 = master_trainer_->unicharset().unichar_to_id("1");
197  EXPECT_GT(unichar_1, 0);
198  // Now get the shape ids.
199  int shape_I = shape_table_->FindShape(unichar_I, font_id);
200  EXPECT_GE(shape_I, 0);
201  int shape_l = shape_table_->FindShape(unichar_l, font_id);
202  EXPECT_GE(shape_l, 0);
203  int shape_1 = shape_table_->FindShape(unichar_1, font_id);
204  EXPECT_GE(shape_1, 0);
205 
206  float dist_I_l = master_trainer_->ShapeDistance(*shape_table_, shape_I, shape_l);
207  // No tolerance here. We expect that I and l should match exactly.
208  EXPECT_EQ(0.0f, dist_I_l);
209  float dist_l_I = master_trainer_->ShapeDistance(*shape_table_, shape_l, shape_I);
210  // BOTH ways.
211  EXPECT_EQ(0.0f, dist_l_I);
212 
213  // l/1 on the other hand should be distinct.
214  float dist_l_1 = master_trainer_->ShapeDistance(*shape_table_, shape_l, shape_1);
215  EXPECT_GT(dist_l_1, kMin1lDistance);
216  float dist_1_l = master_trainer_->ShapeDistance(*shape_table_, shape_1, shape_l);
217  EXPECT_GT(dist_1_l, kMin1lDistance);
218 
219  // So should I/1.
220  float dist_I_1 = master_trainer_->ShapeDistance(*shape_table_, shape_I, shape_1);
221  EXPECT_GT(dist_I_1, kMin1lDistance);
222  float dist_1_I = master_trainer_->ShapeDistance(*shape_table_, shape_1, shape_I);
223  EXPECT_GT(dist_1_I, kMin1lDistance);
224  }
225 
226  // Objects declared here can be used by all tests in this test fixture.
228  std::unique_ptr<MasterTrainer> master_trainer_;
229 #endif
230 };
231 
232 // Tests that the MasterTrainer correctly loads its data and reaches the correct
233 // conclusion over the distance between Arial I l and 1.
235 #ifdef DISABLED_LEGACY_ENGINE
236  // Skip test because LoadTrainingData is missing.
237  GTEST_SKIP();
238 #else
239  // Initialize the master_trainer_ and load the Arial tr file.
240  LoadMasterTrainer();
241  VerifyIl1();
242 #endif
243 }
244 
245 // Tests the ErrorCounter using a MockClassifier to check that it counts
246 // error categories correctly.
247 TEST_F(MasterTrainerTest, ErrorCounterTest) {
248 #ifdef DISABLED_LEGACY_ENGINE
249  // Skip test because LoadTrainingData is missing.
250  GTEST_SKIP();
251 #else
252  // Initialize the master_trainer_ from the saved tmp file.
253  LoadMasterTrainer();
254  // Add the space character to the shape_table_ if not already present to
255  // count junk.
256  if (shape_table_->FindShape(0, -1) < 0) {
257  shape_table_->AddShape(0, 0);
258  }
259  // Make a mock classifier.
260  auto shape_classifier = std::make_unique<MockClassifier>(shape_table_);
261  // Get the accuracy report.
262  std::string accuracy_report;
263  master_trainer_->TestClassifierOnSamples(tesseract::CT_UNICHAR_TOP1_ERR, 0, false,
264  shape_classifier.get(), &accuracy_report);
265  LOG(INFO) << accuracy_report.c_str();
266  std::string result_string = accuracy_report.c_str();
267  std::vector<std::string> results = split(result_string, '\t');
268  EXPECT_EQ(tesseract::CT_SIZE + 1, results.size());
269  int result_values[tesseract::CT_SIZE];
270  for (int i = 0; i < tesseract::CT_SIZE; ++i) {
271  EXPECT_TRUE(safe_strto32(results[i + 1], &result_values[i]));
272  }
273  // These tests are more-or-less immune to additions to the number of
274  // categories or changes in the training data.
275  int num_samples = master_trainer_->GetSamples()->num_raw_samples();
276  EXPECT_EQ(kNumCorrect, result_values[tesseract::CT_UNICHAR_TOP_OK]);
277  EXPECT_EQ(1, result_values[tesseract::CT_FONT_ATTR_ERR]);
278  EXPECT_EQ(kNumTopTopErrs, result_values[tesseract::CT_UNICHAR_TOPTOP_ERR]);
279  EXPECT_EQ(kNumTop1Errs, result_values[tesseract::CT_UNICHAR_TOP1_ERR]);
280  EXPECT_EQ(kNumTop2Errs, result_values[tesseract::CT_UNICHAR_TOP2_ERR]);
281  EXPECT_EQ(kNumTopNErrs, result_values[tesseract::CT_UNICHAR_TOPN_ERR]);
282  // Each of the TOPTOP errs also counts as a multi-unichar.
283  EXPECT_EQ(kNumTopTopErrs - kNumTop1Errs, result_values[tesseract::CT_OK_MULTI_UNICHAR]);
284  EXPECT_EQ(num_samples - kNumNonReject, result_values[tesseract::CT_REJECT]);
285  EXPECT_EQ(kNumAnswers, result_values[tesseract::CT_NUM_RESULTS]);
286 #endif
287 }
@ LOG
@ INFO
Definition: log.h:28
const double kMin1lDistance
const std::vector< std::string > split(const std::string &s, char c)
Definition: helpers.h:41
std::string TestDataNameToPath(const std::string &name)
int UNICHAR_ID
Definition: unichar.h:36
std::unique_ptr< MasterTrainer > LoadTrainingData(const char *const *filelist, bool replication, ShapeTable **shape_table, std::string &file_prefix)
@ CT_UNICHAR_TOPN_ERR
Definition: errorcounter.h:76
@ CT_UNICHAR_TOP_OK
Definition: errorcounter.h:70
@ CT_UNICHAR_TOP1_ERR
Definition: errorcounter.h:74
@ CT_UNICHAR_TOP2_ERR
Definition: errorcounter.h:75
@ CT_UNICHAR_TOPTOP_ERR
Definition: errorcounter.h:77
@ CT_FONT_ATTR_ERR
Definition: errorcounter.h:82
@ CT_OK_MULTI_UNICHAR
Definition: errorcounter.h:78
@ CT_NUM_RESULTS
Definition: errorcounter.h:84
TEST_F(EuroText, FastLatinOCR)
UNICHAR_ID class_id() const
static void MakeTmpdir()
Definition: include_gunit.h:38
static std::string JoinPath(const std::string &s1, const std::string &s2)
Definition: include_gunit.h:65
const ShapeTable * GetShapeTable() const override
~MockClassifier() override=default
MockClassifier(ShapeTable *shape_table)
int ClassifySample(const TrainingSample &sample, Image page_pix, int debug, UNICHAR_ID keep_this, std::vector< ShapeRating > *results) override
std::string TestDataNameToPath(const std::string &name)
std::string TmpNameToPath(const std::string &name)
void SetUp() override
std::unique_ptr< MasterTrainer > master_trainer_