18 #ifndef TESSERACT_LSTM_NETWORK_H_
19 #define TESSERACT_LSTM_NETWORK_H_
117 return needs_to_backprop_;
140 const std::string &
name()
const {
143 virtual std::string
spec()
const {
147 return (network_flags_ & flag) != 0;
171 virtual void SetNetworkFlags(uint32_t flags);
178 virtual int InitWeights(
float range,
TRand *randomizer);
191 [[maybe_unused]]
const std::vector<int> &code_map) {
201 virtual void SetRandomizer(
TRand *randomizer);
206 virtual bool SetupNeedsBackprop(
bool needs_backprop);
235 virtual void Update([[maybe_unused]]
float learning_rate,
236 [[maybe_unused]]
float momentum,
237 [[maybe_unused]]
float adam_beta,
238 [[maybe_unused]]
int num_samples) {}
243 [[maybe_unused]]
TFloat *same,
244 [[maybe_unused]]
TFloat *changed)
const {}
283 void DisplayForward(
const NetworkIO &matrix);
285 void DisplayBackward(
const NetworkIO &matrix);
288 static void ClearWindow(
bool tess_coords,
const char *window_name,
int width,
bool Serialize(FILE *fp, const std::vector< T > &data)
@ NT_LSTM_SOFTMAX_ENCODED
virtual void Update([[maybe_unused]] float learning_rate, [[maybe_unused]] float momentum, [[maybe_unused]] float adam_beta, [[maybe_unused]] int num_samples)
virtual void CountAlternators([[maybe_unused]] const Network &other, [[maybe_unused]] TFloat *same, [[maybe_unused]] TFloat *changed) const
virtual int XScaleFactor() const
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)=0
const std::string & name() const
virtual bool DeSerialize(TFile *fp)=0
virtual bool IsPlumbingType() const
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)=0
bool needs_to_backprop() const
ScrollView * forward_win_
ScrollView * backward_win_
virtual void DebugWeights()=0
virtual int RemapOutputs([[maybe_unused]] int old_no, [[maybe_unused]] const std::vector< int > &code_map)
virtual StaticShape OutputShape(const StaticShape &input_shape) const
bool TestFlag(NetworkFlags flag) const
virtual std::string spec() const
virtual ~Network()=default
virtual void CacheXScaleFactor([[maybe_unused]] int factor)
virtual void ConvertToInt()
virtual StaticShape InputShape() const
void set_depth(int value)