Skip to content

Commit bae2b2c

Browse files
committed
[mlir][dxsa] Refine dcl_input and dcl_output instructions
Encode the operand as an inline attribute instead of separate SSA values. Example: dxsa.dcl_input <type = input, components = 4, mask = <x>, index = [0]> dxsa.dcl_output <type = output, components = 4, mask = <x, y, z, w>, index = [0]> Signed-off-by: Vladimir Shiryaev <tagolog@users.noreply.github.com>
1 parent b1ceaab commit bae2b2c

4 files changed

Lines changed: 241 additions & 23 deletions

File tree

mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td

Lines changed: 153 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
include "mlir/Dialect/DXSA/IR/DXSADialect.td"
1313
include "mlir/Dialect/DXSA/IR/DXSATypes.td"
14+
include "mlir/IR/AttrTypeBase.td"
15+
include "mlir/IR/BuiltinAttributeInterfaces.td"
1416
include "mlir/IR/EnumAttr.td"
1517

1618
//===----------------------------------------------------------------------===//
@@ -253,6 +255,32 @@ def DXSA_SystemValueNameAttr :
253255
let assemblyFormat = "$value";
254256
}
255257

258+
//===----------------------------------------------------------------------===//
259+
// DXSA ComponentMask bit-enum (mask field of operand, normalized to bits 0..3)
260+
//===----------------------------------------------------------------------===//
261+
262+
def DXSA_ComponentMask_X : I32BitEnumAttrCaseBit<"x", 0>;
263+
def DXSA_ComponentMask_Y : I32BitEnumAttrCaseBit<"y", 1>;
264+
def DXSA_ComponentMask_Z : I32BitEnumAttrCaseBit<"z", 2>;
265+
def DXSA_ComponentMask_W : I32BitEnumAttrCaseBit<"w", 3>;
266+
267+
def DXSA_ComponentMask : I32BitEnumAttr<
268+
"ComponentMask", "component mask (subset of x, y, z, w)", [
269+
DXSA_ComponentMask_X,
270+
DXSA_ComponentMask_Y,
271+
DXSA_ComponentMask_Z,
272+
DXSA_ComponentMask_W
273+
]> {
274+
let separator = ", ";
275+
let cppNamespace = "::mlir::dxsa";
276+
let genSpecializedAttr = 0;
277+
}
278+
279+
def DXSA_ComponentMaskAttr :
280+
EnumAttr<DXSADialect, DXSA_ComponentMask, "component_mask"> {
281+
let assemblyFormat = "`<` $value `>`";
282+
}
283+
256284
//===----------------------------------------------------------------------===//
257285
// DXSA op definitions
258286
//===----------------------------------------------------------------------===//
@@ -335,6 +363,127 @@ def DXSA_Instruction : DXSA_Op<"instruction"> {
335363
let assemblyFormat = "$mnemonic $operands attr-dict";
336364
}
337365

