diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td index 74b850a7ed7c..ed5a295a38db 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td @@ -11,6 +11,8 @@ include "mlir/Dialect/DXSA/IR/DXSADialect.td" include "mlir/Dialect/DXSA/IR/DXSATypes.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/IR/EnumAttr.td" //===----------------------------------------------------------------------===// @@ -253,6 +255,32 @@ def DXSA_SystemValueNameAttr : let assemblyFormat = "$value"; } +//===----------------------------------------------------------------------===// +// DXSA ComponentMask bit-enum (mask field of operand, normalized to bits 0..3) +//===----------------------------------------------------------------------===// + +def DXSA_ComponentMask_X : I32BitEnumAttrCaseBit<"x", 0>; +def DXSA_ComponentMask_Y : I32BitEnumAttrCaseBit<"y", 1>; +def DXSA_ComponentMask_Z : I32BitEnumAttrCaseBit<"z", 2>; +def DXSA_ComponentMask_W : I32BitEnumAttrCaseBit<"w", 3>; + +def DXSA_ComponentMask : I32BitEnumAttr< + "ComponentMask", "component mask (subset of x, y, z, w)", [ + DXSA_ComponentMask_X, + DXSA_ComponentMask_Y, + DXSA_ComponentMask_Z, + DXSA_ComponentMask_W + ]> { + let separator = ", "; + let cppNamespace = "::mlir::dxsa"; + let genSpecializedAttr = 0; +} + +def DXSA_ComponentMaskAttr : + EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + //===----------------------------------------------------------------------===// // DXSA op definitions //===----------------------------------------------------------------------===// @@ -335,6 +363,127 @@ def DXSA_Instruction : DXSA_Op<"instruction"> { let assemblyFormat = "$mnemonic $operands attr-dict"; } +def DXSA_InlineOperandType_Temp : I32EnumAttrCase<"temp", 0>; +def DXSA_InlineOperandType_Input : I32EnumAttrCase<"input", 1>; +def DXSA_InlineOperandType_Output : I32EnumAttrCase<"output", 2>; +def DXSA_InlineOperandType_IndexableTemp : I32EnumAttrCase<"indexable_temp", 3>; +def DXSA_InlineOperandType_Immediate32 : I32EnumAttrCase<"immediate32", 4>; +def DXSA_InlineOperandType_Immediate64 : I32EnumAttrCase<"immediate64", 5>; +def DXSA_InlineOperandType_Sampler : I32EnumAttrCase<"sampler", 6>; +def DXSA_InlineOperandType_Resource : I32EnumAttrCase<"resource", 7>; +def DXSA_InlineOperandType_ConstantBuffer : I32EnumAttrCase<"constant_buffer", 8>; +def DXSA_InlineOperandType_ImmediateConstantBuffer : I32EnumAttrCase<"immediate_constant_buffer", 9>; +def DXSA_InlineOperandType_Label : I32EnumAttrCase<"label", 10>; +def DXSA_InlineOperandType_InputPrimitiveId : I32EnumAttrCase<"input_primitive_id", 11>; +def DXSA_InlineOperandType_OutputDepth : I32EnumAttrCase<"output_depth", 12>; +def DXSA_InlineOperandType_Null : I32EnumAttrCase<"null", 13>; +def DXSA_InlineOperandType_Rasterizer : I32EnumAttrCase<"rasterizer", 14>; +def DXSA_InlineOperandType_OutputCoverageMask : I32EnumAttrCase<"output_coverage_mask", 15>; +def DXSA_InlineOperandType_Stream : I32EnumAttrCase<"stream", 16>; +def DXSA_InlineOperandType_FunctionBody : I32EnumAttrCase<"function_body", 17>; +def DXSA_InlineOperandType_FunctionTable : I32EnumAttrCase<"function_table", 18>; +def DXSA_InlineOperandType_Interface : I32EnumAttrCase<"interface", 19>; +def DXSA_InlineOperandType_FunctionInput : I32EnumAttrCase<"function_input", 20>; +def DXSA_InlineOperandType_FunctionOutput : I32EnumAttrCase<"function_output", 21>; +def DXSA_InlineOperandType_OutputControlPointId : I32EnumAttrCase<"output_control_point_id", 22>; +def DXSA_InlineOperandType_InputForkInstanceId : I32EnumAttrCase<"input_fork_instance_id", 23>; +def DXSA_InlineOperandType_InputJoinInstanceId : I32EnumAttrCase<"input_join_instance_id", 24>; +def DXSA_InlineOperandType_InputControlPoint : I32EnumAttrCase<"input_control_point", 25>; +def DXSA_InlineOperandType_OutputControlPoint : I32EnumAttrCase<"output_control_point", 26>; +def DXSA_InlineOperandType_InputPatchConstant : I32EnumAttrCase<"input_patch_constant", 27>; +def DXSA_InlineOperandType_InputDomainPoint : I32EnumAttrCase<"input_domain_point", 28>; +def DXSA_InlineOperandType_ThisPointer : I32EnumAttrCase<"this_pointer", 29>; +def DXSA_InlineOperandType_Uav : I32EnumAttrCase<"uav", 30>; +def DXSA_InlineOperandType_ThreadGroupSharedMemory : I32EnumAttrCase<"thread_group_shared_memory", 31>; +def DXSA_InlineOperandType_InputThreadId : I32EnumAttrCase<"input_thread_id", 32>; +def DXSA_InlineOperandType_InputThreadGroupId : I32EnumAttrCase<"input_thread_group_id", 33>; +def DXSA_InlineOperandType_InputThreadIdInGroup : I32EnumAttrCase<"input_thread_id_in_group", 34>; +def DXSA_InlineOperandType_InputCoverageMask : I32EnumAttrCase<"input_coverage_mask", 35>; +def DXSA_InlineOperandType_InputThreadIdInGroupFlattened : I32EnumAttrCase<"input_thread_id_in_group_flattened", 36>; +def DXSA_InlineOperandType_InputGsInstanceId : I32EnumAttrCase<"input_gs_instance_id", 37>; +def DXSA_InlineOperandType_OutputDepthGe : I32EnumAttrCase<"output_depth_ge", 38>; +def DXSA_InlineOperandType_OutputDepthLe : I32EnumAttrCase<"output_depth_le", 39>; +def DXSA_InlineOperandType_CycleCounter : I32EnumAttrCase<"cycle_counter", 40>; +def DXSA_InlineOperandType_OutputStencilRef : I32EnumAttrCase<"output_stencil_ref", 41>; +def DXSA_InlineOperandType_InnerCoverage : I32EnumAttrCase<"inner_coverage", 42>; + +def DXSA_InlineOperandType : I32EnumAttr< + "InlineOperandType", "operand type", [ + DXSA_InlineOperandType_Temp, + DXSA_InlineOperandType_Input, + DXSA_InlineOperandType_Output, + DXSA_InlineOperandType_IndexableTemp, + DXSA_InlineOperandType_Immediate32, + DXSA_InlineOperandType_Immediate64, + DXSA_InlineOperandType_Sampler, + DXSA_InlineOperandType_Resource, + DXSA_InlineOperandType_ConstantBuffer, + DXSA_InlineOperandType_ImmediateConstantBuffer, + DXSA_InlineOperandType_Label, + DXSA_InlineOperandType_InputPrimitiveId, + DXSA_InlineOperandType_OutputDepth, + DXSA_InlineOperandType_Null, + DXSA_InlineOperandType_Rasterizer, + DXSA_InlineOperandType_OutputCoverageMask, + DXSA_InlineOperandType_Stream, + DXSA_InlineOperandType_FunctionBody, + DXSA_InlineOperandType_FunctionTable, + DXSA_InlineOperandType_Interface, + DXSA_InlineOperandType_FunctionInput, + DXSA_InlineOperandType_FunctionOutput, + DXSA_InlineOperandType_OutputControlPointId, + DXSA_InlineOperandType_InputForkInstanceId, + DXSA_InlineOperandType_InputJoinInstanceId, + DXSA_InlineOperandType_InputControlPoint, + DXSA_InlineOperandType_OutputControlPoint, + DXSA_InlineOperandType_InputPatchConstant, + DXSA_InlineOperandType_InputDomainPoint, + DXSA_InlineOperandType_ThisPointer, + DXSA_InlineOperandType_Uav, + DXSA_InlineOperandType_ThreadGroupSharedMemory, + DXSA_InlineOperandType_InputThreadId, + DXSA_InlineOperandType_InputThreadGroupId, + DXSA_InlineOperandType_InputThreadIdInGroup, + DXSA_InlineOperandType_InputCoverageMask, + DXSA_InlineOperandType_InputThreadIdInGroupFlattened, + DXSA_InlineOperandType_InputGsInstanceId, + DXSA_InlineOperandType_OutputDepthGe, + DXSA_InlineOperandType_OutputDepthLe, + DXSA_InlineOperandType_CycleCounter, + DXSA_InlineOperandType_OutputStencilRef, + DXSA_InlineOperandType_InnerCoverage + ]> { + let cppNamespace = "::mlir::dxsa"; + let genSpecializedAttr = 0; +} + +def DXSA_InlineOperandAttr : AttrDef { + let mnemonic = "inline_operand"; + let summary = "inline operand of an instruction"; + let description = [{ + The `#dxsa.inline_operand` attribute carries a fully decoded operand token + + Example: + + ```mlir + dxsa.dcl_output , index = [0]> + dxsa.dcl_output + ``` + }]; + let parameters = (ins + EnumParameter:$type, + "uint32_t":$components, + OptionalParameter<"::mlir::dxsa::ComponentMaskAttr">:$mask, + OptionalParameter<"::mlir::DenseI64ArrayAttr">:$index); + let assemblyFormat = [{ + `<` `type` `=` $type + `,` `components` `=` $components + (`,` `mask` `=` $mask^)? + (`,` `index` `=` $index^)? + `>` + }]; +} + def DXSA_DclGlobalFlags : DXSA_Op<"dcl_global_flags"> { let summary = "declares global shader flags"; let description = [{ @@ -534,4 +683,40 @@ def DXSA_DclInputPsSgv : DXSA_Op<"dcl_input_ps_sgv"> { let assemblyFormat = "$operand `,` $name attr-dict"; } +def DXSA_DclInput : DXSA_Op<"dcl_input"> { + let summary = "declares a shader input register"; + let description = [{ + The `dxsa.dcl_input` operation declares a shader input register. + + The register operand can be either an indexed input register + or a special scalar input. + + Example: + + ```mlir + dxsa.dcl_input , index = [0]> + ``` + }]; + let arguments = (ins DXSA_InlineOperandAttr:$operand); + let assemblyFormat = "$operand attr-dict"; +} + +def DXSA_DclOutput : DXSA_Op<"dcl_output"> { + let summary = "declares a shader output register"; + let description = [{ + The `dxsa.dcl_output` operation declares a shader output register. + + The register operand can be either an indexed output register + or a special scalar output. + + Example: + + ```mlir + dxsa.dcl_output , index = [0]> + ``` + }]; + let arguments = (ins DXSA_InlineOperandAttr:$operand); + let assemblyFormat = "$operand attr-dict"; +} + #endif // DXSA_OPS diff --git a/mlir/lib/Target/DXSA/BinaryParser.cpp b/mlir/lib/Target/DXSA/BinaryParser.cpp index c45a616a5648..f1c860c758a2 100644 --- a/mlir/lib/Target/DXSA/BinaryParser.cpp +++ b/mlir/lib/Target/DXSA/BinaryParser.cpp @@ -30,6 +30,10 @@ using UINT = unsigned int; using namespace mlir; using namespace llvm; +#define FAILURE_IF_FAILED(RES) \ + if (failed(RES)) \ + return failure(); + enum OpcodeClass { D3D10_SB_FLOAT_OP, D3D10_SB_INT_OP, @@ -400,6 +404,19 @@ struct OperandComponents { }; }; +static dxsa::ComponentMask decodeComponentMask(uint32_t rawComponentMask) { + auto componentMask = static_cast(0); + if (rawComponentMask & D3D10_SB_OPERAND_4_COMPONENT_MASK_X) + componentMask |= dxsa::ComponentMask::x; + if (rawComponentMask & D3D10_SB_OPERAND_4_COMPONENT_MASK_Y) + componentMask |= dxsa::ComponentMask::y; + if (rawComponentMask & D3D10_SB_OPERAND_4_COMPONENT_MASK_Z) + componentMask |= dxsa::ComponentMask::z; + if (rawComponentMask & D3D10_SB_OPERAND_4_COMPONENT_MASK_W) + componentMask |= dxsa::ComponentMask::w; + return componentMask; +} + class DXBuilder { public: DXBuilder(MLIRContext *context, StringAttr name) @@ -597,6 +614,27 @@ class DXBuilder { systemValueNameAttr); } + dxsa::InlineOperandAttr buildInlineOperandAttr( + dxsa::InlineOperandType operandType, uint32_t components, + std::optional mask, ArrayRef indexArray) { + auto *ctx = builder.getContext(); + auto maskAttr = mask ? dxsa::ComponentMaskAttr::get(ctx, *mask) + : dxsa::ComponentMaskAttr(); + auto indexAttr = indexArray.empty() + ? DenseI64ArrayAttr() + : DenseI64ArrayAttr::get(ctx, indexArray); + return dxsa::InlineOperandAttr::get(ctx, operandType, components, maskAttr, + indexAttr); + } + + Instruction buildDclInput(dxsa::InlineOperandAttr operand, Location loc) { + return dxsa::DclInput::create(builder, loc, operand); + } + + Instruction buildDclOutput(dxsa::InlineOperandAttr operand, Location loc) { + return dxsa::DclOutput::create(builder, loc, operand); + } + private: MLIRContext *context; ModuleOp module; @@ -1076,6 +1114,66 @@ class Parser { return builder.buildDclInputPsSgv(*operand, *systemValueName, loc); } + FailureOr parseInlineOperand() { + auto token = parseToken(); + FAILURE_IF_FAILED(token); + + auto rawOperandType = DECODE_D3D10_SB_OPERAND_TYPE(*token); + auto isExtended = DECODE_IS_D3D10_SB_OPERAND_EXTENDED(*token); + auto loc = getLocation(); + + if (isImmOperand(*token)) + return emitError(loc, "immediate operand is not supported yet"); + + auto type = dxsa::symbolizeInlineOperandType(rawOperandType); + if (!type) + return emitError(loc, "unknown operand type: ") << rawOperandType; + + auto components = parseOperandComponents(*token); + FAILURE_IF_FAILED(components); + + auto indexTypes = parseOperandIndexTypes(*token); + FAILURE_IF_FAILED(indexTypes); + + if (isExtended) + return emitError(loc, "extended operand tokens are not yet supported in " + "inline operand attribute"); + + if (components->kind == OperandComponentsKind::Swizzle || + components->kind == OperandComponentsKind::One) + return emitError(loc, "swizzled / single-component operand selection is " + "not supported in inline operand attribute"); + + std::optional mask; + if (components->kind == OperandComponentsKind::Mask) + mask = decodeComponentMask(components->mask); + + SmallVector indices; + for (auto indexType : *indexTypes) { + if (indexType != D3D10_SB_OPERAND_INDEX_IMMEDIATE32) + return emitError(getLocation(), "unsupported index representation: ") + << indexType; + auto value = parseToken(); + FAILURE_IF_FAILED(value); + indices.push_back(static_cast(*value)); + } + + return builder.buildInlineOperandAttr(*type, components->num, mask, + indices); + } + + FailureOr parseDclInput(Location loc) { + auto operand = parseInlineOperand(); + FAILURE_IF_FAILED(operand); + return builder.buildDclInput(*operand, loc); + } + + FailureOr parseDclOutput(Location loc) { + auto operand = parseInlineOperand(); + FAILURE_IF_FAILED(operand); + return builder.buildDclOutput(*operand, loc); + } + OptionalParseResult parseDclInstruction(uint32_t opcodeToken, Location loc, Instruction &out) { FailureOr result; @@ -1116,6 +1214,12 @@ class Parser { case D3D10_SB_OPCODE_DCL_INPUT_PS_SGV: result = parseDclInputPsSgv(loc); break; + case D3D10_SB_OPCODE_DCL_INPUT: + result = parseDclInput(loc); + break; + case D3D10_SB_OPCODE_DCL_OUTPUT: + result = parseDclOutput(loc); + break; default: return std::nullopt; } diff --git a/mlir/test/Target/DXSA/dcl_input.mlir b/mlir/test/Target/DXSA/dcl_input.mlir new file mode 100644 index 000000000000..4f42d4c31c14 --- /dev/null +++ b/mlir/test/Target/DXSA/dcl_input.mlir @@ -0,0 +1,10 @@ +// RUN: mlir-translate --import-dxsa-bin %S/inputs/dcl_input.bin | FileCheck %s + +// CHECK-LABEL: module +module { + // dcl_input v0.x + // CHECK: dxsa.dcl_input , index = [0]> + + // dcl_input vOutputControlPointID + // CHECK-NEXT: dxsa.dcl_input +} diff --git a/mlir/test/Target/DXSA/dcl_output.mlir b/mlir/test/Target/DXSA/dcl_output.mlir new file mode 100644 index 000000000000..6c5ff0eb0e59 --- /dev/null +++ b/mlir/test/Target/DXSA/dcl_output.mlir @@ -0,0 +1,10 @@ +// RUN: mlir-translate --import-dxsa-bin %S/inputs/dcl_output.bin | FileCheck %s + +// CHECK-LABEL: module +module { + // dcl_output o0.xyzw + // CHECK: dxsa.dcl_output , index = [0]> + + // dcl_output oDepth + // CHECK-NEXT: dxsa.dcl_output +} diff --git a/mlir/test/Target/DXSA/inputs/dcl_input.bin b/mlir/test/Target/DXSA/inputs/dcl_input.bin new file mode 100644 index 000000000000..831303295432 Binary files /dev/null and b/mlir/test/Target/DXSA/inputs/dcl_input.bin differ diff --git a/mlir/test/Target/DXSA/inputs/dcl_output.bin b/mlir/test/Target/DXSA/inputs/dcl_output.bin new file mode 100644 index 000000000000..2757d4b80f65 Binary files /dev/null and b/mlir/test/Target/DXSA/inputs/dcl_output.bin differ