Skip to content

Commit b3c06cf

Browse files
authored
Merge branch 'main' into feat-issue-574-basis-inttuplebuilder
2 parents 72cf4e8 + 61a9130 commit b3c06cf

13 files changed

Lines changed: 524 additions & 111 deletions

File tree

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
project = "FlyDSL"
1616
copyright = "2024-2026, Advanced Micro Devices, Inc."
1717
author = "AMD"
18-
release = "0.1.8"
18+
release = "0.2.0"
1919

2020
# -- General configuration ---------------------------------------------------
2121
extensions = [

include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,25 @@ def FlyROCDL_MmaOpGFX11_WMMA : FlyROCDL_MmaOp<"MmaOpGFX11_WMMA", "gfx11.wmma", [
102102
"int32_t":$k,
103103
"Type":$elemTyA,
104104
"Type":$elemTyB,
105-
"Type":$elemTyAcc
105+
"Type":$elemTyAcc,
106+
// Integer-WMMA controls, forwarded to the ROCDL iu8/iu4 intrinsic. Ignored
107+
// (and required to be false) on fp16/bf16 paths — those intrinsics have no
108+
// such operands. Always-printed in the assembly format for clarity.
109+
"bool":$signA,
110+
"bool":$signB,
111+
"bool":$clamp
106112
);
107-
let assemblyFormat = "`<` custom<MNKDimensionList>($m, $n, $k) `,` `(` $elemTyA `,` $elemTyB `)` `->` $elemTyAcc `>`";
113+
let assemblyFormat = "`<` custom<MNKDimensionList>($m, $n, $k) `,` `(` $elemTyA `,` $elemTyB `)` `->` $elemTyAcc `,` `signA` `=` $signA `,` `signB` `=` $signB `,` `clamp` `=` $clamp `>`";
108114

109115
let builders = [
116+
// Legacy 6-arg builder: defaults signA/signB/clamp=false. Backward-compat
117+
// for fp16/bf16 callers; integer callers should use the 9-arg form below.
110118
TypeBuilderWithInferredContext<(ins "int32_t":$m, "int32_t":$n, "int32_t":$k, "Type":$elemTyA, "Type":$elemTyB, "Type":$elemTyAcc), [{
111-
return $_get(elemTyA.getContext(), m, n, k, elemTyA, elemTyB, elemTyAcc);
119+
return $_get(elemTyA.getContext(), m, n, k, elemTyA, elemTyB, elemTyAcc, false, false, false);
120+
}]>,
121+
// Explicit 9-arg builder for integer-WMMA callers.
122+
TypeBuilderWithInferredContext<(ins "int32_t":$m, "int32_t":$n, "int32_t":$k, "Type":$elemTyA, "Type":$elemTyB, "Type":$elemTyAcc, "bool":$signA, "bool":$signB, "bool":$clamp), [{
123+
return $_get(elemTyA.getContext(), m, n, k, elemTyA, elemTyB, elemTyAcc, signA, signB, clamp);
112124
}]>
113125
];
114126
let genVerifyDecl = 1;

lib/Bindings/Python/FlyROCDLExtension.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,18 @@ struct PyMmaOpGFX11_WMMAType : PyConcreteType<PyMmaOpGFX11_WMMAType> {
8585
c.def_static(
8686
"get",
8787
[](int32_t m, int32_t n, int32_t k, PyType &elemTyA, PyType &elemTyB, PyType &elemTyAcc,
88-
DefaultingPyMlirContext context) {
89-
return PyMmaOpGFX11_WMMAType(context->getRef(), wrap(MmaOpGFX11_WMMAType::get(
90-
m, n, k, unwrap(elemTyA),
91-
unwrap(elemTyB), unwrap(elemTyAcc))));
88+
bool signA, bool signB, bool clamp, DefaultingPyMlirContext context) {
89+
return PyMmaOpGFX11_WMMAType(
90+
context->getRef(),
91+
wrap(MmaOpGFX11_WMMAType::get(m, n, k, unwrap(elemTyA), unwrap(elemTyB),
92+
unwrap(elemTyAcc), signA, signB, clamp)));
9293
},
9394
"m"_a, "n"_a, "k"_a, "elem_ty_a"_a, "elem_ty_b"_a, "elem_ty_acc"_a, nb::kw_only(),
94-
"context"_a = nb::none(),
95+
"sign_a"_a = false, "sign_b"_a = false, "clamp"_a = false, "context"_a = nb::none(),
9596
"Create a MmaOpGFX11_WMMAType with m, n, k dimensions and element types "
96-
"(RDNA3 / RDNA3.5 wave32 WMMA, v16 operand ABI)");
97+
"(RDNA3 / RDNA3.5 wave32 WMMA, v16 operand ABI). "
98+
"sign_a/sign_b/clamp are forwarded to the iu8/iu4 intrinsic for integer "
99+
"paths; must be false for fp16/bf16.");
97100
}
98101
};
99102

