#if !defined(__AVX2__)
#  if defined(__i686__) || defined(__x86_64__)
#    error Implementation only for AVX2 capable architectures
#  endif
#else
#  include <immintrin.h>

// Number of outputs held in each register. 8x 32 bit ints.
constexpr int kNumOutputsPerRegister = 8;
// Maximum number of registers that we will use.
constexpr int kMaxOutputRegisters = 8;
// Number of inputs in the inputs register.
constexpr int kNumInputsPerRegister = 32;
// Number of inputs in each weight group.
constexpr int kNumInputsPerGroup = 4;
// Number of groups of inputs to be broadcast.
constexpr int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup;
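
// Computes one set of 4x8 products of inputs and weights, adding to result.
// Horizontally adds 4 adjacent results, making 8x32-bit results.
// rep_input = 32x8-bit signed inputs, with 4 consecutive inputs repeated 8
// times. ones = 16x16-bit 1s. wi = pointer to the first weight to use.
// weights and reps are scratch registers, passed by reference so that they
// stay in 256-bit registers across the unrolled calls below and the calls
// carry no dependency chains.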
static inline void MultiplyGroup(const __m256i &rep_input, const __m256i &ones, const int8_t *&wi,
                                 __m256i &weights, __m256i &reps, __m256i &result) {
  // Load a 4x8 block of weights.
  weights = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(wi));
  wi += kNumInputsPerRegister;
  // Normalize the signs on rep_input and weights, so weights is always +ve.
  reps = _mm256_sign_epi8(rep_input, weights);
  weights = _mm256_sign_epi8(weights, weights);
  // Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results,
  // with adjacent pairs added.
  weights = _mm256_maddubs_epi16(weights, reps);
  // Multiply 16x16-bit result by 16x16-bit ones to make 8x32-bit results,
  // with adjacent pairs added: a horizontal add of each group of 4 inputs.
  weights = _mm256_madd_epi16(weights, ones);
  result = _mm256_add_epi32(result, weights);
}
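
// Loads 64 bits (8x8-bit weights) into the bottom of a 128-bit register.
// The top 64 bits are unused, but end up zeroed.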
static inline __m128i load64_to_128(const int8_t *wi_) {
  const auto *wi = reinterpret_cast<const int64_t *>(wi_);
  return _mm_set_epi64x(0, wi[0]);
}

#if defined(FAST_FLOAT)
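
// Extracts 8 32-bit accumulators from result, adds the bias weights from wi
// (times 127, matching the scaling of the implicit 1 bias input), scales by
// the per-output float scales, and stores 8 floats to v.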
static inline void ExtractResults8(__m256i result, const int8_t *wi, const float *scales,
                                   float *v) {
  __m128i w128 = load64_to_128(wi);          // 8x8bit vals in bottom of 128bit reg
  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256 scale01234567 = _mm256_loadu_ps(scales);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result = _mm256_add_epi32(result, w256);     // result += bias * 127
  __m256 res01234567 = _mm256_cvtepi32_ps(result);
  result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v, res01234567);
}
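
// As ExtractResults8, but extracts 16 results from result0 and result1,
// advancing wi, scales and v by 16 places each.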
static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
                                    const float *&scales, float *&v) {
  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
  // 16x8bit vals, the 16 bias weights, in the 128bit reg
  const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  __m256 scale01234567 = _mm256_loadu_ps(scales);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result0 = _mm256_add_epi32(result0, w256);   // result += bias * 127
  __m256 res01234567 = _mm256_cvtepi32_ps(result0);
  result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v, res01234567);
  // Same again for the second 8 results, using the high half of w8.
  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
  w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  scale01234567 = _mm256_loadu_ps(scales + 8);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result1 = _mm256_add_epi32(result1, w256);   // result += bias * 127
  res01234567 = _mm256_cvtepi32_ps(result1);
  result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v + 8, res01234567);
  // Advance the in/out pointers past the 16 consumed/produced values.
  wi += 16;
  scales += 16;
  v += 16;
}
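
// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi provide
// (num_in/kNumInputsPerGroup groups of (N output dim groups of
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup * ceil(num_in / kNumInputsPerGroup) elements.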
static void PartialMatrixDotVector64(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  // 16x16-bit 1s for the horizontal add in MultiplyGroup.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  // Permutation that rotates the 8 32-bit lanes down by one.
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  __m256i result4 = _mm256_setzero_si256();
  __m256i result5 = _mm256_setzero_si256();
  __m256i result6 = _mm256_setzero_si256();
  __m256i result7 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inner loop. Shifts the inputs down by one group of 4 per iteration.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
  ExtractResults16(result4, result5, wi, scales, v);
  ExtractResults16(result6, result7, wi, scales, v);
}
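
// As PartialMatrixDotVector64, but computes N=32 results using 4 accumulator
// registers. The same weight layout and input padding requirements apply.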
static void PartialMatrixDotVector32(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
}
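
// As PartialMatrixDotVector64, but computes N=16 results using 2 accumulator
// registers.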
static void PartialMatrixDotVector16(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
}
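
// As PartialMatrixDotVector64, but computes N=8 results using a single
// accumulator register.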
static inline void PartialMatrixDotVector8(const int8_t *wi, const float *scales, const int8_t *u,
                                           int num_in, float *v) {
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  __m256i result0 = _mm256_setzero_si256();
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
    }
  }
  ExtractResults8(result0, wi, scales, v);
}
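
// Computes matrix.vector v = Wu.
// u is of size dim2 - 1 and the output v is of size dim1.
// u is imagined to have an extra element at the end with value 1, to
// implement the bias, but it doesn't actually have it; the bias weights are
// applied inside ExtractResults8/16 instead.
// Runs the 64-output kernel while a full group of outputs remains, then
// halves the group size (32, 16, 8) for the rest.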
static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const float *scales,
                            const int8_t *u, float *v) {
  const int num_out = dim1;
  const int num_in = dim2 - 1;
  // Each call to a partial kernel produces group_size outputs, except the
  // last one, which can produce less.
  const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
  const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
  int output = 0;
  int w_step = (rounded_num_in + 1) * group_size;
  // Run with this group size, until it would produce too much output, then
  // switch to a smaller size.
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  group_size /= 2;
  w_step /= 2;
  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;
  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;
  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
  }
}

#else
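
// Double-precision variant of ExtractResults8: adds the bias weights from wi
// (times 127, matching the scaling of the implicit 1 bias input) to the 8
// accumulators, scales by the double scales, and stores 8 doubles to v.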
static inline void ExtractResults8(__m256i result, const int8_t *wi, const double *scales,
                                   double *v) {
  __m128i w128 = load64_to_128(wi);          // 8x8bit vals in bottom of 128bit reg
  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256d scale0123 = _mm256_loadu_pd(scales);
  __m256d scale4567 = _mm256_loadu_pd(scales + 4);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result = _mm256_add_epi32(result, w256);     // result += bias * 127
  __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
  // Move the top 4 int32s down to reuse the same extraction.
  result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
  __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v, res0123);
  _mm256_storeu_pd(v + 4, res4567);
}
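
// As ExtractResults8 above, but extracts 16 results from result0 and result1,
// advancing wi, scales and v by 16 places each.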
static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
                                    const double *&scales, double *&v) {
  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
  // 16x8bit vals, the 16 bias weights, in the 128bit reg
  const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  __m256d scale0123 = _mm256_loadu_pd(scales);
  __m256d scale4567 = _mm256_loadu_pd(scales + 4);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result0 = _mm256_add_epi32(result0, w256);   // result += bias * 127
  __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
  result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
  __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v, res0123);
  _mm256_storeu_pd(v + 4, res4567);
  // Same again for the second 8 results, using the high half of w8.
  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
  w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  scale0123 = _mm256_loadu_pd(scales + 8);
  scale4567 = _mm256_loadu_pd(scales + 12);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result1 = _mm256_add_epi32(result1, w256);   // result += bias * 127
  res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
  result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
  res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v + 8, res0123);
  _mm256_storeu_pd(v + 12, res4567);
  // Advance the in/out pointers past the 16 consumed/produced values.
  wi += 16;
  scales += 16;
  v += 16;
}
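
// Computes part of matrix.vector v = Wu. Computes N=64 results.
// For the required weight layout and input padding, see the float version
// above; only the output type differs.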
static void PartialMatrixDotVector64(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  __m256i result4 = _mm256_setzero_si256();
  __m256i result5 = _mm256_setzero_si256();
  __m256i result6 = _mm256_setzero_si256();
  __m256i result7 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inner loop. Shifts the inputs down by one group of 4 per iteration.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
  ExtractResults16(result4, result5, wi, scales, v);
  ExtractResults16(result6, result7, wi, scales, v);
}
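
// As PartialMatrixDotVector64, but computes N=32 results.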
static void PartialMatrixDotVector32(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
}
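
// As PartialMatrixDotVector64, but computes N=16 results.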
static void PartialMatrixDotVector16(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
}
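
// As PartialMatrixDotVector64, but computes N=8 results.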
static inline void PartialMatrixDotVector8(const int8_t *wi, const double *scales, const int8_t *u,
                                           int num_in, double *v) {
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  __m256i result0 = _mm256_setzero_si256();
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
    }
  }
  ExtractResults8(result0, wi, scales, v);
}
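
// Computes matrix.vector v = Wu for double outputs. The strategy matches the
// float version above: run the 64-output kernel while a full group of outputs
// remains, then halve the group size (32, 16, 8) for the rest.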
static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *scales,
                            const int8_t *u, double *v) {
  const int num_out = dim1;
  const int num_in = dim2 - 1;
  // Each call to a partial kernel produces group_size outputs, except the
  // last one, which can produce less.
  const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
  const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
  int output = 0;
  int w_step = (rounded_num_in + 1) * group_size;
  // Run with this group size, until it would produce too much output, then
  // switch to a smaller size.
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  group_size /= 2;
  w_step /= 2;
  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;
  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;
  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
  }
}

#endif // FAST_FLOAT
static const IntSimdMatrix intSimdMatrixAVX2 = {
    // Function.
    matrixDotVector,
    // Number of 32 bit outputs held in each register.
    kNumOutputsPerRegister,
    // Maximum number of registers that we will use to hold outputs.
    kMaxOutputRegisters,
    // Number of 8 bit inputs in the inputs register.
    kNumInputsPerRegister,
    // Number of inputs in each weight group.
    kNumInputsPerGroup};

// Note: Roundup(int input, int factor) is the static helper declared in
// intsimdmatrix.h; it rounds input up to the next multiple of factor.

#endif // !defined(__AVX2__)