Skip to content

Commit 3285a60

Browse files
committed
[mlir][dxsa] Add dcl_thread_group instruction
Example: `dxsa.dcl_thread_group<x = 1, y = 1, z = 1>`

Signed-off-by: Vladimir Shiryaev <tagolog@users.noreply.github.com>
1 parent 014c7b3 commit 3285a60

6 files changed

Lines changed: 100 additions & 0 deletions

File tree

mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,4 +719,28 @@ def DXSA_DclOutput : DXSA_Op<"dcl_output"> {
719719
let assemblyFormat = "$operand attr-dict";
720720
}
721721

722+
// Op definition for `dxsa.dcl_thread_group`: carries the three thread-group
// dimensions as integer attributes and prints them as `<x = .., y = .., z = ..>`.
def DXSA_DclThreadGroup : DXSA_Op<"dcl_thread_group"> {
723+
let summary = "declares compute shader thread group dimensions";
724+
let description = [{
725+
The `dxsa.dcl_thread_group` operation declares the `$x`, `$y` and `$z`
726+
dimensions of a compute shader thread group.
727+
728+
The product `$x` * `$y` * `$z` must not exceed 1024.
729+
730+
Example:
731+
732+
```mlir
733+
dxsa.dcl_thread_group<x = 1, y = 1, z = 1>
734+
```
735+
}];
736+
// Per-dimension limits enforced by the attribute constraints below:
// x and y must be in [1, 1024], z must be in [1, 64].
let arguments = (ins
737+
ConfinedAttr<I32Attr, [IntPositive, IntMaxValue<1024>]>:$x,
738+
ConfinedAttr<I32Attr, [IntPositive, IntMaxValue<1024>]>:$y,
739+
ConfinedAttr<I32Attr, [IntPositive, IntMaxValue<64>]>:$z);
740+
let assemblyFormat = [{
741+
`<` `x` `=` $x `,` `y` `=` $y `,` `z` `=` $z `>` attr-dict
742+
}];
743+
// Requests a hand-written verify() hook; the x*y*z <= 1024 rule stated in the
// description cannot be expressed by the per-attribute constraints alone.
let hasVerifier = 1;
744+
}
745+
722746
#endif // DXSA_OPS

mlir/lib/Dialect/DXSA/IR/DXSA.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,19 @@ void DXSADialect::initialize() {
3434
>();
3535
}
3636

37+
//===----------------------------------------------------------------------===//
38+
// DclThreadGroup
39+
//===----------------------------------------------------------------------===//
40+
41+
/// Verifies the cross-dimension invariant of `dxsa.dcl_thread_group`: the
/// total number of threads x * y * z must not exceed 1024.
///
/// Emits an op error carrying the limit and the computed total on violation;
/// returns success() otherwise.
LogicalResult DclThreadGroup::verify() {
42+
constexpr int64_t maxTotalThreads = 1024;
43+
// The first factor is cast to int64_t so the product is evaluated in a wide
// type rather than in the attributes' narrower integer type.
if (auto total = static_cast<int64_t>(getX()) * getY() * getZ();
44+
total > maxTotalThreads)
45+
return emitOpError("thread group size x*y*z must be <= ")
46+
<< maxTotalThreads << ", got " << total;
47+
return success();
48+
}
49+
3750
//===----------------------------------------------------------------------===//
3851
// TableGen'd op method definitions
3952
//===----------------------------------------------------------------------===//

mlir/lib/Target/DXSA/BinaryParser.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,13 @@ class DXBuilder {
635635
return dxsa::DclOutput::create(builder, loc, operand);
636636
}
637637

