Skip to content

Commit cb5bf3d

Browse files
Style-align SM120 FP4 MMA support
1 parent 33266ae commit cb5bf3d

5 files changed

Lines changed: 95 additions & 79 deletions

File tree

src/tl_templates/cuda/cuda_fp4.h

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) || \
66
(defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 1200))
77
#include <cuda_fp4.h>
8+
#include <cute/numeric/numeric_types.hpp>
89

910
// Wrapper for __nv_fp4_e2m1 with implicit conversions
1011
struct fp4_e2_t {
@@ -45,6 +46,12 @@ struct fp4_e2_t {
4546
TL_DEVICE operator __half() const { return __half(float(*this)); }
4647
};
4748

49+
namespace tl {
50+
template <> struct to_cute_type<::fp4_e2_t> {
51+
using type = cute::float_e2m1_t;
52+
};
53+
} // namespace tl
54+
4855
class fp4_e2_2_t {
4956
public:
5057
__nv_fp4x2_storage_t __x;
@@ -163,26 +170,27 @@ TL_DEVICE fp4_e2_32_t make_fp4_e2_32_t(
163170
}
164171

165172
// Pack sixty-four fp4_e2_t values.
166-
template <typename... Args>
167-
TL_DEVICE fp4_e2_64_t make_fp4_e2_64_t(Args... args) {
168-
static_assert(sizeof...(Args) == 64,
169-
"make_fp4_e2_64_t expects 64 fp4 values");
170-
fp4_e2_t values[64] = {fp4_e2_t(args)...};
173+
TL_DEVICE fp4_e2_64_t make_fp4_e2_64_t(
174+
fp4_e2_t x0, fp4_e2_t x1, fp4_e2_t x2, fp4_e2_t x3, fp4_e2_t x4,
175+
fp4_e2_t x5, fp4_e2_t x6, fp4_e2_t x7, fp4_e2_t x8, fp4_e2_t x9,
176+
fp4_e2_t x10, fp4_e2_t x11, fp4_e2_t x12, fp4_e2_t x13, fp4_e2_t x14,
177+
fp4_e2_t x15, fp4_e2_t x16, fp4_e2_t x17, fp4_e2_t x18, fp4_e2_t x19,
178+
fp4_e2_t x20, fp4_e2_t x21, fp4_e2_t x22, fp4_e2_t x23, fp4_e2_t x24,
179+
fp4_e2_t x25, fp4_e2_t x26, fp4_e2_t x27, fp4_e2_t x28, fp4_e2_t x29,
180+
fp4_e2_t x30, fp4_e2_t x31, fp4_e2_t y0, fp4_e2_t y1, fp4_e2_t y2,
181+
fp4_e2_t y3, fp4_e2_t y4, fp4_e2_t y5, fp4_e2_t y6, fp4_e2_t y7,
182+
fp4_e2_t y8, fp4_e2_t y9, fp4_e2_t y10, fp4_e2_t y11, fp4_e2_t y12,
183+
fp4_e2_t y13, fp4_e2_t y14, fp4_e2_t y15, fp4_e2_t y16, fp4_e2_t y17,
184+
fp4_e2_t y18, fp4_e2_t y19, fp4_e2_t y20, fp4_e2_t y21, fp4_e2_t y22,
185+
fp4_e2_t y23, fp4_e2_t y24, fp4_e2_t y25, fp4_e2_t y26, fp4_e2_t y27,
186+
fp4_e2_t y28, fp4_e2_t y29, fp4_e2_t y30, fp4_e2_t y31) {
171187
fp4_e2_64_t result;
172-
result.x = make_fp4_e2_32_t(
173-
values[0], values[1], values[2], values[3], values[4], values[5],
174-
values[6], values[7], values[8], values[9], values[10], values[11],
175-
values[12], values[13], values[14], values[15], values[16], values[17],
176-
values[18], values[19], values[20], values[21], values[22], values[23],
177-
values[24], values[25], values[26], values[27], values[28], values[29],
178-
values[30], values[31]);
179-
result.y = make_fp4_e2_32_t(
180-
values[32], values[33], values[34], values[35], values[36], values[37],
181-
values[38], values[39], values[40], values[41], values[42], values[43],
182-
values[44], values[45], values[46], values[47], values[48], values[49],
183-
values[50], values[51], values[52], values[53], values[54], values[55],
184-
values[56], values[57], values[58], values[59], values[60], values[61],
185-
values[62], values[63]);
188+
result.x = make_fp4_e2_32_t(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11,
189+
x12, x13, x14, x15, x16, x17, x18, x19, x20, x21,
190+
x22, x23, x24, x25, x26, x27, x28, x29, x30, x31);
191+
result.y = make_fp4_e2_32_t(y0, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, y11,
192+
y12, y13, y14, y15, y16, y17, y18, y19, y20, y21,
193+
y22, y23, y24, y25, y26, y27, y28, y29, y30, y31);
186194
return result;
187195
}
188196

src/tl_templates/cuda/gemm_mma.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,6 @@ using _X = Underscore;
4646
#include "cuda_fp8.h"
4747
#include <cute/arch/mma_sm120.hpp>
4848
#include <cute/arch/mma_sm80.hpp>
49-
namespace tl {
50-
template <> struct to_cute_type<fp4_e2_t> {
51-
using type = cute::float_e2m1_t;
52-
};
53-
} // namespace tl
5449
TL_DISPATCH_MMA_TEMPLATE(fp4_e2_t, fp4_e2_t, float, SM120_16x8x32_TN)
5550
TL_DISPATCH_MMA_TEMPLATE(fp8_e4_t, fp4_e2_t, float, SM120_16x8x32_TN)
5651
TL_DISPATCH_MMA_TEMPLATE(fp4_e2_t, fp8_e4_t, float, SM120_16x8x32_TN)

src/tl_templates/cuda/instruction/mma.h

Lines changed: 51 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,9 @@ struct MmaDispatcher {
6969
}
7070
};
7171

72-
#define TL_DEFINE_MMA_DISPATCHER(ATypeEnum, BTypeEnum, CTypeEnum, MValue, \
73-
NValue, KValue, TransAValue, TransBValue, \
74-
SaturateValue, ImplType) \
72+
#define TL_DEFINE_MMA_DISPATCHER_IMPL( \
73+
ATypeEnum, BTypeEnum, CTypeEnum, MValue, NValue, KValue, TransAValue, \
74+
TransBValue, SaturateValue, ShiftAValue, ShiftBValue, ImplType) \
7575
template <> \
7676
struct MmaDispatcher<DataType::ATypeEnum, DataType::BTypeEnum, \
7777
DataType::CTypeEnum, MValue, NValue, KValue, \
@@ -84,12 +84,46 @@ struct MmaDispatcher {
8484
static_assert( \
8585
std::is_same_v<typename Traits::DReg, typename Traits::CReg>, \
8686
"tl::mma_sync requires matching accumulator/output regs"); \
87+
template <bool Shift, class Reg> \
88+
static TL_DEVICE Reg maybe_shift_fp4_reg(Reg reg) { \
89+
if constexpr (Shift) { \
90+
return reg << 2; \
91+
} else { \
92+
return reg; \
93+
} \
94+
} \
8795
static TL_DEVICE void exec(CRegType *d, const ARegType *a, \
8896
const BRegType *b, const CRegType *c) { \
89-
call_fma<Impl>(d, a, b, c); \
97+
if constexpr (ShiftAValue || ShiftBValue) { \
98+
ARegType as[Traits::kARegs]; \
99+
BRegType bs[Traits::kBRegs]; \
100+
_Pragma("unroll") for (int i = 0; i < Traits::kARegs; ++i) { \
101+
as[i] = maybe_shift_fp4_reg<ShiftAValue>(a[i]); \
102+
} \
103+
_Pragma("unroll") for (int i = 0; i < Traits::kBRegs; ++i) { \
104+
bs[i] = maybe_shift_fp4_reg<ShiftBValue>(b[i]); \
105+
} \
106+
call_fma<Impl>(d, as, bs, c); \
107+
} else { \
108+
call_fma<Impl>(d, a, b, c); \
109+
} \
90110
} \
91111
};
92112

