diff --git a/shared/stinkytofu/include/stinkytofu/ir/asm/StinkyModifiers.hpp b/shared/stinkytofu/include/stinkytofu/ir/asm/StinkyModifiers.hpp index d6164a25987f..340819d8fb83 100644 --- a/shared/stinkytofu/include/stinkytofu/ir/asm/StinkyModifiers.hpp +++ b/shared/stinkytofu/include/stinkytofu/ir/asm/StinkyModifiers.hpp @@ -272,11 +272,17 @@ struct DPPModifiers : public TypedModifier { int row_bcast; int bound_ctrl; - DPPModifiers(int row_shr = -1, int row_bcast = -1, int bound_ctrl = -1) + // Raw dpp_ctrl string for modes not modelled as typed fields + // (e.g. "row_xmask:8", "row_shl:4"). Emitted verbatim. + std::string dppCtrl; + + DPPModifiers(int row_shr = -1, int row_bcast = -1, int bound_ctrl = -1, + const std::string& dppCtrl = "") : TypedModifier(), row_shr(row_shr), row_bcast(row_bcast), - bound_ctrl(bound_ctrl) {} + bound_ctrl(bound_ctrl), + dppCtrl(dppCtrl) {} }; struct VOP3Modifiers : public TypedModifier { diff --git a/shared/stinkytofu/src/serialization/asm/ModifierSerializer.cpp b/shared/stinkytofu/src/serialization/asm/ModifierSerializer.cpp index 39388f9fdda1..7e10c5f193dd 100644 --- a/shared/stinkytofu/src/serialization/asm/ModifierSerializer.cpp +++ b/shared/stinkytofu/src/serialization/asm/ModifierSerializer.cpp @@ -410,19 +410,32 @@ void deserializeVisit(StinkyInstruction* inst, const std::string& attrKey, } else if (attrKey == "mod.swaitstorecnt") { inst->addModifier(SWaitStoreCntData(static_cast(getInt(fields, "storecnt", -1)))); } else if (attrKey == "mod.mfma") { + // Build inputPermute from individual matrix format fields when not + // provided as a pre-built string (RawAsmParser provides individual + // fields, STIR/pipeline path provides the pre-built string). + std::string inputPermute = getStr(fields, "inputPermute"); + if (inputPermute.empty()) { + for (const auto& key : + {"matrix_a_fmt", "matrix_b_fmt", "matrix_a_scale_fmt", "matrix_b_scale_fmt"}) { + auto it = fields.find(key); + if (it != fields.end()) { + if (!inputPermute.empty()) inputPermute += ' '; + inputPermute += std::string(key) + ":" + it->second; + } + } + } bool isMX = getBool(fields, "isMXMFMA", false); if (isMX) { - inst->addModifier(MFMAModifiers( - getStr(fields, "inputPermute"), getStr(fields, "scaleStr"), - getStr(fields, "negStr"), getBool(fields, "reuseA", false), - getBool(fields, "reuseB", false), getInt(fields, "mxInstType", 0), - getInt(fields, "mxScaleAType", 0), getInt(fields, "mxScaleBType", 0))); + inst->addModifier( + MFMAModifiers(inputPermute, getStr(fields, "scaleStr"), getStr(fields, "negStr"), + getBool(fields, "reuseA", false), getBool(fields, "reuseB", false), + getInt(fields, "mxInstType", 0), getInt(fields, "mxScaleAType", 0), + getInt(fields, "mxScaleBType", 0))); } else { inst->addModifier( - MFMAModifiers(getStr(fields, "inputPermute"), getStr(fields, "scaleStr"), - getStr(fields, "negStr"), getBool(fields, "reuseA", false), - getBool(fields, "reuseB", false), getBool(fields, "neg_lo", false), - getBool(fields, "neg_hi", false))); + MFMAModifiers(inputPermute, getStr(fields, "scaleStr"), getStr(fields, "negStr"), + getBool(fields, "reuseA", false), getBool(fields, "reuseB", false), + getBool(fields, "neg_lo", false), getBool(fields, "neg_hi", false))); } } else if (attrKey == "mod.delayalu") { auto toInstType = [](const std::string& s) { @@ -455,8 +468,12 @@ void deserializeVisit(StinkyInstruction* inst, const std::string& attrKey, if (fields.count("tokens")) { inst->addModifier(MemTokenData(getIntVector(fields, "tokens"))); } + } else if (attrKey == "mod.dpp") { + inst->addModifier( + DPPModifiers(getInt(fields, "row_shr", -1), getInt(fields, "row_bcast", -1), + getInt(fields, "bound_ctrl", -1), getStr(fields, "dppCtrl"))); } - // mod.sdwa, mod.dpp, mod.vop3p, mod.true16: no deserialize support yet + // mod.sdwa, mod.vop3p, mod.true16: no deserialize support yet } } // namespace diff --git a/shared/stinkytofu/src/serialization/asm/RawAsmParser.cpp b/shared/stinkytofu/src/serialization/asm/RawAsmParser.cpp index b5a0381292ed..fd0d260d2872 100644 --- a/shared/stinkytofu/src/serialization/asm/RawAsmParser.cpp +++ b/shared/stinkytofu/src/serialization/asm/RawAsmParser.cpp @@ -400,6 +400,169 @@ std::optional parseOneRegister(IRLexer& lexer, const SymbolTable return StinkyRegister(regTypeStr + "[" + bracketContent + "]"); } +//---------------------------------------------------------------------- +// FieldType helpers +//---------------------------------------------------------------------- + +/// Returns true if the given FieldType uses custom textual syntax that +/// parseOneRegister cannot handle. These operands are dispatched to +/// dedicated parsers in the per-field operand loop. +bool hasCustomOperandSyntax(FieldType ft) { + switch (ft) { + case FieldType::delay: + case FieldType::wait_alu: + case FieldType::hwreg: + return true; + default: + return false; + } +} + +//---------------------------------------------------------------------- +// Custom operand parsers +//---------------------------------------------------------------------- + +/// Parse s_delay_alu instid0/instskip/instid1 syntax into mod.delayalu fields. +bool parseDelayAluSyntax(IRLexer& lexer, ParsedInstruction& inst) { + auto parseInstId = [](const std::string& s) -> std::pair { + if (s.find("VALU_DEP_") == 0) return {"VALU", std::stoi(s.substr(9))}; + if (s.find("TRANS32_DEP_") == 0) return {"TRANS", std::stoi(s.substr(12))}; + if (s.find("SALU_CYCLE_") == 0) return {"SALU", std::stoi(s.substr(11))}; + return {"NO_DEP", 0}; + }; + auto parseSkip = [](const std::string& s) -> int { + if (s == "SAME") return 0; + if (s == "NEXT") return 1; + if (s.find("SKIP_") == 0) return std::stoi(s.substr(5)) + 1; + return 0; + }; + auto& fields = inst.modifiers["mod.delayalu"]; + while (!lexer.isAtEnd() && lexer.peek().kind != TokenKind::Eof && + lexer.peek().kind != TokenKind::Newline) { + if (lexer.peek().kind != TokenKind::Identifier) { + lexer.consume(); + continue; + } + std::string key(lexer.consume().text); + std::string val; + if (!lexer.isAtEnd() && lexer.peek().kind == TokenKind::LeftParen) { + lexer.consume(); + if (!lexer.isAtEnd()) val = std::string(lexer.consume().text); + if (!lexer.isAtEnd() && lexer.peek().kind == TokenKind::RightParen) lexer.consume(); + } + if (key == "instid0") { + auto [type, dist] = parseInstId(val); + fields["instid0Type"] = type; + fields["instid0Distance"] = std::to_string(dist); + } else if (key == "instskip") { + fields["instSkip"] = std::to_string(parseSkip(val)); + } else if (key == "instid1") { + auto [type, dist] = parseInstId(val); + fields["instid1Type"] = type; + fields["instid1Distance"] = std::to_string(dist); + } + } + if (fields.find("instid0Type") == fields.end()) { + fields["instid0Type"] = "NO_DEP"; + fields["instid0Distance"] = "0"; + } + return true; +} + +/// Parse s_wait_alu depctr_*() syntax into mod.waitalu fields. +bool parseWaitAluSyntax(IRLexer& lexer, ParsedInstruction& inst) { + auto& fields = inst.modifiers["mod.waitalu"]; + while (!lexer.isAtEnd() && lexer.peek().kind != TokenKind::Eof && + lexer.peek().kind != TokenKind::Newline) { + if (lexer.peek().kind != TokenKind::Identifier) { + lexer.consume(); + continue; + } + std::string key(lexer.consume().text); + std::string val; + if (!lexer.isAtEnd() && lexer.peek().kind == TokenKind::LeftParen) { + lexer.consume(); + if (!lexer.isAtEnd()) val = std::string(lexer.consume().text); + if (!lexer.isAtEnd() && lexer.peek().kind == TokenKind::RightParen) lexer.consume(); + } + if (key.find("depctr_") == 0) key = key.substr(7); + if (!key.empty() && !val.empty()) fields[key] = val; + } + return true; +} + +/// Map of string key → string value used for parsed modifier fields. +using FieldMap = std::unordered_map; + +/// Returns true if the field map contains any DPP control modifier key. +/// DPP is an add-on encoding variant layered on top of VOP formats — +/// the assembler picks a wider encoding when DPP modifiers are present. +bool hasDPPFields(const FieldMap& fields) { + // dpp_ctrl modes (13 symbolic names) + return fields.count("quad_perm") || fields.count("row_shl") || fields.count("row_shr") || + fields.count("row_ror") || fields.count("wave_shl") || fields.count("wave_shr") || + fields.count("wave_rol") || fields.count("wave_ror") || fields.count("row_mirror") || + fields.count("row_half_mirror") || fields.count("row_bcast") || + fields.count("row_share") || fields.count("row_xmask") || + // other DPP encoding fields + fields.count("bound_ctrl") || fields.count("row_mask") || fields.count("bank_mask") || + fields.count("fi"); +} + +/// Parse hwreg(id, offset, width) syntax and store as a LiteralString +/// register on the instruction for verbatim round-trip emission. +void parseHwregOperand(IRLexer& lexer, ParsedInstruction& inst, + const HwInstDesc::OperandFieldDesc& field) { + if (lexer.isAtEnd() || lexer.peek().kind != TokenKind::Identifier) return; + std::string text(lexer.peek().text); + if (text != "hwreg") return; + lexer.consume(); + + if (lexer.isAtEnd() || lexer.peek().kind != TokenKind::LeftParen) { + StinkyRegister reg(text); + if (field.isDest) + inst.destRegs.push_back(reg); + else + inst.srcRegs.push_back(reg); + return; + } + text += '('; + lexer.consume(); + + int depth = 1; + while (!lexer.isAtEnd() && depth > 0) { + const auto& tok = lexer.consume(); + if (tok.kind == TokenKind::LeftParen) depth++; + if (tok.kind == TokenKind::RightParen) depth--; + if (depth > 0) text += std::string(tok.text); + } + text += ')'; + + StinkyRegister reg(text); + if (field.isDest) + inst.destRegs.push_back(reg); + else + inst.srcRegs.push_back(reg); +} + +/// Dispatch a custom-syntax operand to its dedicated parser based on FieldType. +void parseCustomOperand(IRLexer& lexer, ParsedInstruction& inst, + const HwInstDesc::OperandFieldDesc& field) { + switch (field.fieldType) { + case FieldType::delay: + parseDelayAluSyntax(lexer, inst); + break; + case FieldType::wait_alu: + parseWaitAluSyntax(lexer, inst); + break; + case FieldType::hwreg: + parseHwregOperand(lexer, inst, field); + break; + default: + break; + } +} + //---------------------------------------------------------------------- // Modifier parsing //---------------------------------------------------------------------- @@ -421,7 +584,6 @@ bool parseModifiers(IRLexer& lexer, ParsedInstruction& inst, const HwInstDesc* h const std::string& mnemonic = inst.opcodeStr; bool isWaitcnt = (mnemonic == "s_waitcnt"); - bool isDelayAlu = (mnemonic == "s_delay_alu"); // Determine modifier namespace from microcode format std::string modKey; @@ -444,6 +606,10 @@ bool parseModifiers(IRLexer& lexer, ParsedInstruction& inst, const HwInstDesc* h case MicrocodeFormat::MC_SMEM: modKey = "mod.smem"; break; + case MicrocodeFormat::MC_VOP3PX2: + case MicrocodeFormat::MC_VOP3PX3: + modKey = "mod.mfma"; + break; default: break; } @@ -459,15 +625,6 @@ bool parseModifiers(IRLexer& lexer, ParsedInstruction& inst, const HwInstDesc* h } } - // s_delay_alu: store the entire remainder as a TEXTBLOCK (complex syntax) - // The instruction will still be created with correct timing; modifiers won't - // be decoded but that's acceptable for a round-trip parser. - if (isDelayAlu) { - // Leave modifiers empty; the passes that need delay_alu data (e.g. wait-cnt - // insertion) will re-insert the correct modifier after scheduling. - return true; - } - // Collect key→value / key(value) / boolean-flag tokens FieldMap fields; // Set when we see syntax we can parse but cannot represent: e.g. @@ -517,14 +674,11 @@ bool parseModifiers(IRLexer& lexer, ParsedInstruction& inst, const HwInstDesc* h fields[tok] = std::move(folded); sawAnyModifier = true; } else if (vk == TokenKind::Identifier) { - // `th:TH_LOAD_NT`, `scope:SCOPE_DEV`, `matrix_a_fmt:MATRIX_FMT_FP8`, - // etc. — gfx12+ memory-hint / matrix-format syntax. Not - // modelled by any of the existing modifier structs, so - // consume the rhs to keep the lexer in sync but signal - // that the line cannot be losslessly round-tripped via - // inst.modifiers; the caller will route to TEXTBLOCK. - lexer.consume(); - sawUnrepresentable = true; + // key:Identifier modifiers (e.g. matrix_a_fmt:MATRIX_FMT_FP8, + // th:TH_LOAD_NT, scope:SCOPE_DEV). Store in fields for the + // modKey dispatch below. Formats without a modifier namespace + // are caught by the modKey.empty() check after the loop. + fields[tok] = std::string(lexer.consume().text); sawAnyModifier = true; } } else if (lexer.peek().kind == TokenKind::LeftParen) { @@ -562,7 +716,16 @@ bool parseModifiers(IRLexer& lexer, ParsedInstruction& inst, const HwInstDesc* h // would be silently discarded below at the `modKey.empty()` early-return, // so flag the line as unrepresentable instead and let the caller fall // back to TEXTBLOCK pass-through. - if (modKey.empty() && sawAnyModifier) sawUnrepresentable = true; + // When no modKey was assigned from the microcode format, check if the + // collected fields match a known add-on modifier pattern. DPP is an + // add-on encoding variant whose base microcode format doesn't have a + // dedicated modKey — detect by checking for DPP-specific field names. + if (modKey.empty() && sawAnyModifier) { + if (hasDPPFields(fields)) + modKey = "mod.dpp"; + else + sawUnrepresentable = true; + } if (modKey.empty() || fields.empty()) return !sawUnrepresentable; @@ -602,6 +765,32 @@ bool parseModifiers(IRLexer& lexer, ParsedInstruction& inst, const HwInstDesc* h if (fields.count("offset")) modFields["offset"] = fields["offset"]; if (fields.count("glc")) modFields["glc"] = "true"; if (fields.count("nv")) modFields["nv"] = "true"; + + } else if (modKey == "mod.dpp") { + if (fields.count("row_shr")) modFields["row_shr"] = fields["row_shr"]; + if (fields.count("row_bcast")) modFields["row_bcast"] = fields["row_bcast"]; + if (fields.count("bound_ctrl")) modFields["bound_ctrl"] = fields["bound_ctrl"]; + // Collect remaining (unmodelled) DPP fields into dppCtrl string for + // verbatim round-trip. Skip the typed fields handled above. + std::string dppCtrl; + for (const auto& [key, val] : fields) { + if (key == "row_shr" || key == "row_bcast" || key == "bound_ctrl") continue; + // Boolean flags (e.g. row_mirror) are stored as "true" — + // emit just the key, not "key:true". + dppCtrl = (val == "true") ? key : key + ":" + val; + break; + } + if (!dppCtrl.empty()) modFields["dppCtrl"] = dppCtrl; + + } else if (modKey == "mod.mfma") { + if (fields.count("matrix_a_fmt")) modFields["matrix_a_fmt"] = fields["matrix_a_fmt"]; + if (fields.count("matrix_b_fmt")) modFields["matrix_b_fmt"] = fields["matrix_b_fmt"]; + if (fields.count("matrix_a_scale_fmt")) + modFields["matrix_a_scale_fmt"] = fields["matrix_a_scale_fmt"]; + if (fields.count("matrix_b_scale_fmt")) + modFields["matrix_b_scale_fmt"] = fields["matrix_b_scale_fmt"]; + if (fields.count("matrix_a_reuse")) modFields["reuseA"] = "true"; + if (fields.count("matrix_b_reuse")) modFields["reuseB"] = "true"; } return !sawUnrepresentable; @@ -641,38 +830,36 @@ std::unique_ptr parseInstLine(const std::string& line, GfxArc inst->issueCycles = hwInstDesc->issue; inst->latencyCycles = hwInstDesc->latency; - // Count dest/src split from operandFields. - // When operandFields is empty the instruction has no explicit register operands - // (e.g. s_endpgm, s_waitcnt, s_delay_alu) so skip register parsing entirely. - int numDest = 0; - bool hasOperandFields = !hwInstDesc->operandFields.empty(); - if (hasOperandFields) { - for (const auto& f : hwInstDesc->operandFields) - if (f.isDest) numDest++; - } - - if (hasOperandFields) { - // Parse comma-separated register operands. - // Stop when: (a) no comma follows, or (b) next token is not a register. - int opIdx = 0; + // Per-field operand dispatch: iterate operandFields and dispatch each + // to the appropriate parser based on FieldType. Custom-syntax fields + // (delay, wait_alu) are parsed by dedicated parsers here. + // Register/immediate fields are parsed by parseOneRegister. + if (!hwInstDesc->operandFields.empty()) { bool firstOp = true; + int regOpIdx = 0; + + for (size_t fi = 0; fi < hwInstDesc->operandFields.size(); fi++) { + if (lexer.isAtEnd() || lexer.peek().kind == TokenKind::Eof || + lexer.peek().kind == TokenKind::Newline) + break; + + const auto& field = hwInstDesc->operandFields[fi]; - while (!lexer.isAtEnd() && lexer.peek().kind != TokenKind::Eof && - lexer.peek().kind != TokenKind::Newline) { if (!firstOp) { if (lexer.peek().kind != TokenKind::Comma) break; lexer.consume(); // eat ',' } firstOp = false; + // Custom-syntax operand: dispatch to dedicated parser. + if (hasCustomOperandSyntax(field.fieldType)) { + parseCustomOperand(lexer, *inst, field); + continue; + } + // Snapshot the lookahead so we can recover an unrecognised - // identifier as a LiteralString if parseOneRegister fails on a - // non-first operand (see recovery block below). For first - // operands we still want to fall through to TEXTBLOCK pass- - // through so instructions with custom textual syntax (e.g. - // `s_delay_alu instid0(VALU_DEP_2)`, `s_wait_alu depctr_va_vdst(0)`) - // round-trip verbatim instead of being half-parsed and then - // tripping the SDelayAluData/SWaitAluData asserts in the emitter. + // identifier as a LiteralString if parseOneRegister fails on + // a non-first operand (see recovery block below). TokenKind preKind = lexer.isAtEnd() ? TokenKind::Eof : lexer.peek().kind; std::string preText = preKind == TokenKind::Identifier ? std::string(lexer.peek().text) : std::string(); @@ -680,12 +867,10 @@ std::unique_ptr parseInstLine(const std::string& line, GfxArc auto reg = parseOneRegister(lexer, syms, preserveSymbolicNames); if (!reg) { - if (opIdx == 0) { - // First operand failed (e.g. symbolic expression like - // v[sym-768:sym-765] or a custom-syntax mnemonic such as - // s_delay_alu / s_wait_alu). parseOneRegister may have - // consumed tokens; safest to preserve the entire line - // verbatim via TEXTBLOCK pass-through. + if (regOpIdx == 0) { + // First register operand failed (e.g. unresolvable symbolic + // expression like v[sym-768:sym-765]). parseOneRegister may + // have consumed tokens; preserve the line as TEXTBLOCK. return nullptr; } // Non-first operand recovery: a `.set` symbol (or any other @@ -707,22 +892,22 @@ std::unique_ptr parseInstLine(const std::string& line, GfxArc // as the IntegerLiteral path in parseOneRegister. std::string text = gatherArithExprSuffix(lexer, preText); StinkyRegister litReg(text); - if (opIdx < numDest) + if (field.isDest) inst->destRegs.push_back(litReg); else inst->srcRegs.push_back(litReg); - opIdx++; + regOpIdx++; continue; } // Later operand failed → stop operand parsing, proceed to modifiers. break; } - if (opIdx < numDest) + if (field.isDest) inst->destRegs.push_back(*reg); else inst->srcRegs.push_back(*reg); - opIdx++; + regOpIdx++; } } diff --git a/shared/stinkytofu/src/serialization/asm/StinkyAsmEmitter.cpp b/shared/stinkytofu/src/serialization/asm/StinkyAsmEmitter.cpp index fa6b676dd6a3..1b68df64065b 100644 --- a/shared/stinkytofu/src/serialization/asm/StinkyAsmEmitter.cpp +++ b/shared/stinkytofu/src/serialization/asm/StinkyAsmEmitter.cpp @@ -272,6 +272,7 @@ inline std::ostream& operator<<(std::ostream& os, const SDWAModifiers& sdwaMod) } inline std::ostream& operator<<(std::ostream& os, const DPPModifiers& dppMod) { + if (!dppMod.dppCtrl.empty()) os << " " << dppMod.dppCtrl; if (dppMod.row_shr != -1) os << " row_shr:" << dppMod.row_shr; if (dppMod.row_bcast != -1) os << " row_bcast:" << dppMod.row_bcast; if (dppMod.bound_ctrl != -1) os << " bound_ctrl:" << dppMod.bound_ctrl;