tesseract  5.0.0
tesseract::FullyConnected Class Reference

#include <fullyconnected.h>

Inheritance diagram for tesseract::FullyConnected:
tesseract::Network

Public Member Functions

TESS_API FullyConnected (const std::string &name, int ni, int no, NetworkType type)
 
 ~FullyConnected () override=default
 
StaticShape OutputShape (const StaticShape &input_shape) const override
 
std::string spec () const override
 
void ChangeType (NetworkType type)
 
void SetEnableTraining (TrainingState state) override
 
int InitWeights (float range, TRand *randomizer) override
 
int RemapOutputs (int old_no, const std::vector< int > &code_map) override
 
void ConvertToInt () override
 
void DebugWeights () override
 
bool Serialize (TFile *fp) const override
 
bool DeSerialize (TFile *fp) override
 
void Forward (bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
 
void SetupForward (const NetworkIO &input, const TransposedArray *input_transpose)
 
void ForwardTimeStep (int t, TFloat *output_line)
 
void ForwardTimeStep (const TFloat *d_input, int t, TFloat *output_line)
 
void ForwardTimeStep (const int8_t *i_input, int t, TFloat *output_line)
 
bool Backward (bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
 
void BackwardTimeStep (const NetworkIO &fwd_deltas, int t, TFloat *curr_errors, TransposedArray *errors_t, TFloat *backprop)
 
void FinishBackward (const TransposedArray &errors_t)
 
void Update (float learning_rate, float momentum, float adam_beta, int num_samples) override
 
void CountAlternators (const Network &other, TFloat *same, TFloat *changed) const override
 
- Public Member Functions inherited from tesseract::Network
 Network ()
 
 Network (NetworkType type, const std::string &name, int ni, int no)
 
virtual ~Network ()=default
 
NetworkType type () const
 
bool IsTraining () const
 
bool needs_to_backprop () const
 
int num_weights () const
 
int NumInputs () const
 
int NumOutputs () const
 
virtual StaticShape InputShape () const
 
const std::string & name () const
 
bool TestFlag (NetworkFlags flag) const
 
virtual bool IsPlumbingType () const
 
virtual void SetNetworkFlags (uint32_t flags)
 
virtual int RemapOutputs ([[maybe_unused]] int old_no, [[maybe_unused]] const std::vector< int > &code_map)
 
virtual void SetRandomizer (TRand *randomizer)
 
virtual bool SetupNeedsBackprop (bool needs_backprop)
 
virtual int XScaleFactor () const
 
virtual void CacheXScaleFactor ([[maybe_unused]] int factor)
 
virtual void Update ([[maybe_unused]] float learning_rate, [[maybe_unused]] float momentum, [[maybe_unused]] float adam_beta, [[maybe_unused]] int num_samples)
 
virtual void CountAlternators ([[maybe_unused]] const Network &other, [[maybe_unused]] TFloat *same, [[maybe_unused]] TFloat *changed) const
 
void DisplayForward (const NetworkIO &matrix)
 
void DisplayBackward (const NetworkIO &matrix)
 

Protected Attributes

WeightMatrix weights_
 
TransposedArray source_t_
 
const TransposedArray * external_source_
 
NetworkIO acts_
 
bool int_mode_
 
- Protected Attributes inherited from tesseract::Network
NetworkType type_
 
TrainingState training_
 
bool needs_to_backprop_
 
int32_t network_flags_
 
int32_t ni_
 
int32_t no_
 
int32_t num_weights_
 
std::string name_
 
ScrollView * forward_win_
 
ScrollView * backward_win_
 
TRand * randomizer_
 

Additional Inherited Members

- Static Public Member Functions inherited from tesseract::Network
static Network * CreateFromFile (TFile *fp)
 
static void ClearWindow (bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
 
static int DisplayImage (Image pix, ScrollView *window)
 
- Protected Member Functions inherited from tesseract::Network
TFloat Random (TFloat range)
 

Detailed Description

Definition at line 28 of file fullyconnected.h.

Constructor & Destructor Documentation

◆ FullyConnected()

tesseract::FullyConnected::FullyConnected ( const std::string &  name,
int  ni,
int  no,
NetworkType  type 
)

Definition at line 42 of file fullyconnected.cpp.

43  : Network(type, name, ni, no), external_source_(nullptr), int_mode_(false) {}
const TransposedArray * external_source_
const std::string & name() const
Definition: network.h:140
NetworkType type() const
Definition: network.h:110

◆ ~FullyConnected()

tesseract::FullyConnected::~FullyConnected ( )
[override], [default]

Member Function Documentation

◆ Backward()

bool tesseract::FullyConnected::Backward ( bool  debug,
const NetworkIO & fwd_deltas,
NetworkScratch * scratch,
NetworkIO * back_deltas 
)
[override], [virtual]

Implements tesseract::Network.

Definition at line 238 of file fullyconnected.cpp.

239  {
240 #ifndef GRAPHICS_DISABLED
241  if (debug) {
242  DisplayBackward(fwd_deltas);
243  }
244 #endif
245  back_deltas->Resize(fwd_deltas, ni_);
246  std::vector<NetworkScratch::FloatVec> errors(kNumThreads);
247  for (int i = 0; i < kNumThreads; ++i) {
248  errors[i].Init(no_, scratch);
249  }
250  std::vector<NetworkScratch::FloatVec> temp_backprops;
251  if (needs_to_backprop_) {
252  temp_backprops.resize(kNumThreads);
253  for (int i = 0; i < kNumThreads; ++i) {
254  temp_backprops[i].Init(ni_, scratch);
255  }
256  }
257  int width = fwd_deltas.Width();
258  NetworkScratch::GradientStore errors_t;
259  errors_t.Init(no_, width, scratch);
260 #ifdef _OPENMP
261 # pragma omp parallel for num_threads(kNumThreads)
262  for (int t = 0; t < width; ++t) {
263  int thread_id = omp_get_thread_num();
264 #else
265  for (int t = 0; t < width; ++t) {
266  int thread_id = 0;
267 #endif
268  TFloat *backprop = nullptr;
269  if (needs_to_backprop_) {
270  backprop = temp_backprops[thread_id];
271  }
272  TFloat *curr_errors = errors[thread_id];
273  BackwardTimeStep(fwd_deltas, t, curr_errors, errors_t.get(), backprop);
274  if (backprop != nullptr) {
275  back_deltas->WriteTimeStep(t, backprop);
276  }
277  }
278  FinishBackward(*errors_t.get());
279  if (needs_to_backprop_) {
280  back_deltas->ZeroInvalidElements();
281 #if DEBUG_DETAIL > 0
282  tprintf("F Backprop:%s\n", name_.c_str());
283  back_deltas->Print(10);
284 #endif
285  return true;
286  }
287  return false; // No point going further back.
288 }
const int kNumThreads
void tprintf(const char *format,...)
Definition: tprintf.cpp:41
double TFloat
Definition: tesstypes.h:39
void FinishBackward(const TransposedArray &errors_t)
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, TFloat *curr_errors, TransposedArray *errors_t, TFloat *backprop)
bool needs_to_backprop_
Definition: network.h:302
std::string name_
Definition: network.h:307
void DisplayBackward(const NetworkIO &matrix)
Definition: network.cpp:341

◆ BackwardTimeStep()

void tesseract::FullyConnected::BackwardTimeStep ( const NetworkIO & fwd_deltas,
int  t,
TFloat * curr_errors,
TransposedArray * errors_t,
TFloat * backprop 
)

Definition at line 290 of file fullyconnected.cpp.

291  {
292  if (type_ == NT_TANH) {
293  acts_.FuncMultiply<GPrime>(fwd_deltas, t, curr_errors);
294  } else if (type_ == NT_LOGISTIC) {
295  acts_.FuncMultiply<FPrime>(fwd_deltas, t, curr_errors);
296  } else if (type_ == NT_POSCLIP) {
297  acts_.FuncMultiply<ClipFPrime>(fwd_deltas, t, curr_errors);
298  } else if (type_ == NT_SYMCLIP) {
299  acts_.FuncMultiply<ClipGPrime>(fwd_deltas, t, curr_errors);
300  } else if (type_ == NT_RELU) {
301  acts_.FuncMultiply<ReluPrime>(fwd_deltas, t, curr_errors);
302  } else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC || type_ == NT_LINEAR) {
303  fwd_deltas.ReadTimeStep(t, curr_errors); // fwd_deltas are the errors.
304  } else {
305  ASSERT_HOST("Invalid fully-connected type!" == nullptr);
306  }
307  // Generate backprop only if needed by the lower layer.
308  if (backprop != nullptr) {
309  weights_.VectorDotMatrix(curr_errors, backprop);
310  }
311  errors_t->WriteStrided(t, curr_errors);
312 }
#define ASSERT_HOST(x)
Definition: errcode.h:59
@ NT_LINEAR
Definition: network.h:65
@ NT_RELU
Definition: network.h:64
@ NT_SOFTMAX
Definition: network.h:66
@ NT_LOGISTIC
Definition: network.h:60
@ NT_SYMCLIP
Definition: network.h:62
@ NT_POSCLIP
Definition: network.h:61
@ NT_SOFTMAX_NO_CTC
Definition: network.h:67
@ NT_TANH
Definition: network.h:63
NetworkType type_
Definition: network.h:300
void FuncMultiply(const NetworkIO &v_io, int t, TFloat *product)
Definition: networkio.h:258
void VectorDotMatrix(const TFloat *u, TFloat *v) const

◆ ChangeType()

void tesseract::FullyConnected::ChangeType ( NetworkType  type)
inline

Definition at line 62 of file fullyconnected.h.

62  {
63  type_ = type;
64  }

◆ ConvertToInt()

void tesseract::FullyConnected::ConvertToInt ( )
[override], [virtual]

Reimplemented from tesseract::Network.

Definition at line 102 of file fullyconnected.cpp.

102  {
103  weights_.ConvertToInt();
104 }

◆ CountAlternators()

void tesseract::FullyConnected::CountAlternators ( const Network & other,
TFloat * same,
TFloat * changed 
) const
override

Definition at line 331 of file fullyconnected.cpp.

331  {
332  ASSERT_HOST(other.type() == type_);
333  const auto *fc = static_cast<const FullyConnected *>(&other);
334  weights_.CountAlternators(fc->weights_, same, changed);
335 }
TESS_API FullyConnected(const std::string &name, int ni, int no, NetworkType type)
void CountAlternators(const WeightMatrix &other, TFloat *same, TFloat *changed) const

◆ DebugWeights()

void tesseract::FullyConnected::DebugWeights ( )
[override], [virtual]

Implements tesseract::Network.

Definition at line 107 of file fullyconnected.cpp.

107  {
108  weights_.Debug2D(name_.c_str());
109 }
void Debug2D(const char *msg)

◆ DeSerialize()

bool tesseract::FullyConnected::DeSerialize ( TFile * fp)
[override], [virtual]

Implements tesseract::Network.

Definition at line 123 of file fullyconnected.cpp.

123  {
124  return weights_.DeSerialize(IsTraining(), fp);
125 }
bool IsTraining() const
Definition: network.h:113
bool DeSerialize(bool training, TFile *fp)

◆ FinishBackward()

void tesseract::FullyConnected::FinishBackward ( const TransposedArray & errors_t)

Definition at line 314 of file fullyconnected.cpp.

314  {
315  if (external_source_ == nullptr) {
316  weights_.SumOuterTransposed(errors_t, source_t_, true);
317  } else {
318  weights_.SumOuterTransposed(errors_t, *external_source_, true);
319  }
320 }
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)

◆ Forward()

void tesseract::FullyConnected::Forward ( bool  debug,
const NetworkIO & input,
const TransposedArray * input_transpose,
NetworkScratch * scratch,
NetworkIO * output 
)
[override], [virtual]

Implements tesseract::Network.

Definition at line 129 of file fullyconnected.cpp.

131  {
132  int width = input.Width();
133  if (type_ == NT_SOFTMAX) {
134  output->ResizeFloat(input, no_);
135  } else {
136  output->Resize(input, no_);
137  }
138  SetupForward(input, input_transpose);
139  std::vector<NetworkScratch::FloatVec> temp_lines(kNumThreads);
140  std::vector<NetworkScratch::FloatVec> curr_input(kNumThreads);
141  int ro = no_;
142  if (IntSimdMatrix::intSimdMatrix) {
143  ro = IntSimdMatrix::intSimdMatrix->RoundOutputs(ro);
144  }
145  for (int i = 0; i < kNumThreads; ++i) {
146  temp_lines[i].Init(ro, scratch);
147  curr_input[i].Init(ni_, scratch);
148  }
149 #ifdef _OPENMP
150 # pragma omp parallel for num_threads(kNumThreads)
151  for (int t = 0; t < width; ++t) {
152  // Thread-local pointer to temporary storage.
153  int thread_id = omp_get_thread_num();
154 #else
155  for (int t = 0; t < width; ++t) {
156  // Thread-local pointer to temporary storage.
157  int thread_id = 0;
158 #endif
159  TFloat *temp_line = temp_lines[thread_id];
160  if (input.int_mode()) {
161  ForwardTimeStep(input.i(t), t, temp_line);
162  } else {
163  input.ReadTimeStep(t, curr_input[thread_id]);
164  ForwardTimeStep(curr_input[thread_id], t, temp_line);
165  }
166  output->WriteTimeStep(t, temp_line);
167  if (IsTraining() && type_ != NT_SOFTMAX) {
168  acts_.CopyTimeStepFrom(t, *output, t);
169  }
170  }
171  // Zero all the elements that are in the padding around images that allows
172  // multiple different-sized images to exist in a single array.
173  // acts_ is only used if this is not a softmax op.
174  if (IsTraining() && type_ != NT_SOFTMAX) {
175  acts_.ZeroInvalidElements();
176  }
177  output->ZeroInvalidElements();
178 #if DEBUG_DETAIL > 0
179  tprintf("F Output:%s\n", name_.c_str());
180  output->Print(10);
181 #endif
182 #ifndef GRAPHICS_DISABLED
183  if (debug) {
184  DisplayForward(*output);
185  }
186 #endif
187 }
int RoundOutputs(int size) const
Definition: intsimdmatrix.h:74
static const IntSimdMatrix * intSimdMatrix
void ForwardTimeStep(int t, TFloat *output_line)
void SetupForward(const NetworkIO &input, const TransposedArray *input_transpose)
void DisplayForward(const NetworkIO &matrix)
Definition: network.cpp:333
void ZeroInvalidElements()
Definition: networkio.cpp:86
void CopyTimeStepFrom(int dest_t, const NetworkIO &src, int src_t)
Definition: networkio.cpp:395

◆ ForwardTimeStep() [1/3]

void tesseract::FullyConnected::ForwardTimeStep ( const int8_t *  i_input,
int  t,
TFloat * output_line 
)

Definition at line 230 of file fullyconnected.cpp.

230  {
231  // input is copied to source_ line-by-line for cache coherency.
232  weights_.MatrixDotVector(i_input, output_line);
233  ForwardTimeStep(t, output_line);
234 }
void MatrixDotVector(const TFloat *u, TFloat *v) const

◆ ForwardTimeStep() [2/3]

void tesseract::FullyConnected::ForwardTimeStep ( const TFloat * d_input,
int  t,
TFloat * output_line 
)

Definition at line 221 of file fullyconnected.cpp.

221  {
222  // input is copied to source_ line-by-line for cache coherency.
223  if (IsTraining() && external_source_ == nullptr) {
224  source_t_.WriteStrided(t, d_input);
225  }
226  weights_.MatrixDotVector(d_input, output_line);
227  ForwardTimeStep(t, output_line);
228 }
void WriteStrided(int t, const float *data)
Definition: weightmatrix.h:40

◆ ForwardTimeStep() [3/3]

void tesseract::FullyConnected::ForwardTimeStep ( int  t,
TFloat * output_line 
)

Definition at line 203 of file fullyconnected.cpp.

203  {
204  if (type_ == NT_TANH) {
205  FuncInplace<GFunc>(no_, output_line);
206  } else if (type_ == NT_LOGISTIC) {
207  FuncInplace<FFunc>(no_, output_line);
208  } else if (type_ == NT_POSCLIP) {
209  FuncInplace<ClipFFunc>(no_, output_line);
210  } else if (type_ == NT_SYMCLIP) {
211  FuncInplace<ClipGFunc>(no_, output_line);
212  } else if (type_ == NT_RELU) {
213  FuncInplace<Relu>(no_, output_line);
214  } else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC) {
215  SoftmaxInPlace(no_, output_line);
216  } else if (type_ != NT_LINEAR) {
217  ASSERT_HOST("Invalid fully-connected type!" == nullptr);
218  }
219 }
void SoftmaxInPlace(int n, T *inout)
Definition: functions.h:181

◆ InitWeights()

int tesseract::FullyConnected::InitWeights ( float  range,
TRand * randomizer 
)
[override], [virtual]

Reimplemented from tesseract::Network.

Definition at line 84 of file fullyconnected.cpp.

84  {
85  Network::SetRandomizer(randomizer);
86  num_weights_ = weights_.InitWeightsFloat(no_, ni_ + 1, TestFlag(NF_ADAM), range, randomizer);
87  return num_weights_;
88 }
@ NF_ADAM
Definition: network.h:86
bool TestFlag(NetworkFlags flag) const
Definition: network.h:146
int32_t num_weights_
Definition: network.h:306
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:145
int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range, TRand *randomizer)

