Skip to content

Commit 143b91e

Browse files
committed
save code
1 parent 9dc75fc commit 143b91e

1 file changed

Lines changed: 25 additions & 13 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -339,27 +339,39 @@ inline float dDequantizeNF4(unsigned char val) {
339339
constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
340340

341341
alignas(16) ElementB* src = reinterpret_cast<ElementB*>(smem_buf) + thread_idx * (K * 4); // for K=64; the factor 4 is hard-coded for 128 B alignment.
342-
const uint8_t* gB_ptr = params.B + (n_coord * BLK_N + thread_idx * N) * params.k/2 + k_tile * BLK_K/2;
343-
reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] = reinterpret_cast<const sycl::vec<src_compress_type, src_vec_size>*>(gB_ptr)[0];
342+
const uint8_t* gB_ptr = params.B + (n_coord * BLK_N + thread_idx * N) * params.k / 2 + k_tile * BLK_K / 2;
343+
//reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] = reinterpret_cast<const sycl::vec<src_compress_type, src_vec_size>*>(gB_ptr)[0];
344344

345345

346-
ElementMMA* private_slm = reinterpret_cast<ElementMMA*>(src + K); // reuse src SLM buffer, for K=64, 每个线程一段 连续 128 B,天然 128 B 对齐
347-
348-
float scale_value = fragment_scale(0);
346+
ElementMMA* dst_slm = reinterpret_cast<ElementMMA*>(src + K); // reuse the src SLM buffer; for K=64, each thread gets one contiguous 128 B segment, which is naturally 128 B aligned
349347

348+
#pragma unroll
349+
for (int n = 0; n < N; n++) {
350+
float scale_value = fragment_scale(n);
350351
#pragma unroll
351-
for (int i = 0; i < src_vec_size; ++i) {
352-
src_compress_type src_value = reinterpret_cast<src_compress_type*>(src)[i];
352+
for (int l = 0; l < src_loop_num; l++) {
353+
reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] = reinterpret_cast<const sycl::vec<src_compress_type, src_vec_size>*>(gB_ptr)[n*src_loop_num + l];
354+
#pragma unroll
355+
for (int v = 0; v < src_vec_size; ++v) {
356+
src_compress_type src_value = reinterpret_cast<src_compress_type*>(src)[v];
357+
int dst_idx = v * src_compress_size;
353358
#pragma unroll
354-
for (int j = 0; j < src_compress_size; ++j) {
355-
uint8_t bit_value = (src_value >> (4 * (((j+1) & 1) + (j >> 1) * 2))) & 0xF;
356-
private_slm[i * src_compress_size + j] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
359+
for (int c = 0; c < src_compress_size; ++c) {
360+
uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
361+
dst_slm[dst_idx + c] = static_cast<ElementMMA>(quant_map[bit_value] * scale_value);
357362
}
363+
}
358364
}
359-
360-
for(int i=0; i<K/4/16; i++){
361-
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[i] = reinterpret_cast<const sycl::vec<dst_compress_type, dst_vec_size>*>(private_slm)[i];
365+
366+
#pragma unroll
367+
for (int l = 0; l < dst_loop_num; l++) {
368+
reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[n*dst_loop_num + l] = reinterpret_cast<const sycl::vec<dst_compress_type, dst_vec_size>*>(dst_slm)[0];
362369
}
370+
}
371+
372+
//for(int i=0; i<K/4/16; i++){
373+
// reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(cute::raw_pointer_cast(mma_B.data()))[i] = reinterpret_cast<const sycl::vec<dst_compress_type, dst_vec_size>*>(private_slm)[i];
374+
//}
363375
};
364376
#endif
365377
#else //register

0 commit comments

Comments
 (0)