@@ -69,9 +69,9 @@ struct MmaDispatcher {
6969 }
7070};
7171
72- #define TL_DEFINE_MMA_DISPATCHER (ATypeEnum, BTypeEnum, CTypeEnum, MValue, \
73- NValue, KValue, TransAValue, TransBValue, \
74- SaturateValue , ImplType) \
72+ #define TL_DEFINE_MMA_DISPATCHER_IMPL ( \
73+ ATypeEnum, BTypeEnum, CTypeEnum, MValue, NValue, KValue, TransAValue, \
74+ TransBValue, SaturateValue, ShiftAValue, ShiftBValue , ImplType) \
7575 template <> \
7676 struct MmaDispatcher <DataType::ATypeEnum, DataType::BTypeEnum, \
7777 DataType::CTypeEnum, MValue, NValue, KValue, \
@@ -84,12 +84,46 @@ struct MmaDispatcher {
8484 static_assert ( \
8585 std::is_same_v<typename Traits::DReg, typename Traits::CReg>, \
8686 " tl::mma_sync requires matching accumulator/output regs" ); \
87+ template <bool Shift, class Reg > \
88+ static TL_DEVICE Reg maybe_shift_fp4_reg (Reg reg) { \
89+ if constexpr (Shift) { \
90+ return reg << 2 ; \
91+ } else { \
92+ return reg; \
93+ } \
94+ } \
8795 static TL_DEVICE void exec (CRegType *d, const ARegType *a, \
8896 const BRegType *b, const CRegType *c) { \
89- call_fma<Impl>(d, a, b, c); \
97+ if constexpr (ShiftAValue || ShiftBValue) { \
98+ ARegType as[Traits::kARegs ]; \
99+ BRegType bs[Traits::kBRegs ]; \
100+ _Pragma (" unroll" ) for (int i = 0 ; i < Traits::kARegs ; ++i) { \
101+ as[i] = maybe_shift_fp4_reg<ShiftAValue>(a[i]); \
102+ } \
103+ _Pragma (" unroll" ) for (int i = 0 ; i < Traits::kBRegs ; ++i) { \
104+ bs[i] = maybe_shift_fp4_reg<ShiftBValue>(b[i]); \
105+ } \
106+ call_fma<Impl>(d, as, bs, c); \
107+ } else { \
108+ call_fma<Impl>(d, a, b, c); \
109+ } \
90110 } \
91111 };
92112
113+ #define TL_DEFINE_MMA_DISPATCHER (ATypeEnum, BTypeEnum, CTypeEnum, MValue, \
114+ NValue, KValue, TransAValue, TransBValue, \
115+ SaturateValue, ImplType) \
116+ TL_DEFINE_MMA_DISPATCHER_IMPL (ATypeEnum, BTypeEnum, CTypeEnum, MValue, \
117+ NValue, KValue, TransAValue, TransBValue, \
118+ SaturateValue, false , false , ImplType)
119+
120+ #define TL_DEFINE_MMA_DISPATCHER_WITH_FP4_SHIFT ( \
121+ ATypeEnum, BTypeEnum, CTypeEnum, MValue, NValue, KValue, TransAValue, \
122+ TransBValue, SaturateValue, ShiftAValue, ShiftBValue, ImplType) \
123+ TL_DEFINE_MMA_DISPATCHER_IMPL ( \
124+ ATypeEnum, BTypeEnum, CTypeEnum, MValue, NValue, KValue, TransAValue, \
125+ TransBValue, SaturateValue, ShiftAValue, ShiftBValue, ImplType)
126+
93127// FP16 inputs (TN layout: A row-major, B column-major)
94128TL_DEFINE_MMA_DISPATCHER (kFloat16 , kFloat16 , kFloat16 , 16 , 8 , 16 , false , true ,
95129 false , cute::SM80_16x8x16_F16F16F16F16_TN)
@@ -154,14 +188,19 @@ using SM120_FP8_FP4_F32_TN =
154188 cute::SM120_16x8x32_TN<cute::float_e4m3_t , cute::float_e2m1_t , float >;
155189using SM120_FP4_FP8_F32_TN =
156190 cute::SM120_16x8x32_TN<cute::float_e2m1_t , cute::float_e4m3_t , float >;
157- TL_DEFINE_MMA_DISPATCHER (kFloat4_e2m1fn , kFloat4_e2m1fn , kFloat32 , 16 , 8 , 32 ,
158- false , true , false , SM120_FP4_FP4_F32_TN)
159- TL_DEFINE_MMA_DISPATCHER (kFloat8_e4m3 , kFloat4_e2m1fn , kFloat32 , 16 , 8 , 32 ,
160- false , true , false , SM120_FP8_FP4_F32_TN)
161- TL_DEFINE_MMA_DISPATCHER (kFloat4_e2m1fn , kFloat8_e4m3 , kFloat32 , 16 , 8 , 32 ,
162- false , true , false , SM120_FP4_FP8_F32_TN)
163-
191+ TL_DEFINE_MMA_DISPATCHER_WITH_FP4_SHIFT (kFloat4_e2m1fn , kFloat4_e2m1fn ,
192+ kFloat32 , 16 , 8 , 32 , false , true , false ,
193+ true , true , SM120_FP4_FP4_F32_TN)
194+ TL_DEFINE_MMA_DISPATCHER_WITH_FP4_SHIFT (kFloat8_e4m3 , kFloat4_e2m1fn , kFloat32 ,
195+ 16 , 8 , 32 , false , true , false , false ,
196+ true , SM120_FP8_FP4_F32_TN)
197+ TL_DEFINE_MMA_DISPATCHER_WITH_FP4_SHIFT (kFloat4_e2m1fn , kFloat8_e4m3 , kFloat32 ,
198+ 16 , 8 , 32 , false , true , false , true ,
199+ false , SM120_FP4_FP8_F32_TN)
200+
201+ #undef TL_DEFINE_MMA_DISPATCHER_WITH_FP4_SHIFT
164202#undef TL_DEFINE_MMA_DISPATCHER
203+ #undef TL_DEFINE_MMA_DISPATCHER_IMPL
165204
166205} // namespace detail
167206
@@ -178,37 +217,7 @@ TL_DEVICE void mma_sync(
178217 TransB, Saturate>;
179218 static_assert (!std::is_void_v<typename Dispatcher::CRegType>,
180219 " tl::mma_sync: unsupported configuration" );
181- if constexpr (AType == DataType::kFloat4_e2m1fn ||
182- BType == DataType::kFloat4_e2m1fn ) {
183- // SM120 f8f6f4 MMA expects FP4 operands in the same register placement as
184- // CuTe's b4x16 load path. Shift only FP4 operands; mixed FP8 operands keep
185- // their native register bits.
186- using AReg = typename Dispatcher::ARegType;
187- using BReg = typename Dispatcher::BRegType;
188- constexpr int nA = detail::MmaImplTraits<typename Dispatcher::Impl>::kARegs ;
189- constexpr int nB = detail::MmaImplTraits<typename Dispatcher::Impl>::kBRegs ;
190- AReg as[nA];
191- BReg bs[nB];
192- #pragma unroll
193- for (int i = 0 ; i < nA; ++i) {
194- if constexpr (AType == DataType::kFloat4_e2m1fn ) {
195- as[i] = a[i] << 2 ;
196- } else {
197- as[i] = a[i];
198- }
199- }
200- #pragma unroll
201- for (int i = 0 ; i < nB; ++i) {
202- if constexpr (BType == DataType::kFloat4_e2m1fn ) {
203- bs[i] = b[i] << 2 ;
204- } else {
205- bs[i] = b[i];
206- }
207- }
208- Dispatcher::exec (c, as, bs, c);
209- } else {
210- Dispatcher::exec (c, a, b, c);
211- }
220+ Dispatcher::exec (c, a, b, c);
212221}
213222
214223} // namespace tl
0 commit comments