◆ OutputShape()

StaticShape tesseract::FullyConnected::OutputShape ( const StaticShape & input_shape) const
[override], [virtual]

Reimplemented from tesseract::Network.

Definition at line 47 of file fullyconnected.cpp.

47  {
48  LossType loss_type = LT_NONE;
49  if (type_ == NT_SOFTMAX) {
50  loss_type = LT_CTC;
51  } else if (type_ == NT_SOFTMAX_NO_CTC) {
52  loss_type = LT_SOFTMAX;
53  } else if (type_ == NT_LOGISTIC) {
54  loss_type = LT_LOGISTIC;
55  }
56  StaticShape result(input_shape);
57  result.set_depth(no_);
58  result.set_loss_type(loss_type);
59  return result;
60 }

◆ RemapOutputs()

int tesseract::FullyConnected::RemapOutputs ( int  old_no,
const std::vector< int > &  code_map 
)
override

Definition at line 93 of file fullyconnected.cpp.

93  {
94  if (type_ == NT_SOFTMAX && no_ == old_no) {
95  num_weights_ = weights_.RemapOutputs(code_map);
96  no_ = code_map.size();
97  }
98  return num_weights_;
99 }
int RemapOutputs(const std::vector< int > &code_map)

◆ Serialize()

bool tesseract::FullyConnected::Serialize ( TFile * fp) const
[override], [virtual]

