Skip to content

Commit 92b1d24

Browse files
committed
add debug code
1 parent 82d692f commit 92b1d24

1 file changed

Lines changed: 10 additions & 4 deletions

File tree

csrc/xpu_cutlass_fusion.cpp

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,8 @@ using namespace cutlass::gemm;
3636

3737
// Define Basic information
3838
//Weight-only-quant (B)
39-
using MmaType = sycl::ext::oneapi::bfloat16; //cutlass::bfloat16_t;
39+
//using MmaType = sycl::ext::oneapi::bfloat16;
40+
using MmaType = cutlass::bfloat16_t;
4041
using QuantType = cutlass::uint4_t; //NF4,FP4
4142

4243
using ElementA = MmaType;
@@ -130,8 +131,11 @@ using ClusterShape = typename DispatchPolicy::ClusterShape;
130131
using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;
131132
using CopyThreadShapeRev = decltype(cute::reverse(CopyThreadShape{}));
132133

133-
//using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
134+
#if 0
135+
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
136+
#else
134137
using GmemTiledCopyA = XE_2D_U16x16x32_LD_N;
138+
#endif
135139
using StrideA = cutlass::gemm::TagToStrideA_t<cutlass::layout::RowMajor>;
136140
using traits_load_A = Copy_Traits<GmemTiledCopyA, StrideA>;
137141
using atom_load_A = Copy_Atom<traits_load_A, ElementA>;
@@ -279,7 +283,7 @@ class gemm_4bit_cutlass_kernel {
279283
}
280284
#endif
281285
#else
282-
using format_type = int; //32
286+
using format_type = uint32_t; //32
283287
static constexpr auto src_bits = sizeof_bits_v<SrcType>; //4
284288
static constexpr auto scalar = sizeof_bits_v<format_type> / src_bits; // 8
285289
static constexpr auto loop_cnt = decltype(size(out))::value / N; // 128 / 2 = 64
@@ -308,12 +312,14 @@ class gemm_4bit_cutlass_kernel {
308312

309313
CUTLASS_PRAGMA_UNROLL
310314
for (int i = 0; i < vec_size/2; i++) {
311-
#if 1
315+
#if 0
312316
dst[i * 2] = static_cast<DstType>(1.0f * ts);
313317
dst[i * 2 + 1] = static_cast<DstType>(1.0f * ts);
314318
#else
315319
dst[i * 2] = static_cast<DstType>(quant_map[(format_data >> (src_bits * (i * 2 + 1))) & 0xf] * ts);
316320
dst[i * 2 + 1] = static_cast<DstType>(quant_map[(format_data >> (src_bits * (i * 2))) & 0xf] * ts);
321+
//dst[i * 2] = quant_map[(format_data >> (src_bits * (i * 2 + 1))) & 0xf] * ts;
322+
//dst[i * 2 + 1] = quant_map[(format_data >> (src_bits * (i * 2))) & 0xf] * ts;
317323
#endif
318324
}
319325
}

0 commit comments

Comments
 (0)