366+
def DXSA_InlineOperandType_Temp : I32EnumAttrCase<"temp", 0>;
367+
def DXSA_InlineOperandType_Input : I32EnumAttrCase<"input", 1>;
368+
def DXSA_InlineOperandType_Output : I32EnumAttrCase<"output", 2>;
369+
def DXSA_InlineOperandType_IndexableTemp : I32EnumAttrCase<"indexable_temp", 3>;
370+
def DXSA_InlineOperandType_Immediate32 : I32EnumAttrCase<"immediate32", 4>;
371+
def DXSA_InlineOperandType_Immediate64 : I32EnumAttrCase<"immediate64", 5>;
372+
def DXSA_InlineOperandType_Sampler : I32EnumAttrCase<"sampler", 6>;
373+
def DXSA_InlineOperandType_Resource : I32EnumAttrCase<"resource", 7>;
374+
def DXSA_InlineOperandType_ConstantBuffer : I32EnumAttrCase<"constant_buffer", 8>;
375+
def DXSA_InlineOperandType_ImmediateConstantBuffer : I32EnumAttrCase<"immediate_constant_buffer", 9>;
376+
def DXSA_InlineOperandType_Label : I32EnumAttrCase<"label", 10>;
377+
def DXSA_InlineOperandType_InputPrimitiveId : I32EnumAttrCase<"input_primitive_id", 11>;
378+
def DXSA_InlineOperandType_OutputDepth : I32EnumAttrCase<"output_depth", 12>;
379+
def DXSA_InlineOperandType_Null : I32EnumAttrCase<"null", 13>;
380+
def DXSA_InlineOperandType_Rasterizer : I32EnumAttrCase<"rasterizer", 14>;
381+
def DXSA_InlineOperandType_OutputCoverageMask : I32EnumAttrCase<"output_coverage_mask", 15>;
382+
def DXSA_InlineOperandType_Stream : I32EnumAttrCase<"stream", 16>;
383+
def DXSA_InlineOperandType_FunctionBody : I32EnumAttrCase<"function_body", 17>;
384+
def DXSA_InlineOperandType_FunctionTable : I32EnumAttrCase<"function_table", 18>;
385+
def DXSA_InlineOperandType_Interface : I32EnumAttrCase<"interface", 19>;
386+
def DXSA_InlineOperandType_FunctionInput : I32EnumAttrCase<"function_input", 20>;
387+
def DXSA_InlineOperandType_FunctionOutput : I32EnumAttrCase<"function_output", 21>;
388+
def DXSA_InlineOperandType_OutputControlPointId : I32EnumAttrCase<"output_control_point_id", 22>;
389+
def DXSA_InlineOperandType_InputForkInstanceId : I32EnumAttrCase<"input_fork_instance_id", 23>;
390+
def DXSA_InlineOperandType_InputJoinInstanceId : I32EnumAttrCase<"input_join_instance_id", 24>;
391+
def DXSA_InlineOperandType_InputControlPoint : I32EnumAttrCase<"input_control_point", 25>;
392+
def DXSA_InlineOperandType_OutputControlPoint : I32EnumAttrCase<"output_control_point", 26>;
393+
def DXSA_InlineOperandType_InputPatchConstant : I32EnumAttrCase<"input_patch_constant", 27>;
394+
def DXSA_InlineOperandType_InputDomainPoint : I32EnumAttrCase<"input_domain_point", 28>;
395+
def DXSA_InlineOperandType_ThisPointer : I32EnumAttrCase<"this_pointer", 29>;
396+
def DXSA_InlineOperandType_Uav : I32EnumAttrCase<"uav", 30>;
397+
def DXSA_InlineOperandType_ThreadGroupSharedMemory : I32EnumAttrCase<"thread_group_shared_memory", 31>;
398+
def DXSA_InlineOperandType_InputThreadId : I32EnumAttrCase<"input_thread_id", 32>;
399+
def DXSA_InlineOperandType_InputThreadGroupId : I32EnumAttrCase<"input_thread_group_id", 33>;
400+
def DXSA_InlineOperandType_InputThreadIdInGroup : I32EnumAttrCase<"input_thread_id_in_group", 34>;
401+
def DXSA_InlineOperandType_InputCoverageMask : I32EnumAttrCase<"input_coverage_mask", 35>;
402+
def DXSA_InlineOperandType_InputThreadIdInGroupFlattened : I32EnumAttrCase<"input_thread_id_in_group_flattened", 36>;
403+
def DXSA_InlineOperandType_InputGsInstanceId : I32EnumAttrCase<"input_gs_instance_id", 37>;
404+
def DXSA_InlineOperandType_OutputDepthGe : I32EnumAttrCase<"output_depth_ge", 38>;
405+
def DXSA_InlineOperandType_OutputDepthLe : I32EnumAttrCase<"output_depth_le", 39>;
406+
def DXSA_InlineOperandType_CycleCounter : I32EnumAttrCase<"cycle_counter", 40>;
407+
def DXSA_InlineOperandType_OutputStencilRef : I32EnumAttrCase<"output_stencil_ref", 41>;
408+
def DXSA_InlineOperandType_InnerCoverage : I32EnumAttrCase<"inner_coverage", 42>;
409+
410+
def DXSA_InlineOperandType : I32EnumAttr<
411+
"InlineOperandType", "operand type", [
412+
DXSA_InlineOperandType_Temp,
413+
DXSA_InlineOperandType_Input,
414+
DXSA_InlineOperandType_Output,
415+
DXSA_InlineOperandType_IndexableTemp,
416+
DXSA_InlineOperandType_Immediate32,
417+
DXSA_InlineOperandType_Immediate64,
418+
DXSA_InlineOperandType_Sampler,
419+
DXSA_InlineOperandType_Resource,
420+
DXSA_InlineOperandType_ConstantBuffer,
421+
DXSA_InlineOperandType_ImmediateConstantBuffer,
422+
DXSA_InlineOperandType_Label,
423+
DXSA_InlineOperandType_InputPrimitiveId,
424+
DXSA_InlineOperandType_OutputDepth,
425+
DXSA_InlineOperandType_Null,
426+
DXSA_InlineOperandType_Rasterizer,
427+
DXSA_InlineOperandType_OutputCoverageMask,
428+
DXSA_InlineOperandType_Stream,
429+
DXSA_InlineOperandType_FunctionBody,
430+
DXSA_InlineOperandType_FunctionTable,
431+
DXSA_InlineOperandType_Interface,
432+
DXSA_InlineOperandType_FunctionInput,
433+
DXSA_InlineOperandType_FunctionOutput,
434+
DXSA_InlineOperandType_OutputControlPointId,
435+
DXSA_InlineOperandType_InputForkInstanceId,
436+
DXSA_InlineOperandType_InputJoinInstanceId,
437+
DXSA_InlineOperandType_InputControlPoint,
438+
DXSA_InlineOperandType_OutputControlPoint,
439+
DXSA_InlineOperandType_InputPatchConstant,
440+
DXSA_InlineOperandType_InputDomainPoint,
441+
DXSA_InlineOperandType_ThisPointer,
442+
DXSA_InlineOperandType_Uav,
443+
DXSA_InlineOperandType_ThreadGroupSharedMemory,
444+
DXSA_InlineOperandType_InputThreadId,
445+
DXSA_InlineOperandType_InputThreadGroupId,
446+
DXSA_InlineOperandType_InputThreadIdInGroup,
447+
DXSA_InlineOperandType_InputCoverageMask,
448+
DXSA_InlineOperandType_InputThreadIdInGroupFlattened,
449+
DXSA_InlineOperandType_InputGsInstanceId,
450+
DXSA_InlineOperandType_OutputDepthGe,
451+
DXSA_InlineOperandType_OutputDepthLe,
452+
DXSA_InlineOperandType_CycleCounter,
453+
DXSA_InlineOperandType_OutputStencilRef,
454+
DXSA_InlineOperandType_InnerCoverage
455+
]> {
456+
let cppNamespace = "::mlir::dxsa";
457+
let genSpecializedAttr = 0;
458+
}
459+
460+
def DXSA_InlineOperandAttr : AttrDef<DXSADialect, "InlineOperand"> {
461+
let mnemonic = "inline_operand";
462+
let summary = "inline operand of an instruction";
463+
let description = [{
464+
The `#dxsa.inline_operand` attribute carries a fully decoded operand token
465+
466+
Example:
467+
468+
```mlir
469+
dxsa.dcl_output <type = output, components = 4, mask = <x, y, z, w>, index = [0]>
470+
dxsa.dcl_output <type = output_depth, components = 1>
471+
```
472+
}];
473+
let parameters = (ins
474+
EnumParameter<DXSA_InlineOperandType>:$type,
475+
"uint32_t":$components,
476+
OptionalParameter<"::mlir::dxsa::ComponentMaskAttr">:$mask,
477+
OptionalParameter<"::mlir::DenseI64ArrayAttr">:$index);
478+
let assemblyFormat = [{
479+
`<` `type` `=` $type
480+
`,` `components` `=` $components
481+
(`,` `mask` `=` $mask^)?
482+
(`,` `index` `=` $index^)?
483+
`>`
484+
}];
485+
}
486+
338487
def DXSA_DclGlobalFlags : DXSA_Op<"dcl_global_flags"> {
339488
let summary = "declares global shader flags";
340489
let description = [{
@@ -545,11 +694,10 @@ def DXSA_DclInput : DXSA_Op<"dcl_input"> {
545694
Example:
546695

547696
```mlir
548-
549-
dxsa.dcl_input %v0
697+
dxsa.dcl_input <type = input, components = 4, mask = <x>, index = [0]>
550698
```
551699
}];
552-
let arguments = (ins DXSA_OperandType:$operand);
700+
let arguments = (ins DXSA_InlineOperandAttr:$operand);
553701
let assemblyFormat = "$operand attr-dict";
554702
}
555703

@@ -564,10 +712,10 @@ def DXSA_DclOutput : DXSA_Op<"dcl_output"> {
564712
Example:
565713

566714
```mlir
567-
dxsa.dcl_output %o0
715+
dxsa.dcl_output <type = output, components = 4, mask = <x, y, z, w>, index = [0]>
568716
```
569717
}];
570-
let arguments = (ins DXSA_OperandType:$operand);
718+
let arguments = (ins DXSA_InlineOperandAttr:$operand);
571719
let assemblyFormat = "$operand attr-dict";
572720
}
573721