Reimplemented from tesseract::Network.

Definition at line 112 of file fullyconnected.cpp.

112  {
113  if (!Network::Serialize(fp)) {
114  return false;
115  }
116  if (!weights_.Serialize(IsTraining(), fp)) {
117  return false;
118  }
119  return true;
120 }
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:158
bool Serialize(bool training, TFile *fp) const

◆ SetEnableTraining()

void tesseract::FullyConnected::SetEnableTraining ( TrainingState  state)
[override], [virtual]

Reimplemented from tesseract::Network.

Definition at line 63 of file fullyconnected.cpp.

63  {
64  if (state == TS_RE_ENABLE) {
65  // Enable only from temp disabled.
66  if (training_ == TS_TEMP_DISABLE) {
67  training_ = TS_ENABLED;
68  }
69  } else if (state == TS_TEMP_DISABLE) {
70  // Temp disable only from enabled.
71  if (training_ == TS_ENABLED) {
72  training_ = state;
73  }
74  } else {
75  if (state == TS_ENABLED && training_ != TS_ENABLED) {
76  weights_.InitBackward();
77  }
78  training_ = state;
79  }
80 }
@ TS_TEMP_DISABLE
Definition: network.h:95
@ TS_ENABLED
Definition: network.h:93
@ TS_RE_ENABLE
Definition: network.h:97
TrainingState training_
Definition: network.h:301

