Skip to content

Commit 82eb011

Browse files
committed
[mlir][dxsa] Add root dxsa.module
Wraps the program into a module with optional attributes program type and shader version. When the binary has no header both attributes are omitted. Example: dxsa.module pixel_shader 5 0 { dxsa.dcl_global_flags <refactoringAllowed> } dxsa.module { dxsa.dcl_global_flags <refactoringAllowed> } Signed-off-by: Vladimir Shiryaev <tagolog@users.noreply.github.com>
1 parent cdfb680 commit 82eb011

35 files changed

Lines changed: 288 additions & 54 deletions

mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td

Lines changed: 106 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,112 @@ include "mlir/IR/AttrTypeBase.td"
1515
include "mlir/IR/BuiltinAttributeInterfaces.td"
1616
include "mlir/IR/EnumAttr.td"
1717

18+
//===----------------------------------------------------------------------===//
19+
// DXSA op base class
20+
//===----------------------------------------------------------------------===//
21+
22+
// Base class for all operations in this dialect.
23+
class DXSA_Op<string mnemonic, list<Trait> traits = []> :
24+
Op<DXSADialect, mnemonic, traits>;
25+
26+
//===----------------------------------------------------------------------===//
27+
// DXSA module — top-level container op for a DXBC tokenized program
28+
//===----------------------------------------------------------------------===//
29+
30+
def DXSA_ProgramType_PixelShader : I32EnumAttrCase<"pixel_shader", 0>;
31+
def DXSA_ProgramType_VertexShader : I32EnumAttrCase<"vertex_shader", 1>;
32+
def DXSA_ProgramType_GeometryShader : I32EnumAttrCase<"geometry_shader", 2>;
33+
def DXSA_ProgramType_HullShader : I32EnumAttrCase<"hull_shader", 3>;
34+
def DXSA_ProgramType_DomainShader : I32EnumAttrCase<"domain_shader", 4>;
35+
def DXSA_ProgramType_ComputeShader : I32EnumAttrCase<"compute_shader", 5>;
36+
def DXSA_ProgramType_MeshShader : I32EnumAttrCase<"mesh_shader", 13>;
37+
def DXSA_ProgramType_AmplificationShader : I32EnumAttrCase<"amplification_shader",14>;
38+
39+
def DXSA_ProgramType : I32EnumAttr<
40+
"ProgramType", "DXBC tokenized program type", [
41+
DXSA_ProgramType_PixelShader,
42+
DXSA_ProgramType_VertexShader,
43+
DXSA_ProgramType_GeometryShader,
44+
DXSA_ProgramType_HullShader,
45+
DXSA_ProgramType_DomainShader,
46+
DXSA_ProgramType_ComputeShader,
47+
DXSA_ProgramType_MeshShader,
48+
DXSA_ProgramType_AmplificationShader
49+
]> {
50+
let cppNamespace = "::mlir::dxsa";
51+
let genSpecializedAttr = 0;
52+
}
53+
54+
def DXSA_ProgramTypeAttr :
55+
EnumAttr<DXSADialect, DXSA_ProgramType, "program_type"> {
56+
let assemblyFormat = "$value";
57+
}
58+
59+
def DXSA_ShaderVersionAttr : AttrDef<DXSADialect, "ShaderVersion"> {
60+
let mnemonic = "shader_version";
61+
let summary = "DXBC shader version (major.minor)";
62+
let description = [{
63+
The `#dxsa.shader_version` attribute holds the major and minor
64+
components of shader model version.
65+
66+
Example:
67+
68+
```mlir
69+
#dxsa.shader_version<5, 0>
70+
```
71+
}];
72+
let parameters = (ins "uint8_t":$major, "uint8_t":$minor);
73+
let assemblyFormat = "`<` $major `,` $minor `>`";
74+
}
75+
76+
def DXSA_ModuleOp : DXSA_Op<"module", [
77+
IsolatedFromAbove, NoRegionArguments, NoTerminator, SingleBlock]> {
78+
let summary = "the top-level container for a shader program";
79+
let description = [{
80+
The `dxsa.module` operation models the top-level container of a single
81+
shader tokenized program (one SHEX section of the DXBC binary).
82+
83+
The two optional attributes are shader program type and version.
84+
Both attributes are either both present (real binary with a SHEX
85+
header) or both absent (header-less raw token streams).
86+
87+
Example:
88+
89+
```mlir
90+
// Binary content with a SHEX header
91+
dxsa.module pixel_shader 5 0 {
92+
dxsa.dcl_global_flags <refactoringAllowed>
93+
}
94+
95+
// Binary content without a SHEX header
96+
dxsa.module {
97+
dxsa.dcl_global_flags <refactoringAllowed>
98+
}
99+
```
100+
}];
101+
102+
let arguments = (ins
103+
OptionalAttr<DXSA_ProgramTypeAttr>:$program_type,
104+
OptionalAttr<DXSA_ShaderVersionAttr>:$shader_version);
105+
let regions = (region SizedRegion<1>:$body);
106+
107+
let hasCustomAssemblyFormat = 1;
108+
let hasVerifier = 1;
109+
110+
let skipDefaultBuilders = 1;
111+
let builders = [
112+
OpBuilder<(ins
113+
CArg<"::mlir::dxsa::ProgramTypeAttr",
114+
"::mlir::dxsa::ProgramTypeAttr()">:$programType,
115+
CArg<"::mlir::dxsa::ShaderVersionAttr",
116+
"::mlir::dxsa::ShaderVersionAttr()">:$shaderVersion)>
117+
];
118+
119+
let extraClassDeclaration = [{
120+
::mlir::Block *getBodyBlock() { return &getBody().front(); }
121+
}];
122+
}
123+
18124
//===----------------------------------------------------------------------===//
19125
// DXSA enum definitions
20126
//===----------------------------------------------------------------------===//
@@ -313,10 +419,6 @@ def DXSA_UIntNonZero : AttrConstraint<
313419
// DXSA op definitions
314420
//===----------------------------------------------------------------------===//
315421