mlir/lib/Target/DXSA/BinaryParser.cpp

Lines changed: 84 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ using UINT = unsigned int;
3030
using namespace mlir;
3131
using namespace llvm;
3232

33+
#define FAILURE_IF_FAILED(RES) \
34+
if (failed(RES)) \
35+
return failure();
36+
3337
enum OpcodeClass {
3438
D3D10_SB_FLOAT_OP,
3539
D3D10_SB_INT_OP,
@@ -400,6 +404,19 @@ struct OperandComponents {
400404
};
401405
};
402406

407+
static dxsa::ComponentMask decodeComponentMask(uint32_t rawComponentMask) {
408+
auto componentMask = static_cast<dxsa::ComponentMask>(0);
409+
if (rawComponentMask & D3D10_SB_OPERAND_4_COMPONENT_MASK_X)
410+
componentMask |= dxsa::ComponentMask::x;
411+
if (rawComponentMask & D3D10_SB_OPERAND_4_COMPONENT_MASK_Y)
412+
componentMask |= dxsa::ComponentMask::y;
413+
if (rawComponentMask & D3D10_SB_OPERAND_4_COMPONENT_MASK_Z)
414+
componentMask |= dxsa::ComponentMask::z;
415+
if (rawComponentMask & D3D10_SB_OPERAND_4_COMPONENT_MASK_W)
416+
componentMask |= dxsa::ComponentMask::w;
417+
return componentMask;
418+
}
419+
403420
class DXBuilder {
404421
public:
405422
DXBuilder(MLIRContext *context, StringAttr name)
@@ -597,11 +614,24 @@ class DXBuilder {
597614
systemValueNameAttr);
598615
}
599616

