tesseract  5.0.0
series.h
Go to the documentation of this file.
1 // File: series.h
3 // Description: Runs networks in series on the same input.
4 // Author: Ray Smith
5 //
6 // (C) Copyright 2013, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
17 
18 #ifndef TESSERACT_LSTM_SERIES_H_
19 #define TESSERACT_LSTM_SERIES_H_
20 
21 #include "plumbing.h"
22 
23 namespace tesseract {
24 
25 // Runs two or more networks in series (layers) on the same input.
26 class Series : public Plumbing {
27 public:
28  // ni_ and no_ will be set by AddToStack.
29  TESS_API
30  explicit Series(const char *name);
31  ~Series() override = default;
32 
33  // Returns the shape output from the network given an input shape (which may
34  // be partially unknown ie zero).
35  StaticShape OutputShape(const StaticShape &input_shape) const override;
36 
37  std::string spec() const override {
38  std::string spec("[");
39  for (auto &it : stack_) {
40  spec += it->spec();
41  }
42  spec += "]";
43  return spec;
44  }
45 
46  // Sets up the network for training. Initializes weights using weights of
47  // scale `range` picked according to the random number generator `randomizer`.
48  // Returns the number of weights initialized.
49  int InitWeights(float range, TRand *randomizer) override;
50  // Recursively searches the network for softmaxes with old_no outputs,
51  // and remaps their outputs according to code_map. See network.h for details.
52  int RemapOutputs(int old_no, const std::vector<int> &code_map) override;
53 
54  // Sets needs_to_backprop_ to needs_backprop and returns true if
55  // needs_backprop || any weights in this network so the next layer forward
56  // can be told to produce backprop for this layer if needed.
57  bool SetupNeedsBackprop(bool needs_backprop) override;
58 
59  // Returns an integer reduction factor that the network applies to the
60  // time sequence. Assumes that any 2-d is already eliminated. Used for
61  // scaling bounding boxes of truth data.
62  // WARNING: if GlobalMinimax is used to vary the scale, this will return
63  // the last used scale factor. Call it before any forward, and it will return
64  // the minimum scale factor of the paths through the GlobalMinimax.
65  int XScaleFactor() const override;
66 
67  // Provides the (minimum) x scale factor to the network (of interest only to
68  // input units) so they can determine how to scale bounding boxes.
69  void CacheXScaleFactor(int factor) override;
70 
71  // Runs forward propagation of activations on the input line.
72  // See Network for a detailed discussion of the arguments.
73  void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose,
74  NetworkScratch *scratch, NetworkIO *output) override;
75 
76  // Runs backward propagation of errors on the deltas line.
77  // See Network for a detailed discussion of the arguments.
78  bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch,
79  NetworkIO *back_deltas) override;
80 
81  // Splits the series after the given index, returning the two parts and
82  // deletes itself. The first part, up to network with index last_start, goes
83  // into start, and the rest goes into end.
84  TESS_API
85  void SplitAt(unsigned last_start, Series **start, Series **end);
86 
87  // Appends the elements of the src series to this, removing from src and
88  // deleting it.
89  TESS_API
90  void AppendSeries(Network *src);
91 };
92 
93 } // namespace tesseract.
94 
95 #endif // TESSERACT_LSTM_SERIES_H_
const std::string & name() const
Definition: network.h:140
std::vector< Network * > stack_
Definition: plumbing.h:150
bool SetupNeedsBackprop(bool needs_backprop) override
Definition: series.cpp:76
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
Definition: series.cpp:128
TESS_API void AppendSeries(Network *src)
Definition: series.cpp:192
TESS_API Series(const char *name)
Definition: series.cpp:28
int XScaleFactor() const override
Definition: series.cpp:90
TESS_API void SplitAt(unsigned last_start, Series **start, Series **end)
Definition: series.cpp:163
std::string spec() const override
Definition: series.h:37
StaticShape OutputShape(const StaticShape &input_shape) const override
Definition: series.cpp:34
void CacheXScaleFactor(int factor) override
Definition: series.cpp:100
int InitWeights(float range, TRand *randomizer) override
Definition: series.cpp:46
~Series() override=default
int RemapOutputs(int old_no, const std::vector< int > &code_map) override
Definition: series.cpp:60
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
Definition: series.cpp:106
#define TESS_API
Definition: export.h:34