Skip to content

Commit 3dc31c5

Browse files
committed
[mlir][dxsa] Add dcl_uav_typed, dcl_uav_raw and dcl_uav_structured instructions
Example: dxsa.dcl_uav_typed <id = 0, dim = texture2d>, <x = float, y = float, z = float, w = float>, <flags = globallyCoherent> dxsa.dcl_uav_raw <id = 1>, <flags = globallyCoherent|rasterizerOrdered> dxsa.dcl_uav_structured <id = 2, struct_byte_stride = 16>, <flags = hasOrderPreservingCounter> Signed-off-by: Vladimir Shiryaev <tagolog@users.noreply.github.com>
1 parent 817a298 commit 3dc31c5

12 files changed

Lines changed: 414 additions & 0 deletions

File tree

mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,27 @@ def DXSA_ResourceReturnTypeAttr :
313313
let assemblyFormat = "$value";
314314
}
315315

316+
def DXSA_UAVFlags_GloballyCoherent : I32BitEnumAttrCaseBit<"globallyCoherent", 0>;
317+
def DXSA_UAVFlags_RasterizerOrdered : I32BitEnumAttrCaseBit<"rasterizerOrdered", 1>;
318+
def DXSA_UAVFlags_HasOrderPreservingCounter : I32BitEnumAttrCaseBit<"hasOrderPreservingCounter", 2>;
319+
320+
def DXSA_UAVFlags : I32BitEnumAttr<
321+
"UAVFlags", "UAV access flags", [
322+
DXSA_UAVFlags_GloballyCoherent,
323+
DXSA_UAVFlags_RasterizerOrdered,
324+
DXSA_UAVFlags_HasOrderPreservingCounter
325+
]> {
326+
let separator = "|";
327+
let cppNamespace = "::mlir::dxsa";
328+
let genSpecializedAttr = 0;
329+
let printBitEnumPrimaryGroups = 1;
330+
}
331+
332+
def DXSA_UAVFlagsAttr :
333+
EnumAttr<DXSADialect, DXSA_UAVFlags, "uav_flags"> {
334+
let assemblyFormat = "$value";
335+
}
336+
316337
//===----------------------------------------------------------------------===//
317338
// DXSA ComponentMask bit-enum (mask field of operand, normalized to bits 0..3)
318339
//===----------------------------------------------------------------------===//
@@ -1022,4 +1043,112 @@ def DXSA_DclResource : DXSA_Op<"dcl_resource"> {
10221043
let hasVerifier = 1;
10231044
}
10241045

