@@ -31,7 +31,13 @@ enum class MfmaTypeId : uint32_t {
3131 Fp8Fp8TyId,
3232 Fp8Bf8TyId,
3333 Bf8Fp8TyId,
34- Bf8Bf8TyId
34+ Bf8Bf8TyId,
35+ // FP8 via scaled MFMA (uses mfma_scale_f32_16x16x128_f8f6f4 with cbsz=0)
36+ // These provide a larger K dimension (128 for 16x16, 64 for 32x32)
37+ Fp8Fp8ScaledTyId,
38+ Fp8Bf8ScaledTyId,
39+ Bf8Fp8ScaledTyId,
40+ Bf8Bf8ScaledTyId
3541};
3642
3743struct MfmaInsnInfo {
@@ -71,7 +77,8 @@ class MfmaInsn {
7177 MfmaInsnAttr getAttr () const ;
7278 Type getArgTypeFor (Type elementTypeA);
7379 VectorType getRetType (Type elementType);
74- bool isCoherentWithK (int64_t kPack , int64_t kPerBlock );
80+ bool isCoherentWithK (int64_t kPack , int64_t kPerBlock ,
81+ int64_t scheduleVersion);
7582};
7683
7784template <typename T>
@@ -138,7 +145,8 @@ class MfmaInsnGroup {
138145public:
139146 static FailureOr<MfmaInsnGroup> select (Type elementTypeA, Type elementTypeB,
140147 StringRef arch, int64_t mnPerXdl,
141- int64_t kPack , int64_t kPackPerBlock );
148+ int64_t kPack , int64_t kPackPerBlock ,
149+ int64_t scheduleVersion);
142150 MfmaInsnGroup (Type elementTypeA, Type elementTypeB, const MfmaInsn &insn,
143151 const MfmaInsnGroupAttr &groupAttr);
144152 int64_t getMRepeats (int64_t mPerWave );
@@ -150,8 +158,13 @@ class MfmaInsnGroup {
150158 Type getArgTypeA ();
151159 Type getArgTypeB ();
152160 VectorType getRetType ();
153- bool isCoherentWithK (int64_t kPack , int64_t kPerBlock );
161+ bool isCoherentWithK (int64_t kPack , int64_t kPerBlock ,
162+ int64_t scheduleVersion);
154163 SmallString<16 > getROCDLIntrinsicName () { return groupAttr.insn ; }
164+
165+ // Check if this is FP8 using scaled MFMA (mfma_scale with cbsz=0, blgp=0)
166+ // These instructions have a larger K dimension (128 for 16x16, 64 for 32x32)
167+ bool isScaledFp8 () const ;
155168};
156169
157170} // namespace rock
0 commit comments