tesseract  5.0.0
mastertrainer.h
Go to the documentation of this file.
1 // Copyright 2010 Google Inc. All Rights Reserved.
2 // Author: rays@google.com (Ray Smith)
4 // File: mastertrainer.h
5 // Description: Trainer to build the MasterClassifier.
6 // Author: Ray Smith
7 //
8 // (C) Copyright 2010, Google Inc.
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 // http://www.apache.org/licenses/LICENSE-2.0
13 // Unless required by applicable law or agreed to in writing, software
14 // distributed under the License is distributed on an "AS IS" BASIS,
15 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 // See the License for the specific language governing permissions and
17 // limitations under the License.
18 //
20 
21 #ifndef TESSERACT_TRAINING_MASTERTRAINER_H_
22 #define TESSERACT_TRAINING_MASTERTRAINER_H_
23 
24 #include "export.h"
25 
26 #include "classify.h"
27 #include "cluster.h"
28 #include "elst.h"
29 #include "errorcounter.h"
30 #include "featdefs.h"
31 #include "fontinfo.h"
32 #include "indexmapbidi.h"
33 #include "intfeaturemap.h"
34 #include "intfeaturespace.h"
35 #include "intfx.h"
36 #include "intmatcher.h"
37 #include "params.h"
38 #include "shapetable.h"
39 #include "trainingsample.h"
40 #include "trainingsampleset.h"
41 #include "unicharset.h"
42 
43 namespace tesseract {
44 
45 class ShapeClassifier;
46 
// Records how far apart two shapes are while shape clustering runs.
// shape1/shape2 presumably index the ShapeTable being clustered (TODO confirm
// against ClusterShapes in mastertrainer.cpp); distance is the measured gap.
// A default-constructed value has both ids and the distance set to zero.
struct ShapeDist {
  ShapeDist() = default;
  ShapeDist(int s1, int s2, float dist) : shape1{s1}, shape2{s2}, distance{dist} {}

  // Ordering used for sorting into ascending order of distance.
  // Only the distance participates, so entries with equal distance compare
  // equivalent regardless of their shape ids.
  bool operator<(const ShapeDist &other) const {
    return other.distance > distance;
  }

  int shape1 = 0;
  int shape2 = 0;
  float distance = 0.0f;
};
61 
62 // Class to encapsulate training processes that use the TrainingSampleSet.
63 // Initially supports shape clustering and mftrainining.
64 // Other important features of the MasterTrainer are conditioning the data
65 // by outlier elimination, replication with perturbation, and serialization.
66 class TESS_COMMON_TRAINING_API MasterTrainer {
67 public:
68  MasterTrainer(NormalizationMode norm_mode, bool shape_analysis, bool replicate_samples,
69  int debug_level);
70  ~MasterTrainer();
71 
72  // Writes to the given file. Returns false in case of error.
73  bool Serialize(FILE *fp) const;
74 
75  // Loads an initial unicharset, or sets one up if the file cannot be read.
76  void LoadUnicharset(const char *filename);
77 
78  // Sets the feature space definition.
79  void SetFeatureSpace(const IntFeatureSpace &fs) {
80  feature_space_ = fs;
81  feature_map_.Init(fs);
82  }
83 
84  // Reads the samples and their features from the given file,
85  // adding them to the trainer with the font_id from the content of the file.
86  // If verification, then these are verification samples, not training.
87  void ReadTrainingSamples(const char *page_name, const FEATURE_DEFS_STRUCT &feature_defs,
88  bool verification);
89 
90  // Adds the given single sample to the trainer, setting the classid
91  // appropriately from the given unichar_str.
92  void AddSample(bool verification, const char *unichar_str, TrainingSample *sample);
93 
94  // Loads all pages from the given tif filename and append to page_images_.
95  // Must be called after ReadTrainingSamples, as the current number of images
96  // is used as an offset for page numbers in the samples.
97  void LoadPageImages(const char *filename);
98 
99  // Cleans up the samples after initial load from the tr files, and prior to
100  // saving the MasterTrainer:
101  // Remaps fragmented chars if running shape analysis.
102  // Sets up the samples appropriately for class/fontwise access.
103  // Deletes outlier samples.
104  void PostLoadCleanup();
105 
106  // Gets the samples ready for training. Use after both
107  // ReadTrainingSamples+PostLoadCleanup or DeSerialize.
108  // Re-indexes the features and computes canonical and cloud features.
109  void PreTrainingSetup();
110 
111  // Sets up the master_shapes_ table, which tells which fonts should stay
112  // together until they get to a leaf node classifier.
113  void SetupMasterShapes();
114 
115  // Adds the junk_samples_ to the main samples_ set. Junk samples are initially
116  // fragments and n-grams (all incorrectly segmented characters).
117  // Various training functions may result in incorrectly segmented characters
118  // being added to the unicharset of the main samples, perhaps because they
119  // form a "radical" decomposition of some (Indic) grapheme, or because they
120  // just look the same as a real character (like rn/m)
121  // This function moves all the junk samples, to the main samples_ set, but
122  // desirable junk, being any sample for which the unichar already exists in
123  // the samples_ unicharset gets the unichar-ids re-indexed to match, but
124  // anything else gets re-marked as unichar_id 0 (space character) to identify
125  // it as junk to the error counter.
126  void IncludeJunk();
127 
128  // Replicates the samples and perturbs them if the enable_replication_ flag
129  // is set. MUST be used after the last call to OrganizeByFontAndClass on
130  // the training samples, ie after IncludeJunk if it is going to be used, as
131  // OrganizeByFontAndClass will eat the replicated samples into the regular
132  // samples.
133  void ReplicateAndRandomizeSamplesIfRequired();
134 
135  // Loads the basic font properties file into fontinfo_table_.
136  // Returns false on failure.
137  bool LoadFontInfo(const char *filename);
138 
139  // Loads the xheight font properties file into xheights_.
140  // Returns false on failure.
141  bool LoadXHeights(const char *filename);
142 
143  // Reads spacing stats from filename and adds them to fontinfo_table.
144  // Returns false on failure.
145  bool AddSpacingInfo(const char *filename);
146 
147  // Returns the font id corresponding to the given font name.
148  // Returns -1 if the font cannot be found.
149  int GetFontInfoId(const char *font_name);
150  // Returns the font_id of the closest matching font name to the given
151  // filename. It is assumed that a substring of the filename will match
152  // one of the fonts. If more than one is matched, the longest is returned.
153  int GetBestMatchingFontInfoId(const char *filename);
154 
155  // Returns the filename of the tr file corresponding to the command-line
156  // argument with the given index.
157  const std::string &GetTRFileName(int index) const {
158  return tr_filenames_[index];
159  }
160 
161  // Sets up a flat shapetable with one shape per class/font combination.
162  void SetupFlatShapeTable(ShapeTable *shape_table);
163 
164  // Sets up a Clusterer for mftraining on a single shape_id.
165  // Call FreeClusterer on the return value after use.
166  CLUSTERER *SetupForClustering(const ShapeTable &shape_table,
167  const FEATURE_DEFS_STRUCT &feature_defs, int shape_id,
168  int *num_samples);
169 
170  // Writes the given float_classes (produced by SetupForFloat2Int) as inttemp
171  // to the given inttemp_file, and the corresponding pffmtable.
172  // The unicharset is the original encoding of graphemes, and shape_set should
173  // match the size of the shape_table, and may possibly be totally fake.
174  void WriteInttempAndPFFMTable(const UNICHARSET &unicharset, const UNICHARSET &shape_set,
175  const ShapeTable &shape_table, CLASS_STRUCT *float_classes,
176  const char *inttemp_file, const char *pffmtable_file);
177 
178  const UNICHARSET &unicharset() const {
179  return samples_.unicharset();
180  }
182  return &samples_;
183  }
184  const ShapeTable &master_shapes() const {
185  return master_shapes_;
186  }
187 
188  // Generates debug output relating to the canonical distance between the
189  // two given UTF8 grapheme strings.
190  void DebugCanonical(const char *unichar_str1, const char *unichar_str2);
191 #ifndef GRAPHICS_DISABLED
192  // Debugging for cloud/canonical features.
193  // Displays a Features window containing:
194  // If unichar_str2 is in the unicharset, and canonical_font is non-negative,
195  // displays the canonical features of the char/font combination in red.
196  // If unichar_str1 is in the unicharset, and cloud_font is non-negative,
197  // displays the cloud feature of the char/font combination in green.
198  // The canonical features are drawn first to show which ones have no
199  // matches in the cloud features.
200  // Until the features window is destroyed, each click in the features window
201  // will display the samples that have that feature in a separate window.
202  void DisplaySamples(const char *unichar_str1, int cloud_font, const char *unichar_str2,
203  int canonical_font);
204 #endif // !GRAPHICS_DISABLED
205 
206  void TestClassifierVOld(bool replicate_samples, ShapeClassifier *test_classifier,
207  ShapeClassifier *old_classifier);
208 
209  // Tests the given test_classifier on the internal samples.
210  // See TestClassifier for details.
211  void TestClassifierOnSamples(CountTypes error_mode, int report_level, bool replicate_samples,
212  ShapeClassifier *test_classifier, std::string *report_string);
213  // Tests the given test_classifier on the given samples
214  // error_mode indicates what counts as an error.
215  // report_levels:
216  // 0 = no output.
217  // 1 = bottom-line error rate.
218  // 2 = bottom-line error rate + time.
219  // 3 = font-level error rate + time.
220  // 4 = list of all errors + short classifier debug output on 16 errors.
221  // 5 = list of all errors + short classifier debug output on 25 errors.
222  // If replicate_samples is true, then the test is run on an extended test
223  // sample including replicated and systematically perturbed samples.
224  // If report_string is non-nullptr, a summary of the results for each font
225  // is appended to the report_string.
226  double TestClassifier(CountTypes error_mode, int report_level, bool replicate_samples,
227  TrainingSampleSet *samples, ShapeClassifier *test_classifier,
228  std::string *report_string);
229 
230  // Returns the average (in some sense) distance between the two given
231  // shapes, which may contain multiple fonts and/or unichars.
232  // This function is public to facilitate testing.
233  float ShapeDistance(const ShapeTable &shapes, int s1, int s2);
234 
235 private:
236  // Replaces samples that are always fragmented with the corresponding
237  // fragment samples.
238  void ReplaceFragmentedSamples();
239 
240  // Runs a hierarchical agglomerative clustering to merge shapes in the given
241  // shape_table, while satisfying the given constraints:
242  // * End with at least min_shapes left in shape_table,
243  // * No shape shall have more than max_shape_unichars in it,
244  // * Don't merge shapes where the distance between them exceeds max_dist.
245  void ClusterShapes(int min_shapes, int max_shape_unichars, float max_dist,
246  ShapeTable *shape_table);
247 
248 private:
249  NormalizationMode norm_mode_;
250  // Character set we are training for.
251  UNICHARSET unicharset_;
252  // Original feature space. Subspace mapping is contained in feature_map_.
253  IntFeatureSpace feature_space_;
254  TrainingSampleSet samples_;
255  TrainingSampleSet junk_samples_;
256  TrainingSampleSet verify_samples_;
257  // Master shape table defines what fonts stay together until the leaves.
258  ShapeTable master_shapes_;
259  // Flat shape table has each unichar/font id pair in a separate shape.
260  ShapeTable flat_shapes_;
261  // Font metrics gathered from multiple files.
262  FontInfoTable fontinfo_table_;
263  // Array of xheights indexed by font ids in fontinfo_table_;
264  std::vector<int32_t> xheights_;
265 
266  // Non-serialized data initialized by other means or used temporarily
267  // during loading of training samples.
268  // Number of different class labels in unicharset_.
269  int charsetsize_;
270  // Flag to indicate that we are running shape analysis and need fragments
271  // fixing.
272  bool enable_shape_analysis_;
273  // Flag to indicate that sample replication is required.
274  bool enable_replication_;
275  // Array of classids of fragments that replace the correctly segmented chars.
276  int *fragments_;
277  // Classid of previous correctly segmented sample that was added.
278  int prev_unichar_id_;
279  // Debug output control.
280  int debug_level_;
281  // Feature map used to construct reduced feature spaces for compact
282  // classifiers.
283  IntFeatureMap feature_map_;
284  // Vector of Pix pointers used for classifiers that need the image.
285  // Indexed by page_num_ in the samples.
286  // These images are owned by the trainer and need to be pixDestroyed.
287  std::vector<Image > page_images_;
288  // Vector of filenames of loaded tr files.
289  std::vector<std::string> tr_filenames_;
290 };
291 
292 } // namespace tesseract.
293 
294 #endif // TESSERACT_TRAINING_MASTERTRAINER_H_
void ReadTrainingSamples(const FEATURE_DEFS_STRUCT &feature_definitions, const char *feature_name, int max_samples, UNICHARSET *unicharset, FILE *file, LIST *training_samples)
bool Serialize(FILE *fp, const std::vector< T > &data)
Definition: helpers.h:251
FEATURE_DEFS_STRUCT feature_defs
NormalizationMode
Definition: normalis.h:46
void Init(uint8_t xbuckets, uint8_t ybuckets, uint8_t thetabuckets)
bool operator<(const ShapeDist &other) const
Definition: mastertrainer.h:53
ShapeDist(int s1, int s2, float dist)
Definition: mastertrainer.h:50
const UNICHARSET & unicharset() const
const std::string & GetTRFileName(int index) const
void SetFeatureSpace(const IntFeatureSpace &fs)
Definition: mastertrainer.h:79
const ShapeTable & master_shapes() const
TrainingSampleSet * GetSamples()