@@ -275,6 +275,10 @@ class gemm_4bit_cutlass_kernel {
275275 int prefetch_k = k_start_idx;
276276
277277#if 1 // SLM
278+ // alignas(16) ElementB* slm_B = reinterpret_cast<ElementB*>(smem_buf) + thread_idx * (64 * 4) * k_tile_count;
279+ // const uint8_t* gB_ptr = params.B + (n_coord * BLK_N + thread_idx * 1) * params.k/2;
280+ // //using total_vec = 4*k_tile_count;
281+ // reinterpret_cast<sycl::vec<uint64_t, 16>*>(slm_B)[0] = reinterpret_cast<const sycl::vec<uint64_t, 16>*>(gB_ptr)[0];
278282 #if 1
// dequant(k_tile): expands packed 4-bit values into ElementMMA values and
// stores them into the mma_B fragment.
// Visible steps: (1) copy one vector of packed words from slm_B into the local
// array `src`; (2) for each 4-bit field, look it up in quant_map, multiply by
// scale_value and store it as ElementMMA; (3) copy the result into mma_B's
// storage in sycl::vec<int64_t, 16> chunks (16 x 8 bytes = 128 bytes each).
// This text is a unified diff: lines tagged `-` are the previous version
// (private array `dst`), lines tagged `+` are the new version, which writes the
// dequantized values back through `private_slm`, an alias of slm_B.
// `k_tile` is not referenced in the lines visible here; parts of the body are
// elided by the hunk boundaries below.
279283 auto dequant = [&] (int k_tile) {
280284 constexpr int N = decltype (cute::size<1 >(mma_B))::value;
@@ -291,8 +295,8 @@ class gemm_4bit_cutlass_kernel {
// Local copy of the packed input, loaded with a single vector read from slm_B.
// Because `src` is a copy, the later writes through private_slm (same buffer)
// cannot corrupt the values still being decoded.
291295 compress_type src[vec_size];
292296 reinterpret_cast <sycl::vec<compress_type, vec_size>*>(src)[0 ] = reinterpret_cast <const sycl::vec<compress_type, vec_size>*>(slm_B)[0 ];
293297
294- // ElementMMA* private_slm = reinterpret_cast<ElementMMA*>(slm_B); // reuse src SLM buffer; for K=64 each thread owns one contiguous 128 B segment, naturally 128 B aligned
295- ElementMMA dst[K];
// NOTE(review): private_slm aliases the buffer the packed input was read from.
// The dequantized output is presumably larger than the packed input (the
// comment on the next line cites 128 B per thread for K=64) — confirm that each
// thread's slm_B region is sized for the output and does not overlap a region
// another thread still has to read.
298+ ElementMMA* private_slm = reinterpret_cast <ElementMMA*>(slm_B); // reuse src SLM buffer; for K=64 each thread owns one contiguous 128 B segment, naturally 128 B aligned
299+ // ElementMMA dst[K];
296300
// Single scale factor applied to every value decoded in this call.
297301 float scale_value = fragment_scale (0 );
298302
@@ -301,13 +305,15 @@ class gemm_4bit_cutlass_kernel {
// Inner loop over the 4-bit fields of src[i]; the enclosing loop over `i` is
// outside the visible hunk. The shift amount is 4, 0, 12, 8 for j = 0, 1, 2, 3:
// within each byte the high nibble is decoded before the low nibble.
301305 #pragma unroll
302306 for (int j = 0 ; j < compress_size; ++j) {
303307 uint8_t bit_value = (src[i] >> (4 * (((j+1 ) & 1 ) + (j >> 1 ) * 2 ))) & 0xF ;
304- // private_slm[i * compress_size + j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
305- dst[i*compress_size+j] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
308+ private_slm[i * compress_size + j] = static_cast <ElementMMA>(quant_map[bit_value] * scale_value);
309+ // dst[i*compress_size+j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
306310 }
307311 }
308-
309- // reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data()))[0] = reinterpret_cast<const sycl::vec<int64_t, 16>*>(private_slm)[0];
310- reinterpret_cast <sycl::vec<int64_t , 16 >*>(cute::raw_pointer_cast (mma_B.data ()))[0 ] = reinterpret_cast <sycl::vec<int64_t , 16 >*>(dst)[0 ];
312+
// Copy the dequantized values into mma_B's storage, 128 bytes per iteration.
// NOTE(review): the trip count K/4/16 is integer division, so it is 0 when
// K < 64 and drops any remainder when K is not a multiple of 64 — confirm K is
// always a multiple of 64. The count matches K elements only if ElementMMA is
// 2 bytes wide — TODO confirm.
313+ for (int i=0 ; i<K/4 /16 ; i++){
314+ reinterpret_cast <sycl::vec<int64_t , 16 >*>(cute::raw_pointer_cast (mma_B.data ()))[i] = reinterpret_cast <const sycl::vec<int64_t , 16 >*>(private_slm)[i];
315+ // reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data()))[i] = reinterpret_cast<sycl::vec<int64_t, 16>*>(dst)[i];
316+ }
311317 };
312318 #endif
313319#else //register
0 commit comments