Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,24 @@ def DXSA_ComponentMaskAttr :
let assemblyFormat = "`<` $value `>`";
}

def DXSA_ConstantBufferAccessPattern_ImmediateIndexed : I32EnumAttrCase<"immediateIndexed", 0>;
def DXSA_ConstantBufferAccessPattern_DynamicIndexed : I32EnumAttrCase<"dynamicIndexed", 1>;

def DXSA_ConstantBufferAccessPattern : I32EnumAttr<
"ConstantBufferAccessPattern", "constant buffer access pattern", [
DXSA_ConstantBufferAccessPattern_ImmediateIndexed,
DXSA_ConstantBufferAccessPattern_DynamicIndexed
]> {
let cppNamespace = "::mlir::dxsa";
let genSpecializedAttr = 0;
}

def DXSA_ConstantBufferAccessPatternAttr :
EnumAttr<DXSADialect, DXSA_ConstantBufferAccessPattern,
"constant_buffer_access_pattern"> {
let assemblyFormat = "`<` $value `>`";
}

//===----------------------------------------------------------------------===//
// DXSA attribute Constraints
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -922,4 +940,33 @@ def DXSA_DclTgsmStructured : DXSA_Op<"dcl_tgsm_structured"> {
let hasVerifier = 1;
}

def DXSA_DclConstantBuffer : DXSA_Op<"dcl_constant_buffer"> {
let summary = "declares a constant buffer";
let description = [{
The `dxsa.dcl_constant_buffer` operation declares a constant buffer
with its access pattern.

Examples:

```mlir
dxsa.dcl_constant_buffer <id = 0, size = 1>, <immediateIndexed>
dxsa.dcl_constant_buffer <id = 0, size = 4, lbound = 0, ubound = 3, space = 1>, <dynamicIndexed>
```
}];
let arguments = (ins
I32Attr:$id,
I32Attr:$size,
OptionalAttr<I32Attr>:$lbound,
OptionalAttr<I32Attr>:$ubound,
OptionalAttr<I32Attr>:$space,
DXSA_ConstantBufferAccessPatternAttr:$access_pattern);
let assemblyFormat = [{
` ` `<` `id` `=` $id `,` `size` `=` $size
(`,` `lbound` `=` $lbound^ `,` `ubound` `=` $ubound
`,` `space` `=` $space)?
`>` `,` $access_pattern attr-dict
}];
let hasVerifier = 1;
}

#endif // DXSA_OPS
9 changes: 9 additions & 0 deletions mlir/lib/Dialect/DXSA/IR/DXSA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ LogicalResult DclTgsmStructured::verify() {
return success();
}

LogicalResult DclConstantBuffer::verify() {
auto lbound = getLbound();
auto ubound = getUbound();
if (lbound && ubound && *lbound > *ubound)
return emitOpError("expected lbound <= ubound, got lbound=")
<< *lbound << ", ubound=" << *ubound;
return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd attribute method definitions
//===----------------------------------------------------------------------===//
Expand Down
72 changes: 72 additions & 0 deletions mlir/lib/Target/DXSA/BinaryParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,18 @@ class DXBuilder {
builder.getI32IntegerAttr(structCount));
}

Instruction buildDclConstantBuffer(
uint32_t id, uint32_t size, std::optional<uint32_t> lbound,
std::optional<uint32_t> ubound, std::optional<uint32_t> space,
dxsa::ConstantBufferAccessPattern accessPattern, Location loc) {
auto optionalToAttr = [&](std::optional<uint32_t> v) -> IntegerAttr {
return v ? builder.getI32IntegerAttr(*v) : IntegerAttr();
};
return dxsa::DclConstantBuffer::create(
builder, loc, id, size, optionalToAttr(lbound), optionalToAttr(ubound),
optionalToAttr(space), accessPattern);
}

private:
MLIRContext *context;
ModuleOp module;
Expand Down Expand Up @@ -1317,6 +1329,63 @@ class Parser {
*structCount, loc);
}