600-
Instruction buildDclInput(Operand operand, Location loc) {
617+
dxsa::InlineOperandAttr buildInlineOperandAttr(
618+
dxsa::InlineOperandType operandType, uint32_t components,
619+
std::optional<dxsa::ComponentMask> mask, ArrayRef<int64_t> indexArray) {
620+
auto *ctx = builder.getContext();
621+
auto maskAttr = mask ? dxsa::ComponentMaskAttr::get(ctx, *mask)
622+
: dxsa::ComponentMaskAttr();
623+
auto indexAttr = indexArray.empty()
624+
? DenseI64ArrayAttr()
625+
: DenseI64ArrayAttr::get(ctx, indexArray);
626+
return dxsa::InlineOperandAttr::get(ctx, operandType, components, maskAttr,
627+
indexAttr);
628+
}
629+
630+
Instruction buildDclInput(dxsa::InlineOperandAttr operand, Location loc) {
601631
return dxsa::DclInput::create(builder, loc, operand);
602632
}
603633

604-
Instruction buildDclOutput(Operand operand, Location loc) {
634+
Instruction buildDclOutput(dxsa::InlineOperandAttr operand, Location loc) {
605635
return dxsa::DclOutput::create(builder, loc, operand);
606636
}
607637

@@ -1084,17 +1114,63 @@ class Parser {
10841114
return builder.buildDclInputPsSgv(*operand, *systemValueName, loc);
10851115
}
10861116

1117+
FailureOr<dxsa::InlineOperandAttr> parseInlineOperand() {
1118+
auto token = parseToken();
1119+
FAILURE_IF_FAILED(token);
1120+
1121+
auto rawOperandType = DECODE_D3D10_SB_OPERAND_TYPE(*token);
1122+
auto isExtended = DECODE_IS_D3D10_SB_OPERAND_EXTENDED(*token);
1123+
auto loc = getLocation();
1124+
1125+
if (isImmOperand(*token))
1126+
return emitError(loc, "immediate operand is not supported yet");
1127+
1128+
auto type = dxsa::symbolizeInlineOperandType(rawOperandType);
1129+
if (!type)
1130+
return emitError(loc, "unknown operand type: ") << rawOperandType;
1131+
1132+
auto components = parseOperandComponents(*token);
1133+
FAILURE_IF_FAILED(components);
1134+
1135+
auto indexTypes = parseOperandIndexTypes(*token);
1136+
FAILURE_IF_FAILED(indexTypes);
1137+
1138+
if (isExtended)
1139+
return emitError(loc, "extended operand tokens are not yet supported in "
1140+
"inline operand attribute");
1141+
1142+
if (components->kind == OperandComponentsKind::Swizzle ||
1143+
components->kind == OperandComponentsKind::One)
1144+
return emitError(loc, "swizzled / single-component operand selection is "
1145+
"not supported in inline operand attribute");
1146+
1147+
std::optional<dxsa::ComponentMask> mask;
1148+
if (components->kind == OperandComponentsKind::Mask)
1149+
mask = decodeComponentMask(components->mask);
1150+
1151+
SmallVector<int64_t, 3> indices;
1152+
for (auto indexType : *indexTypes) {
1153+
if (indexType != D3D10_SB_OPERAND_INDEX_IMMEDIATE32)
1154+
return emitError(getLocation(), "unsupported index representation: ")
1155+
<< indexType;
1156+
auto value = parseToken();
1157+
FAILURE_IF_FAILED(value);
1158+
indices.push_back(static_cast<int32_t>(*value));
1159+
}
1160+
1161+
return builder.buildInlineOperandAttr(*type, components->num, mask,
1162+
indices);
1163+
}
1164+
10871165
FailureOr<Instruction> parseDclInput(Location loc) {
1088-
auto operand = parseOperand();
1089-
if (failed(operand))
1090-
return failure();
1166+
auto operand = parseInlineOperand();
1167+
FAILURE_IF_FAILED(operand);
10911168
return builder.buildDclInput(*operand, loc);
10921169
}
10931170

10941171
FailureOr<Instruction> parseDclOutput(Location loc) {
1095-
auto operand = parseOperand();
1096-
if (failed(operand))
1097-
return failure();
1172+
auto operand = parseInlineOperand();
1173+
FAILURE_IF_FAILED(operand);
10981174
return builder.buildDclOutput(*operand, loc);
10991175
}
11001176

mlir/test/Target/DXSA/dcl_input.mlir

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@
33
// CHECK-LABEL: module
44
module {
55
// dcl_input v0.x
6-
// CHECK: %0 = dxsa.index.imm {imm = 0 : i32}
7-
// CHECK-NEXT: %1 = dxsa.operand %0 {mask = 16 : i32, num_components = 4 : i32, type = 1 : i32}
8-
// CHECK-NEXT: dxsa.dcl_input %1
6+
// CHECK: dxsa.dcl_input <type = input, components = 4, mask = <x>, index = [0]>
97

108
// dcl_input vOutputControlPointID
11-
// CHECK-NEXT: %2 = dxsa.operand {num_components = 1 : i32, type = 22 : i32}
12-
// CHECK-NEXT: dxsa.dcl_input %2
9+
// CHECK-NEXT: dxsa.dcl_input <type = output_control_point_id, components = 1>
1310
}

mlir/test/Target/DXSA/dcl_output.mlir

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@
33
// CHECK-LABEL: module
44
module {
55
// dcl_output o0.xyzw
6-
// CHECK: %0 = dxsa.index.imm {imm = 0 : i32}
7-
// CHECK-NEXT: %1 = dxsa.operand %0 {mask = 240 : i32, num_components = 4 : i32, type = 2 : i32}
8-
// CHECK-NEXT: dxsa.dcl_output %1
6+
// CHECK: dxsa.dcl_output <type = output, components = 4, mask = <x, y, z, w>, index = [0]>
97

108
// dcl_output oDepth
11-
// CHECK-NEXT: %2 = dxsa.operand {num_components = 1 : i32, type = 12 : i32}
12-
// CHECK-NEXT: dxsa.dcl_output %2
9+
// CHECK-NEXT: dxsa.dcl_output <type = output_depth, components = 1>
1310
}

0 commit comments

Comments
 (0)