#include "cpu_ops.h"
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <thread>
#include <vector>
#ifdef HAS_OPENMP
#include <omp.h>
#define BNB_OMP_PARALLEL_FOR _Pragma("omp parallel for")
#else
#define BNB_OMP_PARALLEL_FOR
#endif
namespace {
constexpr int kCodebookSize = 256;
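// Nearest-neighbor search over the sorted 256-entry codebook: lower_bound
// finds the first entry >= value, and we keep whichever of that entry and its
// left neighbor is closer. Illustrative example: with neighbors 0.25 and 0.31,
// value 0.28 maps to 0.25, since exact ties favor the left entry.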
inline unsigned char lookup_code_index(const float* codebook, float value) {
value = std::clamp(value, -1.0f, 1.0f);
const float* begin = codebook;
const float* end = codebook + kCodebookSize;
const float* right = std::lower_bound(begin, end, value);
if (right == begin) {
return 0;
}
if (right == end) {
return static_cast<unsigned char>(kCodebookSize - 1);
}
const float* left = right - 1;
const float dist_left = std::fabs(value - *left);
const float dist_right = std::fabs(*right - value);
const unsigned char idx = static_cast<unsigned char>(right - begin);
return dist_right < dist_left ? idx : idx - 1;
}
} // namespace
#if defined(__AVX512F__)
#include <immintrin.h>
inline __m256i cvt_fp32_to_fp16(const __m512 src) {
return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
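// Fallback FP32 -> BF16 conversion below rounds to nearest-even on the upper
// 16 bits: add 0x7fff plus the LSB of the truncated result, then shift right
// by 16. NaN inputs are replaced with the quiet-NaN pattern 0xffff before
// narrowing, since the rounding-bias addition is only valid for ordered values.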
inline __m256i cvt_fp32_to_bf16(const __m512 src) {
#if defined(__AVX512BF16__)
if (has_avx512bf16()) {
return reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(src));
}
#endif
__m512i value = _mm512_castps_si512(src);
__m512i nan = _mm512_set1_epi32(0xffff);
auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q);
__m512i ones = _mm512_set1_epi32(0x1);
__m512i vec_bias = _mm512_set1_epi32(0x7fff);
// uint32_t lsb = (input >> 16) & 1;
auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones);
// uint32_t rounding_bias = 0x7fff + lsb;
t_value = _mm512_add_epi32(t_value, vec_bias);
// input += rounding_bias;
t_value = _mm512_add_epi32(t_value, value);
// input = input >> 16;
t_value = _mm512_srli_epi32(t_value, 16);
// Check NaN before converting back to bf16
t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value);
return _mm512_cvtusepi32_epi16(t_value);
}
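// The 16-entry tables below are indexed by the raw 4-bit code through
// _mm512_permutexvar_ps. _mm512_set_ps lists elements from lane 15 down to
// lane 0, so the last argument is the value decoded for code 0.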
static inline __m512 set_nf4_lut() {
return _mm512_set_ps(
1.0f, 0.7229568362236023, 0.5626170039176941, 0.44070982933044434, 0.33791524171829224, 0.24611230194568634,
0.16093020141124725, 0.07958029955625534, 0.0f, -0.09105003625154495, -0.18477343022823334,
-0.28444138169288635, -0.39491748809814453, -0.5250730514526367, -0.6961928009986877, -1.0f
);
}
static inline __m512 set_fp4_lut() {
return _mm512_set_ps(
-0.2500f, -0.16666667f, -0.5000f, -0.33333333f, -1.0000f, -0.66666667f, -5.208333333e-03f, 0.0000f, 0.2500f,
0.16666667f, 0.5000f, 0.33333333f, 1.0000f, 0.66666667f, 5.208333333e-03f, 0.0000f
);
}
#endif
// 4-bit (FP4 / NF4) dequantization helper extracted from the original else branch.
// DATA_TYPE: 1 = FP4, 2 = NF4
template <typename T, int DATA_TYPE>
void dequantizeBlockwise4bitCpu(
unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n
) {
static_assert(DATA_TYPE == 1 || DATA_TYPE == 2, "dequantizeBlockwise4bitCpu called with non 4-bit DATA_TYPE");
if (blocksize <= 0 || m < 0 || n <= 0)
return;
#if defined(__AVX512F__)
if (has_avx512f()) {
long long dim_0 = m;
long long dim_1 = n;
long long input_dim_1 = dim_1 >> 1;
long long absmax_dim_1 = dim_1 / blocksize;
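// Two 4-bit codes are packed per byte, so each input row holds dim_1 / 2
// bytes and dim_1 / blocksize per-block scales.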
using Tcomp = float;
constexpr auto VEC_LEN = sizeof(__m512i) / sizeof(Tcomp); // 16
if (dim_1 % VEC_LEN == 0 && blocksize >= VEC_LEN) {
__m512 lut = DATA_TYPE == 1 ? set_fp4_lut() : set_nf4_lut();
constexpr auto k_step = VEC_LEN / 2; // 8
BNB_OMP_PARALLEL_FOR
for (long long block_idx = 0; block_idx < dim_0; ++block_idx) {
for (long long k = 0; k < input_dim_1; k += k_step) {
// Load 64 bits (sixteen packed 4-bit codes) and the block scale
uint8_t* p = &A[block_idx * input_dim_1 + k];
uint64_t packed;
std::memcpy(&packed, p, sizeof(uint64_t));
auto scale_idx = k * 2 / blocksize;
auto vscales = _mm512_set1_ps((float)absmax[block_idx * absmax_dim_1 + scale_idx]);
// unpack nf4 data to 32-bit integers
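// Within each byte the high nibble is the earlier logical element (see the
// scalar path below), while in the little-endian u64 the low nibble sits at
// the lower bit offset; the pairs are therefore swapped as each code is
// widened into its own byte for _mm512_cvtepu8_epi32.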
uint64_t high = 0;
uint64_t low = 0;
for (int i = 0; i < 4; ++i) {
low |= ((packed >> (2 * i * 4)) & 0xf) << ((2 * i + 1) * 8);
low |= ((packed >> ((2 * i + 1) * 4)) & 0xf) << (2 * i * 8);
high |= ((packed >> (2 * i * 4 + 32)) & 0xf) << ((2 * i + 1) * 8);
high |= ((packed >> ((2 * i + 1) * 4 + 32)) & 0xf) << (2 * i * 8);
}
__m128i packed_128 = _mm_set_epi64x(high, low);
__m512i vint32 = _mm512_cvtepu8_epi32(packed_128);
// Table look-up
__m512 vout = _mm512_permutexvar_ps(vint32, lut);
// Apply scale
vout = _mm512_mul_ps(vout, vscales);
// Store results
T* pout = &out[block_idx * dim_1 + k * 2];
if constexpr (std::is_same<T, float>()) {
_mm512_storeu_ps(pout, vout);
} else if constexpr (std::is_same<T, bf16_t>()) {
_mm256_storeu_si256((__m256i*)pout, cvt_fp32_to_bf16(vout));
} else if constexpr (std::is_same<T, fp16_t>()) {
_mm256_storeu_si256((__m256i*)pout, cvt_fp32_to_fp16(vout));
}
}
}
return;
}
}
#endif
// Scalar fallback branch
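// Each byte of A packs two codes: the high nibble decodes to the earlier
// output element and the low nibble to the later one (e.g. byte 0xA3 yields
// code 0xA, then code 0x3). Every block of `blocksize` values shares one
// absmax scale.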
long long total = m * n;
BNB_OMP_PARALLEL_FOR
for (long long block_idx = 0; block_idx < total; block_idx += blocksize) {
long long valid_items = (total - block_idx >= blocksize ? blocksize : total - block_idx);
float scale = absmax[block_idx / blocksize];
for (long long i = 0; i < valid_items; i += 2) {
long long byte_index = (block_idx + i) >> 1;
unsigned char byte = A[byte_index];
// High nibble first (matches previous code logic)
float v0 = (DATA_TYPE == 1 ? dDequantizeFP4(byte >> 4) : dDequantizeNF4(byte >> 4)) * scale;
// Low nibble second
float v1 = (DATA_TYPE == 1 ? dDequantizeFP4(byte & 0x0F) : dDequantizeNF4(byte & 0x0F)) * scale;
if constexpr (std::is_same<T, bf16_t>::value) {
out[block_idx + i] = float_to_bf16(v0);
} else if constexpr (std::is_same<T, fp16_t>::value) {
out[block_idx + i] = float_to_fp16(v0);
} else {
out[block_idx + i] = static_cast<T>(v0);
}
if (i + 1 < valid_items) {
if constexpr (std::is_same<T, bf16_t>::value) {
out[block_idx + i + 1] = float_to_bf16(v1);
} else if constexpr (std::is_same<T, fp16_t>::value) {
out[block_idx + i + 1] = float_to_fp16(v1);
} else {
out[block_idx + i + 1] = static_cast<T>(v1);
}
}
}
}
}
template <typename T>
void dequantizeBlockwise8bitCpu(
float* code, unsigned char* A, const float* absmax, T* out, long long blocksize, long long n
) {
if (blocksize <= 0 || n <= 0)
return;
// 8-bit path
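// Each byte of A indexes directly into the 256-entry `code` table; the
// decoded value is rescaled by the per-block absmax.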
BNB_OMP_PARALLEL_FOR
for (long long block_idx = 0; block_idx < n; block_idx += blocksize) {
long long valid_items = (n - block_idx >= blocksize ? blocksize : n - block_idx);
long long block_end = block_idx + valid_items;
float scale = absmax[block_idx / blocksize];
for (long long i = block_idx; i < block_end; ++i) {
float v = code[A[i]] * scale;
if constexpr (std::is_same<T, bf16_t>::value) {
out[i] = float_to_bf16(v);
} else if constexpr (std::is_same<T, fp16_t>::value) {
out[i] = float_to_fp16(v);
} else {
out[i] = static_cast<T>(v);
}
}
}
}
void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n) {
if (blocksize <= 0 || n <= 0)
return;
// Ensure we cover the full expected dynamic range of the codebook.
code[0] = -1.0f;
const auto process_block = [&](long long block_start, long long block_end) {
float absmax_block = 0.0f;
for (long long i = block_start; i < block_end; ++i) {
absmax_block = std::max(absmax_block, std::fabs(A[i]));
}
long long absmax_idx = block_start / blocksize;
absmax[absmax_idx] = absmax_block;
if (absmax_block == 0.0f) {
std::fill(out + block_start, out + block_end, 0);
return;
}
const float inv_absmax = 1.0f / absmax_block;
for (long long i = block_start; i < block_end; ++i) {
float normed_value = A[i] * inv_absmax;
out[i] = lookup_code_index(code, normed_value);
}
};
const long long num_blocks = (n + blocksize - 1) / blocksize;
const int thread_wave_size = 256;
// We chunk the threads into waves of 256 since the max limit is between 16k and 64k on Linux
// (we reach this when running BLOOM-176B with a large batch size).
for (long long offset = 0; offset < num_blocks; offset += thread_wave_size) {
const long long wave_blocks = std::min<long long>(thread_wave_size, num_blocks - offset);
std::vector<std::thread> threads;
threads.reserve(wave_blocks);
const long long first_block_start = offset * blocksize;
for (long long b = 0; b < wave_blocks; ++b) {
const long long block_start = first_block_start + b * blocksize;
if (block_start >= n)
break;
const long long block_end = std::min(block_start + blocksize, n);
threads.emplace_back(process_block, block_start, block_end);
}
for (auto& thread : threads) {
if (thread.joinable()) {
thread.join();
}
}
}
}
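// Usage sketch (illustrative only): quantizing 1024 floats with blocksize 256
// produces 4 per-block scales, which dequantizeBlockwise8bitCpu can invert:
//
//   std::vector<float> a(1024), code(256), absmax(4), recon(1024);
//   std::vector<unsigned char> q(1024);
//   // `code` must be a sorted 256-entry codebook spanning [-1, 1].
//   quantize_cpu(code.data(), a.data(), absmax.data(), q.data(), 256, 1024);
//   dequantizeBlockwise8bitCpu<float>(
//       code.data(), q.data(), absmax.data(), recon.data(), 256, 1024);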
#if defined(__AVX512F__) && defined(__AVX512BF16__)
#define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16))
template <typename scalar_t, int BLOCK_M, int BLOCK_N, int DATA_TYPE> struct tinygemm_kernel_nn {
static inline void apply(
const scalar_t*, const unsigned char*, scalar_t*, const scalar_t*, int64_t, int, int64_t, int64_t, int64_t,
int64_t, int64_t
) {
static_assert(sizeof(scalar_t) == 0, "tinygemm_kernel_nn primary template should never be instantiated");
}
};
template <int BLOCK_M, int BLOCK_N, int DATA_TYPE> struct tinygemm_kernel_nn<bf16_t, BLOCK_M, BLOCK_N, DATA_TYPE> {
static inline void apply(
const bf16_t* __restrict__ A, const unsigned char* __restrict__ B, bf16_t* __restrict__ C,
const bf16_t* __restrict__ Bs, int64_t K, int group_size, int64_t lda, int64_t ldb, int64_t ldc,
int64_t strideBz, int64_t strideBs
) {
static_assert(BLOCK_N % 32 == 0);
constexpr int ROWS = BLOCK_M; // 1..4
constexpr int COLS = BLOCK_N / 16; // 2 or 4
// prefetch distance
constexpr int PREFETCH_SIZE_K = 16 * 4;
__m512bh va;
__m512bh vb[COLS];
__m512 vc[ROWS * COLS];
__m512 vc_master[ROWS * COLS];
__m256i mask = _mm256_set1_epi8(0xF); // lower 4 bit
__m256i fifteen = _mm256_set1_epi8(15);
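// The 32-entry LUT below stores bf16 bit patterns of the FP4/NF4 values at
// indices 15..30, matching the `fifteen` offset added to each nibble above;
// negative hex literals are sign-extended int16 views of patterns >= 0x8000.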
__m512i lut = DATA_TYPE == 1
? _mm512_set_epi16(
0x0000, -0x4180, -0x41D5, -0x4100, -0x4155, -0x4080, -0x40D5, -0x4455, 0x0000, 0x3E80,
0x3E2B, 0x3F00, 0x3EAB, 0x3F80, 0x3F2B, 0x3BAB, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000
)
: _mm512_set_epi16(
0x0000, 0x3F80, 0x3F39, 0x3F10, 0x3EE2, 0x3EAD, 0x3E7C, 0x3E25, 0x3DA3, 0x0000, -0x4246,
-0x41C3, -0x416E, -0x4136, -0x40FA, -0x40CE, -0x4080, 0x0000, 0x0000, 0x0000, 0x0000,
0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000
);
__m512 scales[COLS];
const int64_t K2 = K >> 1;
const int64_t lda2 = lda >> 1;
const int64_t ldb2 = ldb; // (ldb * 2) >> 1: two 4-bit values per byte
const int64_t gs2 = group_size >> 1; // group size in K/2 units, e.g. 64 / 2 = 32
const float* a_ptr = reinterpret_cast<const float*>(A);
auto loadc = [&](auto i) { vc_master[i] = _mm512_set1_ps(0.f); };
Unroll<ROWS * COLS>{}(loadc);
auto pre_compute = [&](auto i, int64_t kgs) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
vc[i] = _mm512_set1_ps(0.f); // reset accumulator
// load scales
if constexpr (row == 0 && col % 2 == 0) {
// Bs layout: [K/gs, BLOCK_N] : [strideBs, 1], dtype=bf16
__m512i tmp = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(Bs + kgs * strideBs + col * 16));
scales[col] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(tmp, 0));
scales[col + 1] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(tmp, 1));
}
};
auto compute = [&](auto i, int64_t k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
}
if constexpr (row == 0 && col % 2 == 0) {
__m256i vb_u4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(B + k * ldb + col * 16));
// deinterleave and lookup to BF16
__m256i vb_i8_lo = vb_u4 & mask;
__m256i vb_i8_hi = _mm256_srli_epi16(vb_u4, 4) & mask;
vb_i8_lo = _mm256_add_epi8(vb_i8_lo, fifteen);
vb_i8_hi = _mm256_add_epi8(vb_i8_hi, fifteen);
vb[col] = (__m512bh)_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(vb_i8_lo), lut);
vb[col + 1] = (__m512bh)_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(vb_i8_hi), lut);
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
}
}
vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
};
auto post_compute = [&](auto i, int64_t kgs) {
vc_master[i] = _mm512_fmadd_ps(vc[i], scales[i % COLS], vc_master[i]);
};
for (int64_t k = 0; k < K2; k += gs2) {
Unroll<ROWS * COLS>{}(pre_compute, k / gs2);
for (int64_t k_offset = 0; k_offset < gs2; ++k_offset) {
Unroll<ROWS * COLS>{}(compute, k + k_offset);
}
Unroll<ROWS * COLS>{}(post_compute, k / gs2);
}
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col % 2 == 0) {
_mm512_storeu_si512(
reinterpret_cast<__m512i*>(C + row * ldc + col * 16),
(__m512i)(_mm512_cvtne2ps_pbh(vc_master[i + 1], vc_master[i]))
);
}
};
Unroll<ROWS * COLS>{}(storec);
}
};
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE, DATA_TYPE) \
tinygemm_kernel_nn<scalar_t, MB_SIZE, NB_SIZE, DATA_TYPE>::apply( \
A + mb_start * lda, B + nb_start, C + mb_start * ldc + nb_start, Bs + nb_start, K, group_size, lda, ldb, ldc, \
strideBz, strideBs \
);
template <typename scalar_t, int DATA_TYPE>
void tinygemm_kernel(
const scalar_t* __restrict__ A, const unsigned char* __restrict__ B, scalar_t* __restrict__ C,
const scalar_t* __restrict__ Bs, scalar_t* __restrict__ Btmp, float* __restrict__ Ctmp, int64_t M, int64_t N,
int64_t K, int group_size, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBz, int64_t strideBs
) {
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 64;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int64_t mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
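// Dispatch key: mb_size in the high nibble, nb_size / 16 in the low nibble
// (e.g. mb_size = 2, nb_size = 64 -> 0x24).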
switch (mb_size << 4 | nb_size >> 4) {
// mb_size = 1
case 0x12:
LAUNCH_TINYGEMM_KERNEL_NN(1, 32, DATA_TYPE);
break;
case 0x14:
LAUNCH_TINYGEMM_KERNEL_NN(1, 64, DATA_TYPE);
break;
// mb_size = 2
case 0x22:
LAUNCH_TINYGEMM_KERNEL_NN(2, 32, DATA_TYPE);
break;
case 0x24:
LAUNCH_TINYGEMM_KERNEL_NN(2, 64, DATA_TYPE);
break;
// mb_size = 3
case 0x32:
LAUNCH_TINYGEMM_KERNEL_NN(3, 32, DATA_TYPE);
break;
case 0x34:
LAUNCH_TINYGEMM_KERNEL_NN(3, 64, DATA_TYPE);
break;
// mb_size = 4
case 0x42:
LAUNCH_TINYGEMM_KERNEL_NN(4, 32, DATA_TYPE);
break;
case 0x44:
LAUNCH_TINYGEMM_KERNEL_NN(4, 64, DATA_TYPE);
break;
default: {
std::fprintf(
stderr, "[bitsandbytes] Unexpected block size %lldx%lld\n", (long long)mb_size, (long long)nb_size
);
std::abort(); // or return; if you prefer silent exit
}
}
}
}
}
template <typename T, int DATA_TYPE>
void gemv_4bit_inference(
int64_t M, int64_t N, int64_t K, const T* __restrict__ x, const unsigned char* __restrict__ w,
const T* __restrict__ absmax, T* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride
) {
constexpr int64_t BLOCK_M = block_size_m(); // 32
constexpr int64_t BLOCK_N = block_size_n(); // 32
const int64_t MB = div_up(M, BLOCK_M); // ceil(M / BLOCK_M); equals 1 when M <= 32
const int64_t NB = div_up(N, BLOCK_N);
// TODO: enable brgemm in the future.
// const bool use_brgemm = M > 4;
// const bool use_brgemm_dequant_out = M > 512;
// T* Btmp_start = nullptr;
// l2 cache block for n
int64_t cache_blocks_nb = get_cache_blocks<T>(BLOCK_N * K);
parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) {
// for brgemm, use float32 for accumulate
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
alignas(64) T Btmp_inner[BLOCK_N * BLOCK_K]; // BLOCK_K = 128
for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) {
for (int64_t mb = begin_mb; mb < end_mb; ++mb) { // 0-1
for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) {
int64_t mb_start = mb * BLOCK_M; // 0
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
tinygemm_kernel<T, DATA_TYPE>(
/* A */ x + mb_start * x_stride,
/* B */ w + nb_start * K / 2, // divide by 2 since w is u4 packed in u8, K is w.size(1) * 2
/* C */ out + mb_start * out_stride + nb_start,
/* Bs */ absmax + nb_start,
/* Btmp */ Btmp_inner,
/* Ctmp */ Ctmp,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* gs */ blocksize, // group_size
/* lda */ x_stride,
/* ldb */ nb_size,
/* ldc */ out_stride,
/* sBz */ N,
/* sBs */ N
);
}
}
}
// if (use_brgemm) {
// at::native::cpublas::brgemm_release();
// }
});
}
#endif
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
template void dequantizeBlockwise8bitCpu<float>(
float* code, unsigned char* A, const float* absmax, float* out, long long blocksize, long long n
);
template void dequantizeBlockwise8bitCpu<fp16_t>(
float* code, unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long n
);
template void dequantizeBlockwise8bitCpu<bf16_t>(
float* code, unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long n
);
template void dequantizeBlockwise4bitCpu<float, FP4>(
unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n
);
template void dequantizeBlockwise4bitCpu<float, NF4>(
unsigned char* A, const float* absmax, float* out, long long blocksize, long long m, long long n
);
template void dequantizeBlockwise4bitCpu<fp16_t, FP4>(
unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n
);
template void dequantizeBlockwise4bitCpu<fp16_t, NF4>(
unsigned char* A, const float* absmax, fp16_t* out, long long blocksize, long long m, long long n
);
template void dequantizeBlockwise4bitCpu<bf16_t, FP4>(
unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n
);
template void dequantizeBlockwise4bitCpu<bf16_t, NF4>(
unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n
);
#if defined(__AVX512F__) && defined(__AVX512BF16__)
template void gemv_4bit_inference<bf16_t, FP4>(
int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w,
const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride
);
template void gemv_4bit_inference<bf16_t, NF4>(
int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w,
const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride
);
#endif