lib/Dialect/FlyROCDL/GFX11/MmaAtom.cpp

Lines changed: 34 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -117,54 +117,42 @@ Attribute MmaOpGFX11_WMMAType::getThrValLayoutC() const {
117117

118118
LogicalResult MmaOpGFX11_WMMAType::verify(function_ref<InFlightDiagnostic()> emitError, int32_t m,
119119
int32_t n, int32_t k, Type elemTyA, Type elemTyB,
120-
Type elemTyAcc) {
120+
Type elemTyAcc, bool signA, bool signB, bool clamp) {
121121
if (m != 16 || n != 16 || k != 16) {
122122
return emitError() << "GFX11 WMMA requires M=N=K=16, got " << m << "x" << n << "x" << k;
123123
}
124124

125-
bool valid = false;
126-
127-
// fp16/bf16 inputs, f32 accumulator. (16-bit accumulator variants exist on
128-
// RDNA3 but require VGPR-pair packing/expansion around OPSEL; not yet
129-
// implemented here.)
130-
if (elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF32())
131-
valid = true;
132-
if (elemTyA.isBF16() && elemTyB.isBF16() && elemTyAcc.isF32())
133-
valid = true;
134-
135-
// Integer inputs: REQUIRE explicit unsigned signedness (ui8/ui4).
136-
//
137-
// The atom contract is unsigned-only because emitAtomCallSSA invokes the
138-
// ROCDL iu8/iu4 intrinsics with signA=signB=false (unsigned interpretation
139-
// of the packed operands).
140-
auto isUI = [](Type t, unsigned width) {
125+
// Determine which path this is. fp16/bf16 inputs go to the f32-accumulator
126+
// intrinsics, which have no sign/clamp operands. iu8/iu4 inputs go to the
127+
// i32-accumulator intrinsics, which take all three.
128+
const bool isFp = (elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF32()) ||
129+
(elemTyA.isBF16() && elemTyB.isBF16() && elemTyAcc.isF32());
130+
131+
// For integer paths, accept any IntegerType width 8 or 4 regardless of
132+
// signedness (signless/si/ui). The caller controls how the input bits are
133+
// interpreted via signA/signB on the intrinsic.
134+
auto isInt = [](Type t, unsigned width) {
141135
auto it = dyn_cast<IntegerType>(t);
142-
return it && it.getWidth() == width && it.isUnsigned();
136+
return it && it.getWidth() == width;
143137
};
138+
const bool isI8x8 = isInt(elemTyA, 8) && isInt(elemTyB, 8) && elemTyAcc.isInteger(32);
139+
const bool isI4x4 = isInt(elemTyA, 4) && isInt(elemTyB, 4) && elemTyAcc.isInteger(32);
140+
const bool isInt8or4 = isI8x8 || isI4x4;
144141

145-
if (isUI(elemTyA, 8) && isUI(elemTyB, 8) && elemTyAcc.isInteger(32))
146-
valid = true;
147-
if (isUI(elemTyA, 4) && isUI(elemTyB, 4) && elemTyAcc.isInteger(32))
148-
valid = true;
149-
150-
if (!valid) {
151-
// Steer the caller to ui8/ui4 explicitly.
152-
auto looksLikeInt = [](Type t, unsigned w) {
153-
auto it = dyn_cast<IntegerType>(t);
154-
return it && it.getWidth() == w;
155-
};
156-
if ((looksLikeInt(elemTyA, 8) || looksLikeInt(elemTyA, 4)) && elemTyAcc.isInteger(32)) {
157-
return emitError() << "GFX11 WMMA integer inputs must be unsigned "
158-
"(ui8/ui4); got A="
159-
<< elemTyA << ", B=" << elemTyB
160-
<< ". The lowered ROCDL iu8/iu4 intrinsic is invoked "
161-
"with signA=signB=false, so signless/signed "
162-
"operands would silently get unsigned semantics. "
163-
"Signed-integer WMMA is not yet implemented.";
164-
}
142+
if (!isFp && !isInt8or4) {
165143
return emitError() << "unsupported GFX11 WMMA configuration: " << m << "x" << n << "x" << k
166144
<< " with A=" << elemTyA << ", B=" << elemTyB << ", Acc=" << elemTyAcc;
167145
}
146+
147+
// fp16/bf16 intrinsics do not have signA/signB/clamp operands. Refuse to
148+
// construct an atom that promises something the codegen cannot deliver.
149+
if (isFp && (signA || signB || clamp)) {
150+
return emitError() << "GFX11 WMMA fp16/bf16 path does not accept signA/signB/clamp "
151+
"(the ROCDL fp WMMA intrinsics have no such operands); "
152+
"got signA="
153+
<< signA << ", signB=" << signB << ", clamp=" << clamp;
154+
}
155+
168156
return success();
169157
}
170158

@@ -247,7 +235,6 @@ FailureOr<Value> MmaOpGFX11_WMMAType::emitAtomCallSSA(OpBuilder &builder, Locati
247235
StringRef opName;
248236
SmallVector<NamedAttribute, 3> attrs;
249237
SmallVector<Value, 3> operands;
250-
BoolAttr falseAttr = builder.getBoolAttr(false);
251238

252239
if (elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF32()) {
253240
opName = ROCDL::wmma_f32_16x16x16_f16::getOperationName();
@@ -256,21 +243,19 @@ FailureOr<Value> MmaOpGFX11_WMMAType::emitAtomCallSSA(OpBuilder &builder, Locati
256243
opName = ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
257244
operands = {a, b, c};
258245
} else if (elemTyA.isInteger(8) && elemTyB.isInteger(8) && elemTyAcc.isInteger(32)) {
259-
// Unsigned-only by contract (see verify()). signA=signB=false matches the
260-
// ui8 element type enforced there. clamp=false preserves wraparound on the
261-
// i32 accumulator.
246+
// Integer paths: signA/signB/clamp come from the type parameters so the
247+
// caller controls whether each operand is interpreted as signed.
262248
opName = ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
263249
operands = {a, b, c};
264-
attrs.push_back({builder.getStringAttr("signA"), falseAttr});
265-
attrs.push_back({builder.getStringAttr("signB"), falseAttr});
266-
attrs.push_back({builder.getStringAttr("clamp"), falseAttr});
250+
attrs.push_back({builder.getStringAttr("signA"), builder.getBoolAttr(getSignA())});
251+
attrs.push_back({builder.getStringAttr("signB"), builder.getBoolAttr(getSignB())});
252+
attrs.push_back({builder.getStringAttr("clamp"), builder.getBoolAttr(getClamp())});
267253
} else if (elemTyA.isInteger(4) && elemTyB.isInteger(4) && elemTyAcc.isInteger(32)) {
268-
// Same unsigned-only contract as iu8; see verify().
269254
opName = ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
270255
operands = {a, b, c};
271-
attrs.push_back({builder.getStringAttr("signA"), falseAttr});
272-
attrs.push_back({builder.getStringAttr("signB"), falseAttr});
273-
attrs.push_back({builder.getStringAttr("clamp"), falseAttr});
256+
attrs.push_back({builder.getStringAttr("signA"), builder.getBoolAttr(getSignA())});
257+
attrs.push_back({builder.getStringAttr("signB"), builder.getBoolAttr(getSignB())});
258+
attrs.push_back({builder.getStringAttr("clamp"), builder.getBoolAttr(getClamp())});
274259
} else {
275260
return failure();
276261
}

python/flydsl/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22
# Copyright (c) 2025 FlyDSL Project Contributors
33
# ruff: noqa: I001
44

5-
__version__ = "0.1.9"
5+
__version__ = "0.2.0"
66

77
from .autotune import Config as Config, autotune as autotune # noqa: E402

0 commit comments

Comments
 (0)