Skip to content

Commit a89b52d

Browse files
committed
[mlir][dxsa] Add dcl_thread_group instruction
Example: dxsa.dcl_thread_group 8, 8, 1

Signed-off-by: Vladimir Shiryaev <tagolog@users.noreply.github.com>
1 parent 856b5d3 commit a89b52d

6 files changed

Lines changed: 112 additions & 0 deletions

File tree

mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,4 +158,14 @@ def DXSA_DclTemps : DXSA_Op<"dcl_temps"> {
158158
let assemblyFormat = [{ $count attr-dict }];
159159
}
160160

161+
// Thread group size declaration. Carries three I32 attributes: x and y must be
// positive and at most 1024, z must be positive and at most 64. The textual
// form is `dxsa.dcl_thread_group <x>, <y>, <z>`.
def DXSA_DclThreadGroup : DXSA_Op<"dcl_thread_group"> {
162+
let summary = "declare thread group size";
163+
let arguments = (ins
164+
ConfinedAttr<I32Attr, [IntPositive, IntMaxValue<1024>]>:$x,
165+
ConfinedAttr<I32Attr, [IntPositive, IntMaxValue<1024>]>:$y,
166+
ConfinedAttr<I32Attr, [IntPositive, IntMaxValue<64>]>:$z);
167+
let assemblyFormat = [{ $x `,` $y `,` $z attr-dict }];
168+
// The per-attribute constraints above cannot express a bound on the product
// of the three dimensions, so a hand-written verify() is requested as well.
let hasVerifier = 1;
169+
}
170+
161171
#endif // DXSA_OPS

mlir/lib/Dialect/DXSA/IR/DXSA.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ LogicalResult DclGlobalFlags::verify() {
4444
return success();
4545
}
4646

47+
//===----------------------------------------------------------------------===//
48+
// DclThreadGroup
49+
//===----------------------------------------------------------------------===//
50+
51+
/// Rejects thread groups whose total size x*y*z exceeds 1024. On failure the
/// emitted op error includes the computed total.
LogicalResult DclThreadGroup::verify() {
52+
// Each dimension is widened to 64 bits before the product is formed.
auto total = uint64_t(getX()) * uint64_t(getY()) * uint64_t(getZ());
53+
if (total > 1024)
54+
return emitOpError("thread group size x*y*z must be <= 1024, got ")
55+
<< total;
56+
return success();
57+
}
58+
4759
//===----------------------------------------------------------------------===//
4860
// TableGen'd op method definitions
4961
//===----------------------------------------------------------------------===//

mlir/lib/Target/DXSA/BinaryParser.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,13 @@ class DXBuilder {
516516
builder.getI32IntegerAttr(count));
517517
}
518518

519+
/// Creates a dxsa::DclThreadGroup op at `loc`, passing `x`, `y` and `z` as
/// i32 integer attributes. No range checking is performed in this function.
Instruction buildDclThreadGroup(uint32_t x, uint32_t y, uint32_t z,
520+
Location loc) {
521+
return dxsa::DclThreadGroup::create(
522+
builder, loc, builder.getI32IntegerAttr(x),
523+
builder.getI32IntegerAttr(y), builder.getI32IntegerAttr(z));
524+
}
525+
519526
private:
520527
MLIRContext *context;
521528
ModuleOp module;
@@ -831,6 +838,41 @@ class Parser {
831838
return builder.buildDclTemps(count, loc);
832839
}
833840

