Skip to content

Commit bc69c6d

Browse files
committed
[mlir][dxsa] Add root dxsa.module
Wraps the program into a module with optional attributes program type and shader version. When the binary has no header both attributes are omitted. Example: dxsa.module pixel_shader 5 0 { dxsa.dcl_global_flags <refactoringAllowed> } dxsa.module { dxsa.dcl_global_flags <refactoringAllowed> } Signed-off-by: Vladimir Shiryaev <tagolog@users.noreply.github.com>
1 parent 85119f8 commit bc69c6d

35 files changed

Lines changed: 305 additions & 55 deletions

mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td

Lines changed: 106 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,112 @@ include "mlir/IR/AttrTypeBase.td"
1515
include "mlir/IR/BuiltinAttributeInterfaces.td"
1616
include "mlir/IR/EnumAttr.td"
1717

18+
//===----------------------------------------------------------------------===//
19+
// DXSA op base class
20+
//===----------------------------------------------------------------------===//
21+
22+
// Base class for all operations in this dialect.
23+
class DXSA_Op<string mnemonic, list<Trait> traits = []> :
24+
Op<DXSADialect, mnemonic, traits>;
25+
26+
//===----------------------------------------------------------------------===//
27+
// DXSA module — top-level container op for a DXBC tokenized program
28+
//===----------------------------------------------------------------------===//
29+
30+
def DXSA_ProgramType_PixelShader : I32EnumAttrCase<"pixel_shader", 0>;
31+
def DXSA_ProgramType_VertexShader : I32EnumAttrCase<"vertex_shader", 1>;
32+
def DXSA_ProgramType_GeometryShader : I32EnumAttrCase<"geometry_shader", 2>;
33+
def DXSA_ProgramType_HullShader : I32EnumAttrCase<"hull_shader", 3>;
34+
def DXSA_ProgramType_DomainShader : I32EnumAttrCase<"domain_shader", 4>;
35+
def DXSA_ProgramType_ComputeShader : I32EnumAttrCase<"compute_shader", 5>;
36+
def DXSA_ProgramType_MeshShader : I32EnumAttrCase<"mesh_shader", 13>;
37+
def DXSA_ProgramType_AmplificationShader : I32EnumAttrCase<"amplification_shader",14>;
38+
39+
def DXSA_ProgramType : I32EnumAttr<
40+
"ProgramType", "DXBC tokenized program type", [
41+
DXSA_ProgramType_PixelShader,
42+
DXSA_ProgramType_VertexShader,
43+
DXSA_ProgramType_GeometryShader,
44+
DXSA_ProgramType_HullShader,
45+
DXSA_ProgramType_DomainShader,
46+
DXSA_ProgramType_ComputeShader,
47+
DXSA_ProgramType_MeshShader,
48+
DXSA_ProgramType_AmplificationShader
49+
]> {
50+
let cppNamespace = "::mlir::dxsa";
51+
let genSpecializedAttr = 0;
52+
}
53+
54+
def DXSA_ProgramTypeAttr :
55+
EnumAttr<DXSADialect, DXSA_ProgramType, "program_type"> {
56+
let assemblyFormat = "$value";
57+
}
58+
59+
def DXSA_ShaderVersionAttr : AttrDef<DXSADialect, "ShaderVersion"> {
60+
let mnemonic = "shader_version";
61+
let summary = "DXBC shader version (major.minor)";
62+
let description = [{
63+
The `#dxsa.shader_version` attribute holds the major and minor
64+
components of shader model version.
65+
66+
Example:
67+
68+
```mlir
69+
#dxsa.shader_version<5, 0>
70+
```
71+
}];
72+
let parameters = (ins "uint8_t":$major, "uint8_t":$minor);
73+
let assemblyFormat = "`<` $major `,` $minor `>`";
74+
}
75+
76+
def DXSA_ModuleOp : DXSA_Op<"module", [
77+
IsolatedFromAbove, NoRegionArguments, NoTerminator, SingleBlock]> {
78+
let summary = "the top-level container for a shader program";
79+
let description = [{
80+
The `dxsa.module` operation models the top-level container of a single
81+
shader tokenized program (one SHEX section of the DXBC binary).
82+
83+
The two optional attributes are shader program type and version.
84+
Both attributes are either both present (real binary with a SHEX
85+
header) or both absent (header-less raw token streams).
86+
87+
Example:
88+
89+
```mlir
90+
// Binary content with a SHEX header
91+
dxsa.module pixel_shader 5 0 {
92+
dxsa.dcl_global_flags <refactoringAllowed>
93+
}
94+
95+
// Binary content without a SHEX header
96+
dxsa.module {
97+
dxsa.dcl_global_flags <refactoringAllowed>
98+
}
99+
```
100+
}];
101+
102+
let arguments = (ins
103+
OptionalAttr<DXSA_ProgramTypeAttr>:$program_type,
104+
OptionalAttr<DXSA_ShaderVersionAttr>:$shader_version);
105+
let regions = (region SizedRegion<1>:$body);
106+
107+
let hasCustomAssemblyFormat = 1;
108+
let hasVerifier = 1;
109+
110+
let skipDefaultBuilders = 1;
111+
let builders = [
112+
OpBuilder<(ins
113+
CArg<"::mlir::dxsa::ProgramTypeAttr",
114+
"::mlir::dxsa::ProgramTypeAttr()">:$programType,
115+
CArg<"::mlir::dxsa::ShaderVersionAttr",
116+
"::mlir::dxsa::ShaderVersionAttr()">:$shaderVersion)>
117+
];
118+
119+
let extraClassDeclaration = [{
120+
::mlir::Block *getBodyBlock() { return &getBody().front(); }
121+
}];
122+
}
123+
18124
//===----------------------------------------------------------------------===//
19125
// DXSA enum definitions
20126
//===----------------------------------------------------------------------===//
@@ -392,10 +498,6 @@ def DXSA_UIntNonZero : AttrConstraint<
392498
// DXSA op definitions
393499
//===----------------------------------------------------------------------===//
394500