638+
/// Creates a `dxsa::DclThreadGroup` op at `loc` whose x, y and z attributes
/// are 32-bit integer attributes built from the given values.
///
/// NOTE(review): the parameters are `uint32_t` but are stored through
/// `getI32IntegerAttr`; values above INT32_MAX presumably change sign in that
/// conversion -- confirm that later op verification rejects them.
Instruction buildDclThreadGroup(uint32_t x, uint32_t y, uint32_t z,
639+
Location loc) {
640+
return dxsa::DclThreadGroup::create(
641+
builder, loc, builder.getI32IntegerAttr(x),
642+
builder.getI32IntegerAttr(y), builder.getI32IntegerAttr(z));
643+
}
644+
638645
private:
639646
MLIRContext *context;
640647
ModuleOp module;
@@ -1174,6 +1181,16 @@ class Parser {
11741181
return builder.buildDclOutput(*operand, loc);
11751182
}
11761183

1184+
/// Parses the payload of a dcl_thread_group instruction: three consecutive
/// tokens read in the order x, y, z, which are then handed to the builder.
///
/// No range checks are done here; the values are passed through unchanged.
/// NOTE(review): FAILURE_IF_FAILED presumably returns failure() from this
/// function when the preceding parseToken() failed -- confirm against the
/// macro definition.
FailureOr<Instruction> parseDclThreadGroup(Location loc) {
1185+
auto x = parseToken();
1186+
FAILURE_IF_FAILED(x);
1187+
auto y = parseToken();
1188+
FAILURE_IF_FAILED(y);
1189+
auto z = parseToken();
1190+
FAILURE_IF_FAILED(z);
1191+
return builder.buildDclThreadGroup(*x, *y, *z, loc);
1192+
}
1193+
11771194
OptionalParseResult parseDclInstruction(uint32_t opcodeToken, Location loc,
11781195
Instruction &out) {
11791196
FailureOr<Instruction> result;
@@ -1220,6 +1237,9 @@ class Parser {
12201237
case D3D10_SB_OPCODE_DCL_OUTPUT:
12211238
result = parseDclOutput(loc);
12221239
break;
1240+
case D3D11_SB_OPCODE_DCL_THREAD_GROUP:
1241+
result = parseDclThreadGroup(loc);
1242+
break;
12231243
default:
12241244
return std::nullopt;
12251245
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// RUN: mlir-translate --import-dxsa-bin %S/inputs/dcl_thread_group.bin | FileCheck %s
2+
3+
// CHECK: module {
4+
// CHECK-NEXT: dxsa.dcl_thread_group<x = 1, y = 1, z = 1>
5+
// CHECK-NEXT: dxsa.dcl_thread_group<x = 1024, y = 1, z = 1>
6+
// CHECK-NEXT: dxsa.dcl_thread_group<x = 1, y = 1024, z = 1>
7+
// CHECK-NEXT: dxsa.dcl_thread_group<x = 1, y = 1, z = 64>
8+
// CHECK-NEXT: }
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
2+
3+
// expected-error@+1 {{attribute 'x' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 1024}}
4+
dxsa.dcl_thread_group <x = 0, y = 1, z = 1>
5+
6+
// -----
7+
8+
// expected-error@+1 {{attribute 'x' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 1024}}
9+
dxsa.dcl_thread_group <x = 1025, y = 1, z = 1>
10+
11+
// -----
12+
13+
// expected-error@+1 {{attribute 'y' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 1024}}
14+
dxsa.dcl_thread_group <x = 1, y = 0, z = 1>
15+
16+
// -----
17+
18+
// expected-error@+1 {{attribute 'y' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 1024}}
19+
dxsa.dcl_thread_group <x = 1, y = 1025, z = 1>
20+
21+
// -----
22+
23+
// expected-error@+1 {{attribute 'z' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 64}}
24+
dxsa.dcl_thread_group <x = 1, y = 1, z = 0>
25+
26+
// -----
27+
28+
// expected-error@+1 {{attribute 'z' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 64}}
29+
dxsa.dcl_thread_group <x = 1, y = 1, z = 65>
30+
31+
// -----
32+
33+
// 64 * 8 * 4 == 2048
34+
// expected-error@+1 {{'dxsa.dcl_thread_group' op thread group size x*y*z must be <= 1024, got 2048}}
35+
dxsa.dcl_thread_group <x = 64, y = 8, z = 4>
64 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)