tesseract  5.0.0
lstm.cpp
1 // File: lstm.cpp
3 // Description: Long-term-short-term-memory Recurrent neural network.
4 // Author: Ray Smith
5 //
6 // (C) Copyright 2013, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
17 
18 #ifdef HAVE_CONFIG_H
19 # include "config_auto.h"
20 #endif
21 
22 #include "lstm.h"
23 
24 #ifdef _OPENMP
25 # include <omp.h>
26 #endif
27 #include <cstdio>
28 #include <cstdlib>
29 #include <sstream> // for std::ostringstream
30 
31 #if !defined(__GNUC__) && defined(_MSC_VER)
32 # include <intrin.h> // _BitScanReverse
33 #endif
34 
35 #include "fullyconnected.h"
36 #include "functions.h"
37 #include "networkscratch.h"
38 #include "tprintf.h"
39 
40 // Macros for openmp code if it is available, otherwise empty macros.
41 #ifdef _OPENMP
42 # define PARALLEL_IF_OPENMP(__num_threads) \
43  PRAGMA(omp parallel if (__num_threads > 1) num_threads(__num_threads)) { \
44  PRAGMA(omp sections nowait) { \
45  PRAGMA(omp section) {
46 # define SECTION_IF_OPENMP \
47  } \
48  PRAGMA(omp section) {
49 # define END_PARALLEL_IF_OPENMP \
50  } \
51  } /* end of sections */ \
52  } /* end of parallel section */
53 
54 // Define the portable PRAGMA macro.
55 # ifdef _MSC_VER // Different _Pragma
56 # define PRAGMA(x) __pragma(x)
57 # else
58 # define PRAGMA(x) _Pragma(# x)
59 # endif // _MSC_VER
60 
61 #else // _OPENMP
62 # define PARALLEL_IF_OPENMP(__num_threads)
63 # define SECTION_IF_OPENMP
64 # define END_PARALLEL_IF_OPENMP
65 #endif // _OPENMP
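// Illustrative expansion (not part of the original source): with _OPENMP
// defined, a use such as
//   PARALLEL_IF_OPENMP(GFS)
//     ... compute one gate ...
//   SECTION_IF_OPENMP
//     ... compute another gate ...
//   END_PARALLEL_IF_OPENMP
// becomes an "omp parallel" region wrapping "omp sections", with one
// "omp section" per gate, so the gate matrix multiplies can run concurrently.
// Without _OPENMP all three macros expand to nothing and the same code runs
// sequentially.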
66 
67 namespace tesseract {
68 
69 // Max absolute value of state_. It is reasonably high to enable the state
70 // to count things.
71 const TFloat kStateClip = 100.0;
72 // Max absolute value of gate_errors (the gradients).
73 const TFloat kErrClip = 1.0f;
74 
75 // Calculate ceil(log2(n)).
76 static inline uint32_t ceil_log2(uint32_t n) {
77  // l2 = (unsigned)log2(n).
78 #if defined(__GNUC__)
79  // Use fast inline assembler code for gcc or clang.
80  uint32_t l2 = 31 - __builtin_clz(n);
81 #elif defined(_MSC_VER)
82  // Use fast intrinsic function for MS compiler.
83  unsigned long l2 = 0;
84  _BitScanReverse(&l2, n);
85 #else
86  if (n == 0)
87  return UINT_MAX;
88  if (n == 1)
89  return 0;
90  uint32_t val = n;
91  uint32_t l2 = 0;
92  while (val > 1) {
93  val >>= 1;
94  l2++;
95  }
96 #endif
97  // Round up if n is not a power of 2.
98  return (n == (1u << l2)) ? l2 : l2 + 1;
99 }
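// Illustrative values (not part of the original source): ceil_log2(1) == 0,
// ceil_log2(2) == 1, ceil_log2(8) == 3, ceil_log2(9) == 4. The constructor
// below uses it for NT_LSTM_SOFTMAX_ENCODED, where the no_ softmax outputs are
// fed back as a binary code of nf_ = ceil_log2(no_) bits, e.g. no_ = 111
// classes need nf_ = 7 feedback inputs.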
100 
101 LSTM::LSTM(const std::string &name, int ni, int ns, int no, bool two_dimensional, NetworkType type)
102  : Network(type, name, ni, no)
103  , na_(ni + ns)
104  , ns_(ns)
105  , nf_(0)
106  , is_2d_(two_dimensional)
107  , softmax_(nullptr)
108  , input_width_(0) {
109  if (two_dimensional) {
110  na_ += ns_;
111  }
112  if (type_ == NT_LSTM || type_ == NT_LSTM_SUMMARY) {
113  nf_ = 0;
114  // networkbuilder ensures this is always true.
115  ASSERT_HOST(no == ns);
116  } else if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
117  nf_ = type_ == NT_LSTM_SOFTMAX ? no_ : ceil_log2(no_);
118  softmax_ = new FullyConnected("LSTM Softmax", ns_, no_, NT_SOFTMAX);
119  } else {
120  tprintf("%d is invalid type of LSTM!\n", type);
121  ASSERT_HOST(false);
122  }
123  na_ += nf_;
124 }
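// Worked example (illustrative, not part of the original source): a 1-D
// NT_LSTM with ni = 48 inputs and ns = no = 96 states has na_ = ni + ns = 144.
// Making it two_dimensional adds another ns (na_ = 240), and an
// NT_LSTM_SOFTMAX_ENCODED head with no = 111 adds nf_ = ceil_log2(111) = 7,
// so each gate row ends up with na_ weights plus one bias (na_ + 1 columns in
// InitWeights below).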
125 
126 LSTM::~LSTM() {
127  delete softmax_;
128 }
129 
130 // Returns the shape output from the network given an input shape (which may
131 // be partially unknown, i.e. zero).
132 StaticShape LSTM::OutputShape(const StaticShape &input_shape) const {
133  StaticShape result = input_shape;
134  result.set_depth(no_);
135  if (type_ == NT_LSTM_SUMMARY) {
136  result.set_width(1);
137  }
138  if (softmax_ != nullptr) {
139  return softmax_->OutputShape(result);
140  }
141  return result;
142 }
143 
144 // Suspends/Enables training by setting the training_ flag. Serialize and
145 // DeSerialize only operate on the run-time data if state is false.
146 void LSTM::SetEnableTraining(TrainingState state) {
147  if (state == TS_RE_ENABLE) {
148  // Enable only from temp disabled.
149  if (training_ == TS_TEMP_DISABLE) {
150  training_ = TS_ENABLED;
151  }
152  } else if (state == TS_TEMP_DISABLE) {
153  // Temp disable only from enabled.
154  if (training_ == TS_ENABLED) {
155  training_ = state;
156  }
157  } else {
158  if (state == TS_ENABLED && training_ != TS_ENABLED) {
159  for (int w = 0; w < WT_COUNT; ++w) {
160  if (w == GFS && !Is2D()) {
161  continue;
162  }
163  gate_weights_[w].InitBackward();
164  }
165  }
166  training_ = state;
167  }
168  if (softmax_ != nullptr) {
169  softmax_->SetEnableTraining(state);
170  }
171 }
172 
173 // Sets up the network for training. Initializes weights using weights of
174 // scale `range` picked according to the random number generator `randomizer`.
175 int LSTM::InitWeights(float range, TRand *randomizer) {
176  Network::SetRandomizer(randomizer);
177  num_weights_ = 0;
178  for (int w = 0; w < WT_COUNT; ++w) {
179  if (w == GFS && !Is2D()) {
180  continue;
181  }
182  num_weights_ +=
183  gate_weights_[w].InitWeightsFloat(ns_, na_ + 1, TestFlag(NF_ADAM), range, randomizer);
184  }
185  if (softmax_ != nullptr) {
186  num_weights_ += softmax_->InitWeights(range, randomizer);
187  }
188  return num_weights_;
189 }
190 
191 // Recursively searches the network for softmaxes with old_no outputs,
192 // and remaps their outputs according to code_map. See network.h for details.
193 int LSTM::RemapOutputs(int old_no, const std::vector<int> &code_map) {
194  if (softmax_ != nullptr) {
195  num_weights_ -= softmax_->num_weights();
196  num_weights_ += softmax_->RemapOutputs(old_no, code_map);
197  }
198  return num_weights_;
199 }
200 
201 // Converts a float network to an int network.
202 void LSTM::ConvertToInt() {
203  for (int w = 0; w < WT_COUNT; ++w) {
204  if (w == GFS && !Is2D()) {
205  continue;
206  }
207  gate_weights_[w].ConvertToInt();
208  }
209  if (softmax_ != nullptr) {
210  softmax_->ConvertToInt();
211  }
212 }
213 
214 // Provides debug output on the weights.
215 void LSTM::DebugWeights() {
216  for (int w = 0; w < WT_COUNT; ++w) {
217  if (w == GFS && !Is2D()) {
218  continue;
219  }
220  std::ostringstream msg;
221  msg << name_ << " Gate weights " << w;
222  gate_weights_[w].Debug2D(msg.str().c_str());
223  }
224  if (softmax_ != nullptr) {
225  softmax_->DebugWeights();
226  }
227 }
228 
229 // Writes to the given file. Returns false in case of error.
230 bool LSTM::Serialize(TFile *fp) const {
231  if (!Network::Serialize(fp)) {
232  return false;
233  }
234  if (!fp->Serialize(&na_)) {
235  return false;
236  }
237  for (int w = 0; w < WT_COUNT; ++w) {
238  if (w == GFS && !Is2D()) {
239  continue;
240  }
241  if (!gate_weights_[w].Serialize(IsTraining(), fp)) {
242  return false;
243  }
244  }
245  if (softmax_ != nullptr && !softmax_->Serialize(fp)) {
246  return false;
247  }
248  return true;
249 }
250 
251 // Reads from the given file. Returns false in case of error.
252 
253 bool LSTM::DeSerialize(TFile *fp) {
254  if (!fp->DeSerialize(&na_)) {
255  return false;
256  }
257  if (type_ == NT_LSTM_SOFTMAX) {
258  nf_ = no_;
259  } else if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
260  nf_ = ceil_log2(no_);
261  } else {
262  nf_ = 0;
263  }
264  is_2d_ = false;
265  for (int w = 0; w < WT_COUNT; ++w) {
266  if (w == GFS && !Is2D()) {
267  continue;
268  }
269  if (!gate_weights_[w].DeSerialize(IsTraining(), fp)) {
270  return false;
271  }
272  if (w == CI) {
273  ns_ = gate_weights_[CI].NumOutputs();
274  is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
275  }
276  }
277  delete softmax_;
278  if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
279  softmax_ = static_cast<FullyConnected *>(Network::CreateFromFile(fp));
280  if (softmax_ == nullptr) {
281  return false;
282  }
283  } else {
284  softmax_ = nullptr;
285  }
286  return true;
287 }
288 
289 // Runs forward propagation of activations on the input line.
290 // See NetworkCpp for a detailed discussion of the arguments.
291 void LSTM::Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose,
292  NetworkScratch *scratch, NetworkIO *output) {
293  input_map_ = input.stride_map();
294  input_width_ = input.Width();
295  if (softmax_ != nullptr) {
296  output->ResizeFloat(input, no_);
297  } else if (type_ == NT_LSTM_SUMMARY) {
298  output->ResizeXTo1(input, no_);
299  } else {
300  output->Resize(input, no_);
301  }
302  ResizeForward(input);
303  // Temporary storage of forward computation for each gate.
304  NetworkScratch::FloatVec temp_lines[WT_COUNT];
305  int ro = ns_;
306  if (source_.int_mode() && IntSimdMatrix::intSimdMatrix) {
307  ro = IntSimdMatrix::intSimdMatrix->RoundOutputs(ro);
308  }
309  for (auto &temp_line : temp_lines) {
310  temp_line.Init(ns_, ro, scratch);
311  }
312  // Single timestep buffers for the current/recurrent output and state.
313  NetworkScratch::FloatVec curr_state, curr_output;
314  curr_state.Init(ns_, scratch);
315  ZeroVector<TFloat>(ns_, curr_state);
316  curr_output.Init(ns_, scratch);
317  ZeroVector<TFloat>(ns_, curr_output);
318  // Rotating buffers of width buf_width allow storage of the state and output
319  // for the other dimension, used only when working in true 2D mode. The width
320  // is enough to hold an entire strip of the major direction.
321  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
322  std::vector<NetworkScratch::FloatVec> states, outputs;
323  if (Is2D()) {
324  states.resize(buf_width);
325  outputs.resize(buf_width);
326  for (int i = 0; i < buf_width; ++i) {
327  states[i].Init(ns_, scratch);
328  ZeroVector<TFloat>(ns_, states[i]);
329  outputs[i].Init(ns_, scratch);
330  ZeroVector<TFloat>(ns_, outputs[i]);
331  }
332  }
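  // Illustrative note (not part of the original source): the buffers are
  // buf_width slots wide and indexed with mod_t = Modulo(t, buf_width), so at
  // timestep t the slot states[mod_t]/outputs[mod_t] still holds the value
  // written exactly one strip earlier, i.e. the state/output of the
  // neighbouring position in the other dimension; it is only overwritten with
  // this timestep's values at the bottom of the loop.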
333  // Used only if a softmax LSTM.
334  NetworkScratch::FloatVec softmax_output;
335  NetworkScratch::IO int_output;
336  if (softmax_ != nullptr) {
337  softmax_output.Init(no_, scratch);
338  ZeroVector<TFloat>(no_, softmax_output);
339  int rounded_softmax_inputs = gate_weights_[CI].RoundInputs(ns_);
340  if (input.int_mode()) {
341  int_output.Resize2d(true, 1, rounded_softmax_inputs, scratch);
342  }
343  softmax_->SetupForward(input, nullptr);
344  }
345  NetworkScratch::FloatVec curr_input;
346  curr_input.Init(na_, scratch);
347  StrideMap::Index src_index(input_map_);
348  // Used only by NT_LSTM_SUMMARY.
349  StrideMap::Index dest_index(output->stride_map());
350  do {
351  int t = src_index.t();
352  // True if there is a valid old state for the 2nd dimension.
353  bool valid_2d = Is2D();
354  if (valid_2d) {
355  StrideMap::Index dim_index(src_index);
356  if (!dim_index.AddOffset(-1, FD_HEIGHT)) {
357  valid_2d = false;
358  }
359  }
360  // Index of the 2-D revolving buffers (outputs, states).
361  int mod_t = Modulo(t, buf_width); // Current timestep.
362  // Setup the padded input in source.
363  source_.CopyTimeStepGeneral(t, 0, ni_, input, t, 0);
364  if (softmax_ != nullptr) {
365  source_.WriteTimeStepPart(t, ni_, nf_, softmax_output);
366  }
367  source_.WriteTimeStepPart(t, ni_ + nf_, ns_, curr_output);
368  if (Is2D()) {
369  source_.WriteTimeStepPart(t, ni_ + nf_ + ns_, ns_, outputs[mod_t]);
370  }
371  if (!source_.int_mode()) {
372  source_.ReadTimeStep(t, curr_input);
373  }
374  // Matrix multiply the inputs with the source.
375  PARALLEL_IF_OPENMP(GFS)
376  // It looks inefficient to create the threads on each t iteration, but the
377  // alternative of putting the parallel outside the t loop, a single around
378  // the t-loop and then tasks in place of the sections is a *lot* slower.
379  // Cell inputs.
380  if (source_.int_mode()) {
381  gate_weights_[CI].MatrixDotVector(source_.i(t), temp_lines[CI]);
382  } else {
383  gate_weights_[CI].MatrixDotVector(curr_input, temp_lines[CI]);
384  }
385  FuncInplace<GFunc>(ns_, temp_lines[CI]);
386 
387  SECTION_IF_OPENMP
388  // Input Gates.
389  if (source_.int_mode()) {
390  gate_weights_[GI].MatrixDotVector(source_.i(t), temp_lines[GI]);
391  } else {
392  gate_weights_[GI].MatrixDotVector(curr_input, temp_lines[GI]);
393  }
394  FuncInplace<FFunc>(ns_, temp_lines[GI]);
395 
396  SECTION_IF_OPENMP
397  // 1-D forget gates.
398  if (source_.int_mode()) {
399  gate_weights_[GF1].MatrixDotVector(source_.i(t), temp_lines[GF1]);
400  } else {
401  gate_weights_[GF1].MatrixDotVector(curr_input, temp_lines[GF1]);
402  }
403  FuncInplace<FFunc>(ns_, temp_lines[GF1]);
404 
405  // 2-D forget gates.
406  if (Is2D()) {
407  if (source_.int_mode()) {
408  gate_weights_[GFS].MatrixDotVector(source_.i(t), temp_lines[GFS]);
409  } else {
410  gate_weights_[GFS].MatrixDotVector(curr_input, temp_lines[GFS]);
411  }
412  FuncInplace<FFunc>(ns_, temp_lines[GFS]);
413  }
414 
415  SECTION_IF_OPENMP
416  // Output gates.
417  if (source_.int_mode()) {
418  gate_weights_[GO].MatrixDotVector(source_.i(t), temp_lines[GO]);
419  } else {
420  gate_weights_[GO].MatrixDotVector(curr_input, temp_lines[GO]);
421  }
422  FuncInplace<FFunc>(ns_, temp_lines[GO]);
423  END_PARALLEL_IF_OPENMP
424 
425  // Apply forget gate to state.
426  MultiplyVectorsInPlace(ns_, temp_lines[GF1], curr_state);
427  if (Is2D()) {
428  // Max-pool the forget gates (in 2-d) instead of blindly adding.
429  int8_t *which_fg_col = which_fg_[t];
430  memset(which_fg_col, 1, ns_ * sizeof(which_fg_col[0]));
431  if (valid_2d) {
432  const TFloat *stepped_state = states[mod_t];
433  for (int i = 0; i < ns_; ++i) {
434  if (temp_lines[GF1][i] < temp_lines[GFS][i]) {
435  curr_state[i] = temp_lines[GFS][i] * stepped_state[i];
436  which_fg_col[i] = 2;
437  }
438  }
439  }
440  }
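  // Illustrative example (not part of the original source): if for cell i the
  // 1-D forget gate is 0.3 and the 2-D gate is 0.7, the max-pool above keeps
  // only the stronger path, so curr_state[i] becomes 0.7 * states[mod_t][i]
  // (the neighbour in the other dimension) and which_fg_[t][i] is set to 2,
  // letting Backward() route the state error along the same path.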
441  MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
442  // Clip curr_state to a sane range.
443  ClipVector<TFloat>(ns_, -kStateClip, kStateClip, curr_state);
444  if (IsTraining()) {
445  // Save the gate node values.
446  node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
447  node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
448  node_values_[GF1].WriteTimeStep(t, temp_lines[GF1]);
449  node_values_[GO].WriteTimeStep(t, temp_lines[GO]);
450  if (Is2D()) {
451  node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
452  }
453  }
454  FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
455  if (IsTraining()) {
456  state_.WriteTimeStep(t, curr_state);
457  }
458  if (softmax_ != nullptr) {
459  if (input.int_mode()) {
460  int_output->WriteTimeStepPart(0, 0, ns_, curr_output);
461  softmax_->ForwardTimeStep(int_output->i(0), t, softmax_output);
462  } else {
463  softmax_->ForwardTimeStep(curr_output, t, softmax_output);
464  }
465  output->WriteTimeStep(t, softmax_output);
466  if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
467  CodeInBinary(no_, nf_, softmax_output);
468  }
469  } else if (type_ == NT_LSTM_SUMMARY) {
470  // Output only at the end of a row.
471  if (src_index.IsLast(FD_WIDTH)) {
472  output->WriteTimeStep(dest_index.t(), curr_output);
473  dest_index.Increment();
474  }
475  } else {
476  output->WriteTimeStep(t, curr_output);
477  }
478  // Save states for use by the 2nd dimension only if needed.
479  if (Is2D()) {
480  CopyVector(ns_, curr_state, states[mod_t]);
481  CopyVector(ns_, curr_output, outputs[mod_t]);
482  }
483  // Always zero the states at the end of every row, but only for the major
484  // direction. The 2-D state remains intact.
485  if (src_index.IsLast(FD_WIDTH)) {
486  ZeroVector<TFloat>(ns_, curr_state);
487  ZeroVector<TFloat>(ns_, curr_output);
488  }
489  } while (src_index.Increment());
490 #if DEBUG_DETAIL > 0
491  tprintf("Source:%s\n", name_.c_str());
492  source_.Print(10);
493  tprintf("State:%s\n", name_.c_str());
494  state_.Print(10);
495  tprintf("Output:%s\n", name_.c_str());
496  output->Print(10);
497 #endif
498 #ifndef GRAPHICS_DISABLED
499  if (debug) {
500  DisplayForward(*output);
501  }
502 #endif
503 }
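// Illustrative summary (not part of the original source) of the per-timestep
// update performed by the loop above, writing x for the padded source vector
// at time t and "*" for elementwise multiplication:
//   g = GFunc(gate_weights_[CI] . x)   // cell input
//   i = FFunc(gate_weights_[GI] . x)   // input gate
//   f = FFunc(gate_weights_[GF1] . x)  // 1-D forget gate (GFS covers the
//                                      // second dimension in 2-D mode)
//   o = FFunc(gate_weights_[GO] . x)   // output gate
//   curr_state  = clip(f * prev_state + g * i, +/-kStateClip)
//   curr_output = o * HFunc(curr_state)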
504 
505 // Runs backward propagation of errors on the deltas line.
506 // See NetworkCpp for a detailed discussion of the arguments.
507 bool LSTM::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch,
508  NetworkIO *back_deltas) {
509 #ifndef GRAPHICS_DISABLED
510  if (debug) {
511  DisplayBackward(fwd_deltas);
512  }
513 #endif
514  back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_, ni_);
515  // ======Scratch space.======
516  // Output errors from deltas with recurrence from sourceerr.
517  NetworkScratch::FloatVec outputerr;
518  outputerr.Init(ns_, scratch);
519  // Recurrent error in the state/source.
520  NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
521  curr_stateerr.Init(ns_, scratch);
522  curr_sourceerr.Init(na_, scratch);
523  ZeroVector<TFloat>(ns_, curr_stateerr);
524  ZeroVector<TFloat>(na_, curr_sourceerr);
525  // Errors in the gates.
526  NetworkScratch::FloatVec gate_errors[WT_COUNT];
527  for (auto &gate_error : gate_errors) {
528  gate_error.Init(ns_, scratch);
529  }
530  // Rotating buffers of width buf_width allow storage of the recurrent time-
531  // steps used only for true 2-D. Stores one full strip of the major direction.
532  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
533  std::vector<NetworkScratch::FloatVec> stateerr, sourceerr;
534  if (Is2D()) {
535  stateerr.resize(buf_width);
536  sourceerr.resize(buf_width);
537  for (int t = 0; t < buf_width; ++t) {
538  stateerr[t].Init(ns_, scratch);
539  sourceerr[t].Init(na_, scratch);
540  ZeroVector<TFloat>(ns_, stateerr[t]);
541  ZeroVector<TFloat>(na_, sourceerr[t]);
542  }
543  }
544  // Parallel-generated sourceerr from each of the gates.
545  NetworkScratch::FloatVec sourceerr_temps[WT_COUNT];
546  for (auto &sourceerr_temp : sourceerr_temps) {
547  sourceerr_temp.Init(na_, scratch);
548  }
549  int width = input_width_;
550  // Transposed gate errors stored over all timesteps for sum outer.
551  NetworkScratch::GradientStore gate_errors_t[WT_COUNT];
552  for (auto &w : gate_errors_t) {
553  w.Init(ns_, width, scratch);
554  }
555  // Used only if softmax_ != nullptr.
556  NetworkScratch::FloatVec softmax_errors;
557  NetworkScratch::GradientStore softmax_errors_t;
558  if (softmax_ != nullptr) {
559  softmax_errors.Init(no_, scratch);
560  softmax_errors_t.Init(no_, width, scratch);
561  }
562  TFloat state_clip = Is2D() ? 9.0 : 4.0;
563 #if DEBUG_DETAIL > 1
564  tprintf("fwd_deltas:%s\n", name_.c_str());
565  fwd_deltas.Print(10);
566 #endif
567  StrideMap::Index dest_index(input_map_);
568  dest_index.InitToLast();
569  // Used only by NT_LSTM_SUMMARY.
570  StrideMap::Index src_index(fwd_deltas.stride_map());
571  src_index.InitToLast();
572  do {
573  int t = dest_index.t();
574  bool at_last_x = dest_index.IsLast(FD_WIDTH);
575  // up_pos is the 2-D back step, down_pos is the 2-D fwd step, and are only
576  // valid if >= 0, which is true if 2d and not on the top/bottom.
577  int up_pos = -1;
578  int down_pos = -1;
579  if (Is2D()) {
580  if (dest_index.index(FD_HEIGHT) > 0) {
581  StrideMap::Index up_index(dest_index);
582  if (up_index.AddOffset(-1, FD_HEIGHT)) {
583  up_pos = up_index.t();
584  }
585  }
586  if (!dest_index.IsLast(FD_HEIGHT)) {
587  StrideMap::Index down_index(dest_index);
588  if (down_index.AddOffset(1, FD_HEIGHT)) {
589  down_pos = down_index.t();
590  }
591  }
592  }
593  // Index of the 2-D revolving buffers (sourceerr, stateerr).
594  int mod_t = Modulo(t, buf_width); // Current timestep.
595  // Zero the state in the major direction only at the end of every row.
596  if (at_last_x) {
597  ZeroVector<TFloat>(na_, curr_sourceerr);
598  ZeroVector<TFloat>(ns_, curr_stateerr);
599  }
600  // Setup the outputerr.
601  if (type_ == NT_LSTM_SUMMARY) {
602  if (dest_index.IsLast(FD_WIDTH)) {
603  fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
604  src_index.Decrement();
605  } else {
606  ZeroVector<TFloat>(ns_, outputerr);
607  }
608  } else if (softmax_ == nullptr) {
609  fwd_deltas.ReadTimeStep(t, outputerr);
610  } else {
611  softmax_->BackwardTimeStep(fwd_deltas, t, softmax_errors, softmax_errors_t.get(), outputerr);
612  }
613  if (!at_last_x) {
614  AccumulateVector(ns_, curr_sourceerr + ni_ + nf_, outputerr);
615  }
616  if (down_pos >= 0) {
617  AccumulateVector(ns_, sourceerr[mod_t] + ni_ + nf_ + ns_, outputerr);
618  }
619  // Apply the 1-d forget gates.
620  if (!at_last_x) {
621  const float *next_node_gf1 = node_values_[GF1].f(t + 1);
622  for (int i = 0; i < ns_; ++i) {
623  curr_stateerr[i] *= next_node_gf1[i];
624  }
625  }
626  if (Is2D() && t + 1 < width) {
627  for (int i = 0; i < ns_; ++i) {
628  if (which_fg_[t + 1][i] != 1) {
629  curr_stateerr[i] = 0.0;
630  }
631  }
632  if (down_pos >= 0) {
633  const float *right_node_gfs = node_values_[GFS].f(down_pos);
634  const TFloat *right_stateerr = stateerr[mod_t];
635  for (int i = 0; i < ns_; ++i) {
636  if (which_fg_[down_pos][i] == 2) {
637  curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
638  }
639  }
640  }
641  }
642  state_.FuncMultiply3Add<HPrime>(node_values_[GO], t, outputerr, curr_stateerr);
643  // Clip stateerr_ to a sane range.
644  ClipVector<TFloat>(ns_, -state_clip, state_clip, curr_stateerr);
645 #if DEBUG_DETAIL > 1
646  if (t + 10 > width) {
647  tprintf("t=%d, stateerr=", t);
648  for (int i = 0; i < ns_; ++i)
649  tprintf(" %g,%g,%g", curr_stateerr[i], outputerr[i], curr_sourceerr[ni_ + nf_ + i]);
650  tprintf("\n");
651  }
652 #endif
653  // Matrix multiply to get the source errors.
654  PARALLEL_IF_OPENMP(GFS)
655 
656  // Cell inputs.
657  node_values_[CI].FuncMultiply3<GPrime>(t, node_values_[GI], t, curr_stateerr, gate_errors[CI]);
658  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[CI].get());
659  gate_weights_[CI].VectorDotMatrix(gate_errors[CI], sourceerr_temps[CI]);
660  gate_errors_t[CI].get()->WriteStrided(t, gate_errors[CI]);
661 
662  SECTION_IF_OPENMP
663  // Input Gates.
664  node_values_[GI].FuncMultiply3<FPrime>(t, node_values_[CI], t, curr_stateerr, gate_errors[GI]);
665  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GI].get());
666  gate_weights_[GI].VectorDotMatrix(gate_errors[GI], sourceerr_temps[GI]);
667  gate_errors_t[GI].get()->WriteStrided(t, gate_errors[GI]);
668 
669  SECTION_IF_OPENMP
670  // 1-D forget Gates.
671  if (t > 0) {
672  node_values_[GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr, gate_errors[GF1]);
673  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GF1].get());
674  gate_weights_[GF1].VectorDotMatrix(gate_errors[GF1], sourceerr_temps[GF1]);
675  } else {
676  memset(gate_errors[GF1], 0, ns_ * sizeof(gate_errors[GF1][0]));
677  memset(sourceerr_temps[GF1], 0, na_ * sizeof(*sourceerr_temps[GF1]));
678  }
679  gate_errors_t[GF1].get()->WriteStrided(t, gate_errors[GF1]);
680 
681  // 2-D forget Gates.
682  if (up_pos >= 0) {
683  node_values_[GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr, gate_errors[GFS]);
684  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GFS].get());
685  gate_weights_[GFS].VectorDotMatrix(gate_errors[GFS], sourceerr_temps[GFS]);
686  } else {
687  memset(gate_errors[GFS], 0, ns_ * sizeof(gate_errors[GFS][0]));
688  memset(sourceerr_temps[GFS], 0, na_ * sizeof(*sourceerr_temps[GFS]));
689  }
690  if (Is2D()) {
691  gate_errors_t[GFS].get()->WriteStrided(t, gate_errors[GFS]);
692  }
693 
694  SECTION_IF_OPENMP
695  // Output gates.
696  state_.Func2Multiply3<HFunc, FPrime>(node_values_[GO], t, outputerr, gate_errors[GO]);
697  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GO].get());
698  gate_weights_[GO].VectorDotMatrix(gate_errors[GO], sourceerr_temps[GO]);
699  gate_errors_t[GO].get()->WriteStrided(t, gate_errors[GO]);
700  END_PARALLEL_IF_OPENMP
701 
702  SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI], sourceerr_temps[GF1],
703  sourceerr_temps[GO], sourceerr_temps[GFS], curr_sourceerr);
704  back_deltas->WriteTimeStep(t, curr_sourceerr);
705  // Save states for use by the 2nd dimension only if needed.
706  if (Is2D()) {
707  CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
708  CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
709  }
710  } while (dest_index.Decrement());
711 #if DEBUG_DETAIL > 2
712  for (int w = 0; w < WT_COUNT; ++w) {
713  tprintf("%s gate errors[%d]\n", name_.c_str(), w);
714  gate_errors_t[w].get()->PrintUnTransposed(10);
715  }
716 #endif
717  // Transposed source_ used to speed-up SumOuter.
718  NetworkScratch::GradientStore source_t, state_t;
719  source_t.Init(na_, width, scratch);
720  source_.Transpose(source_t.get());
721  state_t.Init(ns_, width, scratch);
722  state_.Transpose(state_t.get());
723 #ifdef _OPENMP
724 # pragma omp parallel for num_threads(GFS) if (!Is2D())
725 #endif
726  for (int w = 0; w < WT_COUNT; ++w) {
727  if (w == GFS && !Is2D()) {
728  continue;
729  }
730  gate_weights_[w].SumOuterTransposed(*gate_errors_t[w], *source_t, false);
731  }
732  if (softmax_ != nullptr) {
733  softmax_->FinishBackward(*softmax_errors_t);
734  }
735  return needs_to_backprop_;
736 }
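// Illustrative summary (not part of the original source) of the error flow in
// Backward() above: outputerr at each timestep is gathered from fwd_deltas (or
// the softmax/summary head) plus the recurrent source errors of the following
// step; curr_stateerr is carried backwards along the row, scaled by the next
// step's 1-D forget gate (and, in 2-D mode, only along whichever forget path
// won the max-pool recorded in which_fg_), then clipped to +/-state_clip. Each
// gate error is the corresponding derivative times curr_stateerr or outputerr,
// clipped to +/-kErrClip, pushed back through the gate weights with
// VectorDotMatrix to build back_deltas, and stored transposed for the final
// SumOuterTransposed weight-gradient accumulation.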
737 
738 // Updates the weights using the given learning rate, momentum and adam_beta.
739 // num_samples is used in the adam computation iff use_adam_ is true.
740 void LSTM::Update(float learning_rate, float momentum, float adam_beta, int num_samples) {
741 #if DEBUG_DETAIL > 3
742  PrintW();
743 #endif
744  for (int w = 0; w < WT_COUNT; ++w) {
745  if (w == GFS && !Is2D()) {
746  continue;
747  }
748  gate_weights_[w].Update(learning_rate, momentum, adam_beta, num_samples);
749  }
750  if (softmax_ != nullptr) {
751  softmax_->Update(learning_rate, momentum, adam_beta, num_samples);
752  }
753 #if DEBUG_DETAIL > 3
754  PrintDW();
755 #endif
756 }
757 
758 // Sums the products of weight updates in *this and other, splitting into
759 // positive (same direction) in *same and negative (different direction) in
760 // *changed.
761 void LSTM::CountAlternators(const Network &other, TFloat *same, TFloat *changed) const {
762  ASSERT_HOST(other.type() == type_);
763  const LSTM *lstm = static_cast<const LSTM *>(&other);
764  for (int w = 0; w < WT_COUNT; ++w) {
765  if (w == GFS && !Is2D()) {
766  continue;
767  }
768  gate_weights_[w].CountAlternators(lstm->gate_weights_[w], same, changed);
769  }
770  if (softmax_ != nullptr) {
771  softmax_->CountAlternators(*lstm->softmax_, same, changed);
772  }
773 }
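// Illustrative example (not part of the original source): for a single weight
// whose latest update was +0.2 in this network and +0.3 in `other`, the
// product of the updates is positive, so it is accumulated into *same; had the
// other update been -0.3, the negative product would go into *changed instead.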
774 
775 #if DEBUG_DETAIL > 3
776 
777 // Prints the weights for debug purposes.
778 void LSTM::PrintW() {
779  tprintf("Weight state:%s\n", name_.c_str());
780  for (int w = 0; w < WT_COUNT; ++w) {
781  if (w == GFS && !Is2D()) {
782  continue;
783  }
784  tprintf("Gate %d, inputs\n", w);
785  for (int i = 0; i < ni_; ++i) {
786  tprintf("Row %d:", i);
787  for (int s = 0; s < ns_; ++s) {
788  tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
789  }
790  tprintf("\n");
791  }
792  tprintf("Gate %d, outputs\n", w);
793  for (int i = ni_; i < ni_ + ns_; ++i) {
794  tprintf("Row %d:", i - ni_);
795  for (int s = 0; s < ns_; ++s) {
796  tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
797  }
798  tprintf("\n");
799  }
800  tprintf("Gate %d, bias\n", w);
801  for (int s = 0; s < ns_; ++s) {
802  tprintf(" %g", gate_weights_[w].GetWeights(s)[na_]);
803  }
804  tprintf("\n");
805  }
806 }
807 
808 // Prints the weight deltas for debug purposes.
809 void LSTM::PrintDW() {
810  tprintf("Delta state:%s\n", name_.c_str());
811  for (int w = 0; w < WT_COUNT; ++w) {
812  if (w == GFS && !Is2D()) {
813  continue;
814  }
815  tprintf("Gate %d, inputs\n", w);
816  for (int i = 0; i < ni_; ++i) {
817  tprintf("Row %d:", i);
818  for (int s = 0; s < ns_; ++s) {
819  tprintf(" %g", gate_weights_[w].GetDW(s, i));
820  }
821  tprintf("\n");
822  }
823  tprintf("Gate %d, outputs\n", w);
824  for (int i = ni_; i < ni_ + ns_; ++i) {
825  tprintf("Row %d:", i - ni_);
826  for (int s = 0; s < ns_; ++s) {
827  tprintf(" %g", gate_weights_[w].GetDW(s, i));
828  }
829  tprintf("\n");
830  }
831  tprintf("Gate %d, bias\n", w);
832  for (int s = 0; s < ns_; ++s) {
833  tprintf(" %g", gate_weights_[w].GetDW(s, na_));
834  }
835  tprintf("\n");
836  }
837 }
838 
839 #endif
840 
841 // Resizes forward data to cope with an input image of the given width.
842 void LSTM::ResizeForward(const NetworkIO &input) {
843  int rounded_inputs = gate_weights_[CI].RoundInputs(na_);
844  source_.Resize(input, rounded_inputs);
845  which_fg_.ResizeNoInit(input.Width(), ns_);
846  if (IsTraining()) {
847  state_.ResizeFloat(input, ns_);
848  for (int w = 0; w < WT_COUNT; ++w) {
849  if (w == GFS && !Is2D()) {
850  continue;
851  }
852  node_values_[w].ResizeFloat(input, ns_);
853  }
854  }
855 }
856 
857 } // namespace tesseract.