113+
#define TL_DEFINE_MMA_DISPATCHER(ATypeEnum, BTypeEnum, CTypeEnum, MValue, \
114+
NValue, KValue, TransAValue, TransBValue, \
115+
SaturateValue, ImplType) \
116+
TL_DEFINE_MMA_DISPATCHER_IMPL(ATypeEnum, BTypeEnum, CTypeEnum, MValue, \
117+
NValue, KValue, TransAValue, TransBValue, \
118+
SaturateValue, false, false, ImplType)
119+
120+
#define TL_DEFINE_MMA_DISPATCHER_WITH_FP4_SHIFT( \
121+
ATypeEnum, BTypeEnum, CTypeEnum, MValue, NValue, KValue, TransAValue, \
122+
TransBValue, SaturateValue, ShiftAValue, ShiftBValue, ImplType) \
123+
TL_DEFINE_MMA_DISPATCHER_IMPL( \
124+
ATypeEnum, BTypeEnum, CTypeEnum, MValue, NValue, KValue, TransAValue, \
125+
TransBValue, SaturateValue, ShiftAValue, ShiftBValue, ImplType)
126+
93127
// FP16 inputs (TN layout: A row-major, B column-major)
94128
TL_DEFINE_MMA_DISPATCHER(kFloat16, kFloat16, kFloat16, 16, 8, 16, false, true,
95129
false, cute::SM80_16x8x16_F16F16F16F16_TN)
@@ -154,14 +188,19 @@ using SM120_FP8_FP4_F32_TN =
154188
cute::SM120_16x8x32_TN<cute::float_e4m3_t, cute::float_e2m1_t, float>;
155189
using SM120_FP4_FP8_F32_TN =
156190
cute::SM120_16x8x32_TN<cute::float_e2m1_t, cute::float_e4m3_t, float>;
157-
TL_DEFINE_MMA_DISPATCHER(kFloat4_e2m1fn, kFloat4_e2m1fn, kFloat32, 16, 8, 32,
158-
false, true, false, SM120_FP4_FP4_F32_TN)
159-
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat4_e2m1fn, kFloat32, 16, 8, 32,
160-
false, true, false, SM120_FP8_FP4_F32_TN)
161-
TL_DEFINE_MMA_DISPATCHER(kFloat4_e2m1fn, kFloat8_e4m3, kFloat32, 16, 8, 32,
162-
false, true, false, SM120_FP4_FP8_F32_TN)
163-
191+
TL_DEFINE_MMA_DISPATCHER_WITH_FP4_SHIFT(kFloat4_e2m1fn, kFloat4_e2m1fn,
192+
kFloat32, 16, 8, 32, false, true, false,
193+
true, true, SM120_FP4_FP4_F32_TN)
194+
TL_DEFINE_MMA_DISPATCHER_WITH_FP4_SHIFT(kFloat8_e4m3, kFloat4_e2m1fn, kFloat32,
195+
16, 8, 32, false, true, false, false,
196+
true, SM120_FP8_FP4_F32_TN)
197+
TL_DEFINE_MMA_DISPATCHER_WITH_FP4_SHIFT(kFloat4_e2m1fn, kFloat8_e4m3, kFloat32,
198+
16, 8, 32, false, true, false, true,
199+
false, SM120_FP4_FP8_F32_TN)
200+
201+
#undef TL_DEFINE_MMA_DISPATCHER_WITH_FP4_SHIFT
164202
#undef TL_DEFINE_MMA_DISPATCHER
203+
#undef TL_DEFINE_MMA_DISPATCHER_IMPL
165204