◆ SetupForward()

void tesseract::FullyConnected::SetupForward ( const NetworkIO & input,
const TransposedArray * input_transpose 
)

Definition at line 190 of file fullyconnected.cpp.

190  {
191  // Softmax output is always float, so save the input type.
192  int_mode_ = input.int_mode();
193  if (IsTraining()) {
194  acts_.Resize(input, no_);
195  // Source_ is a transposed copy of input. It isn't needed if provided.
196  external_source_ = input_transpose;
197  if (external_source_ == nullptr) {
198  source_t_.ResizeNoInit(ni_, input.Width());
199  }
200  }
201 }
void ResizeNoInit(int size1, int size2, int pad=0)
Definition: matrix.h:94
void Resize(const NetworkIO &src, int num_features)
Definition: networkio.h:45

◆ spec()

std::string tesseract::FullyConnected::spec ( ) const
[inline], [override], [virtual]

Reimplemented from tesseract::Network.

Definition at line 38 of file fullyconnected.h.

38  {
39  std::string spec;
40  if (type_ == NT_TANH) {
41  spec += "Ft" + std::to_string(no_);
42  } else if (type_ == NT_LOGISTIC) {
43  spec += "Fs" + std::to_string(no_);
44  } else if (type_ == NT_RELU) {
45  spec += "Fr" + std::to_string(no_);
46  } else if (type_ == NT_LINEAR) {
47  spec += "Fl" + std::to_string(no_);
48  } else if (type_ == NT_POSCLIP) {
49  spec += "Fp" + std::to_string(no_);
50  } else if (type_ == NT_SYMCLIP) {
51  spec += "Fn" + std::to_string(no_);
52  } else if (type_ == NT_SOFTMAX) {
53  spec += "Fc" + std::to_string(no_);
54  } else {
55  spec += "Fm" + std::to_string(no_);
56  }
57  return spec;
58  }
std::string spec() const override

◆ Update()

void tesseract::FullyConnected::Update ( float  learning_rate,
float  momentum,
float  adam_beta,
int  num_samples 
)
override

Definition at line 324 of file fullyconnected.cpp.

324  {
325  weights_.Update(learning_rate, momentum, adam_beta, num_samples);
326 }
void Update(float learning_rate, float momentum, float adam_beta, int num_samples)

Member Data Documentation

◆ acts_

NetworkIO tesseract::FullyConnected::acts_
protected

Definition at line 124 of file fullyconnected.h.

◆ external_source_

const TransposedArray* tesseract::FullyConnected::external_source_
protected

Definition at line 122 of file fullyconnected.h.

◆ int_mode_

bool tesseract::FullyConnected::int_mode_
protected

Definition at line 127 of file fullyconnected.h.

◆ source_t_

TransposedArray tesseract::FullyConnected::source_t_
protected

Definition at line 119 of file fullyconnected.h.

◆ weights_

WeightMatrix tesseract::FullyConnected::weights_
protected

Definition at line 117 of file fullyconnected.h.


The documentation for this class was generated from the following files: