// File: intsimdmatrixavx2.cpp
// Description: matrix-vector product for 8-bit data on avx2.
// Author: Ray Smith
//
// (C) Copyright 2017, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "intsimdmatrix.h"

#if !defined(__AVX2__)
#  if defined(__i686__) || defined(__x86_64__)
#    error Implementation only for AVX2 capable architectures
#  endif
#else
#  include <immintrin.h>
#  include <algorithm>
#  include <cstdint>
#  include <vector>

namespace tesseract {

// Number of outputs held in each register. 8 x 32 bit ints.
constexpr int kNumOutputsPerRegister = 8;
// Maximum number of registers that we will use.
constexpr int kMaxOutputRegisters = 8;
// Number of inputs in the inputs register.
constexpr int kNumInputsPerRegister = 32;
// Number of inputs in each weight group.
constexpr int kNumInputsPerGroup = 4;
// Number of groups of inputs to be broadcast.
constexpr int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup;
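// Sanity check added for illustration (not in the original file): the register
// of 32 inputs is consumed as kNumInputGroups broadcasts of kNumInputsPerGroup
// inputs each.
static_assert(kNumInputGroups * kNumInputsPerGroup == kNumInputsPerRegister,
              "input grouping must cover the whole input register");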

// Functions to compute part of a matrix.vector multiplication. The weights
// are in a very specific order (see above) in w, which is multiplied by
// u of length num_in, to produce output v after scaling the integer results
// by the corresponding member of scales.
// The amount of w and scales consumed is fixed and not available to the
// caller. The number of outputs written to v will be at most num_out.

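// Illustrative scalar reference (not part of the original implementation):
// ignoring the SIMD-oriented weight shuffling, the kernels below compute, for
// each output o,
//   v[o] = scales[o] * (sum_i w[o][i] * u[i] + bias[o] * 127)
// where each output's int8_t bias is stored after its weights and is scaled by
// 127, the int8_t value used to represent an input of 1.0. T stands for the
// scale/output type (float or double depending on FAST_FLOAT).
template <typename T>
static void ReferenceMatrixDotVector(int num_out, int num_in, const int8_t *w,
                                     const int8_t *bias, const T *scales,
                                     const int8_t *u, T *v) {
  for (int o = 0; o < num_out; ++o) {
    int32_t total = static_cast<int32_t>(bias[o]) * 127;
    for (int i = 0; i < num_in; ++i) {
      total += static_cast<int32_t>(w[o * num_in + i]) * static_cast<int32_t>(u[i]);
    }
    v[o] = static_cast<T>(total) * scales[o];
  }
}
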
// Computes one set of 4x8 products of inputs and weights, adding to result.
// Horizontally adds 4 adjacent results, making 8x32-bit results.
// rep_input is assumed to be an 8x replicated set of 4x8-bit signed integers.
// Note that wi must previously have been re-organized with blocks of 4x8
// weights in contiguous memory.
// ones is a register of 16x16-bit values all equal to 1.
// Note: wi is incremented by the amount of data read.
// weights and reps are scratch registers.
// This function must be inlined with references in order for the compiler to
// correctly use the registers declared in the caller.
static inline void MultiplyGroup(const __m256i &rep_input, const __m256i &ones, const int8_t *&wi,
                                 __m256i &weights, __m256i &reps, __m256i &result) {
  // Load a 4x8 block of weights.
  weights = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(wi));
  wi += kNumInputsPerRegister;
  // Normalize the signs on rep_input, weights, so weights is always +ve.
  reps = _mm256_sign_epi8(rep_input, weights);
  weights = _mm256_sign_epi8(weights, weights);
  // Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results,
  // with adjacent pairs added.
  weights = _mm256_maddubs_epi16(weights, reps);
  // Multiply 16x16-bit result by 16x16-bit ones to make 8x32-bit results,
  // with adjacent pairs added. What we really want is a horizontal add of
  // 16+16=32 bit result, but there is no such instruction, so multiply by
  // 16-bit ones instead. It is probably faster than all the sign-extending,
  // permuting and adding that would otherwise be required.
  weights = _mm256_madd_epi16(weights, ones);
  result = _mm256_add_epi32(result, weights);
}
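
// Illustrative scalar sketch (not in the original file) of one MultiplyGroup
// call: 8 outputs each accumulate the dot product of the same 4 inputs with
// their own 4 weights. The SIMD version gets the same signed products out of
// the unsigned-by-signed _mm256_maddubs_epi16 by first moving the weight
// signs onto the replicated inputs.
static inline void MultiplyGroupScalarSketch(const int8_t input4[4], const int8_t *&wi,
                                             int32_t result[8]) {
  for (int out = 0; out < 8; ++out) {
    for (int i = 0; i < 4; ++i) {
      result[out] += static_cast<int32_t>(wi[out * 4 + i]) * static_cast<int32_t>(input4[i]);
    }
  }
  wi += kNumInputsPerRegister; // 32 weights consumed, as in the AVX2 version.
}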

// Load 64 bits into the bottom of a 128bit register.
// We don't actually care what the top 64bits are, but this ends
// up with them being zero.
static inline __m128i load64_to_128(const int8_t *wi_) {
  const auto *wi = reinterpret_cast<const int64_t *>(wi_);
  return _mm_set_epi64x(0, wi[0]);
}

#if defined(FAST_FLOAT)

static inline void ExtractResults8(__m256i result, const int8_t *wi,
                                   const float *scales, float *v) {
  __m128i w128 = load64_to_128(wi);          // 8x8bit vals in bottom of 128bit reg
  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256 scale01234567 = _mm256_loadu_ps(scales);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result = _mm256_add_epi32(result, w256);     // result += bias * 127
  __m256 res01234567 = _mm256_cvtepi32_ps(result);
  result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v, res01234567);
}
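
// Illustrative scalar equivalent (not in the original file) of ExtractResults8
// above: the 8 int8_t bias values at wi are folded in as bias * 127, and the
// accumulated integer dot products are converted to float and scaled.
static inline void ExtractResults8ScalarSketch(const int32_t result[8], const int8_t *wi,
                                               const float *scales, float *v) {
  for (int k = 0; k < 8; ++k) {
    v[k] = scales[k] * static_cast<float>(result[k] + static_cast<int32_t>(wi[k]) * 127);
  }
}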

static inline void ExtractResults16(__m256i result0, __m256i result1,
                                    const int8_t *&wi, const float *&scales,
                                    float *&v) {
  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
  // 8x8bit vals in bottom of 128bit reg
  const __m256i bias_scale =
      _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  __m256 scale01234567 = _mm256_loadu_ps(scales);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result0 = _mm256_add_epi32(result0, w256);   // result += bias * 127
  __m256 res01234567 = _mm256_cvtepi32_ps(result0);
  result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v, res01234567);
  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
  w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  scale01234567 = _mm256_loadu_ps(scales + 8);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result1 = _mm256_add_epi32(result1, w256);   // result += bias * 127
  res01234567 = _mm256_cvtepi32_ps(result1);
  result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v + 8, res01234567);
  wi += 16;
  scales += 16;
  v += 16;
}

// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi
// provide (num_in/kNumInputsPerGroup groups of (N output dim groups of
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  __m256i result4 = _mm256_setzero_si256();
  __m256i result5 = _mm256_setzero_si256();
  __m256i result6 = _mm256_setzero_si256();
  __m256i result7 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
  ExtractResults16(result4, result5, wi, scales, v);
  ExtractResults16(result6, result7, wi, scales, v);
}
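
// Illustrative helper (not in the original file): the number of int8_t values
// one PartialMatrixDotVector64 call consumes from wi, i.e. N = 64 weights for
// each rounded input plus 64 bias bytes. This matches w_step in
// matrixDotVector below.
static inline int Partial64WeightBlockSize(int num_in) {
  const int n = kNumOutputsPerRegister * kMaxOutputRegisters; // 64 outputs
  const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
  return (rounded_num_in + 1) * n;
}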

// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static inline void PartialMatrixDotVector8(const int8_t *wi, const float *scales, const int8_t *u,
                                           int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
    }
  }
  ExtractResults8(result0, wi, scales, v);
}

static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const float *scales,
                            const int8_t *u, float *v) {
  const int num_out = dim1;
  const int num_in = dim2 - 1;
  // Each call to a PartialMatrixDotVector function produces group_size
  // outputs, except the last one, which can produce fewer.
  const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
  const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
  int output = 0;

  int w_step = (rounded_num_in + 1) * group_size;

  // Run with this group size, until it would produce too much output, then
  // switch to a smaller size.
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
  }
}
#else
static inline void ExtractResults8(__m256i result, const int8_t *wi, const double *scales,
                                   double *v) {
  __m128i w128 = load64_to_128(wi);          // 8x8bit vals in bottom of 128bit reg
  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256d scale0123 = _mm256_loadu_pd(scales);
  __m256d scale4567 = _mm256_loadu_pd(scales + 4);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result = _mm256_add_epi32(result, w256);     // result += bias * 127
  __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
  result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
  __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v, res0123);
  _mm256_storeu_pd(v + 4, res4567);
}

static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
                                    const double *&scales, double *&v) {
  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
  // 8x8bit vals in bottom of 128bit reg
  const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  __m256d scale0123 = _mm256_loadu_pd(scales);
  __m256d scale4567 = _mm256_loadu_pd(scales + 4);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result0 = _mm256_add_epi32(result0, w256);   // result += bias * 127
  __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
  result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
  __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v, res0123);
  _mm256_storeu_pd(v + 4, res4567);
  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
  w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  scale0123 = _mm256_loadu_pd(scales + 8);
  scale4567 = _mm256_loadu_pd(scales + 12);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result1 = _mm256_add_epi32(result1, w256);   // result += bias * 127
  res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
  result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
  res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v + 8, res0123);
  _mm256_storeu_pd(v + 12, res4567);
  wi += 16;
  scales += 16;
  v += 16;
}

// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi
// provide (num_in/kNumInputsPerGroup groups of (N output dim groups of
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  __m256i result4 = _mm256_setzero_si256();
  __m256i result5 = _mm256_setzero_si256();
  __m256i result6 = _mm256_setzero_si256();
  __m256i result7 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
  ExtractResults16(result4, result5, wi, scales, v);
  ExtractResults16(result6, result7, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static inline void PartialMatrixDotVector8(const int8_t *wi, const double *scales, const int8_t *u,
                                           int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
    }
  }
  ExtractResults8(result0, wi, scales, v);
}

static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *scales,
                            const int8_t *u, double *v) {
  const int num_out = dim1;
  const int num_in = dim2 - 1;
  // Each call to a PartialMatrixDotVector function produces group_size
  // outputs, except the last one, which can produce fewer.
  const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
  const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
  int output = 0;

  int w_step = (rounded_num_in + 1) * group_size;

  // Run with this group size, until it would produce too much output, then
  // switch to a smaller size.
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
  }
}
#endif
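
// Illustrative sketch (not in the original file) of how matrixDotVector above
// splits the output dimension into blocks: as many 64-output blocks as fit,
// then at most one 32-, one 16- and one 8-output block. For example,
// num_out = 100 rounds up to 104 = 64 + 32 + 8, i.e. three blocks.
static inline int CountPartialBlocks(int num_out) {
  int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters; // 64
  int output = 0;
  int blocks = 0;
  while (output + group_size <= rounded_num_out) { // 64-output blocks
    output += group_size;
    ++blocks;
  }
  for (group_size /= 2; group_size >= kNumOutputsPerRegister; group_size /= 2) {
    if (output + group_size <= rounded_num_out) { // at most one 32/16/8 block each
      output += group_size;
      ++blocks;
    }
  }
  return blocks;
}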

const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = {
    // Function.
    matrixDotVector,
    // Number of 32 bit outputs held in each register.
    kNumOutputsPerRegister,
    // Maximum number of registers that we will use to hold outputs.
    kMaxOutputRegisters,
    // Number of 8 bit inputs in the inputs register.
    kNumInputsPerRegister,
    // Number of inputs in each weight group.
    kNumInputsPerGroup
};
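
// Illustrative usage sketch (not part of the original file): calling the AVX2
// kernel directly with zero-filled buffers of the right shapes. dim2 counts
// the bias column, u is padded generously so the unaligned 32-byte loads stay
// in bounds, and with all-zero weights the shuffled weight layout is trivially
// satisfied, so every output is 0. Real callers go through the IntSimdMatrix
// interface in intsimdmatrix.h, which also reshuffles the weights into the
// order this file expects.
static inline void MatrixDotVectorZeroExample() {
#if defined(FAST_FLOAT)
  using ExampleFloat = float; // scales/outputs are float in FAST_FLOAT builds
#else
  using ExampleFloat = double; // otherwise double
#endif
  const int num_out = 20; // dim1: outputs
  const int num_in = 50;  // dim2 - 1: real inputs (the last column is the bias)
  const int rounded_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
  const int rounded_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
  // Shuffled weights plus one bias byte per output for each block; all zero here.
  std::vector<int8_t> wi((rounded_in + 1) * rounded_out, 0);
  // Inputs padded to a whole number of 32-byte registers.
  std::vector<int8_t> u(IntSimdMatrix::Roundup(num_in, kNumInputsPerRegister), 0);
  std::vector<ExampleFloat> scales(rounded_out, ExampleFloat(1));
  std::vector<ExampleFloat> v(rounded_out);
  matrixDotVector(num_out, num_in + 1, wi.data(), scales.data(), u.data(), v.data());
  // With all-zero weights and biases every element of v is 0.
}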

} // namespace tesseract.

#endif