@@ -36,7 +36,8 @@ using namespace cutlass::gemm;
3636
3737// Define Basic information
3838// Weight-only-quant (B)
39- using MmaType = sycl::ext::oneapi::bfloat16; // cutlass::bfloat16_t;
39+ // using MmaType = sycl::ext::oneapi::bfloat16;
40+ using MmaType = cutlass::bfloat16_t ;
4041using QuantType = cutlass::uint4_t ; // NF4,FP4
4142
4243using ElementA = MmaType;
@@ -130,8 +131,11 @@ using ClusterShape = typename DispatchPolicy::ClusterShape;
130131using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;
131132using CopyThreadShapeRev = decltype (cute::reverse(CopyThreadShape{}));
132133
133- // using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
134+ #if 0
135+ using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
136+ #else
134137using GmemTiledCopyA = XE_2D_U16x16x32_LD_N;
138+ #endif
135139using StrideA = cutlass::gemm::TagToStrideA_t<cutlass::layout::RowMajor>;
136140using traits_load_A = Copy_Traits<GmemTiledCopyA, StrideA>;
137141using atom_load_A = Copy_Atom<traits_load_A, ElementA>;
@@ -279,7 +283,7 @@ class gemm_4bit_cutlass_kernel {
279283 }
280284#endif
281285#else
282- using format_type = int ; // 32
286+ using format_type = uint32_t ; // 32
283287 static constexpr auto src_bits = sizeof_bits_v<SrcType>; // 4
284288 static constexpr auto scalar = sizeof_bits_v<format_type> / src_bits; // 8
285289 static constexpr auto loop_cnt = decltype (size (out))::value / N; // 128 / 2 = 64
@@ -308,12 +312,14 @@ class gemm_4bit_cutlass_kernel {
308312
309313 CUTLASS_PRAGMA_UNROLL
310314 for (int i = 0 ; i < vec_size/2 ; i++) {
311- #if 1
315+ #if 0
312316 dst[i * 2] = static_cast<DstType>(1.0f * ts);
313317 dst[i * 2 + 1] = static_cast<DstType>(1.0f * ts);
314318#else
315319 dst[i * 2 ] = static_cast <DstType>(quant_map[(format_data >> (src_bits * (i * 2 + 1 ))) & 0xf ] * ts);
316320 dst[i * 2 + 1 ] = static_cast <DstType>(quant_map[(format_data >> (src_bits * (i * 2 ))) & 0xf ] * ts);
321+ // dst[i * 2] = quant_map[(format_data >> (src_bits * (i * 2 + 1))) & 0xf] * ts;
322+ // dst[i * 2 + 1] = quant_map[(format_data >> (src_bits * (i * 2))) & 0xf] * ts;
317323#endif
318324 }
319325 }
0 commit comments