@@ -420,15 +420,23 @@ static dxsa::ComponentMask decodeComponentMask(uint32_t rawComponentMask) {
420420
421421class DXBuilder {
422422public:
423- DXBuilder (MLIRContext *context, StringAttr name)
424- : context(context),
425- module (ModuleOp::create(builder, FileLineColLoc::get(name, 0 , 0 ))),
426- builder(module .getRegion()) {}
423+ explicit DXBuilder (MLIRContext *context)
424+ : context(context), builder(context) {}
427425
428426 using Index = mlir::Value;
429427 using Operand = mlir::Value;
430428 using Instruction = mlir::Operation *;
431- using Module = mlir::ModuleOp;
429+ using Module = mlir::dxsa::ModuleOp;
430+
431+ Module createModule (mlir::dxsa::ProgramTypeAttr programType,
432+ mlir::dxsa::ShaderVersionAttr shaderVersion,
433+ Location loc) {
434+ OperationState state (loc, Module::getOperationName ());
435+ Module::build (builder, state, programType, shaderVersion);
436+ auto module = cast<Module>(Operation::create (state));
437+ builder.setInsertionPointToStart (&module .getBody ().front ());
438+ return module ;
439+ }
432440
433441 Index buildIndexImm32 (int32_t imm, FileLineColLoc loc) {
434442 Operation *op =
@@ -523,10 +531,6 @@ class DXBuilder {
523531 builder.getStringAttr (name));
524532 }
525533
526- Module buildModule (ArrayRef<Instruction> instructions, FileLineColLoc loc) {
527- return module ;
528- }
529-
530534 Instruction buildDclGlobalFlags (dxsa::GlobalFlags flags, Location loc) {
531535 auto flagsAttr = dxsa::GlobalFlagsAttr::get (builder.getContext (), flags);
532536 return dxsa::DclGlobalFlags::create (builder, loc, flagsAttr);
@@ -706,7 +710,6 @@ class DXBuilder {
706710
707711private:
708712 MLIRContext *context;
709- ModuleOp module ;
710713 OpBuilder builder;
711714};
712715
@@ -1530,15 +1533,60 @@ class Parser {
15301533
15311534 FailureOr<Module> parseModule () {
15321535 FileLineColLoc loc = getLocation (0 );
1533- std::vector<Instruction> instructions;
1536+ auto header = parseProgramHeader ();
1537+ FAILURE_IF_FAILED (header);
1538+ mlir::dxsa::ProgramTypeAttr programType;
1539+ mlir::dxsa::ShaderVersionAttr shaderVersion;
1540+ if (*header) {
1541+ programType =
1542+ mlir::dxsa::ProgramTypeAttr::get (name.getContext (), (*header)->type );
1543+ shaderVersion = mlir::dxsa::ShaderVersionAttr::get (
1544+ name.getContext (), (*header)->major , (*header)->minor );
1545+ }
1546+ auto module = builder.createModule (programType, shaderVersion, loc);
15341547 while (currentTokenOffset < buffer.size ()) {
15351548 FailureOr<Instruction> inst = parseInstruction ();
15361549 if (failed (inst)) {
15371550 return failure ();
15381551 }
1539- instructions.push_back (*inst);
15401552 }
1541- return builder.buildModule (instructions, loc);
1553+ return module ;
1554+ }
1555+
1556+ struct ProgramHeader {
1557+ mlir::dxsa::ProgramType type;
1558+ uint8_t major;
1559+ uint8_t minor;
1560+ };
1561+
1562+ // / If the buffer begins with a tokenized-program header (VersionToken +
1563+ // / LengthToken), decode and consume both tokens and return the program type
1564+ // / and shader model. Otherwise return without touching the parser current
1565+ // / position.
1566+ FailureOr<std::optional<ProgramHeader>> parseProgramHeader () {
1567+ constexpr size_t headerSize = 2 * sizeof (uint32_t );
1568+ if (currentTokenOffset + headerSize > buffer.size ())
1569+ return std::optional<ProgramHeader>{};
1570+
1571+ auto versionToken = support::endian::read<uint32_t >(
1572+ buffer.begin () + currentTokenOffset, endianness::little);
1573+ if (DECODE_D3D10_SB_TOKENIZED_INSTRUCTION_LENGTH (versionToken) != 0 )
1574+ return std::optional<ProgramHeader>{};
1575+
1576+ auto rawType = static_cast <uint32_t >(
1577+ DECODE_D3D10_SB_TOKENIZED_PROGRAM_TYPE (versionToken));
1578+ auto programType = dxsa::symbolizeProgramType (rawType);
1579+ if (!programType)
1580+ return std::optional<ProgramHeader>{};
1581+
1582+ auto major = static_cast <uint8_t >(
1583+ DECODE_D3D10_SB_TOKENIZED_PROGRAM_MAJOR_VERSION (versionToken));
1584+ auto minor = static_cast <uint8_t >(
1585+ DECODE_D3D10_SB_TOKENIZED_PROGRAM_MINOR_VERSION (versionToken));
1586+
1587+ FAILURE_IF_FAILED (parseToken ()); // VersionToken
1588+ FAILURE_IF_FAILED (parseToken ()); // LengthToken
1589+ return std::optional<ProgramHeader>{{*programType, major, minor}};
15421590 }
15431591
15441592 LogicalResult verifyInstructionLength (size_t beginOffset, uint32_t length) {
@@ -1558,8 +1606,8 @@ class Parser {
15581606};
15591607
15601608namespace mlir ::dxsa {
1561- OwningOpRef<ModuleOp> importDxsaBinaryToModule (llvm::SourceMgr &source,
1562- MLIRContext *context) {
1609+ OwningOpRef<ModuleOp> deserialize (llvm::SourceMgr &source,
1610+ MLIRContext *context) {
15631611
15641612 if (source.getNumBuffers () != 1 ) {
15651613 emitError (UnknownLoc::get (context), " one source file should be provided" );
@@ -1575,7 +1623,7 @@ OwningOpRef<ModuleOp> importDxsaBinaryToModule(llvm::SourceMgr &source,
15751623 context->allowUnregisteredDialects ();
15761624 context->loadAllAvailableDialects ();
15771625
1578- DXBuilder builder (context, name );
1626+ DXBuilder builder (context);
15791627 Parser parser (builder, name, buffer);
15801628 FailureOr<ModuleOp> mod = parser.parseModule ();
15811629 if (failed (mod))
0 commit comments