1046+
def DXSA_DclUavTyped : DXSA_Op<"dcl_uav_typed"> {
1047+
let summary = "declares a typed UnorderedAccessView (UAV) for use by a shader";
1048+
let description = [{
1049+
The `dxsa.dcl_uav_typed` operation declares a typed UnorderedAccessView (UAV) for use by a shader.
1050+
1051+
Example:
1052+
1053+
```mlir
1054+
dxsa.dcl_uav_typed <id = 0, dim = buffer>,
1055+
<x = unorm, y = snorm, z = sint, w = uint>
1056+
dxsa.dcl_uav_typed <id = 1, dim = texture2d>,
1057+
<x = float, y = float, z = float, w = float>,
1058+
<flags = globallyCoherent>
1059+
dxsa.dcl_uav_typed <id = 0, dim = texture3d,
1060+
lbound = 0, ubound = 3, space = 1>,
1061+
<x = float, y = float, z = float, w = float>
1062+
```
1063+
}];
1064+
1065+
let arguments = (ins
1066+
I32Attr:$id,
1067+
DXSA_ResourceDimensionAttr:$dim,
1068+
DXSA_ResourceReturnTypeAttr:$x,
1069+
DXSA_ResourceReturnTypeAttr:$y,
1070+
DXSA_ResourceReturnTypeAttr:$z,
1071+
DXSA_ResourceReturnTypeAttr:$w,
1072+
OptionalAttr<DXSA_UAVFlagsAttr>:$flags,
1073+
OptionalAttr<I32Attr>:$lbound,
1074+
OptionalAttr<I32Attr>:$ubound,
1075+
OptionalAttr<I32Attr>:$space);
1076+
let assemblyFormat = [{
1077+
` ` `<` `id` `=` $id `,` `dim` `=` $dim
1078+
(`,` `lbound` `=` $lbound^ `,` `ubound` `=` $ubound
1079+
`,` `space` `=` $space)? `>` `,`
1080+
`<` `x` `=` $x `,` `y` `=` $y `,` `z` `=` $z `,` `w` `=` $w `>`
1081+
(`,` `<` `flags` `=` $flags^ `>`)?
1082+
attr-dict
1083+
}];
1084+
let hasVerifier = 1;
1085+
}
1086+
1087+
def DXSA_DclUavRaw : DXSA_Op<"dcl_uav_raw"> {
1088+
let summary = "declares a raw UnorderedAccessView (UAV) for use by a shader";
1089+
let description = [{
1090+
The `dxsa.dcl_uav_raw` operation declares a raw UnorderedAccessView (UAV) for use by a shader.
1091+
1092+
Example:
1093+
1094+
```mlir
1095+
dxsa.dcl_uav_raw <id = 0>
1096+
dxsa.dcl_uav_raw <id = 1>, <flags = globallyCoherent>
1097+
dxsa.dcl_uav_raw <id = 0, lbound = 0, ubound = 3, space = 1>
1098+
```
1099+
}];
1100+
1101+
let arguments = (ins
1102+
I32Attr:$id,
1103+
OptionalAttr<DXSA_UAVFlagsAttr>:$flags,
1104+
OptionalAttr<I32Attr>:$lbound,
1105+
OptionalAttr<I32Attr>:$ubound,
1106+
OptionalAttr<I32Attr>:$space);
1107+
let assemblyFormat = [{
1108+
` ` `<` `id` `=` $id
1109+
(`,` `lbound` `=` $lbound^ `,` `ubound` `=` $ubound
1110+
`,` `space` `=` $space)? `>`
1111+
(`,` `<` `flags` `=` $flags^ `>`)?
1112+
attr-dict
1113+
}];
1114+
let hasVerifier = 1;
1115+
}
1116+
1117+
def DXSA_DclUavStructured : DXSA_Op<"dcl_uav_structured"> {
1118+
let summary = "declares a structured UnorderedAccessView (UAV) for use by a shader";
1119+
let description = [{
1120+
The `dxsa.dcl_uav_structured` operation declares a structured UnorderedAccessView (UAV) for use by a shader.
1121+
1122+
`$struct_byte_stride` is the structure size in bytes; it must be a
1123+
multiple of 4.
1124+
1125+
Example:
1126+
1127+
```mlir
1128+
dxsa.dcl_uav_structured <id = 0, struct_byte_stride = 16>
1129+
dxsa.dcl_uav_structured <id = 1, struct_byte_stride = 32>,
1130+
<flags = globallyCoherent|hasOrderPreservingCounter>
1131+
dxsa.dcl_uav_structured <id = 0, struct_byte_stride = 32,
1132+
lbound = 0, ubound = 3, space = 1>
1133+
```
1134+
}];
1135+
1136+
let arguments = (ins
1137+
I32Attr:$id,
1138+
ConfinedAttr<I32Attr, [IntPositive]>:$struct_byte_stride,
1139+
OptionalAttr<DXSA_UAVFlagsAttr>:$flags,
1140+
OptionalAttr<I32Attr>:$lbound,
1141+
OptionalAttr<I32Attr>:$ubound,
1142+
OptionalAttr<I32Attr>:$space);
1143+
let assemblyFormat = [{
1144+
` ` `<` `id` `=` $id
1145+
`,` `struct_byte_stride` `=` $struct_byte_stride
1146+
(`,` `lbound` `=` $lbound^ `,` `ubound` `=` $ubound
1147+
`,` `space` `=` $space)? `>`
1148+
(`,` `<` `flags` `=` $flags^ `>`)?
1149+
attr-dict
1150+
}];
1151+
let hasVerifier = 1;
1152+
}
1153+
10251154
#endif // DXSA_OPS

mlir/lib/Dialect/DXSA/IR/DXSA.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,56 @@ LogicalResult DclResource::verify() {
103103
return success();
104104
}
105105

