Skip to content

Commit 678eeaa

Browse files
committed
save code
1 parent 5d06871 commit 678eeaa

1 file changed

Lines changed: 43 additions & 1 deletion

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ class gemm_4bit_cutlass_kernel {
177177
return {grid.x, grid.y, grid.z};
178178
}
179179
}
180+
180181
inline float dDequantizeNF4(unsigned char val) {
181182

182183
// the values for this tree were generated by test_normal_map_tree
@@ -237,7 +238,7 @@ inline float dDequantizeNF4(unsigned char val) {
237238
//static constexpr std::array<float, 16> quant_map{};
238239
// {
239240
// Load the dequantize LUT and save it to SLM; 16 entries for 4 bits
240-
float* quant_map = reinterpret_cast<float*>(smem_buf);
241+
alignas(16) float* quant_map = reinterpret_cast<float*>(smem_buf);
241242
if (thread_idx < 16) {
242243
quant_map[thread_idx] = params.datatype[thread_idx];
243244
}
@@ -328,6 +329,46 @@ inline float dDequantizeNF4(unsigned char val) {
328329
const int k_start_idx = crd2idx((*k_tile_iter), make_shape(params.k));
329330
int prefetch_k = k_start_idx;
330331

332+
#if 0
333+
constexpr float VALUE0 = -1.0f;
334+
constexpr float VALUE1 = -0.6961928f;
335+
constexpr float VALUE2 = -0.52507305f;
336+
constexpr float VALUE3 = -0.39491749f;
337+
constexpr float VALUE4 = -0.28444138f;
338+
constexpr float VALUE5 = -0.18477343f;
339+
constexpr float VALUE6 = -0.09105004f;
340+
constexpr float VALUE7 = 0.0f;
341+
constexpr float VALUE8 = 0.0795803f;
342+
constexpr float VALUE9 = 0.1609302f;
343+
constexpr float VALUE10 = 0.2461123f;
344+
constexpr float VALUE11 = 0.33791524f;
345+
constexpr float VALUE12 = 0.44070983f;
346+
constexpr float VALUE13 = 0.562617f;
347+
constexpr float VALUE14 = 0.72295684f;
348+
constexpr float VALUE15 = 1.0f;
349+
350+
auto quant_map_alias = [&](uint8_t index) {
351+
switch(index) {
352+
case 0: return VALUE0;
353+
case 1: return VALUE1;
354+
case 2: return VALUE2;
355+
case 3: return VALUE3;
356+
case 4: return VALUE4;
357+
case 5: return VALUE5;
358+
case 6: return VALUE6;
359+
case 7: return VALUE7;
360+
case 8: return VALUE8;
361+
case 9: return VALUE9;
362+
case 10: return VALUE10;
363+
case 11: return VALUE11;
364+
case 12: return VALUE12;
365+
case 13: return VALUE13;
366+
case 14: return VALUE14;
367+
case 15: return VALUE15;
368+
}
369+
};
370+
#endif
371+
331372
#if 0 //SLM
332373
#if 1
333374
auto dequant = [&] (int k_tile) {
@@ -429,6 +470,7 @@ if(thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0) {
429470
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
430471
float scale_value = fragment_scale(n * (BLK_K / GROUP_SIZE) + (dst_base_idx + c) / GROUP_SIZE);
431472
dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
473+
//dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_alias(bit_value) * scale_value);
432474
#if 0
433475
if(thread_idx==60 && m_coord==0 && n_coord==0 && l_coord==0){
434476
printf("tid = %d, m_coord = %d, n_coord = %d, l_coord = %d, n = %d, src_l = %d, dst_dx = %d, scale_idx = %d, scale_value = %f\n", thread_idx, m_coord, n_coord, l_coord, n, l, dst_base_idx+c, n * (BLK_K / GROUP_SIZE) + (dst_base_idx+c)/GROUP_SIZE, scale_value);

0 commit comments

Comments
 (0)