Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -894,4 +894,32 @@ def DXSA_DclTgsmRaw : DXSA_Op<"dcl_tgsm_raw"> {
let hasVerifier = 1;
}

def DXSA_DclTgsmStructured : DXSA_Op<"dcl_tgsm_structured"> {
let summary = "declares a reference to a Thread Group Shared Memory region";
let description = [{
The `dxsa.dcl_tgsm_structured` operation declares a reference to a Thread Group Shared Memory region.
The memory is viewed as an array of structures.

The `$operand` is the `g#` register being declared.
The `$struct_byte_stride` is a uint in bytes and must be a multiple of 4.
The `$struct_count` is the number of structures.
The total size `$struct_byte_stride * $struct_count` must not exceed
32 KB.

Example:

```mlir
dxsa.dcl_tgsm_structured <type = thread_group_shared_memory, components = 0, index = [0]>, 16, 64
```
}];
let arguments = (ins DXSA_InlineOperandAttr:$operand,
ConfinedAttr<I32Attr,
[IntPositive, IntMaxValue<32768>]>:$struct_byte_stride,
ConfinedAttr<I32Attr,
[IntPositive, IntMaxValue<8192>]>:$struct_count);
let assemblyFormat =
"$operand `,` $struct_byte_stride `,` $struct_count attr-dict";
let hasVerifier = 1;
}

#endif // DXSA_OPS
14 changes: 14 additions & 0 deletions mlir/lib/Dialect/DXSA/IR/DXSA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,20 @@ LogicalResult DclTgsmRaw::verify() {
return success();
}

LogicalResult DclTgsmStructured::verify() {
auto stride = getStructByteStride();
auto count = getStructCount();
if (stride % 4 != 0)
return emitOpError("struct byte stride must be a multiple of 4, got ")
<< stride;
auto totalSize = static_cast<uint64_t>(stride) * count;
if (totalSize > 32768)
return emitOpError("total size struct_byte_stride * struct_count must "
"be <= 32768, got ")
<< totalSize;
return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd attribute method definitions
//===----------------------------------------------------------------------===//
Expand Down
22 changes: 22 additions & 0 deletions mlir/lib/Target/DXSA/BinaryParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,14 @@ class DXBuilder {
builder.getI32IntegerAttr(byteCount));
}

Instruction buildDclTgsmStructured(dxsa::InlineOperandAttr operand,
uint32_t structByteStride,
uint32_t structCount, Location loc) {
return dxsa::DclTgsmStructured::create(
builder, loc, operand, builder.getI32IntegerAttr(structByteStride),
builder.getI32IntegerAttr(structCount));
}

private:
MLIRContext *context;
ModuleOp module;
Expand Down Expand Up @@ -1298,6 +1306,17 @@ class Parser {
return builder.buildDclTgsmRaw(*operand, *byteCount, loc);
}

FailureOr<Instruction> parseDclTgsmStructured(Location loc) {
auto operand = parseInlineOperand();
FAILURE_IF_FAILED(operand);
auto structByteStride = parseToken();
FAILURE_IF_FAILED(structByteStride);
auto structCount = parseToken();
FAILURE_IF_FAILED(structCount);
return builder.buildDclTgsmStructured(*operand, *structByteStride,
*structCount, loc);
}

OptionalParseResult parseDclInstruction(uint32_t opcodeToken, Location loc,
Instruction &out) {
FailureOr<Instruction> result;
Expand Down Expand Up @@ -1371,6 +1390,9 @@ class Parser {
case D3D11_SB_OPCODE_DCL_THREAD_GROUP_SHARED_MEMORY_RAW:
result = parseDclTgsmRaw(loc);
break;
case D3D11_SB_OPCODE_DCL_THREAD_GROUP_SHARED_MEMORY_STRUCTURED:
result = parseDclTgsmStructured(loc);
break;
default:
return std::nullopt;
}
Expand Down
5 changes: 5 additions & 0 deletions mlir/test/Target/DXSA/dcl_tgsm_structured.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// RUN: mlir-translate --import-dxsa-bin %S/inputs/dcl_tgsm_structured.bin | FileCheck %s

// CHECK: module {
// CHECK-NEXT: dxsa.dcl_tgsm_structured <type = thread_group_shared_memory, components = 0, index = [0]>, 16, 64
// CHECK-NEXT: }
29 changes: 29 additions & 0 deletions mlir/test/Target/DXSA/dcl_tgsm_structured_invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics

// expected-error@+1 {{'dxsa.dcl_tgsm_structured' op struct byte stride must be a multiple of 4, got 6}}
dxsa.dcl_tgsm_structured <type = thread_group_shared_memory, components = 0, index = [0]>, 6, 64

// -----

// expected-error@+1 {{attribute 'struct_byte_stride' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 32768}}
dxsa.dcl_tgsm_structured <type = thread_group_shared_memory, components = 0, index = [0]>, 0, 64

// -----

// expected-error@+1 {{attribute 'struct_byte_stride' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 32768}}
dxsa.dcl_tgsm_structured <type = thread_group_shared_memory, components = 0, index = [0]>, 32772, 1

// -----

// expected-error@+1 {{attribute 'struct_count' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 8192}}
dxsa.dcl_tgsm_structured <type = thread_group_shared_memory, components = 0, index = [0]>, 16, 0

// -----

// expected-error@+1 {{attribute 'struct_count' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 8192}}
dxsa.dcl_tgsm_structured <type = thread_group_shared_memory, components = 0, index = [0]>, 4, 8193

// -----

// expected-error@+1 {{'dxsa.dcl_tgsm_structured' op total size struct_byte_stride * struct_count must be <= 32768, got 65536}}
dxsa.dcl_tgsm_structured <type = thread_group_shared_memory, components = 0, index = [0]>, 32768, 2
Binary file not shown.