106+
static LogicalResult verifyUavRange(Operation *op,
107+
std::optional<uint32_t> lbound,
108+
std::optional<uint32_t> ubound) {
109+
if (lbound && ubound && *lbound > *ubound)
110+
return op->emitOpError("expected lbound <= ubound, got lbound=")
111+
<< *lbound << ", ubound=" << *ubound;
112+
return success();
113+
}
114+
115+
static LogicalResult
116+
verifyNoOrderPreservingCounter(Operation *op, std::optional<UAVFlags> flags) {
117+
if (flags && bitEnumContainsAny(*flags, UAVFlags::hasOrderPreservingCounter))
118+
return op->emitOpError(
119+
"hasOrderPreservingCounter flag is only valid for dcl_uav_structured");
120+
return success();
121+
}
122+
123+
LogicalResult DclUavTyped::verify() {
124+
auto dim = getDim();
125+
switch (dim) {
126+
case ResourceDimension::buffer:
127+
case ResourceDimension::texture1d:
128+
case ResourceDimension::texture1darray:
129+
case ResourceDimension::texture2d:
130+
case ResourceDimension::texture2darray:
131+
case ResourceDimension::texture3d:
132+
break;
133+
default:
134+
return emitOpError("invalid dimension for typed UAV: ")
135+
<< stringifyResourceDimension(dim);
136+
}
137+
if (failed(verifyNoOrderPreservingCounter(*this, getFlags())))
138+
return failure();
139+
return verifyUavRange(*this, getLbound(), getUbound());
140+
}
141+
142+
LogicalResult DclUavRaw::verify() {
143+
if (failed(verifyNoOrderPreservingCounter(*this, getFlags())))
144+
return failure();
145+
return verifyUavRange(*this, getLbound(), getUbound());
146+
}
147+
148+
LogicalResult DclUavStructured::verify() {
149+
auto stride = getStructByteStride();
150+
if (stride % 4 != 0)
151+
return emitOpError("struct byte stride must be a multiple of 4, got ")
152+
<< stride;
153+
return verifyUavRange(*this, getLbound(), getUbound());
154+
}
155+
106156
//===----------------------------------------------------------------------===//
107157
// TableGen'd attribute method definitions
108158
//===----------------------------------------------------------------------===//

mlir/lib/Target/DXSA/BinaryParser.cpp

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,51 @@ class DXBuilder {
706706
toAttr(ubound), toAttr(space));
707707
}
708708

709+
dxsa::UAVFlagsAttr buildUavFlagsAttr(std::optional<dxsa::UAVFlags> flags) {
710+
return flags ? dxsa::UAVFlagsAttr::get(builder.getContext(), *flags)
711+
: dxsa::UAVFlagsAttr();
712+
}
713+
714+
Instruction buildDclUavTyped(
715+
uint32_t id, dxsa::ResourceDimension dim, dxsa::ResourceReturnType x,
716+
dxsa::ResourceReturnType y, dxsa::ResourceReturnType z,
717+
dxsa::ResourceReturnType w, std::optional<dxsa::UAVFlags> flags,
718+
std::optional<uint32_t> lbound, std::optional<uint32_t> ubound,
719+
std::optional<uint32_t> space, Location loc) {
720+
auto toAttr = [&](std::optional<uint32_t> v) -> IntegerAttr {
721+
return v ? builder.getI32IntegerAttr(*v) : IntegerAttr();
722+
};
723+
return dxsa::DclUavTyped::create(builder, loc, id, dim, x, y, z, w,
724+
buildUavFlagsAttr(flags), toAttr(lbound),
725+
toAttr(ubound), toAttr(space));
726+
}
727+
728+
Instruction buildDclUavRaw(uint32_t id, std::optional<dxsa::UAVFlags> flags,
729+
std::optional<uint32_t> lbound,
730+
std::optional<uint32_t> ubound,
731+
std::optional<uint32_t> space, Location loc) {
732+
auto toAttr = [&](std::optional<uint32_t> v) -> IntegerAttr {
733+
return v ? builder.getI32IntegerAttr(*v) : IntegerAttr();
734+
};
735+
return dxsa::DclUavRaw::create(builder, loc, id, buildUavFlagsAttr(flags),
736+
toAttr(lbound), toAttr(ubound),
737+
toAttr(space));
738+
}
739+
740+
Instruction buildDclUavStructured(uint32_t id, uint32_t structByteStride,
741+
std::optional<dxsa::UAVFlags> flags,
742+
std::optional<uint32_t> lbound,
743+
std::optional<uint32_t> ubound,
744+
std::optional<uint32_t> space,
745+
Location loc) {
746+
auto toAttr = [&](std::optional<uint32_t> v) -> IntegerAttr {
747+
return v ? builder.getI32IntegerAttr(*v) : IntegerAttr();
748+
};
749+
return dxsa::DclUavStructured::create(
750+
builder, loc, id, structByteStride, buildUavFlagsAttr(flags),
751+
toAttr(lbound), toAttr(ubound), toAttr(space));
752+
}
753+
709754
private:
710755
MLIRContext *context;
711756
ModuleOp module;
@@ -1398,6 +1443,105 @@ class Parser {
13981443
lbound, ubound, space, loc);
13991444
}
14001445