166205
} // namespace detail
167206

@@ -178,37 +217,7 @@ TL_DEVICE void mma_sync(
178217
TransB, Saturate>;
179218
static_assert(!std::is_void_v<typename Dispatcher::CRegType>,
180219
"tl::mma_sync: unsupported configuration");
181-
if constexpr (AType == DataType::kFloat4_e2m1fn ||
182-
BType == DataType::kFloat4_e2m1fn) {
183-
// SM120 f8f6f4 MMA expects FP4 operands in the same register placement as
184-
// CuTe's b4x16 load path. Shift only FP4 operands; mixed FP8 operands keep
185-
// their native register bits.
186-
using AReg = typename Dispatcher::ARegType;
187-
using BReg = typename Dispatcher::BRegType;
188-
constexpr int nA = detail::MmaImplTraits<typename Dispatcher::Impl>::kARegs;
189-
constexpr int nB = detail::MmaImplTraits<typename Dispatcher::Impl>::kBRegs;
190-
AReg as[nA];
191-
BReg bs[nB];
192-
#pragma unroll
193-
for (int i = 0; i < nA; ++i) {
194-
if constexpr (AType == DataType::kFloat4_e2m1fn) {
195-
as[i] = a[i] << 2;
196-
} else {
197-
as[i] = a[i];
198-
}
199-
}
200-
#pragma unroll
201-
for (int i = 0; i < nB; ++i) {
202-
if constexpr (BType == DataType::kFloat4_e2m1fn) {
203-
bs[i] = b[i] << 2;
204-
} else {
205-
bs[i] = b[i];
206-
}
207-
}
208-
Dispatcher::exec(c, as, bs, c);
209-
} else {
210-
Dispatcher::exec(c, a, b, c);
211-
}
220+
Dispatcher::exec(c, a, b, c);
212221
}
213222

214223
} // namespace tl

