@@ -117,54 +117,42 @@ Attribute MmaOpGFX11_WMMAType::getThrValLayoutC() const {
117117
118118LogicalResult MmaOpGFX11_WMMAType::verify (function_ref<InFlightDiagnostic()> emitError, int32_t m,
119119 int32_t n, int32_t k, Type elemTyA, Type elemTyB,
120- Type elemTyAcc) {
120+ Type elemTyAcc, bool signA, bool signB, bool clamp ) {
121121 if (m != 16 || n != 16 || k != 16 ) {
122122 return emitError () << " GFX11 WMMA requires M=N=K=16, got " << m << " x" << n << " x" << k;
123123 }
124124
125- bool valid = false ;
126-
127- // fp16/bf16 inputs, f32 accumulator. (16-bit accumulator variants exist on
128- // RDNA3 but require VGPR-pair packing/expansion around OPSEL; not yet
129- // implemented here.)
130- if (elemTyA.isF16 () && elemTyB.isF16 () && elemTyAcc.isF32 ())
131- valid = true ;
132- if (elemTyA.isBF16 () && elemTyB.isBF16 () && elemTyAcc.isF32 ())
133- valid = true ;
134-
135- // Integer inputs: REQUIRE explicit unsigned signedness (ui8/ui4).
136- //
137- // The atom contract is unsigned-only because emitAtomCallSSA invokes the
138- // ROCDL iu8/iu4 intrinsics with signA=signB=false (unsigned interpretation
139- // of the packed operands).
140- auto isUI = [](Type t, unsigned width) {
125+ // Determine which path this is. fp16/bf16 inputs go to the f32-accumulator
126+ // intrinsics, which have no sign/clamp operands. iu8/iu4 inputs go to the
127+ // i32-accumulator intrinsics, which take all three.
128+ const bool isFp = (elemTyA.isF16 () && elemTyB.isF16 () && elemTyAcc.isF32 ()) ||
129+ (elemTyA.isBF16 () && elemTyB.isBF16 () && elemTyAcc.isF32 ());
130+
131+ // For integer paths, accept any IntegerType width 8 or 4 regardless of
132+ // signedness (signless/si/ui). The caller controls how the input bits are
133+ // interpreted via signA/signB on the intrinsic.
134+ auto isInt = [](Type t, unsigned width) {
141135 auto it = dyn_cast<IntegerType>(t);
142- return it && it.getWidth () == width && it. isUnsigned () ;
136+ return it && it.getWidth () == width;
143137 };
138+ const bool isI8x8 = isInt (elemTyA, 8 ) && isInt (elemTyB, 8 ) && elemTyAcc.isInteger (32 );
139+ const bool isI4x4 = isInt (elemTyA, 4 ) && isInt (elemTyB, 4 ) && elemTyAcc.isInteger (32 );
140+ const bool isInt8or4 = isI8x8 || isI4x4;
144141
145- if (isUI (elemTyA, 8 ) && isUI (elemTyB, 8 ) && elemTyAcc.isInteger (32 ))
146- valid = true ;
147- if (isUI (elemTyA, 4 ) && isUI (elemTyB, 4 ) && elemTyAcc.isInteger (32 ))
148- valid = true ;
149-
150- if (!valid) {
151- // Steer the caller to ui8/ui4 explicitly.
152- auto looksLikeInt = [](Type t, unsigned w) {
153- auto it = dyn_cast<IntegerType>(t);
154- return it && it.getWidth () == w;
155- };
156- if ((looksLikeInt (elemTyA, 8 ) || looksLikeInt (elemTyA, 4 )) && elemTyAcc.isInteger (32 )) {
157- return emitError () << " GFX11 WMMA integer inputs must be unsigned "
158- " (ui8/ui4); got A="
159- << elemTyA << " , B=" << elemTyB
160- << " . The lowered ROCDL iu8/iu4 intrinsic is invoked "
161- " with signA=signB=false, so signless/signed "
162- " operands would silently get unsigned semantics. "
163- " Signed-integer WMMA is not yet implemented." ;
164- }
142+ if (!isFp && !isInt8or4) {
165143 return emitError () << " unsupported GFX11 WMMA configuration: " << m << " x" << n << " x" << k
166144 << " with A=" << elemTyA << " , B=" << elemTyB << " , Acc=" << elemTyAcc;
167145 }
146+
147+ // fp16/bf16 intrinsics do not have signA/signB/clamp operands. Refuse to
148+ // construct an atom that promises something the codegen cannot deliver.
149+ if (isFp && (signA || signB || clamp)) {
150+ return emitError () << " GFX11 WMMA fp16/bf16 path does not accept signA/signB/clamp "
151+ " (the ROCDL fp WMMA intrinsics have no such operands); "
152+ " got signA="
153+ << signA << " , signB=" << signB << " , clamp=" << clamp;
154+ }
155+
168156 return success ();
169157}
170158
@@ -247,7 +235,6 @@ FailureOr<Value> MmaOpGFX11_WMMAType::emitAtomCallSSA(OpBuilder &builder, Locati
247235 StringRef opName;
248236 SmallVector<NamedAttribute, 3 > attrs;
249237 SmallVector<Value, 3 > operands;
250- BoolAttr falseAttr = builder.getBoolAttr (false );
251238
252239 if (elemTyA.isF16 () && elemTyB.isF16 () && elemTyAcc.isF32 ()) {
253240 opName = ROCDL::wmma_f32_16x16x16_f16::getOperationName ();
@@ -256,21 +243,19 @@ FailureOr<Value> MmaOpGFX11_WMMAType::emitAtomCallSSA(OpBuilder &builder, Locati
256243 opName = ROCDL::wmma_f32_16x16x16_bf16::getOperationName ();
257244 operands = {a, b, c};
258245 } else if (elemTyA.isInteger (8 ) && elemTyB.isInteger (8 ) && elemTyAcc.isInteger (32 )) {
259- // Unsigned-only by contract (see verify()). signA=signB=false matches the
260- // ui8 element type enforced there. clamp=false preserves wraparound on the
261- // i32 accumulator.
246+ // Integer paths: signA/signB/clamp come from the type parameters so the
247+ // caller controls whether each operand is interpreted as signed.
262248 opName = ROCDL::wmma_i32_16x16x16_iu8::getOperationName ();
263249 operands = {a, b, c};
264- attrs.push_back ({builder.getStringAttr (" signA" ), falseAttr });
265- attrs.push_back ({builder.getStringAttr (" signB" ), falseAttr });
266- attrs.push_back ({builder.getStringAttr (" clamp" ), falseAttr });
250+ attrs.push_back ({builder.getStringAttr (" signA" ), builder. getBoolAttr ( getSignA ()) });
251+ attrs.push_back ({builder.getStringAttr (" signB" ), builder. getBoolAttr ( getSignB ()) });
252+ attrs.push_back ({builder.getStringAttr (" clamp" ), builder. getBoolAttr ( getClamp ()) });
267253 } else if (elemTyA.isInteger (4 ) && elemTyB.isInteger (4 ) && elemTyAcc.isInteger (32 )) {
268- // Same unsigned-only contract as iu8; see verify().
269254 opName = ROCDL::wmma_i32_16x16x16_iu4::getOperationName ();
270255 operands = {a, b, c};
271- attrs.push_back ({builder.getStringAttr (" signA" ), falseAttr });
272- attrs.push_back ({builder.getStringAttr (" signB" ), falseAttr });
273- attrs.push_back ({builder.getStringAttr (" clamp" ), falseAttr });
256+ attrs.push_back ({builder.getStringAttr (" signA" ), builder. getBoolAttr ( getSignA ()) });
257+ attrs.push_back ({builder.getStringAttr (" signB" ), builder. getBoolAttr ( getSignB ()) });
258+ attrs.push_back ({builder.getStringAttr (" clamp" ), builder. getBoolAttr ( getClamp ()) });
274259 } else {
275260 return failure ();
276261 }
0 commit comments