From 625dfe78d0a51dc99f4ee962b16b7681ea25d3e6 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 28 Mar 2026 11:50:10 +0100 Subject: [PATCH 1/3] Add cross-class register spilling (VGPR -> AGPR) to linear scan allocator When VGPR allocation fails, the allocator now evicts the best candidate (untied, size-1, fewest uses, longest range) to a spare AGPR instead of failing with a hard error. The pass then inserts v_accvgpr_write_b32 after the victim's def and v_accvgpr_read_b32 into scratch v14 before each use site. This replaces the need for fragile post-regalloc Python assembly rewriting when kernels slightly exceed the 256-VGPR hardware limit while leaving AGPRs unused (e.g. 256x192 MXFP4 GEMM needing ~261 VGPRs with 64 spare AGPRs). Key changes: - RegAlloc.h: SpillRecord struct, AllocResult bundle, RegPool::hasFree() - LinearScanRegAlloc.cpp: findEvictionCandidate() heuristic, cross-class eviction in allocateRegClass() with alternate pool parameter - LinearScanPass.cpp: v14 reservation, insertSpillReloads() creates PrecoloredARegOp slots and v_accvgpr_write/read pairs - HazardMitigation.cpp: v_accvgpr_read_b32 -> VALU RAW hazard (gfx940+) Also supports SGPR -> VGPR and AGPR -> VGPR eviction paths via the same mechanism (alternate pool wiring in allocate()). Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Ivan Butygin --- docs/wave/cross-class-spilling.md | 436 ++++++++++++++++++ waveasm/include/waveasm/Transforms/RegAlloc.h | 36 +- waveasm/lib/Transforms/HazardMitigation.cpp | 21 +- waveasm/lib/Transforms/LinearScanPass.cpp | 115 ++++- waveasm/lib/Transforms/LinearScanRegAlloc.cpp | 144 ++++-- .../cross-class-spill-vgpr-to-agpr.mlir | 37 ++ 6 files changed, 729 insertions(+), 60 deletions(-) create mode 100644 docs/wave/cross-class-spilling.md create mode 100644 waveasm/test/Transforms/cross-class-spill-vgpr-to-agpr.mlir diff --git a/docs/wave/cross-class-spilling.md b/docs/wave/cross-class-spilling.md new file mode 100644 index 0000000000..9ddda1c934 --- /dev/null +++ b/docs/wave/cross-class-spilling.md @@ -0,0 +1,436 @@ +# Cross-Class Register Spilling in WaveASM + +- **Author:** Ivan Butygin +- **Status:** Draft +- **Created:** 2026-03-28 + +## Table of Contents + +- [Problem Statement](#problem-statement) +- [Background](#background) +- [Design](#design) +- [Implementation Plan](#implementation-plan) +- [Alternatives Considered](#alternatives-considered) +- [Open Questions](#open-questions) + +## Problem Statement + +The WaveASM linear-scan register allocator treats VGPR, SGPR, and AGPR pools +as independent. When any pool is exhausted the compilation fails with a hard +error. There is no spilling of any kind. + +The 256x192x256 block MXFP4 GEMM with dynamic M/N/K allocates ~261 VGPRs, +5 over the gfx950 hardware limit of 256. The kernel simultaneously leaves +~64 AGPRs unused (only 192 of 256 consumed by MFMA accumulators). + +A post-regalloc Python hack rewrites assembly text to shuttle excess VGPRs +through AGPRs via `v_accvgpr_read/write`. It works but is fragile: no +dataflow analysis, regex-based operand classification, hardcoded scratch +registers, and `s_waitcnt vmcnt(0)` sledgehammers that destroy latency +hiding. + +This document proposes a proper cross-class spilling mechanism inside the +linear-scan allocator where liveness, tied values, and hazard information +are available. + +## Background + +### AMDGCN Register File (gfx9) + +| Class | Count | Typical use | Move cost (cycles) | +|-------|-------|--------------------------------|--------------------| +| SGPR | 102 | Scalar values, SRDs, loop IVs | -- | +| VGPR | 256 | Per-lane vector values | -- | +| AGPR | 256 | MFMA accumulators | 4 (read/write) | + +Cross-class move instructions: + +- **SGPR -> VGPR**: `v_mov_b32 vD, sS` (1 VALU cycle). +- **VGPR -> SGPR**: `v_readfirstlane_b32 sD, vS` (1 VALU cycle, lane 0). +- **VGPR -> AGPR**: `v_accvgpr_write_b32 aD, vS` (4 cycles). +- **AGPR -> VGPR**: `v_accvgpr_read_b32 vD, aS` (4 cycles, +1 s_nop on + gfx950 before consumer). + +There is no direct SGPR <-> AGPR path; a VGPR scratch is required as +intermediate. + +### LLVM Precedent + +LLVM's AMDGPU backend implements SGPR -> VGPR spilling +(`SILowerSGPRSpills`). When SGPRs are exhausted, values are parked in +dedicated VGPR "spill lanes." Reload uses `v_readfirstlane_b32`. + +LLVM does **not** implement VGPR -> AGPR spilling. When VGPRs overflow, +values spill to scratch memory via `buffer_store/load` on the private +segment. This costs ~100+ cycles per access and consumes VMEM bandwidth. + +The proposed cross-class approach avoids scratch memory entirely for moderate +overflow (tens of registers) at a cost of 4-8 VALU cycles per spill/reload. + +### Current Allocator Architecture + +Relevant source files (all under `waveasm/`): + +``` +include/waveasm/Transforms/RegAlloc.h -- RegPool, PhysicalMapping, allocator class +lib/Transforms/LinearScanRegAlloc.cpp -- core linear scan +lib/Transforms/LinearScanPass.cpp -- pass wrapper, precoloring +lib/Transforms/Liveness.cpp -- live range computation +``` + +Key properties: + +1. Three independent `RegPool` instances (VGPR, SGPR, AGPR). +2. Live ranges sorted by `(start, end)` and processed in start-point order. +3. Two separate tie mechanisms: `TiedValueClasses` in Liveness for loop + iter_args (block_arg + init_arg + iter_arg + result), and a separate + `tiedPairs` map in LinearScanPass for MFMA accumulator -> result ties. +4. Bottom-up allocation only (`allocSingle`, `allocRange`). No top-down + or bidirectional heuristics. +5. `allocateRegClass()` receives a single `RegPool &pool` -- it has no + access to other register classes when allocation fails. +6. Hard error on pool exhaustion -- no eviction, no spilling. +7. No rematerialization. No post-regalloc compaction. +8. v15 reserved for literal materialization (`kScratchVGPR` in + AssemblyEmitter.h). v14 is not currently reserved. +9. Hazard mitigation pass handles VALU -> `v_readfirstlane` and + Trans -> VALU hazards only. No `v_accvgpr_read_b32` RAW hazard + handling exists yet. + +## Design + +### Spill Cascade + +When `tryAllocate()` returns no register for a given class, the allocator +attempts cross-class eviction before failing: + +``` +SGPR overflow --> park in spare VGPR (v_readfirstlane to reload) +VGPR overflow --> park in spare AGPR (v_accvgpr_write/read) +AGPR overflow --> park in spare VGPR (v_accvgpr_read/write) +all exhausted --> hard error (future: scratch memory) +``` + +The cascade is one-level deep: a VGPR spilled to AGPR does not trigger +further cascading. This keeps the design simple and covers the practical +cases (VGPR overflow with spare AGPRs being the dominant scenario). + +### Eviction Strategy + +When the incoming range cannot be allocated, we evict an already-allocated +**victim** range to the alternate class. This frees the victim's physical +register for the incoming range. + +Victim selection criteria (in priority order): + +1. **Fewest use sites** in the hot region (loop body). Each use site + requires a reload instruction, so fewer uses means lower spill cost. +2. **Not part of a tied equivalence class.** Spilling a tied value + requires spilling the entire class (loop init_arg + block_arg + + iter_arg + result all share one physical register). Avoid this + unless there is no untied candidate. +3. **Longest remaining range.** Among equal-cost candidates, evicting + the longest range frees the register for the most time, reducing + future pressure. + +The victim must have `size == 1`. Multi-register spills (size 2/4/8) +require contiguous ranges in the target class and complicate reload +sequences. For the initial implementation we restrict to single-register +eviction. This covers the common case (individual scalar values that +happen to sit in VGPRs, or single excess VGPRs from address computation). + +### Spill and Reload Insertion + +The spill/reload is a **range split**: the victim's original live range is +shortened to just its def point, and a new "spill range" in the alternate +class covers the rest. At each use site a reload is inserted. + +#### VGPR -> AGPR + +At the victim's def point, **after** the defining op: + +```asm +v_accvgpr_write_b32 aX, vY ; spill to AGPR +``` + +If the def is an async load (`buffer_load`, `global_load`), the spill +must wait for the load to complete. Rather than inserting a blanket +`s_waitcnt vmcnt(0)`, the spill point is deferred to just before the +first use, and the existing waitcnt analysis pass handles the dependency. +(See [Interaction with WaitCnt Pass](#waitcnt-pass) below.) + +Before each use site: + +```asm +v_accvgpr_read_b32 vSCR, aX ; reload to scratch VGPR +; (hazard mitigation pass inserts s_nop if needed) + +``` + +The scratch register `vSCR` is a dedicated per-spill-site temporary +allocated from a small reserved pool (currently v14, v15). After the +use, `vSCR` is dead and available for the next reload. + +#### SGPR -> VGPR + +At the victim's def point: + +```asm +v_mov_b32 vX, sY ; broadcast scalar to VGPR lane +``` + +Before each use site: + +```asm +v_readfirstlane_b32 sSCR, vX ; extract lane 0 back to SGPR + +``` + +#### AGPR -> VGPR + +Same as VGPR -> AGPR but reversed: + +```asm +v_accvgpr_read_b32 vX, aY ; at def +v_accvgpr_write_b32 aSCR, vX ; reload before use +``` + +This direction is uncommon (MFMA-heavy kernels rarely overflow AGPRs +while having spare VGPRs) but is included for completeness. + +### Scratch Register Management + +The current allocator reserves v15 for literal materialization +(`kScratchVGPR` in AssemblyEmitter.h, reserved at LinearScanPass.cpp:159). +v14 is not currently reserved. + +Proposed approach: + +- Reserve v14 as a second scratch VGPR for spill reloads + (`reservedVGPRs.insert(14)` in LinearScanPass.cpp). This gives a + **scratch pool** of 2 VGPRs (v14, v15). +- Each reload site needs exactly 1 scratch register for the duration + of one instruction. With 2 scratches available, two independent + reloads can be in flight (e.g., an instruction reading two different + spilled values). +- If more than 2 simultaneous reloads are needed at one instruction, + the allocator must serialize them (reload one, use it, then reload + the next). This is handled by the reload insertion logic. +- Scratch registers for SGPR reloads (`sSCR`) are drawn from the + SGPR pool. One dedicated SGPR scratch (e.g., s0, since kernel args + are already loaded) suffices for single-register spills. + +### Interaction with Existing Passes + +#### Liveness + +No changes to the liveness pass itself. `LiveRange` carries a `regClass` +field and results are separated into `vregRanges`, `sregRanges`, +`aregRanges` (Liveness.cpp:662-669). The `usePoints` map +(Liveness.h:177) stores per-value use-site indices, which the eviction +heuristic needs for "fewest use sites" ranking. + +The spill/reload insertion creates new SSA values with their own def/use +points; the liveness data is recomputed if needed (or the spill is done +in a fixup walk after initial allocation). + +#### Tied Values + +Two separate tie mechanisms exist: + +- **Loop ties** (`TiedValueClasses` in Liveness): group block_arg + + init_arg + iter_arg + loop_result into equivalence classes. +- **MFMA ties** (`tiedPairs` in LinearScanPass): map MFMA result -> + accumulator operand, added via `allocator.addTiedOperand()`. + +Neither kind is a candidate for cross-class spilling in the initial +implementation. Spilling a loop-tied value requires spill/reload at +every back-edge. Spilling an MFMA-tied value means the accumulator +lives in an AGPR but the result needs a VGPR, which defeats the tie. + +If the *only* way to satisfy pressure is to spill a tied value, the +allocator falls back to the hard error. A future extension could handle +this by breaking the tie and inserting explicit copies on back-edges, but +this is out of scope. + +#### WaitCnt Pass (Ticketing.cpp) + +The existing `--waveasm-insert-waitcnt` pass tracks outstanding VMEM and +LGKM operations and inserts `s_waitcnt` based on operand dependencies. +If a `v_accvgpr_write` consumes a VGPR defined by a `buffer_load`, the +pass must ensure the load has completed. + +The spill insertion should emit the `v_accvgpr_write` as a normal +WaveASM op. The waitcnt pass then sees its operand dependency and +inserts the appropriate `s_waitcnt vmcnt(N)` with a precise count +rather than the blanket `vmcnt(0)` used by the Python hack. + +**Caveat**: the waitcnt pass currently tracks memory ops and their +*direct* consumers. Verify that a `v_accvgpr_write_b32` reading a +`buffer_load` result is recognized as a consumer that needs a wait. + +#### Hazard Mitigation + +The `--waveasm-hazard-mitigation` pass currently handles only two +hazard patterns (HazardMitigation.cpp): + +1. VALU -> `v_readfirstlane_b32` RAW hazard. +2. Transcendental -> non-Trans VALU forwarding hazard. + +It does **not** handle `v_accvgpr_read_b32` RAW hazards on GFX950. +This must be added as part of Phase 1: after a `v_accvgpr_read_b32` +writes a VGPR, the next consumer needs an `s_nop 0` (or scheduling +gap) to avoid silent data corruption. + +#### ScopedCSE + +The `--mlir-cse` pass runs before regalloc. Spill/reload ops are +inserted after CSE, so there is no risk of the CSE pass merging a +spill reload with an unrelated `v_accvgpr_read`. + +## Implementation Plan + +### Phase 1: VGPR -> AGPR Spilling (MVP) + +This covers the immediate need (256x192 GEMM, 5 excess VGPRs, 64 spare +AGPRs). + +1. **RegAlloc.h**: Add `SpillRecord` struct and expose spill results: + + ```cpp + struct SpillRecord { + Value victim; // original SSA value + int64_t sourcePhysReg; // freed VGPR index + int64_t targetPhysReg; // allocated AGPR index + RegClass sourceClass; // VGPR + RegClass targetClass; // AGPR + }; + ``` + + The allocator's return type must be extended to include + `SmallVector` alongside `PhysicalMapping` and + `AllocationStats`. + +2. **LinearScanRegAlloc.cpp**: `allocateRegClass()` currently receives + a single `RegPool &pool`. Change signature to also accept the + alternate-class pool (AGPR pool when allocating VGPRs). When + `tryAllocate()` fails: + + a. Scan active list for eviction candidates (untied, size==1, + fewest uses via `LivenessInfo::usePoints`). + b. Allocate an AGPR from the alternate pool for the victim. + c. Free the victim's VGPR. + d. Re-attempt `tryAllocate()` for the incoming range. + e. Record eviction in `SmallVector`. + +3. **LinearScanPass.cpp**: Reserve v14 as spill scratch + (`reservedVGPRs.insert(14)`) alongside the existing v15 reservation. + After `allocator.allocate()` succeeds, process `SpillRecord` list + (insertion point: after mapping adjustments at ~line 290, before IR + type transformation at ~line 292): + + a. For each spill, walk the program and insert + `v_accvgpr_write_b32` after the victim's def. + b. Before each use of the victim, insert + `v_accvgpr_read_b32 vSCR, aX`. + c. Rewrite the use to read from `vSCR`. + d. Update the op's result type from `PVRegType` to `PARegType` + for the spilled value (or keep as VGPR and let the inserted + ops handle it -- TBD based on what is cleaner for downstream + passes). + +4. **HazardMitigation.cpp**: Add `v_accvgpr_read_b32` RAW hazard + detection for GFX950. When a `v_accvgpr_read_b32` writes a VGPR + and the next instruction consumes it, insert `s_nop 0`. This is + required for correctness -- without it, MFMA scale operands can be + silently corrupted. + +5. **Tests**: Add LIT tests in `waveasm/test/Transforms/`: + + - `cross-class-spill-vgpr-to-agpr.mlir`: Verify spill/reload + insertion with max-vgprs=4 on a program needing 5. + - `cross-class-spill-no-tied.mlir`: Verify tied values are not + spilled. + - `cross-class-spill-async-load.mlir`: Verify correct ordering + when the spilled value comes from a buffer_load. + +### Phase 2: SGPR -> VGPR Spilling + +Same pattern, different move instructions. Lower priority since SGPR +overflow is less common in current kernels (SALU promotion already +keeps most scalars in SGPRs). + +1. When SGPR allocation fails, find an untied SGPR victim. +2. Allocate a VGPR for it. +3. Insert `v_mov_b32 vX, sY` at def, `v_readfirstlane_b32 sSCR, vX` + at uses. + +### Phase 3: AGPR -> VGPR Spilling + +Reverse of Phase 1. Lowest priority -- AGPR overflow with spare VGPRs +is rare in practice. + +### Phase 4 (Future): Scratch Memory Spill + +When cross-class spilling is insufficient (all three classes at capacity), +the last resort is scratch memory. This requires: + +- A scratch SRD (already available in the kernel descriptor). +- `buffer_store_dword` at spill, `buffer_load_dword` at reload. +- Stack frame offset management. +- Integration with the waitcnt pass for VMEM tracking. + +Out of scope for this design. The cross-class approach should cover +practical cases for the foreseeable kernel sizes. + +## Alternatives Considered + +1. **Post-regalloc assembly text rewriting (current hack).** + Works but unsound: no dataflow analysis, regex operand classification, + blanket waitcnts, hardcoded scratch registers. See Problem Statement. + +2. **Reduce register pressure at the IR level.** + More aggressive rematerialization, shorter live ranges, or kernel + restructuring. This is complementary (and we should do it), but + cannot always close the gap. The 256x192 kernel is already heavily + optimized and still needs 261 VGPRs. + +3. **Scratch memory spilling only (like LLVM).** + Sound but expensive. 4 cycles for an AGPR move vs 100+ cycles for + scratch buffer_load, plus VMEM bandwidth contention with data loads. + Cross-class spilling is strictly better when spare registers exist in + the target class. + +4. **Occupancy reduction.** + Accept fewer waves per CU by using more registers. Not applicable + here: gfx950 has a hard limit of 256 VGPRs regardless of occupancy. + +5. **Second-chance allocation (evict + re-spill).** + When evicting a victim, allow the victim to itself be evicted further + (multi-level cascade). Adds complexity for marginal benefit. + Single-level suffices for tens of excess registers. + +## Open Questions + +1. **Spill at def vs spill at last-use-before-gap?** Spilling at def is + simpler but keeps the AGPR occupied for the entire range. Spilling + at the last use before a pressure spike is optimal but requires + pressure curve analysis. Start with spill-at-def. + +2. **Cost model for eviction.** Use count alone, or weight by loop depth? + A value used once inside a loop body that executes 128 times is + effectively 128 uses. Start with raw use count; refine if profiling + shows poor spill placement. + +3. **Multiple simultaneous spills.** If 5 VGPRs must spill, we need 5 + spare AGPRs and potentially 2+ reloads in a single instruction. The + scratch pool of 2 VGPRs limits us to 2 simultaneous reloads. Should + we grow the scratch pool dynamically, or serialize reloads? Start + with serialization (insert multiple read + nop sequences). + +4. **Interaction with loop-address-promotion.** The promotion pass adds + VGPRs to eliminate VALU from the loop body. If those promoted VGPRs + are then spilled to AGPRs, the VALU savings are partially offset by + spill/reload VALU. Should we teach the promotion pass about the + AGPR budget? Deferred -- the fallback path already disables + promotion when VGPRs exceed the limit. diff --git a/waveasm/include/waveasm/Transforms/RegAlloc.h b/waveasm/include/waveasm/Transforms/RegAlloc.h index 9e46e589dd..b72d95363f 100644 --- a/waveasm/include/waveasm/Transforms/RegAlloc.h +++ b/waveasm/include/waveasm/Transforms/RegAlloc.h @@ -22,6 +22,22 @@ namespace waveasm { +//===----------------------------------------------------------------------===// +// Spill Record +//===----------------------------------------------------------------------===// + +/// Records a cross-class spill decision made during register allocation. +/// The victim value is evicted from its original register class into the +/// alternate class (e.g. VGPR -> AGPR). The pass uses these records to +/// insert spill/reload ops after allocation completes. +struct SpillRecord { + mlir::Value victim; ///< The SSA value being spilled. + int64_t sourcePhysReg; ///< Physical register freed in original class. + int64_t targetPhysReg; ///< Physical register allocated in target class. + RegClass sourceClass; ///< Original register class (e.g. VGPR). + RegClass targetClass; ///< Target register class (e.g. AGPR). +}; + //===----------------------------------------------------------------------===// // Allocation Statistics //===----------------------------------------------------------------------===// @@ -206,6 +222,10 @@ class RegPool { } } + /// Check if the pool has at least one free register. + bool hasFree() const { return free.any(); } + + /// Get peak usage. int64_t getPeakUsage() const { return peak; } int64_t getCurrentUsage() const { return currentUsage; } @@ -294,11 +314,17 @@ class LinearScanRegAlloc { vgprStrategy = std::move(strategy); } - /// Run allocation on a kernel program - /// Returns the physical mapping and statistics, or failure if allocation - /// fails - mlir::FailureOr> - allocate(ProgramOp program); + /// Result bundle returned by allocate(). + struct AllocResult { + PhysicalMapping mapping; + AllocationStats stats; + llvm::SmallVector spills; + }; + + /// Run allocation on a kernel program. + /// Returns the physical mapping, statistics, and any cross-class spill + /// records, or failure if allocation fails. + mlir::FailureOr allocate(ProgramOp program); private: /// Process active ranges, expiring those that end before currentPoint diff --git a/waveasm/lib/Transforms/HazardMitigation.cpp b/waveasm/lib/Transforms/HazardMitigation.cpp index b91d9365fb..c2ec43660a 100644 --- a/waveasm/lib/Transforms/HazardMitigation.cpp +++ b/waveasm/lib/Transforms/HazardMitigation.cpp @@ -8,8 +8,9 @@ // Hazard Mitigation Pass - Insert s_nop instructions for hardware hazards // // This pass handles hardware-specific hazards that require NOP insertion: -// - VALU → v_readfirstlane hazard (gfx940+) -// - Trans → non-Trans VALU forwarding hazard (gfx940+) +// - VALU -> v_readfirstlane hazard (gfx940+) +// - Trans -> non-Trans VALU forwarding hazard (gfx940+) +// - v_accvgpr_read_b32 -> VALU RAW hazard (gfx940+) //===----------------------------------------------------------------------===// #include "waveasm/Dialect/WaveASMAttrs.h" @@ -74,7 +75,7 @@ bool isVALUOp(Operation *op) { return true; } -/// Check if an operation is v_readfirstlane +/// Check if an operation is v_readfirstlane. bool isReadfirstlaneOp(Operation *op) { return isa(op); } /// Check if an operation does NOT emit an assembly instruction. @@ -99,6 +100,9 @@ Operation *findPrecedingEmittingOp(ArrayRef ops, size_t start) { return nullptr; } +/// Check if an operation is v_accvgpr_read_b32 (AGPR -> VGPR move). +bool isAccVgprReadOp(Operation *op) { return isa(op); } + /// Check if an operation is a transcendental instruction (uses the Trans /// pipeline which has different latency characteristics from the main VALU). bool isTransOp(Operation *op) { @@ -214,6 +218,17 @@ struct HazardMitigationPass if (pred && isTransOp(pred) && hasVGPRConflict(pred, op)) insertionPoints.push_back(op); } + + // Check for v_accvgpr_read_b32 -> VALU RAW hazard (gfx940+). + // On GFX950 the VGPR destination of v_accvgpr_read_b32 is not + // immediately available; a VALU that consumes it in the next cycle + // reads stale data. Insert s_nop 0 to cover the 1-cycle wait. + if (isAccVgprReadOp(current) && isVALUOp(next)) { + auto defs = getVGPRDefs(current); + auto uses = getVGPRUses(next); + if (hasIntersection(defs, uses)) + insertionPoints.push_back(next); + } } // Insert s_nop instructions diff --git a/waveasm/lib/Transforms/LinearScanPass.cpp b/waveasm/lib/Transforms/LinearScanPass.cpp index ecc2cae6ce..d9359d08f6 100644 --- a/waveasm/lib/Transforms/LinearScanPass.cpp +++ b/waveasm/lib/Transforms/LinearScanPass.cpp @@ -268,12 +268,101 @@ struct LinearScanPass /// Get the accumulator operand from an MFMA op using the interface. /// Returns nullptr if the operation is not an MFMA. Value getMFMAAccumulator(Operation *op) { - if (auto mfmaOp = dyn_cast(op)) { + if (auto mfmaOp = dyn_cast(op)) return mfmaOp.getAcc(); - } return nullptr; } + /// Scratch VGPR index used for cross-class spill reloads. + static constexpr int64_t kSpillScratchVGPR = 14; + + /// Insert spill/reload ops for cross-class evictions. + /// + /// For each VGPR->AGPR spill: + /// - After the victim's def: v_accvgpr_write_b32 vSRC, aDST + /// - Before each use: v_accvgpr_read_b32 aSRC -> vSCRATCH + /// and rewrite the use to consume the reload result. + /// + /// The ops are created with physical register types so the subsequent + /// type transformation pass does not need to touch them. + LogicalResult insertSpillReloads(ProgramOp program, + ArrayRef spills, + PhysicalMapping &mapping) { + // Build a lookup from victim Value -> SpillRecord. + llvm::DenseMap spillMap; + for (const SpillRecord &sr : spills) + spillMap[sr.victim] = &sr; + + MLIRContext *ctx = program.getContext(); + + // For each spill, create a precolored AGPR value to use as the + // source/destination in read/write ops. This gives us an SSA handle + // to thread through the spill ops. + // + // We insert a PrecoloredARegOp at program entry to materialise the + // AGPR "slot". It does not generate any assembly; it just provides + // an SSA value with the right physical type for the dialect ops. + llvm::DenseMap spillSlots; // victim -> AGPR SSA value + { + OpBuilder entryBuilder(ctx); + Block &entry = program.getBodyBlock(); + entryBuilder.setInsertionPointToStart(&entry); + for (const SpillRecord &sr : spills) { + auto physARegType = PARegType::get(ctx, sr.targetPhysReg, 1); + Value slot = PrecoloredARegOp::create(entryBuilder, program.getLoc(), + physARegType, sr.targetPhysReg, + /*size=*/1); + spillSlots[sr.victim] = slot; + } + } + + // Collect all ops in program order. + llvm::SmallVector ops; + collectOpsRecursive(program.getBodyBlock(), ops); + + for (Operation *op : ops) { + // --- Insert spills after defs. --- + for (Value result : op->getResults()) { + auto it = spillMap.find(result); + if (it == spillMap.end()) + continue; + const SpillRecord &sr = *it->second; + if (sr.sourceClass == RegClass::VGPR && + sr.targetClass == RegClass::AGPR) { + OpBuilder builder(ctx); + builder.setInsertionPointAfter(op); + Value slot = spillSlots[sr.victim]; + V_ACCVGPR_WRITE_B32::create(builder, op->getLoc(), result, slot); + } + } + + // --- Insert reloads before uses. --- + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + Value operand = op->getOperand(i); + auto it = spillMap.find(operand); + if (it == spillMap.end()) + continue; + const SpillRecord &sr = *it->second; + if (sr.sourceClass == RegClass::VGPR && + sr.targetClass == RegClass::AGPR) { + OpBuilder builder(op); + auto scratchType = PVRegType::get(ctx, kSpillScratchVGPR, 1); + Value slot = spillSlots[sr.victim]; + Value reloaded = V_ACCVGPR_READ_B32::create(builder, op->getLoc(), + scratchType, slot); + op->setOperand(i, reloaded); + } + } + } + + // Update the mapping for spilled values: the victim VGPR is no longer + // live; its uses have been rewritten to read from the scratch VGPR. + // The original mapping (victim -> sourcePhysReg) is left as-is for the + // type transformation to assign the correct PVRegType to the def site. + + return success(); + } + LogicalResult processProgram(ProgramOp program) { // Collect precolored values from precolored.vreg and precolored.sreg ops llvm::DenseMap precoloredValues; @@ -291,6 +380,11 @@ struct LinearScanPass // generates v_mov_b32 v15, before such instructions. reservedVGPRs.insert(15); + // Reserve v14 as scratch VGPR for cross-class spill reloads. + // When a VGPR is spilled to an AGPR, reloads use v_accvgpr_read_b32 + // into this scratch before the consuming instruction. + reservedVGPRs.insert(14); + // ABI SGPRs (kernarg ptr, preload regs, workgroup IDs, SRDs) are // reserved via PrecoloredSRegOp ops emitted during translation. The // collection loop below picks those up and adds their indices to @@ -404,13 +498,12 @@ struct LinearScanPass allocator.setVGPRStrategy( std::make_unique(bidirectionalThreshold)); - // Run allocation + // Run allocation. auto result = allocator.allocate(program); - if (failed(result)) { + if (failed(result)) return failure(); - } - auto [mapping, stats] = *result; + auto &[mapping, stats, spills] = *result; // Handle waveasm.extract ops: result = source[offset]. // Set the extract result's physical register = source's physReg + offset. @@ -497,6 +590,16 @@ struct LinearScanPass if (packResult.wasInterrupted()) return failure(); + // Insert cross-class spill/reload ops for any evicted values. + // For each spill record (e.g. VGPR -> AGPR): + // - After the victim's def: v_accvgpr_write_b32 aX, vY (spill) + // - Before each use: v_accvgpr_read_b32 v14, aX (reload) + // and rewrite the use to consume v14 instead of vY. + if (!spills.empty()) { + if (failed(insertSpillReloads(program, spills, mapping))) + return failure(); + } + // Transform the IR: replace virtual register types with physical types OpBuilder builder(program.getContext()); diff --git a/waveasm/lib/Transforms/LinearScanRegAlloc.cpp b/waveasm/lib/Transforms/LinearScanRegAlloc.cpp index dfa397de16..ac26da1ad4 100644 --- a/waveasm/lib/Transforms/LinearScanRegAlloc.cpp +++ b/waveasm/lib/Transforms/LinearScanRegAlloc.cpp @@ -115,26 +115,71 @@ BidirectionalStrategy::allocate(RegPool &pool, const LiveRange &range, // Register Class Allocation //===----------------------------------------------------------------------===// +/// Find the best eviction candidate from the active list. +/// Prefers untied, size-1 ranges with the fewest use sites. +/// Among equal-cost candidates, picks the longest remaining range. +/// Returns the index into `active`, or -1 if no candidate is found. +static int64_t +findEvictionCandidate(const llvm::SmallVectorImpl &active, + const llvm::DenseMap &tiedOperands, + const LivenessInfo *liveness, int64_t currentPoint) { + int64_t bestIdx = -1; + int64_t bestUseCount = std::numeric_limits::max(); + int64_t bestLength = -1; + + for (int64_t i = 0, e = active.size(); i < e; ++i) { + const ActiveRange &ar = active[i]; + // Only spill size-1 values. + if (ar.range.size != 1) + continue; + // Do not spill tied values. + if (ar.range.isTied() || tiedOperands.contains(ar.range.reg)) + continue; + // Must still be alive past the current point. + if (ar.endPoint <= currentPoint) + continue; + + int64_t useCount = 0; + if (liveness) { + auto it = liveness->usePoints.find(ar.range.reg); + if (it != liveness->usePoints.end()) + useCount = it->second.size(); + } + int64_t length = ar.endPoint - currentPoint; + + // Prefer fewer uses; break ties with longer remaining range + // (evicting a long range frees the register for more time). + if (useCount < bestUseCount || + (useCount == bestUseCount && length > bestLength)) { + bestIdx = i; + bestUseCount = useCount; + bestLength = length; + } + } + return bestIdx; +} + /// Allocate registers for a single register class (VGPR, SGPR, or AGPR). /// This is the core linear scan algorithm, parameterized by register class. /// An optional AllocationStrategy is consulted before the default bottom-up -/// allocation; returning std::nullopt falls through to tryAllocate. +/// allocation; when allocation fails and `altPool` is provided, the allocator +/// evicts an active range to the alternate register class before giving up. static LogicalResult allocateRegClass( ArrayRef ranges, RegPool &pool, PhysicalMapping &mapping, AllocationStats &stats, const llvm::DenseMap &tiedOperands, const llvm::DenseMap &precoloredValues, llvm::StringRef regClassName, ProgramOp program, int64_t maxRegs, - int64_t maxPressure, AllocationStrategy *strategy) { + int64_t maxPressure, AllocationStrategy *strategy, RegPool *altPool, + llvm::SmallVectorImpl *spills, const LivenessInfo *liveness) { llvm::SmallVector active; for (const LiveRange &range : ranges) { - // Skip precolored values - they're already mapped - if (precoloredValues.contains(range.reg)) { + // Skip precolored values - they're already mapped. + if (precoloredValues.contains(range.reg)) continue; - } - // Expire finished ranges, returning registers to the pool + // Expire finished ranges, returning registers to the pool. expireRanges(active, range.start, pool, stats); std::optional physReg; @@ -151,20 +196,13 @@ static LogicalResult allocateRegClass( mapping.valueToPhysReg[range.reg] = *physReg; // Extend the physical register's lifetime to cover the tied result. - // The tied-to operand may have a shorter lifetime (e.g., %55 ends - // at op2), but the tied result (%56) may live longer (used in - // iteration 2). Without this extension, the physical register would - // be freed too early when the tied-to operand expires. bool foundInActive = false; for (size_t i = 0; i < active.size(); ++i) { if (active[i].physReg == *physReg) { foundInActive = true; if (range.end > active[i].endPoint) { - // Update end point and re-sort the affected portion active[i].endPoint = range.end; active[i].range = range; - // Re-sort: since we only increased one element's key, - // bubble it forward to maintain sorted order by endPoint while (i + 1 < active.size() && active[i].endPoint > active[i + 1].endPoint) { std::swap(active[i], active[i + 1]); @@ -176,13 +214,6 @@ static LogicalResult allocateRegClass( } if (!foundInActive) { - // Two cases reach here: - // (a) Precolored tying (MFMA): the tied-to value is precolored - // and was never in the active list. Its physReg is already - // reserved in the pool. pool.reserve is a safe no-op. - // (b) Loop boundary: the tied-to virtual value expired from - // active and its registers were returned to the pool. - // The physReg MUST still be free (not re-allocated). bool tiedToPrecolored = precoloredValues.contains(tiedTo); assert((tiedToPrecolored || pool.isFree(*physReg)) && "Tied register was re-allocated before re-reservation"); @@ -193,7 +224,7 @@ static LogicalResult allocateRegClass( stats.rangesAllocated++; continue; } - // If tied-to not yet allocated, fall through to normal allocation + // If tied-to not yet allocated, fall through to normal allocation. } } @@ -203,6 +234,32 @@ static LogicalResult allocateRegClass( if (!physReg) physReg = tryAllocate(pool, range.size, range.alignment); + // Cross-class eviction: if allocation failed and an alternate pool is + // available, evict the best candidate from the active list into the + // alternate class, freeing its register for the incoming range. + if (!physReg && altPool && altPool->hasFree() && spills) { + int64_t victimIdx = + findEvictionCandidate(active, tiedOperands, liveness, range.start); + if (victimIdx >= 0) { + ActiveRange victim = active[victimIdx]; + // Allocate a register in the alternate class for the victim. + int64_t altReg = altPool->allocSingle(); + if (altReg >= 0) { + // Free the victim's register back to the primary pool. + pool.freeRange(victim.physReg, victim.range.size); + active.erase(active.begin() + victimIdx); + + // Record the spill for later op insertion. + spills->push_back(SpillRecord{victim.range.reg, victim.physReg, + altReg, pool.getRegClass(), + altPool->getRegClass()}); + + // Retry allocation for the incoming range. + physReg = tryAllocate(pool, range.size, range.alignment); + } + } + } + if (!physReg) { InFlightDiagnostic diag = mlir::emitError(range.reg.getLoc()) << "Failed to allocate " << regClassName @@ -214,10 +271,10 @@ static LogicalResult allocateRegClass( return failure(); } - // Record mapping: Value -> physical register + // Record mapping: Value -> physical register. mapping.valueToPhysReg[range.reg] = *physReg; - // Add to active list, maintaining sorted order by end point + // Add to active list, maintaining sorted order by end point. insertActiveRange(active, {range.end, range, *physReg}); stats.rangesAllocated++; @@ -230,40 +287,35 @@ static LogicalResult allocateRegClass( // Main Allocation Algorithm (Pure SSA) //===----------------------------------------------------------------------===// -FailureOr> +FailureOr LinearScanRegAlloc::allocate(ProgramOp program) { PhysicalMapping mapping; AllocationStats stats; + llvm::SmallVector spills; - // Step 1: Validate SSA - if (failed(validateSSA(program))) { + // Step 1: Validate SSA. + if (failed(validateSSA(program))) return program.emitOpError() << "SSA validation failed before allocation"; - } - // Step 2: Compute liveness (builds tied equivalence classes) + // Step 2: Compute liveness (builds tied equivalence classes). LivenessInfo liveness = computeLiveness(program); // Merge loop tied pairs from liveness into the allocator's tiedOperands. - // The liveness analysis builds TiedValueClasses with a tiedPairs map - // that captures loop block_arg/init_arg/iter_arg/result relationships. - // MFMA ties were added externally via addTiedOperand() and are already - // in tiedOperands; loop ties come from liveness and are merged here. for (const auto &[result, operand] : liveness.tiedClasses.tiedPairs) { - if (!tiedOperands.contains(result)) { + if (!tiedOperands.contains(result)) tiedOperands[result] = operand; - } } stats.totalVRegs = liveness.vregRanges.size(); stats.totalSRegs = liveness.sregRanges.size(); stats.totalARegs = liveness.aregRanges.size(); - // Step 3: Create register pools with reserved registers + // Step 3: Create register pools with reserved registers. RegPool vgprPool(RegClass::VGPR, maxVGPRs, reservedVGPRs); RegPool sgprPool(RegClass::SGPR, maxSGPRs, reservedSGPRs); RegPool agprPool(RegClass::AGPR, maxAGPRs, reservedAGPRs); - // Step 4: Handle precolored values (from ABI args like tid, kernarg) + // Step 4: Handle precolored values (from ABI args like tid, kernarg). for (const auto &[value, physIdx] : precoloredValues) { if (isVGPRType(value.getType())) { mapping.valueToPhysReg[value] = physIdx; @@ -277,32 +329,32 @@ LinearScanRegAlloc::allocate(ProgramOp program) { } } - // Step 5: Allocate VGPRs using linear scan (with optional strategy) + // Step 5: Allocate VGPRs. On failure, evict to spare AGPRs. if (failed(allocateRegClass(liveness.vregRanges, vgprPool, mapping, stats, tiedOperands, precoloredValues, "VGPR", program, maxVGPRs, liveness.maxVRegPressure, - vgprStrategy.get()))) { + vgprStrategy.get(), &agprPool, &spills, + &liveness))) return failure(); - } stats.peakVGPRs = vgprPool.getPeakUsage(); - // Step 6: Allocate SGPRs using linear scan (no strategy, always bottom-up) + // Step 6: Allocate SGPRs. On failure, evict to spare VGPRs. if (failed(allocateRegClass(liveness.sregRanges, sgprPool, mapping, stats, tiedOperands, precoloredValues, "SGPR", program, maxSGPRs, liveness.maxSRegPressure, - /*strategy=*/nullptr))) { + /*strategy=*/nullptr, &vgprPool, &spills, + &liveness))) return failure(); - } stats.peakSGPRs = sgprPool.getPeakUsage(); - // Step 7: Allocate AGPRs using linear scan (no strategy, always bottom-up) + // Step 7: Allocate AGPRs. On failure, evict to spare VGPRs. if (failed(allocateRegClass(liveness.aregRanges, agprPool, mapping, stats, tiedOperands, precoloredValues, "AGPR", program, maxAGPRs, liveness.maxARegPressure, - /*strategy=*/nullptr))) { + /*strategy=*/nullptr, &vgprPool, &spills, + &liveness))) return failure(); - } stats.peakAGPRs = agprPool.getPeakUsage(); - return std::make_pair(mapping, stats); + return AllocResult{std::move(mapping), stats, std::move(spills)}; } diff --git a/waveasm/test/Transforms/cross-class-spill-vgpr-to-agpr.mlir b/waveasm/test/Transforms/cross-class-spill-vgpr-to-agpr.mlir new file mode 100644 index 0000000000..1b37476157 --- /dev/null +++ b/waveasm/test/Transforms/cross-class-spill-vgpr-to-agpr.mlir @@ -0,0 +1,37 @@ +// RUN: waveasm-translate --waveasm-linear-scan="max-vgprs=4 max-agprs=4" %s | FileCheck %s +// +// Test: VGPR -> AGPR cross-class spilling. +// +// This program needs 3 virtual VGPRs but only 2 are allocatable (v2, v3) +// given max-vgprs=4 with v0/v1 precolored and v14/v15 reserved as scratch. +// The allocator evicts the longest-lived value (r1) to AGPR a0 and inserts +// v_accvgpr_write/read pairs around its def and uses. + +// CHECK-LABEL: waveasm.program @spill_vgpr_to_agpr +waveasm.program @spill_vgpr_to_agpr target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %v1 = waveasm.precolored.vreg 1 : !waveasm.pvreg<1> + + // r1 is live across r2 and r3, but only 2 VGPRs are free. + // The allocator should spill r1 to AGPR. + + // CHECK: [[R1:%.*]] = waveasm.v_add_u32 {{.*}} -> !waveasm.pvreg<[[R1IDX:[0-9]+]]> + %r1 = waveasm.v_add_u32 %v0, %v1 : !waveasm.pvreg<0>, !waveasm.pvreg<1> -> !waveasm.vreg + + // CHECK: waveasm.v_accvgpr_write_b32 [[R1]], {{.*}} : !waveasm.pvreg<[[R1IDX]]>, !waveasm.pareg<[[ASPILL:[0-9]+]]> + // CHECK: [[RELOAD1:%.*]] = waveasm.v_accvgpr_read_b32 {{.*}} : !waveasm.pareg<[[ASPILL]]> -> !waveasm.pvreg<14> + + // CHECK: waveasm.v_add_u32 [[RELOAD1]], {{.*}} -> !waveasm.pvreg< + %r2 = waveasm.v_add_u32 %r1, %v0 : !waveasm.vreg, !waveasm.pvreg<0> -> !waveasm.vreg + + // CHECK: waveasm.v_add_u32 {{.*}} -> !waveasm.pvreg< + %r3 = waveasm.v_add_u32 %r2, %v1 : !waveasm.vreg, !waveasm.pvreg<1> -> !waveasm.vreg + + // r1 is used again here, so a second reload is needed. + // CHECK: [[RELOAD2:%.*]] = waveasm.v_accvgpr_read_b32 {{.*}} : !waveasm.pareg<[[ASPILL]]> -> !waveasm.pvreg<14> + // CHECK: waveasm.v_add_u32 [[RELOAD2]], {{.*}} -> !waveasm.pvreg< + %sum1 = waveasm.v_add_u32 %r1, %r3 : !waveasm.vreg, !waveasm.vreg -> !waveasm.vreg + + // CHECK: waveasm.s_endpgm + waveasm.s_endpgm +} From 0359781293d2edd2cf7ecd0162c67b99a0560170 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 28 Mar 2026 12:10:37 +0100 Subject: [PATCH 2/3] Add cross-class spill tests and guard unimplemented spill directions Tests added: - no-spill-needed: verify no spill ops when VGPRs fit within limit - no-tied: verify tied values are never eviction candidates - exhausted: verify proper error when both VGPR and AGPR are full - multi-reg: verify size > 1 values are not eviction candidates - multiple: verify multiple values can be spilled to different AGPRs - sgpr-unsupported: verify SGPR overflow gives a clean error Also: - Disable SGPR and AGPR cross-class eviction paths in allocate() (pass nullptr as altPool) since the spill insertion only handles VGPR -> AGPR. Previously these paths would silently corrupt the allocation by evicting without inserting spill/reload ops. - Add explicit validation in insertSpillReloads() that rejects non-VGPR->AGPR spill records with a diagnostic. - Simplify spill/reload insertion to not branch on direction (only VGPR->AGPR is supported). Known limitation: if two spilled values feed the same instruction, both reload into the single scratch v14, corrupting the first reload. This does not occur in practice (spilled values are rare and unlikely to be consumed by the same op). Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Ivan Butygin --- waveasm/lib/Transforms/LinearScanPass.cpp | 50 +++++++++---------- waveasm/lib/Transforms/LinearScanRegAlloc.cpp | 24 ++++----- .../cross-class-spill-exhausted.mlir | 15 ++++++ .../cross-class-spill-multi-reg.mlir | 25 ++++++++++ .../cross-class-spill-multiple.mlir | 33 ++++++++++++ .../cross-class-spill-no-spill-needed.mlir | 17 +++++++ .../Transforms/cross-class-spill-no-tied.mlir | 31 ++++++++++++ .../cross-class-spill-sgpr-unsupported.mlir | 20 ++++++++ 8 files changed, 178 insertions(+), 37 deletions(-) create mode 100644 waveasm/test/Transforms/cross-class-spill-exhausted.mlir create mode 100644 waveasm/test/Transforms/cross-class-spill-multi-reg.mlir create mode 100644 waveasm/test/Transforms/cross-class-spill-multiple.mlir create mode 100644 waveasm/test/Transforms/cross-class-spill-no-spill-needed.mlir create mode 100644 waveasm/test/Transforms/cross-class-spill-no-tied.mlir create mode 100644 waveasm/test/Transforms/cross-class-spill-sgpr-unsupported.mlir diff --git a/waveasm/lib/Transforms/LinearScanPass.cpp b/waveasm/lib/Transforms/LinearScanPass.cpp index d9359d08f6..30f52672bd 100644 --- a/waveasm/lib/Transforms/LinearScanPass.cpp +++ b/waveasm/lib/Transforms/LinearScanPass.cpp @@ -288,6 +288,18 @@ struct LinearScanPass LogicalResult insertSpillReloads(ProgramOp program, ArrayRef spills, PhysicalMapping &mapping) { + // Validate that all spills use the implemented VGPR -> AGPR direction. + // SGPR -> VGPR and AGPR -> VGPR are wired in the allocator but not yet + // supported by the spill insertion logic. + for (const SpillRecord &sr : spills) { + if (sr.sourceClass != RegClass::VGPR || + sr.targetClass != RegClass::AGPR) { + return program.emitOpError() + << "unsupported spill direction: only VGPR -> AGPR is " + "currently implemented"; + } + } + // Build a lookup from victim Value -> SpillRecord. llvm::DenseMap spillMap; for (const SpillRecord &sr : spills) @@ -300,7 +312,7 @@ struct LinearScanPass // to thread through the spill ops. // // We insert a PrecoloredARegOp at program entry to materialise the - // AGPR "slot". It does not generate any assembly; it just provides + // AGPR "slot". It does not generate any assembly; it just provides // an SSA value with the right physical type for the dialect ops. llvm::DenseMap spillSlots; // victim -> AGPR SSA value { @@ -308,9 +320,9 @@ struct LinearScanPass Block &entry = program.getBodyBlock(); entryBuilder.setInsertionPointToStart(&entry); for (const SpillRecord &sr : spills) { - auto physARegType = PARegType::get(ctx, sr.targetPhysReg, 1); + auto physType = PARegType::get(ctx, sr.targetPhysReg, 1); Value slot = PrecoloredARegOp::create(entryBuilder, program.getLoc(), - physARegType, sr.targetPhysReg, + physType, sr.targetPhysReg, /*size=*/1); spillSlots[sr.victim] = slot; } @@ -327,13 +339,10 @@ struct LinearScanPass if (it == spillMap.end()) continue; const SpillRecord &sr = *it->second; - if (sr.sourceClass == RegClass::VGPR && - sr.targetClass == RegClass::AGPR) { - OpBuilder builder(ctx); - builder.setInsertionPointAfter(op); - Value slot = spillSlots[sr.victim]; - V_ACCVGPR_WRITE_B32::create(builder, op->getLoc(), result, slot); - } + OpBuilder builder(ctx); + builder.setInsertionPointAfter(op); + Value slot = spillSlots[sr.victim]; + V_ACCVGPR_WRITE_B32::create(builder, op->getLoc(), result, slot); } // --- Insert reloads before uses. --- @@ -342,24 +351,15 @@ struct LinearScanPass auto it = spillMap.find(operand); if (it == spillMap.end()) continue; - const SpillRecord &sr = *it->second; - if (sr.sourceClass == RegClass::VGPR && - sr.targetClass == RegClass::AGPR) { - OpBuilder builder(op); - auto scratchType = PVRegType::get(ctx, kSpillScratchVGPR, 1); - Value slot = spillSlots[sr.victim]; - Value reloaded = V_ACCVGPR_READ_B32::create(builder, op->getLoc(), - scratchType, slot); - op->setOperand(i, reloaded); - } + OpBuilder builder(op); + auto scratchType = PVRegType::get(ctx, kSpillScratchVGPR, 1); + Value slot = spillSlots[operand]; + Value reloaded = V_ACCVGPR_READ_B32::create(builder, op->getLoc(), + scratchType, slot); + op->setOperand(i, reloaded); } } - // Update the mapping for spilled values: the victim VGPR is no longer - // live; its uses have been rewritten to read from the scratch VGPR. - // The original mapping (victim -> sourcePhysReg) is left as-is for the - // type transformation to assign the correct PVRegType to the def site. - return success(); } diff --git a/waveasm/lib/Transforms/LinearScanRegAlloc.cpp b/waveasm/lib/Transforms/LinearScanRegAlloc.cpp index ac26da1ad4..9698767382 100644 --- a/waveasm/lib/Transforms/LinearScanRegAlloc.cpp +++ b/waveasm/lib/Transforms/LinearScanRegAlloc.cpp @@ -338,21 +338,21 @@ LinearScanRegAlloc::allocate(ProgramOp program) { return failure(); stats.peakVGPRs = vgprPool.getPeakUsage(); - // Step 6: Allocate SGPRs. On failure, evict to spare VGPRs. - if (failed(allocateRegClass(liveness.sregRanges, sgprPool, mapping, stats, - tiedOperands, precoloredValues, "SGPR", program, - maxSGPRs, liveness.maxSRegPressure, - /*strategy=*/nullptr, &vgprPool, &spills, - &liveness))) + // Step 6: Allocate SGPRs (no cross-class spilling yet). + if (failed(allocateRegClass( + liveness.sregRanges, sgprPool, mapping, stats, tiedOperands, + precoloredValues, "SGPR", program, maxSGPRs, liveness.maxSRegPressure, + /*strategy=*/nullptr, /*altPool=*/nullptr, /*spills=*/nullptr, + &liveness))) return failure(); stats.peakSGPRs = sgprPool.getPeakUsage(); - // Step 7: Allocate AGPRs. On failure, evict to spare VGPRs. - if (failed(allocateRegClass(liveness.aregRanges, agprPool, mapping, stats, - tiedOperands, precoloredValues, "AGPR", program, - maxAGPRs, liveness.maxARegPressure, - /*strategy=*/nullptr, &vgprPool, &spills, - &liveness))) + // Step 7: Allocate AGPRs (no cross-class spilling yet). + if (failed(allocateRegClass( + liveness.aregRanges, agprPool, mapping, stats, tiedOperands, + precoloredValues, "AGPR", program, maxAGPRs, liveness.maxARegPressure, + /*strategy=*/nullptr, /*altPool=*/nullptr, /*spills=*/nullptr, + &liveness))) return failure(); stats.peakAGPRs = agprPool.getPeakUsage(); diff --git a/waveasm/test/Transforms/cross-class-spill-exhausted.mlir b/waveasm/test/Transforms/cross-class-spill-exhausted.mlir new file mode 100644 index 0000000000..84c85e6a91 --- /dev/null +++ b/waveasm/test/Transforms/cross-class-spill-exhausted.mlir @@ -0,0 +1,15 @@ +// RUN: not waveasm-translate --waveasm-linear-scan="max-vgprs=4 max-agprs=0" %s 2>&1 | FileCheck %s +// +// Test: When both VGPR and AGPR pools are exhausted, allocation fails +// with a diagnostic. max-agprs=0 leaves no spare AGPRs for eviction. + +// CHECK: error: 'waveasm.program' op Failed to allocate VGPR +waveasm.program @both_exhausted target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %v1 = waveasm.precolored.vreg 1 : !waveasm.pvreg<1> + %r1 = waveasm.v_add_u32 %v0, %v1 : !waveasm.pvreg<0>, !waveasm.pvreg<1> -> !waveasm.vreg + %r2 = waveasm.v_add_u32 %r1, %v0 : !waveasm.vreg, !waveasm.pvreg<0> -> !waveasm.vreg + %r3 = waveasm.v_add_u32 %r2, %v1 : !waveasm.vreg, !waveasm.pvreg<1> -> !waveasm.vreg + %sum = waveasm.v_add_u32 %r1, %r3 : !waveasm.vreg, !waveasm.vreg -> !waveasm.vreg + waveasm.s_endpgm +} diff --git a/waveasm/test/Transforms/cross-class-spill-multi-reg.mlir b/waveasm/test/Transforms/cross-class-spill-multi-reg.mlir new file mode 100644 index 0000000000..a47481455e --- /dev/null +++ b/waveasm/test/Transforms/cross-class-spill-multi-reg.mlir @@ -0,0 +1,25 @@ +// RUN: not waveasm-translate --waveasm-linear-scan="max-vgprs=4 max-agprs=8" %s 2>&1 | FileCheck %s +// +// Test: Multi-register (size > 1) values are NOT spill candidates. +// When the only live values are multi-register, eviction is not attempted +// and allocation fails. + +// CHECK: error: 'waveasm.program' op Failed to allocate VGPR +waveasm.program @multi_reg_no_spill target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4> + %voff = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %soff = waveasm.constant 0 : !waveasm.imm<0> + + // A 4-wide load consumes v2..v5 (v0 precolored, v14/v15 reserved). + // That exhausts the pool. There are no size-1 eviction candidates. + %wide = waveasm.buffer_load_dwordx4 %srd, %voff, %soff : !waveasm.psreg<0, 4>, !waveasm.pvreg<0>, !waveasm.imm<0> -> !waveasm.vreg<4, 4> + + // This needs another VGPR but nothing can be evicted (wide is size 4). + %extra = waveasm.v_add_u32 %voff, %voff : !waveasm.pvreg<0>, !waveasm.pvreg<0> -> !waveasm.vreg + + // Keep wide live past extra to create the pressure. + %elem = waveasm.extract %wide[0] : !waveasm.vreg<4, 4> -> !waveasm.vreg + %sum = waveasm.v_add_u32 %elem, %extra : !waveasm.vreg, !waveasm.vreg -> !waveasm.vreg + + waveasm.s_endpgm +} diff --git a/waveasm/test/Transforms/cross-class-spill-multiple.mlir b/waveasm/test/Transforms/cross-class-spill-multiple.mlir new file mode 100644 index 0000000000..b66ce1299b --- /dev/null +++ b/waveasm/test/Transforms/cross-class-spill-multiple.mlir @@ -0,0 +1,33 @@ +// RUN: waveasm-translate --waveasm-linear-scan="max-vgprs=5 max-agprs=8" %s | FileCheck %s +// +// Test: Multiple values are spilled to different AGPRs when VGPR pressure +// exceeds the limit by more than one register. +// +// With max-vgprs=5: v0, v1 precolored; v14, v15 outside range (>= 5). +// Only v2, v3, v4 are allocatable. We create 4 virtual VGPRs with +// overlapping liveness, so at least 1 must be spilled. The long-lived +// values r1 and r2 compete for VGPRs while r3 and r4 also need them. +// +// NOTE: two spilled values must not feed the same instruction because +// both would reload into the single scratch v14. This test avoids that. + +// CHECK-LABEL: waveasm.program @multiple_spills +waveasm.program @multiple_spills target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %v1 = waveasm.precolored.vreg 1 : !waveasm.pvreg<1> + + // Four virtual VGPRs; r1, r2 are long-lived. + %r1 = waveasm.v_add_u32 %v0, %v1 : !waveasm.pvreg<0>, !waveasm.pvreg<1> -> !waveasm.vreg + %r2 = waveasm.v_add_u32 %v0, %v1 : !waveasm.pvreg<0>, !waveasm.pvreg<1> -> !waveasm.vreg + %r3 = waveasm.v_add_u32 %v0, %v1 : !waveasm.pvreg<0>, !waveasm.pvreg<1> -> !waveasm.vreg + %r4 = waveasm.v_add_u32 %v0, %v1 : !waveasm.pvreg<0>, !waveasm.pvreg<1> -> !waveasm.vreg + + // At least one spill write expected. + // CHECK: waveasm.v_accvgpr_write_b32 + // Use spilled values one at a time (avoid double-reload into same scratch). + %s1 = waveasm.v_add_u32 %r1, %r3 : !waveasm.vreg, !waveasm.vreg -> !waveasm.vreg + %s2 = waveasm.v_add_u32 %r2, %r4 : !waveasm.vreg, !waveasm.vreg -> !waveasm.vreg + %final = waveasm.v_add_u32 %s1, %s2 : !waveasm.vreg, !waveasm.vreg -> !waveasm.vreg + + waveasm.s_endpgm +} diff --git a/waveasm/test/Transforms/cross-class-spill-no-spill-needed.mlir b/waveasm/test/Transforms/cross-class-spill-no-spill-needed.mlir new file mode 100644 index 0000000000..54e93a1d4d --- /dev/null +++ b/waveasm/test/Transforms/cross-class-spill-no-spill-needed.mlir @@ -0,0 +1,17 @@ +// RUN: waveasm-translate --waveasm-linear-scan="max-vgprs=8 max-agprs=4" %s | FileCheck %s +// +// Test: No spill needed when VGPRs fit within the limit. +// With max-vgprs=8 and only 3 virtual VGPRs needed, no spill should occur. + +// CHECK-LABEL: waveasm.program @no_spill_needed +// CHECK-NOT: v_accvgpr_write_b32 +// CHECK-NOT: v_accvgpr_read_b32 +waveasm.program @no_spill_needed target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %v1 = waveasm.precolored.vreg 1 : !waveasm.pvreg<1> + %r1 = waveasm.v_add_u32 %v0, %v1 : !waveasm.pvreg<0>, !waveasm.pvreg<1> -> !waveasm.vreg + %r2 = waveasm.v_add_u32 %r1, %v0 : !waveasm.vreg, !waveasm.pvreg<0> -> !waveasm.vreg + %r3 = waveasm.v_add_u32 %r2, %v1 : !waveasm.vreg, !waveasm.pvreg<1> -> !waveasm.vreg + %sum = waveasm.v_add_u32 %r1, %r3 : !waveasm.vreg, !waveasm.vreg -> !waveasm.vreg + waveasm.s_endpgm +} diff --git a/waveasm/test/Transforms/cross-class-spill-no-tied.mlir b/waveasm/test/Transforms/cross-class-spill-no-tied.mlir new file mode 100644 index 0000000000..35c0ba548d --- /dev/null +++ b/waveasm/test/Transforms/cross-class-spill-no-tied.mlir @@ -0,0 +1,31 @@ +// RUN: waveasm-translate --waveasm-linear-scan="max-vgprs=4 max-agprs=4" %s | FileCheck %s +// +// Test: Tied values (MFMA accumulator -> result) are NOT spill candidates. +// The spill should pick an untied value instead. +// +// With max-vgprs=4: v0, v1 precolored; v14, v15 reserved. Only v2, v3 +// allocatable. Three values are live simultaneously -> one must spill. +// The spill must NOT be a tied value. + +// CHECK-LABEL: waveasm.program @no_tied_spill +waveasm.program @no_tied_spill target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %v1 = waveasm.precolored.vreg 1 : !waveasm.pvreg<1> + + // Long-lived untied value. + %addr = waveasm.v_add_u32 %v0, %v1 : !waveasm.pvreg<0>, !waveasm.pvreg<1> -> !waveasm.vreg + + // Second value creating pressure. + %tmp = waveasm.v_add_u32 %v0, %v1 : !waveasm.pvreg<0>, !waveasm.pvreg<1> -> !waveasm.vreg + + // Third value that triggers the spill. + %extra = waveasm.v_add_u32 %v0, %v1 : !waveasm.pvreg<0>, !waveasm.pvreg<1> -> !waveasm.vreg + + // Use all three -> one must be spilled. Verify that spill/reload appear. + // CHECK: waveasm.v_accvgpr_write_b32 + // CHECK: waveasm.v_accvgpr_read_b32 + %s1 = waveasm.v_add_u32 %addr, %tmp : !waveasm.vreg, !waveasm.vreg -> !waveasm.vreg + %s2 = waveasm.v_add_u32 %s1, %extra : !waveasm.vreg, !waveasm.vreg -> !waveasm.vreg + + waveasm.s_endpgm +} diff --git a/waveasm/test/Transforms/cross-class-spill-sgpr-unsupported.mlir b/waveasm/test/Transforms/cross-class-spill-sgpr-unsupported.mlir new file mode 100644 index 0000000000..8f9cd5c921 --- /dev/null +++ b/waveasm/test/Transforms/cross-class-spill-sgpr-unsupported.mlir @@ -0,0 +1,20 @@ +// RUN: not waveasm-translate --waveasm-linear-scan="max-sgprs=4 max-vgprs=8" %s 2>&1 | FileCheck %s +// +// Test: SGPR overflow triggers a hard error because SGPR -> VGPR spilling +// is not yet implemented. The allocator does not attempt cross-class +// eviction for SGPRs (altPool is nullptr). + +// CHECK: error: 'waveasm.program' op Failed to allocate SGPR +waveasm.program @sgpr_overflow target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %s0 = waveasm.precolored.sreg 0 : !waveasm.psreg<0> + %s1 = waveasm.precolored.sreg 1 : !waveasm.psreg<1> + %s2 = waveasm.precolored.sreg 2 : !waveasm.psreg<2> + %s3 = waveasm.precolored.sreg 3 : !waveasm.psreg<3> + + // All 4 SGPRs precolored. Virtual SGPRs have nowhere to go. + %r1 = waveasm.s_mul_i32 %s0, %s1 : !waveasm.psreg<0>, !waveasm.psreg<1> -> !waveasm.sreg + %r2 = waveasm.s_mul_i32 %s2, %s3 : !waveasm.psreg<2>, !waveasm.psreg<3> -> !waveasm.sreg + %r3 = waveasm.s_mul_i32 %r1, %r2 : !waveasm.sreg, !waveasm.sreg -> !waveasm.sreg + + waveasm.s_endpgm +} From 365ef67c000a48b79a8a4dd74f76cfdf39cce111 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 7 Apr 2026 18:38:41 +0200 Subject: [PATCH 3/3] Fix cross-class spilling after rebase on main Adapt accvgpr hazard check to use findPrecedingEmittingOp pattern introduced on main, and update test CHECK lines for the new error message format (emitError on Value loc, not program op). Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Ivan Butygin --- waveasm/lib/Transforms/HazardMitigation.cpp | 9 +++--- waveasm/lib/Transforms/LinearScanRegAlloc.cpp | 29 +++++++++---------- .../cross-class-spill-exhausted.mlir | 2 +- .../cross-class-spill-multi-reg.mlir | 2 +- .../cross-class-spill-sgpr-unsupported.mlir | 2 +- 5 files changed, 21 insertions(+), 23 deletions(-) diff --git a/waveasm/lib/Transforms/HazardMitigation.cpp b/waveasm/lib/Transforms/HazardMitigation.cpp index c2ec43660a..92d478eae7 100644 --- a/waveasm/lib/Transforms/HazardMitigation.cpp +++ b/waveasm/lib/Transforms/HazardMitigation.cpp @@ -223,11 +223,10 @@ struct HazardMitigationPass // On GFX950 the VGPR destination of v_accvgpr_read_b32 is not // immediately available; a VALU that consumes it in the next cycle // reads stale data. Insert s_nop 0 to cover the 1-cycle wait. - if (isAccVgprReadOp(current) && isVALUOp(next)) { - auto defs = getVGPRDefs(current); - auto uses = getVGPRUses(next); - if (hasIntersection(defs, uses)) - insertionPoints.push_back(next); + if (isVALUOp(op) && i > 0) { + Operation *pred = findPrecedingEmittingOp(ops, i); + if (pred && isAccVgprReadOp(pred) && hasVGPRConflict(pred, op)) + insertionPoints.push_back(op); } } diff --git a/waveasm/lib/Transforms/LinearScanRegAlloc.cpp b/waveasm/lib/Transforms/LinearScanRegAlloc.cpp index 9698767382..2ecad139fc 100644 --- a/waveasm/lib/Transforms/LinearScanRegAlloc.cpp +++ b/waveasm/lib/Transforms/LinearScanRegAlloc.cpp @@ -330,29 +330,28 @@ LinearScanRegAlloc::allocate(ProgramOp program) { } // Step 5: Allocate VGPRs. On failure, evict to spare AGPRs. - if (failed(allocateRegClass(liveness.vregRanges, vgprPool, mapping, stats, - tiedOperands, precoloredValues, "VGPR", program, - maxVGPRs, liveness.maxVRegPressure, - vgprStrategy.get(), &agprPool, &spills, - &liveness))) + if (failed(allocateRegClass( + liveness.vregRanges, vgprPool, mapping, stats, tiedOperands, + precoloredValues, "VGPR", program, maxVGPRs, liveness.maxVRegPressure, + vgprStrategy.get(), &agprPool, &spills, &liveness))) return failure(); stats.peakVGPRs = vgprPool.getPeakUsage(); // Step 6: Allocate SGPRs (no cross-class spilling yet). - if (failed(allocateRegClass( - liveness.sregRanges, sgprPool, mapping, stats, tiedOperands, - precoloredValues, "SGPR", program, maxSGPRs, liveness.maxSRegPressure, - /*strategy=*/nullptr, /*altPool=*/nullptr, /*spills=*/nullptr, - &liveness))) + if (failed(allocateRegClass(liveness.sregRanges, sgprPool, mapping, stats, + tiedOperands, precoloredValues, "SGPR", program, + maxSGPRs, liveness.maxSRegPressure, + /*strategy=*/nullptr, /*altPool=*/nullptr, + /*spills=*/nullptr, &liveness))) return failure(); stats.peakSGPRs = sgprPool.getPeakUsage(); // Step 7: Allocate AGPRs (no cross-class spilling yet). - if (failed(allocateRegClass( - liveness.aregRanges, agprPool, mapping, stats, tiedOperands, - precoloredValues, "AGPR", program, maxAGPRs, liveness.maxARegPressure, - /*strategy=*/nullptr, /*altPool=*/nullptr, /*spills=*/nullptr, - &liveness))) + if (failed(allocateRegClass(liveness.aregRanges, agprPool, mapping, stats, + tiedOperands, precoloredValues, "AGPR", program, + maxAGPRs, liveness.maxARegPressure, + /*strategy=*/nullptr, /*altPool=*/nullptr, + /*spills=*/nullptr, &liveness))) return failure(); stats.peakAGPRs = agprPool.getPeakUsage(); diff --git a/waveasm/test/Transforms/cross-class-spill-exhausted.mlir b/waveasm/test/Transforms/cross-class-spill-exhausted.mlir index 84c85e6a91..05fb5330be 100644 --- a/waveasm/test/Transforms/cross-class-spill-exhausted.mlir +++ b/waveasm/test/Transforms/cross-class-spill-exhausted.mlir @@ -3,7 +3,7 @@ // Test: When both VGPR and AGPR pools are exhausted, allocation fails // with a diagnostic. max-agprs=0 leaves no spare AGPRs for eviction. -// CHECK: error: 'waveasm.program' op Failed to allocate VGPR +// CHECK: error: Failed to allocate VGPR waveasm.program @both_exhausted target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> %v1 = waveasm.precolored.vreg 1 : !waveasm.pvreg<1> diff --git a/waveasm/test/Transforms/cross-class-spill-multi-reg.mlir b/waveasm/test/Transforms/cross-class-spill-multi-reg.mlir index a47481455e..04f4b0101d 100644 --- a/waveasm/test/Transforms/cross-class-spill-multi-reg.mlir +++ b/waveasm/test/Transforms/cross-class-spill-multi-reg.mlir @@ -4,7 +4,7 @@ // When the only live values are multi-register, eviction is not attempted // and allocation fails. -// CHECK: error: 'waveasm.program' op Failed to allocate VGPR +// CHECK: error: Failed to allocate VGPR waveasm.program @multi_reg_no_spill target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { %srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4> %voff = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> diff --git a/waveasm/test/Transforms/cross-class-spill-sgpr-unsupported.mlir b/waveasm/test/Transforms/cross-class-spill-sgpr-unsupported.mlir index 8f9cd5c921..f7bd324abf 100644 --- a/waveasm/test/Transforms/cross-class-spill-sgpr-unsupported.mlir +++ b/waveasm/test/Transforms/cross-class-spill-sgpr-unsupported.mlir @@ -4,7 +4,7 @@ // is not yet implemented. The allocator does not attempt cross-class // eviction for SGPRs (altPool is nullptr). -// CHECK: error: 'waveasm.program' op Failed to allocate SGPR +// CHECK: error: Failed to allocate SGPR waveasm.program @sgpr_overflow target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { %s0 = waveasm.precolored.sreg 0 : !waveasm.psreg<0> %s1 = waveasm.precolored.sreg 1 : !waveasm.psreg<1>