diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td index c4a86d4dfcdc..9d26894db8c5 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td @@ -894,4 +894,32 @@ def DXSA_DclTgsmRaw : DXSA_Op<"dcl_tgsm_raw"> { let hasVerifier = 1; } +def DXSA_DclTgsmStructured : DXSA_Op<"dcl_tgsm_structured"> { + let summary = "declares a reference to a Thread Group Shared Memory region"; + let description = [{ + The `dxsa.dcl_tgsm_structured` operation declares a reference to a Thread Group Shared Memory region. + The memory is viewed as an array of structures. + + The `$operand` is the `g#` register being declared. + The `$struct_byte_stride` is a uint in bytes and must be a multiple of 4. + The `$struct_count` is the number of structures. + The total size `$struct_byte_stride * $struct_count` must not exceed + 32 KB. + + Example: + + ```mlir + dxsa.dcl_tgsm_structured , 16, 64 + ``` + }]; + let arguments = (ins DXSA_InlineOperandAttr:$operand, + ConfinedAttr]>:$struct_byte_stride, + ConfinedAttr]>:$struct_count); + let assemblyFormat = + "$operand `,` $struct_byte_stride `,` $struct_count attr-dict"; + let hasVerifier = 1; +} + #endif // DXSA_OPS diff --git a/mlir/lib/Dialect/DXSA/IR/DXSA.cpp b/mlir/lib/Dialect/DXSA/IR/DXSA.cpp index 029a94dae411..f703a32e48ea 100644 --- a/mlir/lib/Dialect/DXSA/IR/DXSA.cpp +++ b/mlir/lib/Dialect/DXSA/IR/DXSA.cpp @@ -70,6 +70,20 @@ LogicalResult DclTgsmRaw::verify() { return success(); } +LogicalResult DclTgsmStructured::verify() { + auto stride = getStructByteStride(); + auto count = getStructCount(); + if (stride % 4 != 0) + return emitOpError("struct byte stride must be a multiple of 4, got ") + << stride; + auto totalSize = static_cast(stride) * count; + if (totalSize > 32768) + return emitOpError("total size struct_byte_stride * struct_count must " + "be <= 32768, got ") + << totalSize; + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd attribute method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/DXSA/BinaryParser.cpp b/mlir/lib/Target/DXSA/BinaryParser.cpp index 46de34369d1b..1ac5161fbb96 100644 --- a/mlir/lib/Target/DXSA/BinaryParser.cpp +++ b/mlir/lib/Target/DXSA/BinaryParser.cpp @@ -684,6 +684,14 @@ class DXBuilder { builder.getI32IntegerAttr(byteCount)); } + Instruction buildDclTgsmStructured(dxsa::InlineOperandAttr operand, + uint32_t structByteStride, + uint32_t structCount, Location loc) { + return dxsa::DclTgsmStructured::create( + builder, loc, operand, builder.getI32IntegerAttr(structByteStride), + builder.getI32IntegerAttr(structCount)); + } + private: MLIRContext *context; ModuleOp module; @@ -1298,6 +1306,17 @@ class Parser { return builder.buildDclTgsmRaw(*operand, *byteCount, loc); } + FailureOr parseDclTgsmStructured(Location loc) { + auto operand = parseInlineOperand(); + FAILURE_IF_FAILED(operand); + auto structByteStride = parseToken(); + FAILURE_IF_FAILED(structByteStride); + auto structCount = parseToken(); + FAILURE_IF_FAILED(structCount); + return builder.buildDclTgsmStructured(*operand, *structByteStride, + *structCount, loc); + } + OptionalParseResult parseDclInstruction(uint32_t opcodeToken, Location loc, Instruction &out) { FailureOr result; @@ -1371,6 +1390,9 @@ class Parser { case D3D11_SB_OPCODE_DCL_THREAD_GROUP_SHARED_MEMORY_RAW: result = parseDclTgsmRaw(loc); break; + case D3D11_SB_OPCODE_DCL_THREAD_GROUP_SHARED_MEMORY_STRUCTURED: + result = parseDclTgsmStructured(loc); + break; default: return std::nullopt; } diff --git a/mlir/test/Target/DXSA/dcl_tgsm_structured.mlir b/mlir/test/Target/DXSA/dcl_tgsm_structured.mlir new file mode 100644 index 000000000000..3119f1ec7948 --- /dev/null +++ b/mlir/test/Target/DXSA/dcl_tgsm_structured.mlir @@ -0,0 +1,5 @@ +// RUN: mlir-translate --import-dxsa-bin %S/inputs/dcl_tgsm_structured.bin | FileCheck %s + +// CHECK: module { +// CHECK-NEXT: dxsa.dcl_tgsm_structured , 16, 64 +// CHECK-NEXT: } diff --git a/mlir/test/Target/DXSA/dcl_tgsm_structured_invalid.mlir b/mlir/test/Target/DXSA/dcl_tgsm_structured_invalid.mlir new file mode 100644 index 000000000000..fd9adcc9b1b2 --- /dev/null +++ b/mlir/test/Target/DXSA/dcl_tgsm_structured_invalid.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// expected-error@+1 {{'dxsa.dcl_tgsm_structured' op struct byte stride must be a multiple of 4, got 6}} +dxsa.dcl_tgsm_structured , 6, 64 + +// ----- + +// expected-error@+1 {{attribute 'struct_byte_stride' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 32768}} +dxsa.dcl_tgsm_structured , 0, 64 + +// ----- + +// expected-error@+1 {{attribute 'struct_byte_stride' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 32768}} +dxsa.dcl_tgsm_structured , 32772, 1 + +// ----- + +// expected-error@+1 {{attribute 'struct_count' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 8192}} +dxsa.dcl_tgsm_structured , 16, 0 + +// ----- + +// expected-error@+1 {{attribute 'struct_count' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive whose maximum value is 8192}} +dxsa.dcl_tgsm_structured , 4, 8193 + +// ----- + +// expected-error@+1 {{'dxsa.dcl_tgsm_structured' op total size struct_byte_stride * struct_count must be <= 32768, got 65536}} +dxsa.dcl_tgsm_structured , 32768, 2 diff --git a/mlir/test/Target/DXSA/inputs/dcl_tgsm_structured.bin b/mlir/test/Target/DXSA/inputs/dcl_tgsm_structured.bin new file mode 100644 index 000000000000..c28052eb813b Binary files /dev/null and b/mlir/test/Target/DXSA/inputs/dcl_tgsm_structured.bin differ