Skip to content

Commit 245c354

Browse files
committed
change MMA trait
1 parent 3706871 commit 245c354

13 files changed

Lines changed: 133 additions & 31 deletions

File tree

bitsandbytes/backends/xpu/ops.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -188,7 +188,7 @@ def _(
188188
shape = (*A.shape[:-1], shapeB[0])
189189
#import pdb
190190
#pdb.set_trace()
191-
out = torch.zeros(shape, device=A.device, dtype=torch.float32)
191+
out = torch.zeros(shape, device=A.device, dtype=torch.bfloat16)
192192
_gemv_4bit_impl(A, B, shapeB, absmax.bfloat16(), code, blocksize, out=out)
193193
return out
194194

csrc/pythonInterface.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -381,7 +381,7 @@ void gemv_4bit_inference_fp16(
381381

382382
#if 1
383383
void gemm_4bit_inference_bf16(
384-
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, sycl::ext::oneapi::bfloat16 *absmax, float *datatype, float * out,
384+
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, sycl::ext::oneapi::bfloat16 *absmax, float *datatype, sycl::ext::oneapi::bfloat16 * out,
385385
int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
386386
) {
387387
gemm_4bit_inference_cutlass_dequant<sycl::ext::oneapi::bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
@@ -827,7 +827,7 @@ void cgemv_4bit_inference_fp16(
827827
#if 1
828828
void cgemv_4bit_inference_bf16(
829829
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, sycl::ext::oneapi::bfloat16 *absmax, float *datatype,
830-
float * out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
830+
sycl::ext::oneapi::bfloat16 * out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
831831
) {
832832
gemm_4bit_inference_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
833833
}

csrc/xpu_cutlass.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -109,7 +109,7 @@ void gemv_4bit_inference_cutlass_cute(int m, int n, int k, T *A, T *B,
109109

110110
template <typename T, int BITS>
111111
void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned char *B,
112-
T *absmax, float *datatype, float *out, int lda,
112+
T *absmax, float *datatype, T *out, int lda,
113113
int ldb, int ldc, int blocksize, sycl::queue *stream);
114114

115115
template <typename T, int BITS>

csrc/xpu_cutlass_fusion.cpp

Lines changed: 22 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ using namespace cutlass::gemm;
4040

4141
// Define Basic information
4242
//Weight-only-quant (B)
43-
using MmaType = cutlass::bfloat16_t;
43+
using MmaType = sycl::ext::oneapi::bfloat16; //cutlass::bfloat16_t;
4444
using QuantType = cutlass::uint4_t; //NF4,FP4
4545

4646
using ElementA = MmaType; //bfloat16_t;
@@ -50,18 +50,23 @@ using ElementMMA = ElementA;
5050
using ElementQuant = QuantType;
5151
using ElementScale = MmaType; //sycl::ext::oneapi::bfloat16; //MmaType;
5252

53-
using ElementC = float;
54-
using ElementD = float;
55-
using ElementAccumulator = float; // data_type of accumulator
56-
using ElementComputeEpilogue = float; // data_type of epilogue operations
57-
using ElementOutput = float;
53+
using ElementAccumulator = MmaType; // data_type of accumulator
54+
using ElementComputeEpilogue = MmaType; // data_type of epilogue operations
55+
using ElementOutput = MmaType;
5856

5957
using ProblemShape = Shape<int, int, int, int>;
6058

59+
#if 1
6160
using TileShape = Shape<_256, _256, _32>;
6261
using TiledMma =
63-
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
62+
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_BF16BF16BF16BF16_TT>, Layout<TileShape>,
6463
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
64+
#else
65+
using TileShape = Shape<_16, _64, _64>;
66+
using TiledMma =
67+
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32F16F16F32_TT>, Layout<TileShape>,
68+
Layout<Shape<_1, _2, _1>, Stride<_2, _1, _0>>>::TiledMMA;
69+
#endif
6570

6671
using WorkgroupTileShape = TileShape;
6772
static constexpr auto BLK_M = get<0>(WorkgroupTileShape{}); //256 //16
@@ -94,7 +99,8 @@ static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; // 16
9499

95100
// Design Epilogue
96101
using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;
97-
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementAccumulator, ElementComputeEpilogue, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
102+
//constexpr int kAlignment = 128 / sizeof(ElementOutput);
103+
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
98104
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape, decltype(tile_shape(TiledMma()))>;
99105
using SharedStorage = FusionCallBacks::SharedStorage;
100106

@@ -115,9 +121,9 @@ using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
115121
ElementOutput,
116122
cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>, // Convert CUTLASS 2.x to CUTLASS 3.x representation
117123
FusionCallBacks,
118-
XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C
124+
XE_2D_U16x8x16_LD_N, // The copy atom used to load matrix C
119125
void, void,
120-
XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D
126+
XE_2D_U16x8x16_ST_N, // The copy atom used to store matrix D
121127
void, void>;
122128
using EpilogueParams = typename CollectiveEpilogue::Params;
123129

@@ -166,7 +172,7 @@ class kgemm_4bit_inference_cutlass_dequant {
166172
int m, n, k;
167173
T* A;
168174
uint8_t* B;
169-
float* out;
175+
T* out;
170176
float *datatype; //LUT
171177
int group_size;
172178

@@ -279,7 +285,7 @@ class kgemm_4bit_inference_cutlass_dequant {
279285

280286
T* A = params.A;
281287
uint8_t* B = params.B;
282-
float* out = params.out;
288+
T* out = params.out;
283289
float* datatype = params.datatype;
284290

285291
auto tiled_copy_a = params.tiled_copy_a;
@@ -544,7 +550,7 @@ printf("\n");
544550

545551
template <typename T, int BITS>
546552
void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned char *B,
547-
T *absmax_, float *datatype, float *out, int lda,
553+
T *absmax_, float *datatype, T *out, int lda,
548554
int ldb, int ldc, int blocksize, sycl::queue *stream) {
549555
//std::cout<<"this is gemm_4bit_inference_cutlass_dequant ......................!!!!!!\n";
550556

@@ -593,8 +599,8 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
593599
cutlass::KernelHardwareInfo hw_info;
594600
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
595601
auto problem_shape_MNKL = problem_size; //append<4>(problem_size, 1);
596-
float alpha=1.0f;
597-
float beta=0.f;
602+
T alpha=1.0f;
603+
T beta=0.f;
598604
StrideC stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, l));
599605
StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, l));
600606

@@ -649,6 +655,6 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
649655

650656
template void gemm_4bit_inference_cutlass_dequant<sycl::ext::oneapi::bfloat16, 16>(
651657
int m, int n, int k, sycl::ext::oneapi::bfloat16 *A, unsigned char *B,
652-
sycl::ext::oneapi::bfloat16 *absmax, float *datatype, float *out, int lda,
658+
sycl::ext::oneapi::bfloat16 *absmax, float *datatype, sycl::ext::oneapi::bfloat16 *out, int lda,
653659
int ldb, int ldc, int blocksize, sycl::queue *stream);
654660

include/cute/algorithm/copy.hpp

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -185,6 +185,12 @@ copy(Copy_Atom<CopyArgs...> const& copy_atom,
185185
Tensor dst_v = group_modes<1,R>(dst);
186186

187187
if constexpr (is_static<decltype(shape(src_v))>::value && is_static<decltype(shape(dst_v))>::value) {
188+
#if 0
189+
if(cute::thread0()){
190+
print("src_v : "); print(src_v); print("\n");
191+
print("dst_v : "); print(dst_v); print("\n");
192+
}
193+
#endif
188194
CUTE_STATIC_ASSERT_V(size<1>(src_v) == size<1>(dst_v));
189195

190196
// AutoFilter on the Rest-mode

include/cute/arch/mma_xe.hpp

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,11 @@ SYCL_DEVICE_OCL(cute::intel::float8 intel_sub_group_bf16_bf16_matrix_mad_k16(cut
4545
SYCL_DEVICE_OCL(cute::intel::float4 intel_sub_group_bf16_bf16_matrix_mad_k16(cute::intel::short4 a, cute::intel::int8 b, cute::intel::float4 acc));
4646
SYCL_DEVICE_OCL(cute::intel::float2 intel_sub_group_bf16_bf16_matrix_mad_k16(cute::intel::short2 a, cute::intel::int8 b, cute::intel::float2 acc));
4747
SYCL_DEVICE_OCL(float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, cute::intel::int8 b, float acc));
48+
// mma_bfloat16 with bfloat16 accumulator:
49+
SYCL_EXTERNAL cute::intel::short8 intel_sub_group_bf16_bf16_matrix_mad_k16(cute::intel::short8 a, cute::intel::int8 b, cute::intel::short8 acc);
50+
SYCL_EXTERNAL cute::intel::short4 intel_sub_group_bf16_bf16_matrix_mad_k16(cute::intel::short4 a, cute::intel::int8 b, cute::intel::short4 acc);
51+
SYCL_EXTERNAL cute::intel::short2 intel_sub_group_bf16_bf16_matrix_mad_k16(cute::intel::short2 a, cute::intel::int8 b, cute::intel::short2 acc);
52+
SYCL_EXTERNAL short intel_sub_group_bf16_bf16_matrix_mad_k16( short a, cute::intel::int8 b, short acc);
4853
// mma_half
4954
SYCL_DEVICE_OCL(cute::intel::float8 intel_sub_group_f16_f16_matrix_mad_k16(cute::intel::short8 a, cute::intel::int8 b, cute::intel::float8 acc));
5055
SYCL_DEVICE_OCL(cute::intel::float4 intel_sub_group_f16_f16_matrix_mad_k16(cute::intel::short4 a, cute::intel::int8 b, cute::intel::float4 acc));
@@ -155,6 +160,26 @@ struct XE_1x16x16_F32BF16BF16F32_TT
155160
}
156161
};
157162

163+
struct XE_8x16x16_BF16BF16BF16BF16_TT
164+
{
165+
using DRegisters = intel::short8[1];
166+
using ARegisters = intel::short8[1];
167+
using BRegisters = intel::int8[1];
168+
using CRegisters = intel::short8[1];
169+
170+
CUTE_HOST_DEVICE static void
171+
fma(intel::short8 & d,
172+
intel::short8 const& a,
173+
intel::int8 const& b,
174+
intel::short8 const& c)
175+
{
176+
#if defined(SYCL_INTEL_TARGET)
177+
d = intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c);
178+
#else
179+
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_8x16x16_BF16BF16BF16BF16_TT on non-PVC hardware");
180+
#endif
181+
}
182+
};
158183
//MxNxK_D,A,B,C
159184
//# of vector component of a x subgroup-size x function name
160185
//float8 intel_sub_group_f16_f16_matrix_mad_k16(short8 a, int8 b, int8 acc);

include/cute/arch/mma_xe_builtin.hpp

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,17 @@ SYCL_EXTERNAL cute::intel::short4 intel_sub_group_bf16_bf16_matrix_mad_k16(cute:
6262
SYCL_EXTERNAL cute::intel::short2 intel_sub_group_bf16_bf16_matrix_mad_k16(cute::intel::short2 a, cute::intel::int8 b, cute::intel::short2 acc);
6363
SYCL_EXTERNAL short intel_sub_group_bf16_bf16_matrix_mad_k16( short a, cute::intel::int8 b, short acc);
6464

65+
// Use the spirv functions as the builtins do not work
66+
SYCL_EXTERNAL cute::intel::half8 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int32_t, cute::intel::short8, cute::intel::int8, cute::intel::half8, int32_t);
67+
SYCL_EXTERNAL cute::intel::half4 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int32_t, cute::intel::short4, cute::intel::int8, cute::intel::half4, int32_t);
68+
SYCL_EXTERNAL cute::intel::half2 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int32_t, cute::intel::short2, cute::intel::int8, cute::intel::half2, int32_t);
69+
SYCL_EXTERNAL cute::intel::half __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int32_t, short, cute::intel::int8, cute::intel::half, int32_t);
70+
71+
struct SPIRV_MMAOperands {
72+
static constexpr int SPIRV_MatrixAFp16 = 0x400;
73+
static constexpr int SPIRV_MatrixBFp16 = 0x800;
74+
};
75+
6576
namespace cute::detail
6677
{
6778

@@ -97,6 +108,16 @@ struct XeSubgroupMatrixMultiplyAccumulate<bfloat16_t, bfloat16_t, bfloat16_t, bf
97108
}
98109
};
99110