316-
// Base class for the operation in this dialect
317-
class DXSA_Op<string mnemonic, list<Trait> traits = []> :
318-
Op<DXSADialect, mnemonic, traits>;
319-
320422
def DXSA_Operand : DXSA_Op<"operand"> {
321423
let summary = "defines an operand of an instruction";
322424
let description = [{

mlir/include/mlir/Target/DXSA/BinaryParser.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,18 @@
99
#ifndef MLIR_TARGET_DXSA_BINARYPARSER_H
1010
#define MLIR_TARGET_DXSA_BINARYPARSER_H
1111

12+
#include "mlir/Dialect/DXSA/IR/DXSA.h"
1213
#include "mlir/IR/BuiltinOps.h"
1314
#include "mlir/IR/MLIRContext.h"
1415
#include "mlir/IR/OwningOpRef.h"
1516
#include "llvm/Support/SourceMgr.h"
1617

1718
namespace mlir::dxsa {
19+
/// Deserializes the given binary \p source and creates a MLIR ModuleOp in the given \p context.
20+
OwningOpRef<dxsa::ModuleOp> deserialize(llvm::SourceMgr &source, MLIRContext *context);
1821

19-
/// Decode DXSA binary \p source and return an MLIR module.
20-
OwningOpRef<ModuleOp> importDxsaBinaryToModule(llvm::SourceMgr &source,
21-
MLIRContext *context);
22-
/// Encode \p source to DXSA binary.
23-
LogicalResult exportModuleToDxsaBinary(ModuleOp source, raw_ostream &output);
22+
/// Serializes the given MLIR \p moduleOp and writes to \p output.
23+
LogicalResult serialize(mlir::ModuleOp moduleOp, raw_ostream &output);
2424
} // namespace mlir::dxsa
2525

2626
#endif // MLIR_TARGET_DXSA_BINARYPARSER_H

mlir/lib/Dialect/DXSA/IR/DXSA.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "mlir/IR/Builders.h"
1212
#include "mlir/IR/DialectImplementation.h"
13+
#include "mlir/IR/OpImplementation.h"
1314
#include "llvm/ADT/StringExtras.h"
1415
#include "llvm/ADT/TypeSwitch.h"
1516

@@ -41,6 +42,76 @@ void DXSADialect::initialize() {
4142
#define GET_OP_CLASSES
4243
#include "mlir/Dialect/DXSA/IR/DXSAOps.cpp.inc"
4344

45+
//===----------------------------------------------------------------------===//
46+
// ModuleOp
47+
//===----------------------------------------------------------------------===//
48+
49+
void ModuleOp::build(OpBuilder &builder, OperationState &state,
50+
ProgramTypeAttr programType,
51+
ShaderVersionAttr shaderVersion) {
52+
if (programType)
53+
state.addAttribute("program_type", programType);
54+
if (shaderVersion)
55+
state.addAttribute("shader_version", shaderVersion);
56+
OpBuilder::InsertionGuard guard(builder);
57+
builder.createBlock(state.addRegion());
58+
}
59+
60+
ParseResult ModuleOp::parse(OpAsmParser &parser, OperationState &result) {
61+
Region *body = result.addRegion();
62+
63+
// Parse optional shader information like `pixel_shader 5 0`.
64+
StringRef typeKeyword;
65+
auto typeLoc = parser.getCurrentLocation();
66+
if (succeeded(parser.parseOptionalKeyword(&typeKeyword))) {
67+
auto programType = symbolizeProgramType(typeKeyword);
68+
if (!programType)
69+
return parser.emitError(typeLoc)
70+
<< "unknown program type: " << typeKeyword;
71+
result.addAttribute("program_type", ProgramTypeAttr::get(
72+
parser.getContext(), *programType));
73+
74+
unsigned major = 0, minor = 0;
75+
if (parser.parseInteger(major) || parser.parseInteger(minor))
76+
return failure();
77+
result.addAttribute("shader_version",
78+
ShaderVersionAttr::get(parser.getContext(),
79+
static_cast<uint8_t>(major),
80+
static_cast<uint8_t>(minor)));
81+
}
82+
83+
if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
84+
parser.parseRegion(*body, /*arguments=*/{}))
85+
return failure();
86+
87+
if (body->empty())
88+
body->push_back(new Block());
89+
90+
return success();
91+
}
92+
93+
void ModuleOp::print(OpAsmPrinter &printer) {
94+
if (auto programType = getProgramType()) {
95+
printer << ' ' << stringifyProgramType(*programType);
96+
auto version = getShaderVersionAttr();
97+
printer << ' ' << static_cast<unsigned>(version.getMajor()) << ' '
98+
<< static_cast<unsigned>(version.getMinor());
99+
}
100+
printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
101+
{"program_type", "shader_version"});
102+
printer << ' ';
103+
printer.printRegion(getBody());
104+
}
105+
106+
LogicalResult ModuleOp::verify() {
107+
bool hasType = static_cast<bool>(getProgramTypeAttr());
108+
bool hasVersion = static_cast<bool>(getShaderVersionAttr());
109+
if (hasType != hasVersion)
110+
return emitOpError(
111+
"program_type and shader_version must both be present or both absent");
112+
return success();
113+
}
114+
44115
//===----------------------------------------------------------------------===//
45116
// Op verifiers
46117
//===----------------------------------------------------------------------===//

mlir/lib/Target/DXSA/BinaryParser.cpp

Lines changed: 65 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -420,15 +420,23 @@ static dxsa::ComponentMask decodeComponentMask(uint32_t rawComponentMask) {
420420

421421
class DXBuilder {
422422
public:
423-
DXBuilder(MLIRContext *context, StringAttr name)
424-
: context(context),
425-
module(ModuleOp::create(builder, FileLineColLoc::get(name, 0, 0))),
426-
builder(module.getRegion()) {}
423+
explicit DXBuilder(MLIRContext *context)
424+
: context(context), builder(context) {}
427425

428426
using Index = mlir::Value;
429427
using Operand = mlir::Value;
430428
using Instruction = mlir::Operation *;
431-
using Module = mlir::ModuleOp;
429+
using Module = mlir::dxsa::ModuleOp;
430+
431+
Module createModule(mlir::dxsa::ProgramTypeAttr programType,
432+
mlir::dxsa::ShaderVersionAttr shaderVersion,
433+
FileLineColLoc loc) {
434+
OperationState state(loc, Module::getOperationName());
435+
Module::build(builder, state, programType, shaderVersion);
436+
auto module = cast<Module>(Operation::create(state));
437+
builder.setInsertionPointToStart(&module.getBody().front());
438+
return module;
439+
}
432440

433441
Index buildIndexImm32(int32_t imm, FileLineColLoc loc) {
434442
Operation *op =
@@ -523,10 +531,6 @@ class DXBuilder {
523531
builder.getStringAttr(name));
524532
}
525533

526-
Module buildModule(ArrayRef<Instruction> instructions, FileLineColLoc loc) {
527-
return module;
528-
}
529-
530534
Instruction buildDclGlobalFlags(dxsa::GlobalFlags flags, Location loc) {
531535
auto flagsAttr = dxsa::GlobalFlagsAttr::get(builder.getContext(), flags);
532536
return dxsa::DclGlobalFlags::create(builder, loc, flagsAttr);
@@ -706,7 +710,6 @@ class DXBuilder {
706710

707711
private:
708712
MLIRContext *context;
709-
ModuleOp module;
710713
OpBuilder builder;
711714
};
712715

@@ -1530,15 +1533,61 @@ class Parser {
15301533

15311534
FailureOr<Module> parseModule() {
15321535
FileLineColLoc loc = getLocation(0);
1533-
std::vector<Instruction> instructions;
1536+
auto header = parseProgramHeader();
1537+
if (failed(header))
1538+
return failure();
1539+
mlir::dxsa::ProgramTypeAttr programType;
1540+
mlir::dxsa::ShaderVersionAttr shaderVersion;
1541+
if (*header) {
1542+
programType =
1543+
mlir::dxsa::ProgramTypeAttr::get(name.getContext(), (*header)->type);
1544+
shaderVersion = mlir::dxsa::ShaderVersionAttr::get(
1545+
name.getContext(), (*header)->major, (*header)->minor);
1546+
}
1547+
auto module = builder.createModule(programType, shaderVersion, loc);
15341548
while (currentTokenOffset < buffer.size()) {
15351549
FailureOr<Instruction> inst = parseInstruction();
15361550
if (failed(inst)) {
15371551
return failure();
15381552
}
1539-
instructions.push_back(*inst);
15401553
}
1541-
return builder.buildModule(instructions, loc);
1554+
return module;
1555+
}
1556+
1557+
struct ProgramHeader {
1558+
mlir::dxsa::ProgramType type;
1559+
uint8_t major;
1560+
uint8_t minor;
1561+
};
1562+
1563+
/// If the buffer begins with a tokenized-program header (VersionToken +
1564+
/// LengthToken), decode and consume both tokens and return the program type
1565+
/// and shader model. Otherwise return without touching the parser current
1566+
/// position.
1567+
FailureOr<std::optional<ProgramHeader>> parseProgramHeader() {
1568+
constexpr size_t headerSize = 2 * sizeof(uint32_t);
1569+
if (currentTokenOffset + headerSize > buffer.size())
1570+
return std::optional<ProgramHeader>{};
1571+
1572+
auto versionToken = support::endian::read<uint32_t>(
1573+
buffer.begin() + currentTokenOffset, endianness::little);
1574+
if (DECODE_D3D10_SB_TOKENIZED_INSTRUCTION_LENGTH(versionToken) != 0)
1575+
return std::optional<ProgramHeader>{};
1576+
1577+
auto rawType = static_cast<uint32_t>(
1578+
DECODE_D3D10_SB_TOKENIZED_PROGRAM_TYPE(versionToken));
1579+
auto programType = dxsa::symbolizeProgramType(rawType);
1580+
if (!programType)
1581+
return std::optional<ProgramHeader>{};
1582+
1583+
auto major = static_cast<uint8_t>(
1584+
DECODE_D3D10_SB_TOKENIZED_PROGRAM_MAJOR_VERSION(versionToken));
1585+
auto minor = static_cast<uint8_t>(
1586+
DECODE_D3D10_SB_TOKENIZED_PROGRAM_MINOR_VERSION(versionToken));
1587+
1588+
FAILURE_IF_FAILED(parseToken()); // VersionToken
1589+
FAILURE_IF_FAILED(parseToken()); // LengthToken
1590+
return std::optional<ProgramHeader>{{*programType, major, minor}};
15421591
}
15431592

15441593
LogicalResult verifyInstructionLength(size_t beginOffset, uint32_t length) {
@@ -1558,8 +1607,8 @@ class Parser {
15581607
};
15591608

15601609
namespace mlir::dxsa {
1561-
OwningOpRef<ModuleOp> importDxsaBinaryToModule(llvm::SourceMgr &source,
1562-
MLIRContext *context) {
1610+
OwningOpRef<ModuleOp> deserialize(llvm::SourceMgr &source,
1611+
MLIRContext *context) {
15631612

15641613
if (source.getNumBuffers() != 1) {
15651614
emitError(UnknownLoc::get(context), "one source file should be provided");
@@ -1575,7 +1624,7 @@ OwningOpRef<ModuleOp> importDxsaBinaryToModule(llvm::SourceMgr &source,
15751624
context->allowUnregisteredDialects();
15761625
context->loadAllAvailableDialects();
15771626

1578-
DXBuilder builder(context, name);
1627+
DXBuilder builder(context);
15791628
Parser parser(builder, name, buffer);
15801629
FailureOr<ModuleOp> mod = parser.parseModule();
15811630
if (failed(mod))

mlir/lib/Target/DXSA/BinaryWriter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ using namespace mlir;
1818
using namespace llvm;
1919

2020
namespace mlir::dxsa {
21-
LogicalResult exportModuleToDxsaBinary(ModuleOp source, raw_ostream &output) {
21+
LogicalResult serialize(mlir::ModuleOp source, raw_ostream &output) {
2222
Region &region = source.getRegion();
2323
assert(region.hasOneBlock() && "invalid module");
2424
return failure();

mlir/lib/Target/DXSA/TranslateRegistration.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ void registerFromDxsaBinTranslation() {
1919
"import-dxsa-bin", "Translate DXSA binary to MLIR",
2020
[](llvm::SourceMgr &sourceMgr,
2121
MLIRContext *context) -> OwningOpRef<Operation *> {
22-
return dxsa::importDxsaBinaryToModule(sourceMgr, context);
22+
return dxsa::deserialize(sourceMgr, context);
2323
},
2424
[](DialectRegistry &registry) { registry.insert<dxsa::DXSADialect>(); }};
2525
}
@@ -28,7 +28,7 @@ void registerToDxsaBinTranslation() {
2828
TranslateFromMLIRRegistration registration{
2929
"export-dxsa-bin", "Translate MLIR to DXSA binary",
3030
[](ModuleOp source, raw_ostream &output) {
31-
return dxsa::exportModuleToDxsaBinary(source, output);
31+
return dxsa::serialize(source, output);
3232
},
3333
[](DialectRegistry &registry) { registry.insert<dxsa::DXSADialect>(); }};
3434
}

0 commit comments

Comments
 (0)