1446+
std::optional<dxsa::UAVFlags> decodeUavFlags(uint32_t opcodeToken) {
1447+
auto flags = static_cast<dxsa::UAVFlags>(0);
1448+
if (opcodeToken & D3D11_SB_GLOBALLY_COHERENT_ACCESS)
1449+
flags = flags | dxsa::UAVFlags::globallyCoherent;
1450+
if (opcodeToken & D3D11_SB_RASTERIZER_ORDERED_ACCESS)
1451+
flags = flags | dxsa::UAVFlags::rasterizerOrdered;
1452+
if (opcodeToken & D3D11_SB_UAV_HAS_ORDER_PRESERVING_COUNTER)
1453+
flags = flags | dxsa::UAVFlags::hasOrderPreservingCounter;
1454+
if (static_cast<uint32_t>(flags) == 0)
1455+
return std::nullopt;
1456+
return flags;
1457+
}
1458+
1459+
struct UavOperand {
1460+
uint32_t id;
1461+
std::optional<uint32_t> lbound;
1462+
std::optional<uint32_t> ubound;
1463+
};
1464+
1465+
FailureOr<UavOperand> parseUavOperand(Location loc) {
1466+
auto operand = parseInlineOperand();
1467+
FAILURE_IF_FAILED(operand);
1468+
if (operand->getType() != dxsa::InlineOperandType::uav)
1469+
return emitError(loc, "operand must be a uav register, got ")
1470+
<< dxsa::stringifyInlineOperandType(operand->getType());
1471+
auto indexArray = operand->getIndex();
1472+
auto indexDim = indexArray ? indexArray.size() : 0;
1473+
if (indexDim != 1 && indexDim != 3)
1474+
return emitError(loc, "operand must have a 1D or 3D index, got ")
1475+
<< indexDim;
1476+
UavOperand uav{static_cast<uint32_t>(indexArray[0]), std::nullopt,
1477+
std::nullopt};
1478+
if (indexDim == 3) {
1479+
uav.lbound = static_cast<uint32_t>(indexArray[1]);
1480+
uav.ubound = static_cast<uint32_t>(indexArray[2]);
1481+
}
1482+
return uav;
1483+
}
1484+
1485+
FailureOr<std::optional<uint32_t>> parseUavSpace(const UavOperand &uav) {
1486+
if (!uav.lbound)
1487+
return std::optional<uint32_t>(std::nullopt);
1488+
auto spaceToken = parseToken();
1489+
FAILURE_IF_FAILED(spaceToken);
1490+
return std::optional<uint32_t>(*spaceToken);
1491+
}
1492+
1493+
FailureOr<Instruction> parseDclUavTyped(uint32_t opcodeToken, Location loc) {
1494+
auto rawDim = DECODE_D3D10_SB_RESOURCE_DIMENSION(opcodeToken);
1495+
auto dim = dxsa::symbolizeResourceDimension(rawDim);
1496+
if (!dim)
1497+
return emitError(loc, "unknown resource dimension: ") << rawDim;
1498+
1499+
auto flags = decodeUavFlags(opcodeToken);
1500+
1501+
auto uav = parseUavOperand(loc);
1502+
FAILURE_IF_FAILED(uav);
1503+
1504+
auto returnTypeToken = parseToken();
1505+
FAILURE_IF_FAILED(returnTypeToken);
1506+
auto x = parseResourceReturnType(*returnTypeToken, 0, loc);
1507+
FAILURE_IF_FAILED(x);
1508+
auto y = parseResourceReturnType(*returnTypeToken, 1, loc);
1509+
FAILURE_IF_FAILED(y);
1510+
auto z = parseResourceReturnType(*returnTypeToken, 2, loc);
1511+
FAILURE_IF_FAILED(z);
1512+
auto w = parseResourceReturnType(*returnTypeToken, 3, loc);
1513+
FAILURE_IF_FAILED(w);
1514+
1515+
auto space = parseUavSpace(*uav);
1516+
FAILURE_IF_FAILED(space);
1517+
1518+
return builder.buildDclUavTyped(uav->id, *dim, *x, *y, *z, *w, flags,
1519+
uav->lbound, uav->ubound, *space, loc);
1520+
}
1521+
1522+
FailureOr<Instruction> parseDclUavRaw(uint32_t opcodeToken, Location loc) {
1523+
auto flags = decodeUavFlags(opcodeToken);
1524+
auto uav = parseUavOperand(loc);
1525+
FAILURE_IF_FAILED(uav);
1526+
auto space = parseUavSpace(*uav);
1527+
FAILURE_IF_FAILED(space);
1528+
return builder.buildDclUavRaw(uav->id, flags, uav->lbound, uav->ubound,
1529+
*space, loc);
1530+
}
1531+
1532+
FailureOr<Instruction> parseDclUavStructured(uint32_t opcodeToken,
1533+
Location loc) {
1534+
auto flags = decodeUavFlags(opcodeToken);
1535+
auto uav = parseUavOperand(loc);
1536+
FAILURE_IF_FAILED(uav);
1537+
auto strideToken = parseToken();
1538+
FAILURE_IF_FAILED(strideToken);
1539+
auto space = parseUavSpace(*uav);
1540+
FAILURE_IF_FAILED(space);
1541+
return builder.buildDclUavStructured(uav->id, *strideToken, flags,
1542+
uav->lbound, uav->ubound, *space, loc);
1543+
}
1544+
14011545
OptionalParseResult parseDclInstruction(uint32_t opcodeToken, Location loc,
14021546
Instruction &out) {
14031547
FailureOr<Instruction> result;
@@ -1477,6 +1621,15 @@ class Parser {
14771621
case D3D10_SB_OPCODE_DCL_RESOURCE:
14781622
result = parseDclResource(opcodeToken, loc);
14791623
break;
1624+
case D3D11_SB_OPCODE_DCL_UNORDERED_ACCESS_VIEW_TYPED:
1625+
result = parseDclUavTyped(opcodeToken, loc);
1626+
break;
1627+
case D3D11_SB_OPCODE_DCL_UNORDERED_ACCESS_VIEW_RAW:
1628+
result = parseDclUavRaw(opcodeToken, loc);
1629+
break;
1630+
case D3D11_SB_OPCODE_DCL_UNORDERED_ACCESS_VIEW_STRUCTURED:
1631+
result = parseDclUavStructured(opcodeToken, loc);
1632+
break;
14801633
default:
14811634
return std::nullopt;
14821635
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// RUN: mlir-translate --import-dxsa-bin %S/inputs/dcl_uav_raw.bin | FileCheck %s
2+
3+
// CHECK: module {
4+
// CHECK-NEXT: dxsa.dcl_uav_raw <id = 0>
5+
// CHECK-NEXT: dxsa.dcl_uav_raw <id = 1>, <flags = globallyCoherent>
6+
// CHECK-NEXT: dxsa.dcl_uav_raw <id = 2>, <flags = globallyCoherent|rasterizerOrdered>
7+
// CHECK-NEXT: dxsa.dcl_uav_raw <id = 0, lbound = 0, ubound = 3, space = 1>
8+
// CHECK-NEXT: }
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
2+
3+
// expected-error@+1 {{'dxsa.dcl_uav_raw' op hasOrderPreservingCounter flag is only valid for dcl_uav_structured}}
4+
dxsa.dcl_uav_raw <id = 0>, <flags = hasOrderPreservingCounter>
5+
6+
// -----
7+
8+
// expected-error@+1 {{'dxsa.dcl_uav_raw' op expected lbound <= ubound, got lbound=5, ubound=3}}
9+
dxsa.dcl_uav_raw <id = 0, lbound = 5, ubound = 3, space = 1>

0 commit comments

Comments
 (0)