841+
/// Parses the operands of a dcl_thread_group declaration and builds the op at
/// `loc`.
///
/// Reads the x, y and z dimensions as three consecutive tokens obtained from
/// parseToken(). Each dimension must lie in [1, max] (max is 1024 for x and y,
/// 64 for z) and the product x*y*z must not exceed 1024. On any violation, or
/// if a token cannot be read, failure() is returned; range and product
/// violations additionally emit an error.
FailureOr<Instruction> parseDclThreadGroup(Location loc) {
842+
// Reads one token and range-checks it as the dimension called
// `dimensionName`, accepting values in [1, maxValue].
auto parseDimension = [&](StringRef dimensionName,
843+
uint32_t maxValue) -> FailureOr<uint32_t> {
844+
auto token = parseToken();
845+
if (failed(token))
846+
return failure();
847+
auto value = *token;
848+
// NOTE(review): diagnostics here are attached to getLocation(), not to the
// `loc` parameter of the enclosing function — confirm this is the intended
// source position.
if (value == 0 || value > maxValue) {
849+
emitError(getLocation(), "thread group ")
850+
<< dimensionName << " dimension must be in [1, " << maxValue
851+
<< "], got " << value;
852+
return failure();
853+
}
854+
return value;
855+
};
856+
857+
// Dimensions are consumed in x, y, z order; stop at the first failure.
auto x = parseDimension("x", 1024);
858+
if (failed(x))
859+
return failure();
860+
auto y = parseDimension("y", 1024);
861+
if (failed(y))
862+
return failure();
863+
auto z = parseDimension("z", 64);
864+
if (failed(z))
865+
return failure();
866+
867+
// Form the product in 64-bit arithmetic and enforce the same x*y*z <= 1024
// bound that DclThreadGroup::verify() checks on the op.
auto total = uint64_t(*x) * uint64_t(*y) * uint64_t(*z);
868+
if (total > 1024) {
869+
emitError(getLocation(), "thread group size x*y*z must be <= 1024, got ")
870+
<< total;
871+
return failure();
872+
}
873+
return builder.buildDclThreadGroup(*x, *y, *z, loc);
874+
}
875+
834876
OptionalParseResult parseDclInstruction(uint32_t opcodeToken, Location loc,
835877
Instruction &out) {
836878
FailureOr<Instruction> result;
@@ -841,6 +883,9 @@ class Parser {
841883
case D3D10_SB_OPCODE_DCL_TEMPS:
842884
result = parseDclTemps(loc);
843885
break;
886+
case D3D11_SB_OPCODE_DCL_THREAD_GROUP:
887+
result = parseDclThreadGroup(loc);
888+
break;
844889
default:
845890
return std::nullopt;
846891
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// Imports the binary fixture inputs/dcl_thread_group.bin and verifies that the
// printed module consists of exactly the six dcl_thread_group ops listed
// below, in this order.
// RUN: mlir-translate --import-dxsa-bin %S/inputs/dcl_thread_group.bin | FileCheck %s
2+
3+
// CHECK: module {
4+
// CHECK-NEXT: dxsa.dcl_thread_group 1, 1, 1
5+
// CHECK-NEXT: dxsa.dcl_thread_group 8, 8, 1
6+
// CHECK-NEXT: dxsa.dcl_thread_group 32, 32, 1
7+
// CHECK-NEXT: dxsa.dcl_thread_group 16, 1, 64
8+
// CHECK-NEXT: dxsa.dcl_thread_group 1024, 1, 1
9+
// CHECK-NEXT: dxsa.dcl_thread_group 1, 1024, 1
10+
// CHECK-NEXT: }
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Negative tests for dxsa.dcl_thread_group. Each split case holds one op that
// must be rejected: zero and above-maximum values for x, y and z (maximum 1024
// for x and y, 64 for z), followed by in-range dimensions whose product
// exceeds 1024.
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
2+
3+
// expected-error@+1 {{attribute 'x' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 1024}}
4+
dxsa.dcl_thread_group 0, 1, 1
5+
6+
// -----
7+
8+
// expected-error@+1 {{attribute 'x' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 1024}}
9+
dxsa.dcl_thread_group 1025, 1, 1
10+
11+
// -----
12+
13+
// expected-error@+1 {{attribute 'y' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 1024}}
14+
dxsa.dcl_thread_group 1, 0, 1
15+
16+
// -----
17+
18+
// expected-error@+1 {{attribute 'y' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 1024}}
19+
dxsa.dcl_thread_group 1, 1025, 1
20+
21+
// -----
22+
23+
// expected-error@+1 {{attribute 'z' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 64}}
24+
dxsa.dcl_thread_group 1, 1, 0
25+
26+
// -----
27+
28+
// expected-error@+1 {{attribute 'z' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 64}}
29+
dxsa.dcl_thread_group 1, 1, 65
30+
31+
// -----
32+
33+
// 64 * 8 * 4 == 2048
34+
// expected-error@+1 {{'dxsa.dcl_thread_group' op thread group size x*y*z must be <= 1024, got 2048}}
35+
dxsa.dcl_thread_group 64, 8, 4
96 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)