111+
template<>
112+
struct XeSubgroupMatrixMultiplyAccumulate<half_t, half_t, half_t, half_t> {
113+
template<typename ARegisters, typename BRegisters, typename CRegisters>
114+
CUTE_HOST_DEVICE
115+
auto operator()(ARegisters a, BRegisters b, CRegisters c) {
116+
return __spirv_SubgroupMatrixMultiplyAccumulateINTEL(16, a, b, c,
117+
SPIRV_MMAOperands::SPIRV_MatrixAFp16 | SPIRV_MMAOperands::SPIRV_MatrixBFp16);
118+
}
119+
};
120+
100121
template<>
101122
struct XeSubgroupMatrixMultiplyAccumulate<int32_t, int8_t, int8_t, int32_t> {
102123
template<typename ARegisters, typename BRegisters, typename CRegisters>

include/cute/arch/mma_xe_spirv.hpp

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,11 @@ SYCL_EXTERNAL cute::intel::short4 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(
5656
SYCL_EXTERNAL cute::intel::short2 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int32_t, cute::intel::short2, cute::intel::int8, cute::intel::short2, int32_t);
5757
SYCL_EXTERNAL short __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int32_t, short, cute::intel::int8, short, int32_t);
5858

59+
SYCL_EXTERNAL cute::intel::half8 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int32_t, cute::intel::short8, cute::intel::int8, cute::intel::half8, int32_t);
60+
SYCL_EXTERNAL cute::intel::half4 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int32_t, cute::intel::short4, cute::intel::int8, cute::intel::half4, int32_t);
61+
SYCL_EXTERNAL cute::intel::half2 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int32_t, cute::intel::short2, cute::intel::int8, cute::intel::half2, int32_t);
62+
SYCL_EXTERNAL cute::intel::half __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int32_t, short, cute::intel::int8, cute::intel::half, int32_t);
63+
5964
struct SPIRV_MMAOperands {
6065
static constexpr int SPIRV_MatrixASigned = 0x1;
6166
static constexpr int SPIRV_MatrixBSigned = 0x2;
@@ -109,6 +114,16 @@ struct XeSubgroupMatrixMultiplyAccumulate<bfloat16_t, bfloat16_t, bfloat16_t, bf
109114
}
110115
};
111116

