@@ -420,15 +420,23 @@ static dxsa::ComponentMask decodeComponentMask(uint32_t rawComponentMask) {
420420
421421class DXBuilder {
422422public:
423- DXBuilder (MLIRContext *context, StringAttr name)
424- : context(context),
425- module (ModuleOp::create(builder, FileLineColLoc::get(name, 0 , 0 ))),
426- builder(module .getRegion()) {}
423+ explicit DXBuilder (MLIRContext *context)
424+ : context(context), builder(context) {}
427425
428426 using Index = mlir::Value;
429427 using Operand = mlir::Value;
430428 using Instruction = mlir::Operation *;
431- using Module = mlir::ModuleOp;
429+ using Module = mlir::dxsa::ModuleOp;
430+
431+ Module createModule (mlir::dxsa::ProgramTypeAttr programType,
432+ mlir::dxsa::ShaderVersionAttr shaderVersion,
433+ FileLineColLoc loc) {
434+ OperationState state (loc, Module::getOperationName ());
435+ Module::build (builder, state, programType, shaderVersion);
436+ auto module = cast<Module>(Operation::create (state));
437+ builder.setInsertionPointToStart (&module .getBody ().front ());
438+ return module ;
439+ }
432440
433441 Index buildIndexImm32 (int32_t imm, FileLineColLoc loc) {
434442 Operation *op =
@@ -523,10 +531,6 @@ class DXBuilder {
523531 builder.getStringAttr (name));
524532 }
525533
526- Module buildModule (ArrayRef<Instruction> instructions, FileLineColLoc loc) {
527- return module ;
528- }
529-
530534 Instruction buildDclGlobalFlags (dxsa::GlobalFlags flags, Location loc) {
531535 auto flagsAttr = dxsa::GlobalFlagsAttr::get (builder.getContext (), flags);
532536 return dxsa::DclGlobalFlags::create (builder, loc, flagsAttr);
@@ -706,7 +710,6 @@ class DXBuilder {
706710
707711private:
708712 MLIRContext *context;
709- ModuleOp module ;
710713 OpBuilder builder;
711714};
712715
@@ -1530,15 +1533,61 @@ class Parser {
15301533
15311534 FailureOr<Module> parseModule () {
15321535 FileLineColLoc loc = getLocation (0 );
1533- std::vector<Instruction> instructions;
1536+ auto header = parseProgramHeader ();
1537+ if (failed (header))
1538+ return failure ();
1539+ mlir::dxsa::ProgramTypeAttr programType;
1540+ mlir::dxsa::ShaderVersionAttr shaderVersion;
1541+ if (*header) {
1542+ programType =
1543+ mlir::dxsa::ProgramTypeAttr::get (name.getContext (), (*header)->type );
1544+ shaderVersion = mlir::dxsa::ShaderVersionAttr::get (
1545+ name.getContext (), (*header)->major , (*header)->minor );
1546+ }
1547+ auto module = builder.createModule (programType, shaderVersion, loc);
15341548 while (currentTokenOffset < buffer.size ()) {
15351549 FailureOr<Instruction> inst = parseInstruction ();
15361550 if (failed (inst)) {
15371551 return failure ();
15381552 }
1539- instructions.push_back (*inst);
15401553 }
1541- return builder.buildModule (instructions, loc);
1554+ return module ;
1555+ }
1556+
1557+ struct ProgramHeader {
1558+ mlir::dxsa::ProgramType type;
1559+ uint8_t major;
1560+ uint8_t minor;
1561+ };
1562+
1563+ // / If the buffer begins with a tokenized-program header (VersionToken +
1564+ // / LengthToken), decode and consume both tokens and return the program type
1565+ // / and shader model. Otherwise return without touching the parser current
1566+ // / position.
1567+ FailureOr<std::optional<ProgramHeader>> parseProgramHeader () {
1568+ constexpr size_t headerSize = 2 * sizeof (uint32_t );
1569+ if (currentTokenOffset + headerSize > buffer.size ())
1570+ return std::optional<ProgramHeader>{};
1571+
1572+ auto versionToken = support::endian::read<uint32_t >(
1573+ buffer.begin () + currentTokenOffset, endianness::little);
1574+ if (DECODE_D3D10_SB_TOKENIZED_INSTRUCTION_LENGTH (versionToken) != 0 )
1575+ return std::optional<ProgramHeader>{};
1576+
1577+ auto rawType = static_cast <uint32_t >(
1578+ DECODE_D3D10_SB_TOKENIZED_PROGRAM_TYPE (versionToken));
1579+ auto programType = dxsa::symbolizeProgramType (rawType);
1580+ if (!programType)
1581+ return std::optional<ProgramHeader>{};
1582+
1583+ auto major = static_cast <uint8_t >(
1584+ DECODE_D3D10_SB_TOKENIZED_PROGRAM_MAJOR_VERSION (versionToken));
1585+ auto minor = static_cast <uint8_t >(
1586+ DECODE_D3D10_SB_TOKENIZED_PROGRAM_MINOR_VERSION (versionToken));
1587+
1588+ FAILURE_IF_FAILED (parseToken ()); // VersionToken
1589+ FAILURE_IF_FAILED (parseToken ()); // LengthToken
1590+ return std::optional<ProgramHeader>{{*programType, major, minor}};
15421591 }
15431592
15441593 LogicalResult verifyInstructionLength (size_t beginOffset, uint32_t length) {
@@ -1558,8 +1607,8 @@ class Parser {
15581607};
15591608
15601609namespace mlir ::dxsa {
1561- OwningOpRef<ModuleOp> importDxsaBinaryToModule (llvm::SourceMgr &source,
1562- MLIRContext *context) {
1610+ OwningOpRef<ModuleOp> deserialize (llvm::SourceMgr &source,
1611+ MLIRContext *context) {
15631612
15641613 if (source.getNumBuffers () != 1 ) {
15651614 emitError (UnknownLoc::get (context), " one source file should be provided" );
@@ -1575,7 +1624,7 @@ OwningOpRef<ModuleOp> importDxsaBinaryToModule(llvm::SourceMgr &source,
15751624 context->allowUnregisteredDialects ();
15761625 context->loadAllAvailableDialects ();
15771626
1578- DXBuilder builder (context, name );
1627+ DXBuilder builder (context);
15791628 Parser parser (builder, name, buffer);
15801629 FailureOr<ModuleOp> mod = parser.parseModule ();
15811630 if (failed (mod))
0 commit comments