395-
// Base class for the operation in this dialect
396-
class DXSA_Op<string mnemonic, list<Trait> traits = []> :
397-
Op<DXSADialect, mnemonic, traits>;
398-
399501
def DXSA_Operand : DXSA_Op<"operand"> {
400502
let summary = "defines an operand of an instruction";
401503
let description = [{

mlir/include/mlir/Target/DXSA/BinaryParser.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,20 @@
99
#ifndef MLIR_TARGET_DXSA_BINARYPARSER_H
1010
#define MLIR_TARGET_DXSA_BINARYPARSER_H
1111

12+
#include "mlir/Dialect/DXSA/IR/DXSA.h"
1213
#include "mlir/IR/BuiltinOps.h"
1314
#include "mlir/IR/MLIRContext.h"
1415
#include "mlir/IR/OwningOpRef.h"
1516
#include "llvm/Support/SourceMgr.h"
1617

1718
namespace mlir::dxsa {
19+
/// Deserializes the given binary \p source and creates a MLIR ModuleOp in the
20+
/// given \p context.
21+
OwningOpRef<dxsa::ModuleOp> deserialize(llvm::SourceMgr &source,
22+
MLIRContext *context);
1823

19-
/// Decode DXSA binary \p source and return an MLIR module.
20-
OwningOpRef<ModuleOp> importDxsaBinaryToModule(llvm::SourceMgr &source,
21-
MLIRContext *context);
22-
/// Encode \p source to DXSA binary.
23-
LogicalResult exportModuleToDxsaBinary(ModuleOp source, raw_ostream &output);
24+
/// Serializes the given MLIR \p moduleOp and writes to \p output.
25+
LogicalResult serialize(mlir::ModuleOp moduleOp, raw_ostream &output);
2426
} // namespace mlir::dxsa
2527

2628
#endif // MLIR_TARGET_DXSA_BINARYPARSER_H

mlir/lib/Dialect/DXSA/IR/DXSA.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "mlir/IR/Builders.h"
1212
#include "mlir/IR/DialectImplementation.h"
13+
#include "mlir/IR/OpImplementation.h"
1314
#include "llvm/ADT/StringExtras.h"
1415
#include "llvm/ADT/TypeSwitch.h"
1516

@@ -41,6 +42,74 @@ void DXSADialect::initialize() {
4142
#define GET_OP_CLASSES
4243
#include "mlir/Dialect/DXSA/IR/DXSAOps.cpp.inc"
4344

45+
//===----------------------------------------------------------------------===//
46+
// ModuleOp
47+
//===----------------------------------------------------------------------===//
48+
49+
void ModuleOp::build(OpBuilder &builder, OperationState &state,
50+
ProgramTypeAttr programType,
51+
ShaderVersionAttr shaderVersion) {
52+
if (programType)
53+
state.addAttribute("program_type", programType);
54+
if (shaderVersion)
55+
state.addAttribute("shader_version", shaderVersion);
56+
OpBuilder::InsertionGuard guard(builder);
57+
builder.createBlock(state.addRegion());
58+
}
59+
60+
ParseResult ModuleOp::parse(OpAsmParser &parser, OperationState &result) {
61+
// Parse optional shader information like `pixel_shader 5 0`.
62+
StringRef typeKeyword;
63+
auto typeLoc = parser.getCurrentLocation();
64+
if (succeeded(parser.parseOptionalKeyword(&typeKeyword))) {
65+
auto programType = symbolizeProgramType(typeKeyword);
66+
if (!programType)
67+
return parser.emitError(typeLoc)
68+
<< "unknown program type: " << typeKeyword;
69+
result.addAttribute("program_type", ProgramTypeAttr::get(
70+
parser.getContext(), *programType));
71+
72+
uint8_t major = 0, minor = 0;
73+
if (parser.parseInteger(major) || parser.parseInteger(minor))
74+
return failure();
75+
result.addAttribute(
76+
"shader_version",
77+
ShaderVersionAttr::get(parser.getContext(), major, minor));
78+
}
79+
80+
Region *body = result.addRegion();
81+
if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
82+
parser.parseRegion(*body, /*arguments=*/{}))
83+
return failure();
84+
85+
if (body->empty())
86+
body->push_back(new Block());
87+
88+
return success();
89+
}
90+
91+
void ModuleOp::print(OpAsmPrinter &printer) {
92+
if (auto programType = getProgramType()) {
93+
printer << ' ' << stringifyProgramType(*programType);
94+
auto version = getShaderVersionAttr();
95+
printer << ' ' << static_cast<unsigned>(version.getMajor()) << ' '
96+
<< static_cast<unsigned>(version.getMinor());
97+
}
98+
printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
99+
{"program_type", "shader_version"});
100+
printer << ' ';
101+
printer.printRegion(getBody());
102+
}
103+
104+
LogicalResult ModuleOp::verify() {
105+
bool hasType = static_cast<bool>(getProgramTypeAttr());
106+
bool hasVersion = static_cast<bool>(getShaderVersionAttr());
107+
if (hasType != hasVersion)
108+
return emitOpError(
109+
"program_type and shader_version must both be present or both absent");
110+
return success();
111+
}
112+
44113
//===----------------------------------------------------------------------===//
45114
// Op verifiers
46115
//===----------------------------------------------------------------------===//

mlir/lib/Target/DXSA/BinaryParser.cpp

Lines changed: 82 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -420,15 +420,22 @@ static dxsa::ComponentMask decodeComponentMask(uint32_t rawComponentMask) {
420420

421421
class DXBuilder {
422422
public:
423-
DXBuilder(MLIRContext *context, StringAttr name)
424-
: context(context),
425-
module(ModuleOp::create(builder, FileLineColLoc::get(name, 0, 0))),
426-
builder(module.getRegion()) {}
423+
explicit DXBuilder(MLIRContext *context)
424+
: context(context), builder(context) {}
427425

428426
using Index = mlir::Value;
429427
using Operand = mlir::Value;
430428
using Instruction = mlir::Operation *;
431-
using Module = mlir::ModuleOp;
429+
using Module = dxsa::ModuleOp;
430+
431+
Module createModule(dxsa::ProgramTypeAttr programType,
432+
dxsa::ShaderVersionAttr shaderVersion, Location loc) {
433+
OperationState state(loc, Module::getOperationName());
434+
Module::build(builder, state, programType, shaderVersion);
435+
auto module = cast<Module>(Operation::create(state));
436+
builder.setInsertionPointToStart(&module.getBody().front());
437+
return module;
438+
}
432439

433440
Index buildIndexImm32(int32_t imm, FileLineColLoc loc) {
434441
Operation *op =
@@ -523,10 +530,6 @@ class DXBuilder {
523530
builder.getStringAttr(name));
524531
}
525532

526-
Module buildModule(ArrayRef<Instruction> instructions, FileLineColLoc loc) {
527-
return module;
528-
}
529-
530533
Instruction buildDclGlobalFlags(dxsa::GlobalFlags flags, Location loc) {
531534
auto flagsAttr = dxsa::GlobalFlagsAttr::get(builder.getContext(), flags);
532535
return dxsa::DclGlobalFlags::create(builder, loc, flagsAttr);
@@ -762,7 +765,6 @@ class DXBuilder {
762765

763766
private:
764767
MLIRContext *context;
765-
ModuleOp module;
766768
OpBuilder builder;
767769
};
768770

@@ -779,9 +781,11 @@ class Parser {
779781
using Instruction = DXBuilder::Instruction;
780782
using Module = DXBuilder::Module;
781783

784+
/// Width of the token in the program binary stream.
785+
static constexpr size_t tokenSize = sizeof(uint32_t);
786+
782787
/// Parse the current token and move the cursor to the next one.
783788
Token parseToken() {
784-
constexpr size_t tokenSize = sizeof(uint32_t);
785789
if ((currentTokenOffset + tokenSize) > buffer.size()) {
786790
emitError(getLocation(), "unexpected end of file");
787791
return failure();
@@ -1764,15 +1768,76 @@ class Parser {
17641768

17651769
FailureOr<Module> parseModule() {
17661770
FileLineColLoc loc = getLocation(0);
1767-
std::vector<Instruction> instructions;
1771+
auto header = parseProgramHeader();
1772+
FAILURE_IF_FAILED(header);
1773+
dxsa::ProgramTypeAttr programType;
1774+
dxsa::ShaderVersionAttr shaderVersion;
1775+
if (*header) {
1776+
programType =
1777+
dxsa::ProgramTypeAttr::get(name.getContext(), (*header)->type);
1778+
shaderVersion = dxsa::ShaderVersionAttr::get(
1779+
name.getContext(), (*header)->major, (*header)->minor);
1780+
}
1781+
auto module = builder.createModule(programType, shaderVersion, loc);
17681782
while (currentTokenOffset < buffer.size()) {
17691783
FailureOr<Instruction> inst = parseInstruction();
17701784
if (failed(inst)) {
17711785
return failure();
17721786
}
1773-
instructions.push_back(*inst);
17741787
}
1775-
return builder.buildModule(instructions, loc);
1788+
return module;
1789+
}
1790+
1791+
struct ProgramHeader {
1792+
dxsa::ProgramType type;
1793+
uint8_t major;
1794+
uint8_t minor;
1795+
};
1796+
1797+
/// If the buffer begins with a tokenized-program header (VersionToken +
1798+
/// LengthToken), decode and consume both tokens and return the program type
1799+
/// and shader version. Otherwise return without touching the parser current
1800+
/// position.
1801+
FailureOr<std::optional<ProgramHeader>> parseProgramHeader() {
1802+
auto remainingBytes = buffer.size() - currentTokenOffset;
1803+
if (remainingBytes < tokenSize)
1804+
return std::optional<ProgramHeader>{};
1805+
1806+
auto versionToken = support::endian::read<uint32_t>(
1807+
buffer.begin() + currentTokenOffset, endianness::little);
1808+
uint32_t rawProgramType =
1809+
DECODE_D3D10_SB_TOKENIZED_PROGRAM_TYPE(versionToken);
1810+
auto programType = dxsa::symbolizeProgramType(rawProgramType);
1811+
if (!programType)
1812+
return std::optional<ProgramHeader>{};
1813+
1814+
constexpr size_t headerSize = 2 * tokenSize;
1815+
if (remainingBytes < headerSize)
1816+
return emitError(getLocation(),
1817+
"expected LengthToken after VersionToken");
1818+
1819+
auto versionTokenLength =
1820+
DECODE_D3D10_SB_TOKENIZED_INSTRUCTION_LENGTH(versionToken);
1821+
if (versionTokenLength != 0)
1822+
return emitError(getLocation(), "VersionToken length must be 0, got ")
1823+
<< versionTokenLength;
1824+
1825+
auto lengthToken = support::endian::read<uint32_t>(
1826+
buffer.begin() + currentTokenOffset + tokenSize, endianness::little);
1827+
auto programLength = DECODE_D3D10_SB_TOKENIZED_PROGRAM_LENGTH(lengthToken);
1828+
constexpr size_t minProgramLen = 2; // VersionToken and LengthToken
1829+
if (programLength < minProgramLen)
1830+
return emitError(getLocation(), "LengthToken must be >= ")
1831+
<< minProgramLen << ", got " << programLength;
1832+
1833+
uint8_t major =
1834+
DECODE_D3D10_SB_TOKENIZED_PROGRAM_MAJOR_VERSION(versionToken);
1835+
uint8_t minor =
1836+
DECODE_D3D10_SB_TOKENIZED_PROGRAM_MINOR_VERSION(versionToken);
1837+
1838+
FAILURE_IF_FAILED(parseToken()); // VersionToken
1839+
FAILURE_IF_FAILED(parseToken()); // LengthToken
1840+
return std::optional<ProgramHeader>{{*programType, major, minor}};
17761841
}
17771842

17781843
LogicalResult verifyInstructionLength(size_t beginOffset, uint32_t length) {
@@ -1792,8 +1857,8 @@ class Parser {
17921857
};
17931858

17941859
namespace mlir::dxsa {
1795-
OwningOpRef<ModuleOp> importDxsaBinaryToModule(llvm::SourceMgr &source,
1796-
MLIRContext *context) {
1860+
OwningOpRef<ModuleOp> deserialize(llvm::SourceMgr &source,
1861+
MLIRContext *context) {
17971862

17981863
if (source.getNumBuffers() != 1) {
17991864
emitError(UnknownLoc::get(context), "one source file should be provided");
@@ -1809,7 +1874,7 @@ OwningOpRef<ModuleOp> importDxsaBinaryToModule(llvm::SourceMgr &source,
18091874
context->allowUnregisteredDialects();
18101875
context->loadAllAvailableDialects();
18111876

1812-
DXBuilder builder(context, name);
1877+
DXBuilder builder(context);
18131878
Parser parser(builder, name, buffer);
18141879
FailureOr<ModuleOp> mod = parser.parseModule();
18151880
if (failed(mod))

0 commit comments

Comments
 (0)