Skip to content

Commit fb9106d

Browse files
committed
refine code
1 parent 62190ab commit fb9106d

1 file changed

Lines changed: 4 additions & 5 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -228,7 +228,7 @@ class kgemm_4bit_inference_cutlass_dequant {
228228
using SrcType = typename EngineIn::value_type;
229229
using DstType = typename EngineOut::value_type;
230230
//using ScaleType = typename EngineScales::value_type;
231-
#if 1
231+
#if 0
232232
int numbers = decltype(size(in))::value;
233233
for(int i=0; i<numbers; i++){
234234
//auto in_ptr_8 = (uint8_t*)(raw_pointer_cast(in.data()));
@@ -259,8 +259,8 @@ class kgemm_4bit_inference_cutlass_dequant {
259259
auto s_tensor = make_tensor((format_type*)(raw_pointer_cast(in.data())), Shape<Int<loop_cnt / scalar>, Int<N>>{});
260260
auto d_tensor = make_tensor(out.data(), Shape<Int<vec_size>, Int<splits>, Int<N>>{});
261261

262-
if(cute::thread0())
263-
printf("thread_idx = %d, decltype(size(in))::value = %d, K = %d, N = %d, L = %d, src_bits = %d, sizeof_bits_v<format_type> = %d, scalar = %d, decltype(size(out))::value = %d, loop_cnt = %d, splits = %d\n",int(ThreadIdxX()), decltype(size(in))::value, decltype(size<0>(in))::value, N, decltype(size<2>(in))::value, src_bits, sizeof_bits_v<format_type>, scalar, decltype(size(out))::value, loop_cnt, splits);
262+
//if(cute::thread0())
263+
// printf("thread_idx = %d, decltype(size(in))::value = %d, K = %d, N = %d, L = %d, src_bits = %d, sizeof_bits_v<format_type> = %d, scalar = %d, decltype(size(out))::value = %d, loop_cnt = %d, splits = %d\n",int(ThreadIdxX()), decltype(size(in))::value, decltype(size<0>(in))::value, N, decltype(size<2>(in))::value, src_bits, sizeof_bits_v<format_type>, scalar, decltype(size(out))::value, loop_cnt, splits);
264264

265265
for (int n = 0; n < N; n++) {
266266
//const auto ts = tCrS_input(n);
@@ -276,8 +276,7 @@ if(cute::thread0())
276276
for (int i = 0; i < vec_size; i++) {
277277
uint8_t value = (format_data >> (src_bits * i)) & 0xf;
278278
dst[i] = (static_cast<DstType>(quant_map[value]));// * ts;
279-
//if(cute::thread0())
280-
printf("n = %d, s = %d, i = %d, src = %d, dst = %f\n", n, s, i, static_cast<int>(value), static_cast<float>(dst[i]));
279+
//if(cute::thread0()) printf("n = %d, s = %d, i = %d, src = %d, dst = %f\n", n, s, i, static_cast<int>(value), static_cast<float>(dst[i]));
281280
}
282281
}
283282
}

0 commit comments

Comments
 (0)