117+
template<>
118+
struct XeSubgroupMatrixMultiplyAccumulate<half_t, half_t, half_t, half_t> {
119+
template<typename ARegisters, typename BRegisters, typename CRegisters>
120+
CUTE_HOST_DEVICE
121+
auto operator()(ARegisters a, BRegisters b, CRegisters c) {
122+
return __spirv_SubgroupMatrixMultiplyAccumulateINTEL(16, a, b, c,
123+
SPIRV_MMAOperands::SPIRV_MatrixAFp16 | SPIRV_MMAOperands::SPIRV_MatrixBFp16);
124+
}
125+
};
126+
112127
template<>
113128
struct XeSubgroupMatrixMultiplyAccumulate<int32_t, int8_t, int8_t, int32_t> {
114129
template<typename ARegisters, typename BRegisters, typename CRegisters>

include/cute/atom/mma_traits_xe.hpp

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,21 @@
3838
namespace cute
3939
{
4040
template <>
41+
struct MMA_Traits<XE_8x16x16_BF16BF16BF16BF16_TT>
42+
{
43+
using ValTypeD = bfloat16_t;
44+
using ValTypeA = bfloat16_t;
45+
using ValTypeB = bfloat16_t;
46+
using ValTypeC = bfloat16_t;
47+
48+
using Shape_MNK = Shape<_8,_16,_16>;
49+
using ThrID = Layout<_16>;
50+
51+
using ALayout = Layout<Shape<_16, _8>, Stride<_8, _1>>;
52+
using BLayout = Layout<Shape<_16, _16>, Stride<_1, _16>>;
53+
using CLayout = Layout<Shape<_16, _8>, Stride<_8, _1>>;
54+
};
55+
template <>
4156
struct MMA_Traits<XE_8x16x16_F32BF16BF16F32_TT>
4257
{
4358
using ValTypeD = float;

include/cute/util/sycl_vec.hpp

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,11 +52,18 @@ using uchar2 = vector_t<uchar, 2>;
5252
using uchar4 = vector_t<uchar, 4>;
5353
using uchar8 = vector_t<uchar, 8>;
5454
using uchar16 = vector_t<uchar, 16>;
55+
using uchar32 = vector_t<uchar, 32>;
56+
using uchar64 = vector_t<uchar, 64>;
5557

5658
using float2 = vector_t<float, 2>;
5759
using float4 = vector_t<float, 4>;
5860
using float8 = vector_t<float, 8>;
5961

62+
using half = _Float16;
63+
using half2 = vector_t<_Float16, 2>;
64+
using half4 = vector_t<_Float16, 4>;
65+
using half8 = vector_t<_Float16, 8>;
66+
6067
using short2 = vector_t<short, 2>;
6168
using short4 = vector_t<short, 4>;
6269
using short8 = vector_t<short, 8>;

0 commit comments

Comments
 (0)