Skip to content

Commit c634424

Browse files
committed
save code
1 parent 10090c5 commit c634424

1 file changed

Lines changed: 13 additions & 7 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,10 @@ class gemm_4bit_cutlass_kernel {
275275
int prefetch_k = k_start_idx;
276276

277277
#if 1 //SLM
278+
//alignas(16) ElementB* slm_B = reinterpret_cast<ElementB*>(smem_buf) + thread_idx * (64 * 4) * k_tile_count;
279+
//const uint8_t* gB_ptr = params.B + (n_coord * BLK_N + thread_idx * 1) * params.k/2;
280+
////using total_vec = 4*k_tile_count;
281+
//reinterpret_cast<sycl::vec<uint64_t, 16>*>(slm_B)[0] = reinterpret_cast<const sycl::vec<uint64_t, 16>*>(gB_ptr)[0];
278282
#if 1
279283
auto dequant = [&] (int k_tile) {
280284
constexpr int N = decltype(cute::size<1>(mma_B))::value;
@@ -291,8 +295,8 @@ class gemm_4bit_cutlass_kernel {
291295
compress_type src[vec_size];
292296
reinterpret_cast<sycl::vec<compress_type, vec_size>*>(src)[0] = reinterpret_cast<const sycl::vec<compress_type, vec_size>*>(slm_B)[0];
293297

294-
//ElementMMA* private_slm = reinterpret_cast<ElementMMA*>(slm_B); // reuse src SLM buffer, for K=64, 每个线程一段 连续 128 B,天然 128 B 对齐
295-
ElementMMA dst[K];
298+
ElementMMA* private_slm = reinterpret_cast<ElementMMA*>(slm_B); // reuse src SLM buffer, for K=64, 每个线程一段 连续 128 B,天然 128 B 对齐
299+
//ElementMMA dst[K];
296300

297301
float scale_value = fragment_scale(0);
298302

@@ -301,13 +305,15 @@ class gemm_4bit_cutlass_kernel {
301305
#pragma unroll
302306
for (int j = 0; j < compress_size; ++j) {
303307
uint8_t bit_value = (src[i] >> (4 * (((j+1) & 1) + (j >> 1) * 2))) & 0xF;
304-
//private_slm[i * compress_size + j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
305-
dst[i*compress_size+j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
308+
private_slm[i * compress_size + j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
309+
//dst[i*compress_size+j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
306310
}
307311
}
308-
309-
//reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data()))[0] = reinterpret_cast<const sycl::vec<int64_t, 16>*>(private_slm)[0];
310-
reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data()))[0] = reinterpret_cast<sycl::vec<int64_t, 16>*>(dst)[0];
312+
313+
for(int i=0; i<K/4/16; i++){
314+
reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data()))[i] = reinterpret_cast<const sycl::vec<int64_t, 16>*>(private_slm)[i];
315+
//reinterpret_cast<sycl::vec<int64_t, 16>*>(cute::raw_pointer_cast(mma_B.data()))[i] = reinterpret_cast<sycl::vec<int64_t, 16>*>(dst)[i];
316+
}
311317
};
312318
#endif
313319
#else //register

0 commit comments

Comments
 (0)