@@ -186,7 +186,7 @@ class gemm_4bit_cutlass_kernel {
186186 ? BlockIdxX () : BlockIdxY ();
187187 const int l_coord = BlockIdxZ ();
188188
189- #if 1
189+ #if 0
190190 float* quant_map;
191191 {
192192 // Load Dequatize LUT and save to SLM, 16 for 4bits
@@ -274,6 +274,7 @@ class gemm_4bit_cutlass_kernel {
274274 return ((quant_idx / 7.5f) - 1.0f) * scale; // 7.5=15/2 (4-bit)
275275 };
276276#endif
277+ #if 0
277278 auto dequant = [&] {
278279 constexpr int N = decltype(cute::size<1>(mma_B))::value;
279280 constexpr int K = decltype(cute::size(mma_B))::value / N;
@@ -284,23 +285,100 @@ class gemm_4bit_cutlass_kernel {
284285
285286 //if(cute::thread0()) printf("N = %d, K = %d, compress_size = %d, vec_size = %d\n", N, K, compress_size, vec_size);
286287 compress_type src[vec_size];
287- ElementMMA dst[K ];
288+ reinterpret_cast<sycl::vec<compress_type, vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<compress_type, vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[0 ];
288289
289290 float scale_value = fragment_scale(0);
290291
291- reinterpret_cast <sycl::vec<compress_type, vec_size>*>(src)[ 0 ] = reinterpret_cast <sycl::vec<compress_type, vec_size >*>(cute::raw_pointer_cast (dequant_frag. data ()))[ 0 ] ;
292+ auto* dst = reinterpret_cast<sycl::vec<int64_t, 16 >*>(&smem_buf[thread_idx * decltype( cute::size(mma_B))::value * 2]) ;
292293
293294 #pragma unroll
294295 for (int i = 0; i < vec_size; i++) {
296+ //compress_type src = src_[i];//(*src_).get(i);
297+
295298 #pragma unroll
296- for (int j = 0 ; j < compress_size; j++) {
297- uint8_t bit_value = (src[i] >> (4 * ((j+1 )%2 + (j/2 )*2 ))) & 0xf ;
298- dst[i*compress_size+j] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
299- // dst[i*compress_size+j] = static_cast<ElementMMA>(convert(bit_value, scale_value));
299+ for (int j = 0; j < compress_size/2; j++) {
300+ uint8_t high = (src[i]>> (4 * (j * 2 + 1))) & 0xf;
301+ uint8_t low = (src[i] >> (4 * (j * 2))) & 0xf;
302+ dst[0][i*compress_size+j*2] = static_cast<ElementMMA>(quant_map[high] * scale_value);
303+ dst[0][i*compress_size+j*2+1] = static_cast<ElementMMA>(quant_map[low] * scale_value);
300304 }
301305 }
302306 reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data()))[0] = reinterpret_cast<sycl::vec<int64_t, 16>*>(dst)[0];
303- };
307+ #else
308+ #if 0
309+ auto dequant = [&] {
310+ constexpr int N = decltype(cute::size<1>(mma_B))::value;
311+ constexpr int K = decltype(cute::size(mma_B))::value / N;
312+ using compress_type = uint32_t;
313+ constexpr int compress_size = cute::sizeof_bits_v<compress_type> / cute::sizeof_bits_v<ElementB>;
314+ constexpr int vec_size = K / compress_size;
315+
316+ compress_type src[vec_size];
317+ reinterpret_cast<sycl::vec<compress_type, vec_size>*>(src)[0] = reinterpret_cast<sycl::vec<compress_type, vec_size>*>(cute::raw_pointer_cast(dequant_frag.data()))[0];
318+
319+ const int tid = thread_idx;
320+ constexpr int BANK_NUM = 32;
321+ constexpr int ELEMS_PER_THREAD = vec_size * compress_size;
322+ constexpr int ELEMS_PER_BANK = (ELEMS_PER_THREAD + BANK_NUM - 1) / BANK_NUM;
323+
324+ ElementMMA* private_slm = reinterpret_cast<ElementMMA*>(smem_buf) + tid * BANK_NUM * ELEMS_PER_BANK;
325+ //auto* private_slm = reinterpret_cast<sycl::vec<int64_t, 16>*>(&smem_buf[thread_idx * BANK_NUM * ELEMS_PER_BANK * 2]);
326+ //if(cute::thread0()) printf("ELEMS_PER_THREAD = %d, ELEMS_PER_BANK = %d\n", ELEMS_PER_THREAD, ELEMS_PER_BANK);
327+ float scale_value = fragment_scale(0);
328+ #pragma unroll
329+ for (int i = 0; i < vec_size; i++) {
330+ #pragma unroll
331+ for (int j = 0; j < compress_size; j++) {
332+ uint8_t bit_value = (src[i] >> (4 * ((j+1)%2 + (j/2)*2))) & 0xf;
333+
334+ const int linear_idx = i * compress_size + j;
335+ const int bank = linear_idx % BANK_NUM;
336+ const int offset = linear_idx / BANK_NUM;
337+ //if(cute::thread0()) printf("i = %d, j = %d, linear_idx = %d, bank = %d, offset = %d, bank * ELEMS_PER_BANK + offset = %d\n",i,j,linear_idx,bank,offset, bank * ELEMS_PER_BANK + offset);
338+
339+ private_slm[bank * ELEMS_PER_BANK + offset] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
340+ }
341+ }
342+
343+ reinterpret_cast<sycl::vec<uint64_t, 16>*>(&mma_B)[0] = *reinterpret_cast<sycl::vec<uint64_t, 16>*>(private_slm);
344+ };
345+ #endif
346+ auto dequant = [&] {
347+ constexpr int N = decltype (cute::size<1 >(mma_B))::value;
348+ constexpr int K = decltype (cute::size (mma_B))::value / N;
349+
350+ using compress_type = uint32_t ;
351+ constexpr int compress_size = 32 / cute::sizeof_bits_v<ElementB>;
352+ constexpr int vec_size = K / compress_size;
353+
354+ constexpr int BANK_NUM = 32 ; // Intel SLM bank 数
355+ constexpr int ELEMS_PER_THREAD = vec_size * compress_size; // 64
356+ constexpr int ELEMS_PER_BANK = (ELEMS_PER_THREAD + BANK_NUM - 1 ) / BANK_NUM ; // 2
357+
358+ compress_type src[vec_size];
359+ *reinterpret_cast <sycl::vec<compress_type, vec_size>*>(src) =
360+ *reinterpret_cast <const sycl::vec<compress_type, vec_size>*>(
361+ cute::raw_pointer_cast (dequant_frag.data ()));
362+
363+ const int tid = thread_idx;
364+ ElementMMA* private_slm = reinterpret_cast <ElementMMA*>(smem_buf) + tid * ELEMS_PER_THREAD ; // 每个线程一段 **连续** 128 B,天然 128 B 对齐
365+
366+ float scale_value = fragment_scale (0 );
367+
368+ #pragma unroll
369+ for (int i = 0 ; i < vec_size; ++i) {
370+ #pragma unroll
371+ for (int j = 0 ; j < compress_size; ++j) {
372+ uint8_t bit_value = (src[i] >> (4 * (((j+1 ) & 1 ) + (j >> 1 ) * 2 ))) & 0xF ;
373+ // uint8_t bit_value = (src[i] >> (4 * ((j+1)%2 + (j/2)*2))) & 0xf;
374+ private_slm[i * compress_size + j] =
375+ static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
376+ }
377+ }
378+
379+ *reinterpret_cast <sycl::vec<int64_t , 16 >*>(cute::raw_pointer_cast (mma_B.data ())) = *reinterpret_cast <const sycl::vec<int64_t , 16 >*>(private_slm);
380+ };
381+ #endif
304382
305383 CUTLASS_PRAGMA_UNROLL
306384 for (int i = 0 ; i < DispatchPolicy::Stages; i++, prefetch_k++) {
@@ -351,7 +429,9 @@ void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
351429
352430 using GemmKernel = gemm_4bit_cutlass_kernel<T, BITS >;
353431
354- static constexpr int smem_size= (16 +1 )*32 /8 ;
432+ static constexpr int smem_size= BLK_N * BLK_K * 16 /8 ; // (16+1)*32/8;
433+ size_t max_slm_size = q.get_device ().get_info <sycl::info::device::local_mem_size>();
434+ assert (smem_size <= max_slm_size);
355435
356436 auto problem_size = ProblemShape{m, n, k, l};
357437
0 commit comments