tilelang/cuda/intrinsics/layout/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ def get_ldmatrix_offset(
2828
transposed: bool = False,
2929
):
3030
assert matrix in ["A", "B"], "matrix should be either A or B"
31-
dtype = DataType(dtype)
32-
dtype_bits = dtype.bits
33-
is_fp4_e2m1fn = dtype_bits == 4 and str(dtype) == "float4_e2m1fn"
31+
dtype_obj = DataType(dtype)
32+
dtype_bits = dtype_obj.bits
33+
is_fp4_e2m1fn = dtype_bits == 4 and str(dtype_obj) == "float4_e2m1fn"
3434
if dtype_bits == 32:
3535
if matrix == "B" and transposed:
3636
transform_func = ldmatrix_32x4_to_shared_16x8_layout_b
@@ -78,7 +78,7 @@ def get_ldmatrix_offset(
7878
else:
7979
raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8")
8080
else:
81-
raise ValueError(f"Unsupported dtype {dtype}")
81+
raise ValueError(f"Unsupported dtype {dtype_obj}")
8282

8383

8484
def shared_16x16_to_mma_32x8_layout(i, j):

tilelang/cuda/intrinsics/macro/mma_macro_generator.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ def __init__(
118118
def _initialize_k_dim(self, a_dtype=T.float16):
119119
if isinstance(a_dtype, str):
120120
a_dtype = DataType(a_dtype)
121-
if str(a_dtype) == "float4_e2m1fn":
121+
a_dtype_str = str(a_dtype)
122+
if a_dtype_str == "float4_e2m1fn":
122123
if self.chunk < 32:
123124
raise ValueError(f"float4_e2m1fn MMA requires chunk >= 32, got chunk={self.chunk}")
124125
self.k_dim = 32
@@ -300,7 +301,9 @@ def _warp_ld_a_fp64(
300301
micro_size_k = self.micro_size_k
301302
local_size_a = self.local_size_a
302303
a_transposed = self.a_transposed
303-
a_dtype_bits = DataType(a_dtype).bits
304+
a_dtype_obj = DataType(a_dtype)
305+
a_dtype_bits = a_dtype_obj.bits
306+
is_fp4_a = str(a_dtype_obj) == "float4_e2m1fn"
304307
# ldmatrix cannot be used for int8 + trans case.
305308
ldmatrix_available = not (a_dtype_bits != 16 and a_transposed)
306309

@@ -344,8 +347,7 @@ def _warp_ldmatrix_a(
344347

345348
if ldmatrix_available:
346349
num = 4
347-
is_fp4 = str(DataType(a_dtype)) == "float4_e2m1fn"
348-
access_extent = 4 * num if is_fp4 else 2 * num
350+
access_extent = 4 * num if is_fp4_a else 2 * num
349351
row_off, col_off = get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed)
350352
src_indices = (
351353
tuple(A_other) + (A_base0 + wk + row_off, A_base1 + wi + col_off)
@@ -416,7 +418,9 @@ def _warp_ld_b_fp64(
416418
micro_size_k = self.micro_size_k
417419
local_size_b = self.local_size_b
418420
b_transposed = self.b_transposed
419-
b_dtype_bits = DataType(b_dtype).bits
421+
b_dtype_obj = DataType(b_dtype)
422+
b_dtype_bits = b_dtype_obj.bits
423+
is_fp4_b = str(b_dtype_obj) == "float4_e2m1fn"
420424
thread_binding = self.get_thread_binding()
421425

422426
# legalize shared buffer to region
@@ -464,8 +468,7 @@ def _warp_ldmatrix_b(
464468

465469
if ldmatrix_available:
466470
num = 4 if replicate_b else 2
467-
is_fp4 = str(DataType(b_dtype)) == "float4_e2m1fn"
468-
access_extent = 4 * num if is_fp4 else 2 * num
471+
access_extent = 4 * num if is_fp4_b else 2 * num
469472
row_off, col_off = get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed)
470473
src_indices = (
471474
tuple(B_other) + (B_base0 + wi + row_off, B_base1 + wk + col_off)
@@ -873,10 +876,11 @@ def __init__(
873876
self._initialize_transform_kind(transform_kind_a, transform_kind_b)
874877

875878
def _initialize_k_dim(self, a_dtype=T.float16):
876-
if str(DataType(a_dtype)) == "float4_e2m1fn":
879+
a_dtype_obj = DataType(a_dtype)
880+
if str(a_dtype_obj) == "float4_e2m1fn":
877881
self.k_dim = 32
878882
else:
879-
self.k_dim = 256 // DataType(a_dtype).bits
883+
self.k_dim = 256 // a_dtype_obj.bits
880884

881885
def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32):
882886
self.local_size_a = (m_dim * k_dim) // warp_size

0 commit comments

Comments
 (0)