@@ -177,6 +177,7 @@ class gemm_4bit_cutlass_kernel {
177177 return {grid.x , grid.y , grid.z };
178178 }
179179 }
180+
180181inline float dDequantizeNF4 (unsigned char val) {
181182
182183 // the values for this tree was generated by test_normal_map_tree
@@ -237,7 +238,7 @@ inline float dDequantizeNF4(unsigned char val) {
237238 // static constexpr std::array<float, 16> quant_map{};
238239 // {
239240 // Load Dequatize LUT and save to SLM, 16 for 4bits
240- float * quant_map = reinterpret_cast <float *>(smem_buf);
241+ alignas ( 16 ) float * quant_map = reinterpret_cast <float *>(smem_buf);
241242 if (thread_idx < 16 ) {
242243 quant_map[thread_idx] = params.datatype [thread_idx];
243244 }
@@ -328,6 +329,46 @@ inline float dDequantizeNF4(unsigned char val) {
328329 const int k_start_idx = crd2idx ((*k_tile_iter), make_shape (params.k ));
329330 int prefetch_k = k_start_idx;
330331
332+ #if 0
333+ constexpr float VALUE0 = -1.0f;
334+ constexpr float VALUE1 = -0.6961928f;
335+ constexpr float VALUE2 = -0.52507305f;
336+ constexpr float VALUE3 = -0.39491749f;
337+ constexpr float VALUE4 = -0.28444138f;
338+ constexpr float VALUE5 = -0.18477343f;
339+ constexpr float VALUE6 = -0.09105004f;
340+ constexpr float VALUE7 = 0.0f;
341+ constexpr float VALUE8 = 0.0795803f;
342+ constexpr float VALUE9 = 0.1609302f;
343+ constexpr float VALUE10 = 0.2461123f;
344+ constexpr float VALUE11 = 0.33791524f;
345+ constexpr float VALUE12 = 0.44070983f;
346+ constexpr float VALUE13 = 0.562617f;
347+ constexpr float VALUE14 = 0.72295684f;
348+ constexpr float VALUE15 = 1.0f;
349+
350+ auto quant_map_alias = [&](uint8_t index) {
351+ switch(index) {
352+ case 0: return VALUE0;
353+ case 1: return VALUE1;
354+ case 2: return VALUE2;
355+ case 3: return VALUE3;
356+ case 4: return VALUE4;
357+ case 5: return VALUE5;
358+ case 6: return VALUE6;
359+ case 7: return VALUE7;
360+ case 8: return VALUE8;
361+ case 9: return VALUE9;
362+ case 10: return VALUE10;
363+ case 11: return VALUE11;
364+ case 12: return VALUE12;
365+ case 13: return VALUE13;
366+ case 14: return VALUE14;
367+ case 15: return VALUE15;
368+ }
369+ };
370+ #endif
371+
331372#if 0 //SLM
332373 #if 1
333374 auto dequant = [&] (int k_tile) {
@@ -429,6 +470,7 @@ if(thread_idx==0 && m_coord==0 && n_coord==0 && l_coord==0) {
429470 uint8_t bit_value = (src_value >> (4 * (((c + 1 ) & 1 ) + (c >> 1 ) * 2 ))) & 0xF ;
430471 float scale_value = fragment_scale (n * (BLK_K / GROUP_SIZE ) + (dst_base_idx + c) / GROUP_SIZE );
431472 dst[dst_base_idx + c] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
473+ // dst[dst_base_idx + c] = static_cast<ElementMMA>(quant_map_alias(bit_value) * scale_value);
432474#if 0
433475if(thread_idx==60 && m_coord==0 && n_coord==0 && l_coord==0){
434476 printf("tid = %d, m_coord = %d, n_coord = %d, l_coord = %d, n = %d, src_l = %d, dst_dx = %d, scale_idx = %d, scale_value = %f\n", thread_idx, m_coord, n_coord, l_coord, n, l, dst_base_idx+c, n * (BLK_K / GROUP_SIZE) + (dst_base_idx+c)/GROUP_SIZE, scale_value);
0 commit comments