@@ -228,7 +228,7 @@ class kgemm_4bit_inference_cutlass_dequant {
228228 using SrcType = typename EngineIn::value_type;
229229 using DstType = typename EngineOut::value_type;
230230 // using ScaleType = typename EngineScales::value_type;
231- #if 1
231+ #if 0
232232 int numbers = decltype(size(in))::value;
233233 for(int i=0; i<numbers; i++){
234234 //auto in_ptr_8 = (uint8_t*)(raw_pointer_cast(in.data()));
@@ -259,8 +259,8 @@ class kgemm_4bit_inference_cutlass_dequant {
259259 auto s_tensor = make_tensor ((format_type*)(raw_pointer_cast (in.data ())), Shape<Int<loop_cnt / scalar>, Int<N>>{});
260260 auto d_tensor = make_tensor (out.data (), Shape<Int<vec_size>, Int<splits>, Int<N>>{});
261261
262- if(cute::thread0())
263- printf("thread_idx = %d, decltype(size(in))::value = %d, K = %d, N = %d, L = %d, src_bits = %d, sizeof_bits_v<format_type> = %d, scalar = %d, decltype(size(out))::value = %d, loop_cnt = %d, splits = %d\n",int(ThreadIdxX()), decltype(size(in))::value, decltype(size<0>(in))::value, N, decltype(size<2>(in))::value, src_bits, sizeof_bits_v<format_type>, scalar, decltype(size(out))::value, loop_cnt, splits);
262+ // if(cute::thread0())
263+ // printf("thread_idx = %d, decltype(size(in))::value = %d, K = %d, N = %d, L = %d, src_bits = %d, sizeof_bits_v<format_type> = %d, scalar = %d, decltype(size(out))::value = %d, loop_cnt = %d, splits = %d\n",int(ThreadIdxX()), decltype(size(in))::value, decltype(size<0>(in))::value, N, decltype(size<2>(in))::value, src_bits, sizeof_bits_v<format_type>, scalar, decltype(size(out))::value, loop_cnt, splits);
264264
265265 for (int n = 0 ; n < N; n++) {
266266 // const auto ts = tCrS_input(n);
@@ -276,8 +276,7 @@ if(cute::thread0())
276276 for (int i = 0 ; i < vec_size; i++) {
277277 uint8_t value = (format_data >> (src_bits * i)) & 0xf ;
278278 dst[i] = (static_cast <DstType>(quant_map[value]));// * ts;
279- //if(cute::thread0())
280- printf("n = %d, s = %d, i = %d, src = %d, dst = %f\n", n, s, i, static_cast<int>(value), static_cast<float>(dst[i]));
279+ // if(cute::thread0()) printf("n = %d, s = %d, i = %d, src = %d, dst = %f\n", n, s, i, static_cast<int>(value), static_cast<float>(dst[i]));
281280 }
282281 }
283282 }
0 commit comments