FailureOr<Instruction> parseDclConstantBuffer(uint32_t opcodeToken,
Comment thread
tagolog marked this conversation as resolved.
Location loc) {
auto rawAccessPattern =
DECODE_D3D10_SB_CONSTANT_BUFFER_ACCESS_PATTERN(opcodeToken);
auto accessPattern =
dxsa::symbolizeConstantBufferAccessPattern(rawAccessPattern);
if (!accessPattern)
return emitError(loc, "unknown constant buffer access pattern: ")
<< rawAccessPattern;

auto operandToken = parseToken();
FAILURE_IF_FAILED(operandToken);

auto operandType = DECODE_D3D10_SB_OPERAND_TYPE(*operandToken);
if (operandType != D3D10_SB_OPERAND_TYPE_CONSTANT_BUFFER)
return emitError(loc, "unexpected operand type: ") << operandType;

if (DECODE_IS_D3D10_SB_OPERAND_EXTENDED(*operandToken))
return emitError(loc, "extended operand tokens are not supported");

auto indexDim = DECODE_D3D10_SB_OPERAND_INDEX_DIMENSION(*operandToken);
if (indexDim != D3D10_SB_OPERAND_INDEX_2D &&
indexDim != D3D10_SB_OPERAND_INDEX_3D)
return emitError(loc, "unsupported index dimension: ") << indexDim;

SmallVector<uint32_t, 3> indices;
indices.reserve(indexDim);
for (uint32_t i = 0; i < indexDim; ++i) {
auto indexRepesentation =
DECODE_D3D10_SB_OPERAND_INDEX_REPRESENTATION(i, *operandToken);
if (indexRepesentation != D3D10_SB_OPERAND_INDEX_IMMEDIATE32)
return emitError(loc, "unsupported index representation: ")
<< indexRepesentation;
auto value = parseToken();
FAILURE_IF_FAILED(value);
indices.push_back(*value);
}

switch (indexDim) {
case D3D10_SB_OPERAND_INDEX_2D:
return builder.buildDclConstantBuffer(
/*id=*/indices[0], /*size=*/indices[1], /*lbound=*/std::nullopt,
/*ubound=*/std::nullopt, /*space=*/std::nullopt, *accessPattern, loc);
case D3D10_SB_OPERAND_INDEX_3D: {
auto sizeToken = parseToken();
FAILURE_IF_FAILED(sizeToken);
auto spaceToken = parseToken();
FAILURE_IF_FAILED(spaceToken);
return builder.buildDclConstantBuffer(
/*id=*/indices[0], /*size=*/*sizeToken, /*lbound=*/indices[1],
/*ubound=*/indices[2], /*space=*/*spaceToken, *accessPattern, loc);
}
default:
llvm_unreachable("indexDim was validated above");
}
}

OptionalParseResult parseDclInstruction(uint32_t opcodeToken, Location loc,
Instruction &out) {
FailureOr<Instruction> result;
Expand Down Expand Up @@ -1393,6 +1462,9 @@ class Parser {
case D3D11_SB_OPCODE_DCL_THREAD_GROUP_SHARED_MEMORY_STRUCTURED:
result = parseDclTgsmStructured(loc);
break;
case D3D10_SB_OPCODE_DCL_CONSTANT_BUFFER:
result = parseDclConstantBuffer(opcodeToken, loc);
break;
default:
return std::nullopt;
}
Expand Down
6 changes: 6 additions & 0 deletions mlir/test/Target/DXSA/dcl_constant_buffer.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// RUN: mlir-translate --import-dxsa-bin %S/inputs/dcl_constant_buffer.bin | FileCheck %s

// CHECK: module {
// CHECK-NEXT: dxsa.dcl_constant_buffer <id = 0, size = 1>, <immediateIndexed>
// CHECK-NEXT: dxsa.dcl_constant_buffer <id = 0, size = 4, lbound = 0, ubound = 3, space = 1>, <dynamicIndexed>
// CHECK-NEXT: }
4 changes: 4 additions & 0 deletions mlir/test/Target/DXSA/dcl_constant_buffer_invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics

// expected-error@+1 {{expected lbound <= ubound, got lbound=5, ubound=3}}
dxsa.dcl_constant_buffer <id = 0, size = 4, lbound = 5, ubound = 3, space = 1>, <dynamicIndexed>
Binary file not shown.