tesseract  5.0.0
tfnetwork.h
Go to the documentation of this file.
1 // File: tfnetwork.h
3 // Description: Encapsulation of an entire TensorFlow graph as a
4 // Tesseract Network.
5 // Author: Ray Smith
6 //
7 // (C) Copyright 2016, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #ifndef TESSERACT_LSTM_TFNETWORK_H_
20 #define TESSERACT_LSTM_TFNETWORK_H_
21 
22 #ifdef INCLUDE_TENSORFLOW
23 
24 # include <memory>
25 # include <string>
26 
27 # include "network.h"
28 # include "static_shape.h"
29 # include "tensorflow/core/framework/graph.pb.h"
30 # include "tensorflow/core/public/session.h"
31 # include "tfnetwork.pb.h"
32 
33 namespace tesseract {
34 
35 class TFNetwork : public Network {
36 public:
37  explicit TFNetwork(const char *name);
38  virtual ~TFNetwork() = default;
39 
40  // Returns the required shape input to the network.
41  StaticShape InputShape() const override {
42  return input_shape_;
43  }
44  // Returns the shape output from the network given an input shape (which may
45  // be partially unknown ie zero).
46  StaticShape OutputShape(const StaticShape &input_shape) const override {
47  return output_shape_;
48  }
49 
50  std::string spec() const override {
51  return spec_;
52  }
53 
54  // Deserializes *this from a serialized TFNetwork proto. Returns 0 if failed,
55  // otherwise the global step of the serialized graph.
56  int InitFromProtoStr(const std::string &proto_str);
57  // The number of classes in this network should be equal to those in the
58  // recoder_ in LSTMRecognizer.
59  int num_classes() const {
60  return output_shape_.depth();
61  }
62 
63  // Writes to the given file. Returns false in case of error.
64  // Should be overridden by subclasses, but called by their Serialize.
65  bool Serialize(TFile *fp) const override;
66  // Reads from the given file. Returns false in case of error.
67  // Should be overridden by subclasses, but NOT called by their DeSerialize.
68  bool DeSerialize(TFile *fp) override;
69 
70  // Runs forward propagation of activations on the input line.
71  // See Network for a detailed discussion of the arguments.
72  void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose,
73  NetworkScratch *scratch, NetworkIO *output) override;
74 
75 private:
76  // Runs backward propagation of errors on the deltas line.
77  // See Network for a detailed discussion of the arguments.
78  bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch,
79  NetworkIO *back_deltas) override {
80  tprintf("Must override Network::Backward for type %d\n", type_);
81  return false;
82  }
83 
84  void DebugWeights() override {
85  tprintf("Must override Network::DebugWeights for type %d\n", type_);
86  }
87 
88  int InitFromProto();
89 
90  // The original network definition for reference.
91  std::string spec_;
92  // Input tensor parameters.
93  StaticShape input_shape_;
94  // Output tensor parameters.
95  StaticShape output_shape_;
96  // The tensor flow graph is contained in here.
97  std::unique_ptr<tensorflow::Session> session_;
98  // The serialized graph is also contained in here.
99  TFNetworkModel model_proto_;
100 };
101 
102 } // namespace tesseract.
103 
104 #endif // ifdef INCLUDE_TENSORFLOW
105 
106 #endif // TESSERACT_LSTM_TFNETWORK_H_
void tprintf(const char *format,...)
Definition: tprintf.cpp:41
bool DeSerialize(bool swap, FILE *fp, std::vector< T > &data)
Definition: helpers.h:220
bool Serialize(FILE *fp, const std::vector< T > &data)
Definition: helpers.h:251