diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td index 9d26894db8c5..d718879c353d 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td @@ -281,6 +281,24 @@ def DXSA_ComponentMaskAttr : let assemblyFormat = "`<` $value `>`"; } +def DXSA_ConstantBufferAccessPattern_ImmediateIndexed : I32EnumAttrCase<"immediateIndexed", 0>; +def DXSA_ConstantBufferAccessPattern_DynamicIndexed : I32EnumAttrCase<"dynamicIndexed", 1>; + +def DXSA_ConstantBufferAccessPattern : I32EnumAttr< + "ConstantBufferAccessPattern", "constant buffer access pattern", [ + DXSA_ConstantBufferAccessPattern_ImmediateIndexed, + DXSA_ConstantBufferAccessPattern_DynamicIndexed + ]> { + let cppNamespace = "::mlir::dxsa"; + let genSpecializedAttr = 0; +} + +def DXSA_ConstantBufferAccessPatternAttr : + EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + //===----------------------------------------------------------------------===// // DXSA attribute Constraints //===----------------------------------------------------------------------===// @@ -922,4 +940,33 @@ def DXSA_DclTgsmStructured : DXSA_Op<"dcl_tgsm_structured"> { let hasVerifier = 1; } +def DXSA_DclConstantBuffer : DXSA_Op<"dcl_constant_buffer"> { + let summary = "declares a constant buffer"; + let description = [{ + The `dxsa.dcl_constant_buffer` operation declares a constant buffer + with its access pattern. + + Examples: + + ```mlir + dxsa.dcl_constant_buffer , + dxsa.dcl_constant_buffer , + ``` + }]; + let arguments = (ins + I32Attr:$id, + I32Attr:$size, + OptionalAttr:$lbound, + OptionalAttr:$ubound, + OptionalAttr:$space, + DXSA_ConstantBufferAccessPatternAttr:$access_pattern); + let assemblyFormat = [{ + ` ` `<` `id` `=` $id `,` `size` `=` $size + (`,` `lbound` `=` $lbound^ `,` `ubound` `=` $ubound + `,` `space` `=` $space)? + `>` `,` $access_pattern attr-dict + }]; + let hasVerifier = 1; +} + #endif // DXSA_OPS diff --git a/mlir/lib/Dialect/DXSA/IR/DXSA.cpp b/mlir/lib/Dialect/DXSA/IR/DXSA.cpp index f703a32e48ea..5e4e2abfde24 100644 --- a/mlir/lib/Dialect/DXSA/IR/DXSA.cpp +++ b/mlir/lib/Dialect/DXSA/IR/DXSA.cpp @@ -84,6 +84,15 @@ LogicalResult DclTgsmStructured::verify() { return success(); } +LogicalResult DclConstantBuffer::verify() { + auto lbound = getLbound(); + auto ubound = getUbound(); + if (lbound && ubound && *lbound > *ubound) + return emitOpError("expected lbound <= ubound, got lbound=") + << *lbound << ", ubound=" << *ubound; + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd attribute method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/DXSA/BinaryParser.cpp b/mlir/lib/Target/DXSA/BinaryParser.cpp index 1ac5161fbb96..9227e71d4e8a 100644 --- a/mlir/lib/Target/DXSA/BinaryParser.cpp +++ b/mlir/lib/Target/DXSA/BinaryParser.cpp @@ -692,6 +692,18 @@ class DXBuilder { builder.getI32IntegerAttr(structCount)); } + Instruction buildDclConstantBuffer( + uint32_t id, uint32_t size, std::optional lbound, + std::optional ubound, std::optional space, + dxsa::ConstantBufferAccessPattern accessPattern, Location loc) { + auto optionalToAttr = [&](std::optional v) -> IntegerAttr { + return v ? builder.getI32IntegerAttr(*v) : IntegerAttr(); + }; + return dxsa::DclConstantBuffer::create( + builder, loc, id, size, optionalToAttr(lbound), optionalToAttr(ubound), + optionalToAttr(space), accessPattern); + } + private: MLIRContext *context; ModuleOp module; @@ -1317,6 +1329,63 @@ class Parser { *structCount, loc); } + FailureOr parseDclConstantBuffer(uint32_t opcodeToken, + Location loc) { + auto rawAccessPattern = + DECODE_D3D10_SB_CONSTANT_BUFFER_ACCESS_PATTERN(opcodeToken); + auto accessPattern = + dxsa::symbolizeConstantBufferAccessPattern(rawAccessPattern); + if (!accessPattern) + return emitError(loc, "unknown constant buffer access pattern: ") + << rawAccessPattern; + + auto operandToken = parseToken(); + FAILURE_IF_FAILED(operandToken); + + auto operandType = DECODE_D3D10_SB_OPERAND_TYPE(*operandToken); + if (operandType != D3D10_SB_OPERAND_TYPE_CONSTANT_BUFFER) + return emitError(loc, "unexpected operand type: ") << operandType; + + if (DECODE_IS_D3D10_SB_OPERAND_EXTENDED(*operandToken)) + return emitError(loc, "extended operand tokens are not supported"); + + auto indexDim = DECODE_D3D10_SB_OPERAND_INDEX_DIMENSION(*operandToken); + if (indexDim != D3D10_SB_OPERAND_INDEX_2D && + indexDim != D3D10_SB_OPERAND_INDEX_3D) + return emitError(loc, "unsupported index dimension: ") << indexDim; + + SmallVector indices; + indices.reserve(indexDim); + for (uint32_t i = 0; i < indexDim; ++i) { + auto indexRepesentation = + DECODE_D3D10_SB_OPERAND_INDEX_REPRESENTATION(i, *operandToken); + if (indexRepesentation != D3D10_SB_OPERAND_INDEX_IMMEDIATE32) + return emitError(loc, "unsupported index representation: ") + << indexRepesentation; + auto value = parseToken(); + FAILURE_IF_FAILED(value); + indices.push_back(*value); + } + + switch (indexDim) { + case D3D10_SB_OPERAND_INDEX_2D: + return builder.buildDclConstantBuffer( + /*id=*/indices[0], /*size=*/indices[1], /*lbound=*/std::nullopt, + /*ubound=*/std::nullopt, /*space=*/std::nullopt, *accessPattern, loc); + case D3D10_SB_OPERAND_INDEX_3D: { + auto sizeToken = parseToken(); + FAILURE_IF_FAILED(sizeToken); + auto spaceToken = parseToken(); + FAILURE_IF_FAILED(spaceToken); + return builder.buildDclConstantBuffer( + /*id=*/indices[0], /*size=*/*sizeToken, /*lbound=*/indices[1], + /*ubound=*/indices[2], /*space=*/*spaceToken, *accessPattern, loc); + } + default: + llvm_unreachable("indexDim was validated above"); + } + } + OptionalParseResult parseDclInstruction(uint32_t opcodeToken, Location loc, Instruction &out) { FailureOr result; @@ -1393,6 +1462,9 @@ class Parser { case D3D11_SB_OPCODE_DCL_THREAD_GROUP_SHARED_MEMORY_STRUCTURED: result = parseDclTgsmStructured(loc); break; + case D3D10_SB_OPCODE_DCL_CONSTANT_BUFFER: + result = parseDclConstantBuffer(opcodeToken, loc); + break; default: return std::nullopt; } diff --git a/mlir/test/Target/DXSA/dcl_constant_buffer.mlir b/mlir/test/Target/DXSA/dcl_constant_buffer.mlir new file mode 100644 index 000000000000..2e2cf2ec2598 --- /dev/null +++ b/mlir/test/Target/DXSA/dcl_constant_buffer.mlir @@ -0,0 +1,6 @@ +// RUN: mlir-translate --import-dxsa-bin %S/inputs/dcl_constant_buffer.bin | FileCheck %s + +// CHECK: module { +// CHECK-NEXT: dxsa.dcl_constant_buffer , +// CHECK-NEXT: dxsa.dcl_constant_buffer , +// CHECK-NEXT: } diff --git a/mlir/test/Target/DXSA/dcl_constant_buffer_invalid.mlir b/mlir/test/Target/DXSA/dcl_constant_buffer_invalid.mlir new file mode 100644 index 000000000000..9f62f32e651c --- /dev/null +++ b/mlir/test/Target/DXSA/dcl_constant_buffer_invalid.mlir @@ -0,0 +1,4 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// expected-error@+1 {{expected lbound <= ubound, got lbound=5, ubound=3}} +dxsa.dcl_constant_buffer , diff --git a/mlir/test/Target/DXSA/inputs/dcl_constant_buffer.bin b/mlir/test/Target/DXSA/inputs/dcl_constant_buffer.bin new file mode 100644 index 000000000000..23216be554f7 Binary files /dev/null and b/mlir/test/Target/DXSA/inputs/dcl_constant_buffer.bin differ