diff --git a/docs/wave/cross-class-spilling.md b/docs/wave/cross-class-spilling.md
new file mode 100644
index 0000000000..9ddda1c934
--- /dev/null
+++ b/docs/wave/cross-class-spilling.md
@@ -0,0 +1,436 @@
+# Cross-Class Register Spilling in WaveASM
+
+- **Author:** Ivan Butygin
+- **Status:** Draft
+- **Created:** 2026-03-28
+
+## Table of Contents
+
+- [Problem Statement](#problem-statement)
+- [Background](#background)
+- [Design](#design)
+- [Implementation Plan](#implementation-plan)
+- [Alternatives Considered](#alternatives-considered)
+- [Open Questions](#open-questions)
+
+## Problem Statement
+
+The WaveASM linear-scan register allocator treats VGPR, SGPR, and AGPR pools
+as independent. When any pool is exhausted the compilation fails with a hard
+error. There is no spilling of any kind.
+
+The 256x192x256 block MXFP4 GEMM with dynamic M/N/K allocates ~261 VGPRs,
+5 over the gfx950 hardware limit of 256. The kernel simultaneously leaves
+~64 AGPRs unused (only 192 of 256 consumed by MFMA accumulators).
+
+A post-regalloc Python hack rewrites assembly text to shuttle excess VGPRs
+through AGPRs via `v_accvgpr_read/write`. It works but is fragile: no
+dataflow analysis, regex-based operand classification, hardcoded scratch
+registers, and `s_waitcnt vmcnt(0)` sledgehammers that destroy latency
+hiding.
+
+This document proposes a proper cross-class spilling mechanism inside the
+linear-scan allocator where liveness, tied values, and hazard information
+are available.
+
+## Background
+
+### AMDGCN Register File (gfx9)
+
+| Class | Count | Typical use                    | Move cost (cycles) |
+|-------|-------|--------------------------------|--------------------|
+| SGPR  | 102   | Scalar values, SRDs, loop IVs  | --                 |
+| VGPR  | 256   | Per-lane vector values         | --                 |
+| AGPR  | 256   | MFMA accumulators              | 4 (read/write)     |
+
+Cross-class move instructions:
+
+- **SGPR -> VGPR**: `v_mov_b32 vD, sS` (1 VALU cycle).
+- **VGPR -> SGPR**: `v_readfirstlane_b32 sD, vS` (1 VALU cycle, lane 0).
+- **VGPR -> AGPR**: `v_accvgpr_write_b32 aD, vS` (4 cycles).
+- **AGPR -> VGPR**: `v_accvgpr_read_b32 vD, aS` (4 cycles, +1 s_nop on
+  gfx950 before consumer).
+
+There is no direct SGPR <-> AGPR path; a VGPR scratch is required as
+intermediate.
+
+### LLVM Precedent
+
+LLVM's AMDGPU backend implements SGPR -> VGPR spilling
+(`SILowerSGPRSpills`). When SGPRs are exhausted, values are parked in
+dedicated VGPR "spill lanes." Spills write a lane with `v_writelane_b32`
+and reloads read it back with `v_readlane_b32`.
+
+LLVM's general fallback for VGPR overflow is scratch memory: values spill
+via `buffer_store/load` on the private segment. This costs ~100+ cycles
+per access and consumes VMEM bandwidth. (LLVM can also place VGPR spill
+slots in unused AGPRs on targets with MAI instructions, via its
+`amdgpu-spill-vgpr-to-agpr` option, but it does so during frame lowering
+rather than inside the allocator.)
+
+The proposed cross-class approach avoids scratch memory entirely for moderate
+overflow (tens of registers) at a cost of 4-8 VALU cycles per spill/reload.
+
+### Current Allocator Architecture
+
+Relevant source files (all under `waveasm/`):
+
+```
+include/waveasm/Transforms/RegAlloc.h   -- RegPool, PhysicalMapping, allocator class
+lib/Transforms/LinearScanRegAlloc.cpp   -- core linear scan
+lib/Transforms/LinearScanPass.cpp       -- pass wrapper, precoloring
+lib/Transforms/Liveness.cpp             -- live range computation
+```
+
+Key properties:
+
+1. Three independent `RegPool` instances (VGPR, SGPR, AGPR).
+2. Live ranges sorted by `(start, end)` and processed in start-point order.
+3. Two separate tie mechanisms: `TiedValueClasses` in Liveness for loop
+   iter_args (block_arg + init_arg + iter_arg + result), and a separate
+   `tiedPairs` map in LinearScanPass for MFMA accumulator -> result ties.
+4. Bottom-up allocation only (`allocSingle`, `allocRange`). No top-down
+   or bidirectional heuristics.
+5. `allocateRegClass()` receives a single `RegPool &pool` -- it has no
+   access to other register classes when allocation fails.
+6. Hard error on pool exhaustion -- no eviction, no spilling.
+7. No rematerialization. No post-regalloc compaction.
+8. v15 reserved for literal materialization (`kScratchVGPR` in
+   AssemblyEmitter.h). v14 is not currently reserved.
+9. Hazard mitigation pass handles VALU -> `v_readfirstlane` and
+   Trans -> VALU hazards only. No `v_accvgpr_read_b32` RAW hazard
+   handling exists yet.
+
+## Design
+
+### Spill Cascade
+
+When `tryAllocate()` returns no register for a given class, the allocator
+attempts cross-class eviction before failing:
+
+```
+SGPR overflow  -->  park in spare VGPR  (v_readfirstlane to reload)
+VGPR overflow  -->  park in spare AGPR  (v_accvgpr_write/read)
+AGPR overflow  -->  park in spare VGPR  (v_accvgpr_read/write)
+all exhausted  -->  hard error (future: scratch memory)
+```
+
+The cascade is one-level deep: a VGPR spilled to AGPR does not trigger
+further cascading. This keeps the design simple and covers the practical
+cases (VGPR overflow with spare AGPRs being the dominant scenario).
+
+### Eviction Strategy
+
+When the incoming range cannot be allocated, we evict an already-allocated
+**victim** range to the alternate class. This frees the victim's physical
+register for the incoming range.
+
+Victim selection criteria (in priority order):
+
+1. **Fewest use sites** in the hot region (loop body). Each use site
+   requires a reload instruction, so fewer uses means lower spill cost.
+2. **Not part of a tied equivalence class.** Spilling a tied value
+   requires spilling the entire class (loop init_arg + block_arg +
+   iter_arg + result all share one physical register). Avoid this
+   unless there is no untied candidate. (The initial implementation
+   excludes tied values outright; see Tied Values below.)
+3. **Longest remaining range.** Among equal-cost candidates, evicting
+   the longest range frees the register for the most time, reducing
+   future pressure.
+
+The victim must have `size == 1`. Multi-register spills (size 2/4/8)
+require contiguous ranges in the target class and complicate reload
+sequences. For the initial implementation we restrict to single-register
+eviction. This covers the common case (individual scalar values that
+happen to sit in VGPRs, or single excess VGPRs from address computation).
+
+### Spill and Reload Insertion
+
+The spill/reload is a **range split**: the victim's original live range is
+shortened to just its def point, and a new "spill range" in the alternate
+class covers the rest. At each use site a reload is inserted.
+
+#### VGPR -> AGPR
+
+At the victim's def point, **after** the defining op:
+
+```asm
+v_accvgpr_write_b32 aX, vY        ; spill to AGPR
+```
+
+If the def is an async load (`buffer_load`, `global_load`), the spill
+must wait for the load to complete. Rather than inserting a blanket
+`s_waitcnt vmcnt(0)`, the spill point is deferred to just before the
+first use, and the existing waitcnt analysis pass handles the dependency.
+(See [Interaction with WaitCnt Pass](#waitcnt-pass-ticketingcpp) below.)
+
+Before each use site:
+
+```asm
+v_accvgpr_read_b32 vSCR, aX       ; reload to scratch VGPR
+; (hazard mitigation pass inserts s_nop if needed)
+
+```
+
+The scratch register `vSCR` is a dedicated per-spill-site temporary
+allocated from a small reserved pool (v14 and v15 under this proposal).
+After the use, `vSCR` is dead and available for the next reload.
+
+#### SGPR -> VGPR
+
+At the victim's def point:
+
+```asm
+v_mov_b32 vX, sY                  ; broadcast scalar to all VGPR lanes
+```
+
+Before each use site:
+
+```asm
+v_readfirstlane_b32 sSCR, vX      ; extract lane 0 back to SGPR
+
+```
+
+#### AGPR -> VGPR
+
+Same as VGPR -> AGPR but reversed:
+
+```asm
+v_accvgpr_read_b32 vX, aY         ; at def
+v_accvgpr_write_b32 aSCR, vX      ; reload before use
+```
+
+This direction is uncommon (MFMA-heavy kernels rarely overflow AGPRs
+while having spare VGPRs) but is included for completeness.
+
+### Scratch Register Management
+
+The current allocator reserves v15 for literal materialization
+(`kScratchVGPR` in AssemblyEmitter.h, reserved at LinearScanPass.cpp:159).
+v14 is not currently reserved.
+
+Proposed approach:
+
+- Reserve v14 as a second scratch VGPR for spill reloads
+  (`reservedVGPRs.insert(14)` in LinearScanPass.cpp). This gives a
+  **scratch pool** of 2 VGPRs (v14, v15).
+- Each reload site needs exactly 1 scratch register for the duration
+  of one instruction. With 2 scratches available, two independent
+  reloads can be in flight (e.g., an instruction reading two different
+  spilled values).
+- If more than 2 simultaneous reloads are needed at one instruction,
+  the allocator must serialize them (reload one, use it, then reload
+  the next). This is handled by the reload insertion logic.
+- Scratch registers for SGPR reloads (`sSCR`) are drawn from the
+  SGPR pool. One dedicated SGPR scratch (e.g., s0, since kernel args
+  are already loaded) suffices for single-register spills.
+
+### Interaction with Existing Passes
+
+#### Liveness
+
+No changes to the liveness pass itself. `LiveRange` carries a `regClass`
+field and results are separated into `vregRanges`, `sregRanges`,
+`aregRanges` (Liveness.cpp:662-669). The `usePoints` map
+(Liveness.h:177) stores per-value use-site indices, which the eviction
+heuristic needs for "fewest use sites" ranking.
+
+The spill/reload insertion creates new SSA values with their own def/use
+points; the liveness data is recomputed if needed (or the spill is done
+in a fixup walk after initial allocation).
+
+#### Tied Values
+
+Two separate tie mechanisms exist:
+
+- **Loop ties** (`TiedValueClasses` in Liveness): group block_arg +
+  init_arg + iter_arg + loop_result into equivalence classes.
+- **MFMA ties** (`tiedPairs` in LinearScanPass): map MFMA result ->
+  accumulator operand, added via `allocator.addTiedOperand()`.
+
+Neither kind is a candidate for cross-class spilling in the initial
+implementation. Spilling a loop-tied value requires spill/reload at
+every back-edge. Spilling an MFMA-tied value means the accumulator
+lives in an AGPR but the result needs a VGPR, which defeats the tie.
+
+If the *only* way to satisfy pressure is to spill a tied value, the
+allocator falls back to the hard error. A future extension could handle
+this by breaking the tie and inserting explicit copies on back-edges, but
+this is out of scope.
+
+#### WaitCnt Pass (Ticketing.cpp)
+
+The existing `--waveasm-insert-waitcnt` pass tracks outstanding VMEM and
+LGKM operations and inserts `s_waitcnt` based on operand dependencies.
+If a `v_accvgpr_write` consumes a VGPR defined by a `buffer_load`, the
+pass must ensure the load has completed.
+
+The spill insertion should emit the `v_accvgpr_write` as a normal
+WaveASM op. The waitcnt pass then sees its operand dependency and
+inserts the appropriate `s_waitcnt vmcnt(N)` with a precise count
+rather than the blanket `vmcnt(0)` used by the Python hack.
+
+**Caveat**: the waitcnt pass currently tracks memory ops and their
+*direct* consumers. Verify that a `v_accvgpr_write_b32` reading a
+`buffer_load` result is recognized as a consumer that needs a wait.
+
+#### Hazard Mitigation
+
+The `--waveasm-hazard-mitigation` pass currently handles only two
+hazard patterns (HazardMitigation.cpp):
+
+1. VALU -> `v_readfirstlane_b32` RAW hazard.
+2. Transcendental -> non-Trans VALU forwarding hazard.
+
+It does **not** handle `v_accvgpr_read_b32` RAW hazards on GFX950.
+This must be added as part of Phase 1: after a `v_accvgpr_read_b32`
+writes a VGPR, the next consumer needs an `s_nop 0` (or scheduling
+gap) to avoid silent data corruption.
+
+#### ScopedCSE
+
+The `--mlir-cse` pass runs before regalloc. Spill/reload ops are
+inserted after CSE, so there is no risk of the CSE pass merging a
+spill reload with an unrelated `v_accvgpr_read`.
+
+## Implementation Plan
+
+### Phase 1: VGPR -> AGPR Spilling (MVP)
+
+This covers the immediate need (256x192 GEMM, 5 excess VGPRs, 64 spare
+AGPRs).
+
+1. **RegAlloc.h**: Add `SpillRecord` struct and expose spill results:
+
+   ```cpp
+   struct SpillRecord {
+     Value victim;            // original SSA value
+     int64_t sourcePhysReg;   // freed VGPR index
+     int64_t targetPhysReg;   // allocated AGPR index
+     RegClass sourceClass;    // VGPR
+     RegClass targetClass;    // AGPR
+   };
+   ```
+
+   The allocator's return type must be extended to include
+   `SmallVector<SpillRecord>` alongside `PhysicalMapping` and
+   `AllocationStats`.
+
+2. **LinearScanRegAlloc.cpp**: `allocateRegClass()` currently receives
+   a single `RegPool &pool`. Change signature to also accept the
+   alternate-class pool (AGPR pool when allocating VGPRs). When
+   `tryAllocate()` fails:
+
+   a. Scan active list for eviction candidates (untied, size==1,
+      fewest uses via `LivenessInfo::usePoints`).
+   b. Allocate an AGPR from the alternate pool for the victim.
+   c. Free the victim's VGPR.
+   d. Re-attempt `tryAllocate()` for the incoming range.
+   e. Record eviction in `SmallVector<SpillRecord>`.
+
+3. **LinearScanPass.cpp**: Reserve v14 as spill scratch
+   (`reservedVGPRs.insert(14)`) alongside the existing v15 reservation.
+   After `allocator.allocate()` succeeds, process `SpillRecord` list
+   (insertion point: after mapping adjustments at ~line 290, before IR
+   type transformation at ~line 292):
+
+   a. For each spill, walk the program and insert
+      `v_accvgpr_write_b32` after the victim's def.
+   b. Before each use of the victim, insert
+      `v_accvgpr_read_b32 vSCR, aX`.
+   c. Rewrite the use to read from `vSCR`.
+   d. Update the op's result type from `PVRegType` to `PARegType`
+      for the spilled value (or keep as VGPR and let the inserted
+      ops handle it -- TBD based on what is cleaner for downstream
+      passes).
+
+4. **HazardMitigation.cpp**: Add `v_accvgpr_read_b32` RAW hazard
+   detection for GFX950. When a `v_accvgpr_read_b32` writes a VGPR
+   and the next instruction consumes it, insert `s_nop 0`. This is
+   required for correctness -- without it, MFMA scale operands can be
+   silently corrupted.
+
+5. **Tests**: Add LIT tests in `waveasm/test/Transforms/`:
+
+   - `cross-class-spill-vgpr-to-agpr.mlir`: Verify spill/reload
+     insertion with max-vgprs=4 on a program needing 5.
+   - `cross-class-spill-no-tied.mlir`: Verify tied values are not
+     spilled.
+   - `cross-class-spill-async-load.mlir`: Verify correct ordering
+     when the spilled value comes from a buffer_load.
+
+### Phase 2: SGPR -> VGPR Spilling
+
+Same pattern, different move instructions. Lower priority since SGPR
+overflow is less common in current kernels (SALU promotion already
+keeps most scalars in SGPRs).
+
+1. When SGPR allocation fails, find an untied SGPR victim.
+2. Allocate a VGPR for it.
+3. Insert `v_mov_b32 vX, sY` at def, `v_readfirstlane_b32 sSCR, vX`
+   at uses.
+
+### Phase 3: AGPR -> VGPR Spilling
+
+Reverse of Phase 1. Lowest priority -- AGPR overflow with spare VGPRs
+is rare in practice.
+
+### Phase 4 (Future): Scratch Memory Spill
+
+When cross-class spilling is insufficient (all three classes at capacity),
+the last resort is scratch memory. This requires:
+
+- A scratch SRD (already available in the kernel descriptor).
+- `buffer_store_dword` at spill, `buffer_load_dword` at reload.
+- Stack frame offset management.
+- Integration with the waitcnt pass for VMEM tracking.
+
+Out of scope for this design. The cross-class approach should cover
+practical cases for the foreseeable kernel sizes.
+
+## Alternatives Considered
+
+1. **Post-regalloc assembly text rewriting (current hack).**
+   Works but unsound: no dataflow analysis, regex operand classification,
+   blanket waitcnts, hardcoded scratch registers. See Problem Statement.
+
+2. **Reduce register pressure at the IR level.**
+   More aggressive rematerialization, shorter live ranges, or kernel
+   restructuring. This is complementary (and we should do it), but
+   cannot always close the gap. The 256x192 kernel is already heavily
+   optimized and still needs 261 VGPRs.
+
+3. **Scratch memory spilling only (like LLVM).**
+   Sound but expensive. 4 cycles for an AGPR move vs 100+ cycles for
+   scratch buffer_load, plus VMEM bandwidth contention with data loads.
+   Cross-class spilling is strictly better when spare registers exist in
+   the target class.
+
+4. **Occupancy reduction.**
+   Accept fewer waves per CU by using more registers. Not applicable
+   here: gfx950 has a hard limit of 256 VGPRs regardless of occupancy.
+
+5. **Second-chance allocation (evict + re-spill).**
+   When evicting a victim, allow the victim to itself be evicted further
+   (multi-level cascade). Adds complexity for marginal benefit.
+   Single-level suffices for tens of excess registers.
+
+## Open Questions
+
+1. **Spill at def vs spill at last-use-before-gap?** Spilling at def is
+   simpler but keeps the AGPR occupied for the entire range. Spilling
+   at the last use before a pressure spike is optimal but requires
+   pressure curve analysis. Start with spill-at-def.
+
+2. **Cost model for eviction.** Use count alone, or weight by loop depth?
+   A value used once inside a loop body that executes 128 times is
+   effectively 128 uses. Start with raw use count; refine if profiling
+   shows poor spill placement.
+
+3. **Multiple simultaneous spills.** If 5 VGPRs must spill, we need 5
+   spare AGPRs and potentially 2+ reloads in a single instruction. The
+   scratch pool of 2 VGPRs limits us to 2 simultaneous reloads. Should
+   we grow the scratch pool dynamically, or serialize reloads? Start
+   with serialization (insert multiple read + nop sequences).
+
+4. **Interaction with loop-address-promotion.** The promotion pass adds
+   VGPRs to eliminate VALU from the loop body. If those promoted VGPRs
+   are then spilled to AGPRs, the VALU savings are partially offset by
+   spill/reload VALU. Should we teach the promotion pass about the
+   AGPR budget? Deferred -- the fallback path already disables
+   promotion when VGPRs exceed the limit.
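The victim ranking from the Eviction Strategy section of the design document can be tried out in isolation. The following is a minimal standalone sketch, not the code from this patch: `Candidate` and `pickVictim` are hypothetical names, the struct is a simplified stand-in for the allocator's active-range entry, and tied values are treated as a hard exclusion (as the initial implementation is described as doing) rather than as a soft preference.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-in for one entry of the allocator's
// active list; the real entry carries more state.
struct Candidate {
  int64_t size;     // registers occupied by the range
  bool tied;        // member of a loop or MFMA tie
  int64_t useCount; // use sites; each one would need a reload
  int64_t endPoint; // program point at which the range ends
};

// Returns the index of the preferred victim, or -1 if none qualifies.
int64_t pickVictim(const std::vector<Candidate> &active,
                   int64_t currentPoint) {
  int64_t best = -1;
  for (std::size_t i = 0; i < active.size(); ++i) {
    const Candidate &c = active[i];
    // Hard filters: single-register, untied, still live past this point.
    if (c.size != 1 || c.tied || c.endPoint <= currentPoint)
      continue;
    if (best < 0) {
      best = static_cast<int64_t>(i);
      continue;
    }
    const Candidate &b = active[static_cast<std::size_t>(best)];
    // Fewer uses wins; among equal use counts the later end point wins,
    // which is the same as the longer remaining range.
    bool fewerUses = c.useCount < b.useCount;
    bool sameUsesLonger =
        c.useCount == b.useCount && c.endPoint > b.endPoint;
    if (fewerUses || sameUsesLonger)
      best = static_cast<int64_t>(i);
  }
  return best;
}
```

Comparing end points directly is enough for the tie-break because every candidate is measured from the same `currentPoint`. Weighting `useCount` by loop depth, as raised in the open questions, would only change how that field is populated.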
diff --git a/waveasm/include/waveasm/Transforms/RegAlloc.h b/waveasm/include/waveasm/Transforms/RegAlloc.h
index 9e46e589dd..b72d95363f 100644
--- a/waveasm/include/waveasm/Transforms/RegAlloc.h
+++ b/waveasm/include/waveasm/Transforms/RegAlloc.h
@@ -22,6 +22,22 @@
 namespace waveasm {
+//===----------------------------------------------------------------------===//
+// Spill Record
+//===----------------------------------------------------------------------===//
+
+/// Records a cross-class spill decision made during register allocation.
+/// The victim value is evicted from its original register class into the
+/// alternate class (e.g. VGPR -> AGPR). The pass uses these records to
+/// insert spill/reload ops after allocation completes.
+struct SpillRecord {
+  mlir::Value victim;    ///< The SSA value being spilled.
+  int64_t sourcePhysReg; ///< Physical register freed in original class.
+  int64_t targetPhysReg; ///< Physical register allocated in target class.
+  RegClass sourceClass;  ///< Original register class (e.g. VGPR).
+  RegClass targetClass;  ///< Target register class (e.g. AGPR).
+};
+
 //===----------------------------------------------------------------------===//
 // Allocation Statistics
 //===----------------------------------------------------------------------===//
@@ -206,6 +222,10 @@ class RegPool {
     }
   }
+  /// Check if the pool has at least one free register.
+  bool hasFree() const { return free.any(); }
+
+  /// Get peak usage.
   int64_t getPeakUsage() const { return peak; }
   int64_t getCurrentUsage() const { return currentUsage; }
@@ -294,11 +314,17 @@ class LinearScanRegAlloc {
     vgprStrategy = std::move(strategy);
   }
-  /// Run allocation on a kernel program
-  /// Returns the physical mapping and statistics, or failure if allocation
-  /// fails
-  mlir::FailureOr>
-  allocate(ProgramOp program);
+  /// Result bundle returned by allocate().
+  struct AllocResult {
+    PhysicalMapping mapping;
+    AllocationStats stats;
+    llvm::SmallVector<SpillRecord> spills;
+  };
+
+  /// Run allocation on a kernel program.
+  /// Returns the physical mapping, statistics, and any cross-class spill
+  /// records, or failure if allocation fails.
+  mlir::FailureOr<AllocResult> allocate(ProgramOp program);
 private:
   /// Process active ranges, expiring those that end before currentPoint
diff --git a/waveasm/lib/Transforms/HazardMitigation.cpp b/waveasm/lib/Transforms/HazardMitigation.cpp
index b91d9365fb..92d478eae7 100644
--- a/waveasm/lib/Transforms/HazardMitigation.cpp
+++ b/waveasm/lib/Transforms/HazardMitigation.cpp
@@ -8,8 +8,9 @@
 // Hazard Mitigation Pass - Insert s_nop instructions for hardware hazards
 //
 // This pass handles hardware-specific hazards that require NOP insertion:
-//   - VALU → v_readfirstlane hazard (gfx940+)
-//   - Trans → non-Trans VALU forwarding hazard (gfx940+)
+//   - VALU -> v_readfirstlane hazard (gfx940+)
+//   - Trans -> non-Trans VALU forwarding hazard (gfx940+)
+//   - v_accvgpr_read_b32 -> VALU RAW hazard (gfx940+)
 //===----------------------------------------------------------------------===//
 #include "waveasm/Dialect/WaveASMAttrs.h"
@@ -74,7 +75,7 @@ bool isVALUOp(Operation *op) {
   return true;
 }
-/// Check if an operation is v_readfirstlane
+/// Check if an operation is v_readfirstlane.
 bool isReadfirstlaneOp(Operation *op) { return isa(op); }
 /// Check if an operation does NOT emit an assembly instruction.
@@ -99,6 +100,9 @@ Operation *findPrecedingEmittingOp(ArrayRef ops, size_t start) {
   return nullptr;
 }
+/// Check if an operation is v_accvgpr_read_b32 (AGPR -> VGPR move).
+bool isAccVgprReadOp(Operation *op) { return isa<V_ACCVGPR_READ_B32>(op); }
+
 /// Check if an operation is a transcendental instruction (uses the Trans
 /// pipeline which has different latency characteristics from the main VALU).
 bool isTransOp(Operation *op) {
@@ -214,6 +218,16 @@ struct HazardMitigationPass
         if (pred && isTransOp(pred) && hasVGPRConflict(pred, op))
           insertionPoints.push_back(op);
       }
+
+      // Check for v_accvgpr_read_b32 -> VALU RAW hazard (gfx940+).
+      // On GFX950 the VGPR destination of v_accvgpr_read_b32 is not
+      // immediately available; a VALU that consumes it in the next cycle
+      // reads stale data. Insert s_nop 0 to cover the 1-cycle wait.
+      if (isVALUOp(op) && i > 0) {
+        Operation *pred = findPrecedingEmittingOp(ops, i);
+        if (pred && isAccVgprReadOp(pred) && hasVGPRConflict(pred, op))
+          insertionPoints.push_back(op);
+      }
     }
     // Insert s_nop instructions
diff --git a/waveasm/lib/Transforms/LinearScanPass.cpp b/waveasm/lib/Transforms/LinearScanPass.cpp
index ecc2cae6ce..30f52672bd 100644
--- a/waveasm/lib/Transforms/LinearScanPass.cpp
+++ b/waveasm/lib/Transforms/LinearScanPass.cpp
@@ -268,12 +268,101 @@ struct LinearScanPass
   /// Get the accumulator operand from an MFMA op using the interface.
   /// Returns nullptr if the operation is not an MFMA.
   Value getMFMAAccumulator(Operation *op) {
-    if (auto mfmaOp = dyn_cast(op)) {
+    if (auto mfmaOp = dyn_cast(op))
       return mfmaOp.getAcc();
-    }
     return nullptr;
   }
+  /// Scratch VGPR index used for cross-class spill reloads.
+  static constexpr int64_t kSpillScratchVGPR = 14;
+
+  /// Insert spill/reload ops for cross-class evictions.
+  ///
+  /// For each VGPR->AGPR spill:
+  ///   - After the victim's def:  v_accvgpr_write_b32 aDST, vSRC
+  ///   - Before each use:         v_accvgpr_read_b32 vSCRATCH, aSRC
+  ///     and rewrite the use to consume the reload result.
+  ///
+  /// The ops are created with physical register types so the subsequent
+  /// type transformation pass does not need to touch them.
+  LogicalResult insertSpillReloads(ProgramOp program,
+                                   ArrayRef<SpillRecord> spills,
+                                   PhysicalMapping &mapping) {
+    // Validate that all spills use the implemented VGPR -> AGPR direction.
+    // SGPR -> VGPR and AGPR -> VGPR are wired in the allocator but not yet
+    // supported by the spill insertion logic.
+    for (const SpillRecord &sr : spills) {
+      if (sr.sourceClass != RegClass::VGPR ||
+          sr.targetClass != RegClass::AGPR) {
+        return program.emitOpError()
+               << "unsupported spill direction: only VGPR -> AGPR is "
+                  "currently implemented";
+      }
+    }
+
+    // Build a lookup from victim Value -> SpillRecord.
+    llvm::DenseMap<Value, const SpillRecord *> spillMap;
+    for (const SpillRecord &sr : spills)
+      spillMap[sr.victim] = &sr;
+
+    MLIRContext *ctx = program.getContext();
+
+    // For each spill, create a precolored AGPR value to use as the
+    // source/destination in read/write ops. This gives us an SSA handle
+    // to thread through the spill ops.
+    //
+    // We insert a PrecoloredARegOp at program entry to materialise the
+    // AGPR "slot". It does not generate any assembly; it just provides
+    // an SSA value with the right physical type for the dialect ops.
+    llvm::DenseMap<Value, Value> spillSlots; // victim -> AGPR SSA value
+    {
+      OpBuilder entryBuilder(ctx);
+      Block &entry = program.getBodyBlock();
+      entryBuilder.setInsertionPointToStart(&entry);
+      for (const SpillRecord &sr : spills) {
+        auto physType = PARegType::get(ctx, sr.targetPhysReg, 1);
+        Value slot = PrecoloredARegOp::create(entryBuilder, program.getLoc(),
+                                              physType, sr.targetPhysReg,
+                                              /*size=*/1);
+        spillSlots[sr.victim] = slot;
+      }
+    }
+
+    // Collect all ops in program order.
+    llvm::SmallVector<Operation *> ops;
+    collectOpsRecursive(program.getBodyBlock(), ops);
+
+    for (Operation *op : ops) {
+      // --- Insert spills after defs. ---
+      for (Value result : op->getResults()) {
+        auto it = spillMap.find(result);
+        if (it == spillMap.end())
+          continue;
+        const SpillRecord &sr = *it->second;
+        OpBuilder builder(ctx);
+        builder.setInsertionPointAfter(op);
+        Value slot = spillSlots[sr.victim];
+        V_ACCVGPR_WRITE_B32::create(builder, op->getLoc(), result, slot);
+      }
+
+      // --- Insert reloads before uses. ---
+      for (unsigned i = 0; i < op->getNumOperands(); ++i) {
+        Value operand = op->getOperand(i);
+        auto it = spillMap.find(operand);
+        if (it == spillMap.end())
+          continue;
+        OpBuilder builder(op);
+        auto scratchType = PVRegType::get(ctx, kSpillScratchVGPR, 1);
+        Value slot = spillSlots[operand];
+        Value reloaded = V_ACCVGPR_READ_B32::create(builder, op->getLoc(),
+                                                    scratchType, slot);
+        op->setOperand(i, reloaded);
+      }
+    }
+
+    return success();
+  }
+
   LogicalResult processProgram(ProgramOp program) {
     // Collect precolored values from precolored.vreg and precolored.sreg ops
     llvm::DenseMap precoloredValues;
@@ -291,6 +380,11 @@ struct LinearScanPass
     // generates v_mov_b32 v15, before such instructions.
     reservedVGPRs.insert(15);
+    // Reserve v14 as scratch VGPR for cross-class spill reloads.
+    // When a VGPR is spilled to an AGPR, reloads use v_accvgpr_read_b32
+    // into this scratch before the consuming instruction.
+    reservedVGPRs.insert(14);
+
     // ABI SGPRs (kernarg ptr, preload regs, workgroup IDs, SRDs) are
     // reserved via PrecoloredSRegOp ops emitted during translation. The
     // collection loop below picks those up and adds their indices to
@@ -404,13 +498,12 @@ struct LinearScanPass
     allocator.setVGPRStrategy(
         std::make_unique(bidirectionalThreshold));
-    // Run allocation
+    // Run allocation.
     auto result = allocator.allocate(program);
-    if (failed(result)) {
+    if (failed(result))
       return failure();
-    }
-    auto [mapping, stats] = *result;
+    auto &[mapping, stats, spills] = *result;
     // Handle waveasm.extract ops: result = source[offset].
     // Set the extract result's physical register = source's physReg + offset.
@@ -497,6 +590,16 @@ struct LinearScanPass
     if (packResult.wasInterrupted())
       return failure();
+    // Insert cross-class spill/reload ops for any evicted values.
+    // For each spill record (e.g. VGPR -> AGPR):
+    //   - After the victim's def:  v_accvgpr_write_b32 aX, vY  (spill)
+    //   - Before each use:         v_accvgpr_read_b32 v14, aX  (reload)
+    //     and rewrite the use to consume v14 instead of vY.
+    if (!spills.empty()) {
+      if (failed(insertSpillReloads(program, spills, mapping)))
+        return failure();
+    }
+
     // Transform the IR: replace virtual register types with physical types
     OpBuilder builder(program.getContext());
diff --git a/waveasm/lib/Transforms/LinearScanRegAlloc.cpp b/waveasm/lib/Transforms/LinearScanRegAlloc.cpp
index dfa397de16..2ecad139fc 100644
--- a/waveasm/lib/Transforms/LinearScanRegAlloc.cpp
+++ b/waveasm/lib/Transforms/LinearScanRegAlloc.cpp
@@ -115,26 +115,71 @@ BidirectionalStrategy::allocate(RegPool &pool, const LiveRange &range,
 // Register Class Allocation
 //===----------------------------------------------------------------------===//
+/// Find the best eviction candidate from the active list.
+/// Prefers untied, size-1 ranges with the fewest use sites.
+/// Among equal-cost candidates, picks the longest remaining range.
+/// Returns the index into `active`, or -1 if no candidate is found.
+static int64_t
+findEvictionCandidate(const llvm::SmallVectorImpl<ActiveRange> &active,
+                      const llvm::DenseMap<mlir::Value, mlir::Value> &tiedOperands,
+                      const LivenessInfo *liveness, int64_t currentPoint) {
+  int64_t bestIdx = -1;
+  int64_t bestUseCount = std::numeric_limits<int64_t>::max();
+  int64_t bestLength = -1;
+
+  for (int64_t i = 0, e = active.size(); i < e; ++i) {
+    const ActiveRange &ar = active[i];
+    // Only spill size-1 values.
+    if (ar.range.size != 1)
+      continue;
+    // Do not spill tied values.
+    if (ar.range.isTied() || tiedOperands.contains(ar.range.reg))
+      continue;
+    // Must still be alive past the current point.
+    if (ar.endPoint <= currentPoint)
+      continue;
+
+    int64_t useCount = 0;
+    if (liveness) {
+      auto it = liveness->usePoints.find(ar.range.reg);
+      if (it != liveness->usePoints.end())
+        useCount = it->second.size();
+    }
+    int64_t length = ar.endPoint - currentPoint;
+
+    // Prefer fewer uses; break ties with longer remaining range
+    // (evicting a long range frees the register for more time).
+    if (useCount < bestUseCount ||
+        (useCount == bestUseCount && length > bestLength)) {
+      bestIdx = i;
+      bestUseCount = useCount;
+      bestLength = length;
+    }
+  }
+  return bestIdx;
+}
+
 /// Allocate registers for a single register class (VGPR, SGPR, or AGPR).
 /// This is the core linear scan algorithm, parameterized by register class.
 /// An optional AllocationStrategy is consulted before the default bottom-up
-/// allocation; returning std::nullopt falls through to tryAllocate.
+/// allocation; when allocation fails and `altPool` is provided, the allocator
+/// evicts an active range to the alternate register class before giving up.
 static LogicalResult allocateRegClass(
     ArrayRef ranges, RegPool &pool, PhysicalMapping &mapping,
     AllocationStats &stats, const llvm::DenseMap &tiedOperands,
     const llvm::DenseMap &precoloredValues,
     llvm::StringRef regClassName, ProgramOp program, int64_t maxRegs,
-    int64_t maxPressure, AllocationStrategy *strategy) {
+    int64_t maxPressure, AllocationStrategy *strategy, RegPool *altPool,
+    llvm::SmallVectorImpl<SpillRecord> *spills, const LivenessInfo *liveness) {
   llvm::SmallVector active;
   for (const LiveRange &range : ranges) {
-    // Skip precolored values - they're already mapped
-    if (precoloredValues.contains(range.reg)) {
+    // Skip precolored values - they're already mapped.
+    if (precoloredValues.contains(range.reg))
       continue;
-    }
-    // Expire finished ranges, returning registers to the pool
+    // Expire finished ranges, returning registers to the pool.
     expireRanges(active, range.start, pool, stats);
     std::optional physReg;
@@ -151,20 +196,13 @@ static LogicalResult allocateRegClass(
         mapping.valueToPhysReg[range.reg] = *physReg;
         // Extend the physical register's lifetime to cover the tied result.
-        // The tied-to operand may have a shorter lifetime (e.g., %55 ends
-        // at op2), but the tied result (%56) may live longer (used in
-        // iteration 2). Without this extension, the physical register would
-        // be freed too early when the tied-to operand expires.
         bool foundInActive = false;
         for (size_t i = 0; i < active.size(); ++i) {
           if (active[i].physReg == *physReg) {
             foundInActive = true;
             if (range.end > active[i].endPoint) {
-              // Update end point and re-sort the affected portion
               active[i].endPoint = range.end;
               active[i].range = range;
-              // Re-sort: since we only increased one element's key,
-              // bubble it forward to maintain sorted order by endPoint
               while (i + 1 < active.size() &&
                      active[i].endPoint > active[i + 1].endPoint) {
                 std::swap(active[i], active[i + 1]);
@@ -176,13 +214,6 @@ static LogicalResult allocateRegClass(
         }
         if (!foundInActive) {
-          // Two cases reach here:
-          // (a) Precolored tying (MFMA): the tied-to value is precolored
-          //     and was never in the active list. Its physReg is already
-          //     reserved in the pool. pool.reserve is a safe no-op.
-          // (b) Loop boundary: the tied-to virtual value expired from
-          //     active and its registers were returned to the pool.
-          //     The physReg MUST still be free (not re-allocated).
           bool tiedToPrecolored = precoloredValues.contains(tiedTo);
           assert((tiedToPrecolored || pool.isFree(*physReg)) &&
                  "Tied register was re-allocated before re-reservation");
@@ -193,7 +224,7 @@ static LogicalResult allocateRegClass(
         stats.rangesAllocated++;
         continue;
       }
-      // If tied-to not yet allocated, fall through to normal allocation
+      // If tied-to not yet allocated, fall through to normal allocation.
       }
     }

@@ -203,6 +234,32 @@ static LogicalResult allocateRegClass(
     if (!physReg)
       physReg = tryAllocate(pool, range.size, range.alignment);

+    // Cross-class eviction: if allocation failed and an alternate pool is
+    // available, evict the best candidate from the active list into the
+    // alternate class, freeing its register for the incoming range.
+    if (!physReg && altPool && altPool->hasFree() && spills) {
+      int64_t victimIdx =
+          findEvictionCandidate(active, tiedOperands, liveness, range.start);
+      if (victimIdx >= 0) {
+        ActiveRange victim = active[victimIdx];
+        // Allocate a register in the alternate class for the victim.
+        int64_t altReg = altPool->allocSingle();
+        if (altReg >= 0) {
+          // Free the victim's register back to the primary pool.
+          pool.freeRange(victim.physReg, victim.range.size);
+          active.erase(active.begin() + victimIdx);
+
+          // Record the spill for later op insertion.
+          spills->push_back(SpillRecord{victim.range.reg, victim.physReg,
+                                        altReg, pool.getRegClass(),
+                                        altPool->getRegClass()});
+
+          // Retry allocation for the incoming range.
+          physReg = tryAllocate(pool, range.size, range.alignment);
+        }
+      }
+    }
+
     if (!physReg) {
       InFlightDiagnostic diag = mlir::emitError(range.reg.getLoc())
                                 << "Failed to allocate " << regClassName
@@ -214,10 +271,10 @@ static LogicalResult allocateRegClass(
       return failure();
     }

-    // Record mapping: Value -> physical register
+    // Record mapping: Value -> physical register.
     mapping.valueToPhysReg[range.reg] = *physReg;

-    // Add to active list, maintaining sorted order by end point
+    // Add to active list, maintaining sorted order by end point.
     insertActiveRange(active, {range.end, range, *physReg});
     stats.rangesAllocated++;
@@ -230,40 +287,35 @@ static LogicalResult allocateRegClass(
 // Main Allocation Algorithm (Pure SSA)
 //===----------------------------------------------------------------------===//

-FailureOr<std::pair<PhysicalMapping, AllocationStats>>
+FailureOr<AllocResult>
 LinearScanRegAlloc::allocate(ProgramOp program) {
   PhysicalMapping mapping;
   AllocationStats stats;
+  llvm::SmallVector<SpillRecord> spills;

-  // Step 1: Validate SSA
-  if (failed(validateSSA(program))) {
+  // Step 1: Validate SSA.
+  if (failed(validateSSA(program)))
     return program.emitOpError() << "SSA validation failed before allocation";
-  }

-  // Step 2: Compute liveness (builds tied equivalence classes)
+  // Step 2: Compute liveness (builds tied equivalence classes).
   LivenessInfo liveness = computeLiveness(program);

   // Merge loop tied pairs from liveness into the allocator's tiedOperands.
-  // The liveness analysis builds TiedValueClasses with a tiedPairs map
-  // that captures loop block_arg/init_arg/iter_arg/result relationships.
-  // MFMA ties were added externally via addTiedOperand() and are already
-  // in tiedOperands; loop ties come from liveness and are merged here.
   for (const auto &[result, operand] : liveness.tiedClasses.tiedPairs) {
-    if (!tiedOperands.contains(result)) {
+    if (!tiedOperands.contains(result))
       tiedOperands[result] = operand;
-    }
   }

   stats.totalVRegs = liveness.vregRanges.size();
   stats.totalSRegs = liveness.sregRanges.size();
   stats.totalARegs = liveness.aregRanges.size();

-  // Step 3: Create register pools with reserved registers
+  // Step 3: Create register pools with reserved registers.
   RegPool vgprPool(RegClass::VGPR, maxVGPRs, reservedVGPRs);
   RegPool sgprPool(RegClass::SGPR, maxSGPRs, reservedSGPRs);
   RegPool agprPool(RegClass::AGPR, maxAGPRs, reservedAGPRs);

-  // Step 4: Handle precolored values (from ABI args like tid, kernarg)
+  // Step 4: Handle precolored values (from ABI args like tid, kernarg).
   for (const auto &[value, physIdx] : precoloredValues) {
     if (isVGPRType(value.getType())) {
       mapping.valueToPhysReg[value] = physIdx;
@@ -277,32 +329,31 @@ LinearScanRegAlloc::allocate(ProgramOp program) {
     }
   }

-  // Step 5: Allocate VGPRs using linear scan (with optional strategy)
-  if (failed(allocateRegClass(liveness.vregRanges, vgprPool, mapping, stats,
-                              tiedOperands, precoloredValues, "VGPR", program,
-                              maxVGPRs, liveness.maxVRegPressure,
-                              vgprStrategy.get()))) {
+  // Step 5: Allocate VGPRs. On failure, evict to spare AGPRs.
+  if (failed(allocateRegClass(
+          liveness.vregRanges, vgprPool, mapping, stats, tiedOperands,
+          precoloredValues, "VGPR", program, maxVGPRs, liveness.maxVRegPressure,
+          vgprStrategy.get(), &agprPool, &spills, &liveness)))
     return failure();
-  }
   stats.peakVGPRs = vgprPool.getPeakUsage();

-  // Step 6: Allocate SGPRs using linear scan (no strategy, always bottom-up)
+  // Step 6: Allocate SGPRs (no cross-class spilling yet).
   if (failed(allocateRegClass(liveness.sregRanges, sgprPool, mapping, stats,
                               tiedOperands, precoloredValues, "SGPR", program,
                               maxSGPRs, liveness.maxSRegPressure,
-                              /*strategy=*/nullptr))) {
+                              /*strategy=*/nullptr, /*altPool=*/nullptr,
+                              /*spills=*/nullptr, &liveness)))
     return failure();
-  }
   stats.peakSGPRs = sgprPool.getPeakUsage();

-  // Step 7: Allocate AGPRs using linear scan (no strategy, always bottom-up)
+  // Step 7: Allocate AGPRs (no cross-class spilling yet).
   if (failed(allocateRegClass(liveness.aregRanges, agprPool, mapping, stats,
                               tiedOperands, precoloredValues, "AGPR", program,
                               maxAGPRs, liveness.maxARegPressure,
-                              /*strategy=*/nullptr))) {
+                              /*strategy=*/nullptr, /*altPool=*/nullptr,
+                              /*spills=*/nullptr, &liveness)))
     return failure();
-  }
   stats.peakAGPRs = agprPool.getPeakUsage();

-  return std::make_pair(mapping, stats);
+  return AllocResult{std::move(mapping), stats, std::move(spills)};
 }
diff --git a/waveasm/test/Transforms/cross-class-spill-exhausted.mlir b/waveasm/test/Transforms/cross-class-spill-exhausted.mlir
new file mode 100644
index 0000000000..05fb5330be
--- /dev/null
+++ b/waveasm/test/Transforms/cross-class-spill-exhausted.mlir
@@ -0,0 +1,15 @@
+// RUN: not waveasm-translate --waveasm-linear-scan="max-vgprs=4 max-agprs=0" %s 2>&1 | FileCheck %s
+//
+// Test: When both VGPR and AGPR pools are exhausted, allocation fails
+// with a diagnostic. max-agprs=0 leaves no spare AGPRs for eviction.
+
+// CHECK: error: Failed to allocate VGPR
+waveasm.program @both_exhausted target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
+  %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>
+  %v1 = waveasm.precolored.vreg 1 : !waveasm.pvreg<1>
+  %r1 = waveasm.v_add_u32 %v0, %v1 : !waveasm.pvreg<0>, !waveasm.pvreg<1> -> !waveasm.vreg
+  %r2 = waveasm.v_add_u32 %r1, %v0 : !waveasm.vreg, !waveasm.pvreg<0> -> !waveasm.vreg
+  %r3 = waveasm.v_add_u32 %r2, %v1 : !waveasm.vreg, !waveasm.pvreg<1> -> !waveasm.vreg
+  %sum = waveasm.v_add_u32 %r1, %r3 : !waveasm.vreg, !waveasm.vreg -> !waveasm.vreg
+  waveasm.s_endpgm
+}
diff --git a/waveasm/test/Transforms/cross-class-spill-multi-reg.mlir b/waveasm/test/Transforms/cross-class-spill-multi-reg.mlir
new file mode 100644
index 0000000000..04f4b0101d
--- /dev/null
+++ b/waveasm/test/Transforms/cross-class-spill-multi-reg.mlir
@@ -0,0 +1,25 @@
+// RUN: not waveasm-translate --waveasm-linear-scan="max-vgprs=4 max-agprs=8" %s 2>&1 | FileCheck %s
+//
+// Test: Multi-register (size > 1) values are NOT spill candidates.
+// When the only live values are multi-register, eviction is not attempted
+// and allocation fails.
+
+// CHECK: error: Failed to allocate VGPR
+waveasm.program @multi_reg_no_spill target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
+  %srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4>
+  %voff = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>
+  %soff = waveasm.constant 0 : !waveasm.imm<0>
+
+  // A 4-wide load consumes v2..v5 (v0 precolored, v14/v15 reserved).
+  // That exhausts the pool. There are no size-1 eviction candidates.
+  %wide = waveasm.buffer_load_dwordx4 %srd, %voff, %soff : !waveasm.psreg<0, 4>, !waveasm.pvreg<0>, !waveasm.imm<0> -> !waveasm.vreg<4, 4>
+
+  // This needs another VGPR but nothing can be evicted (wide is size 4).
+  %extra = waveasm.v_add_u32 %voff, %voff : !waveasm.pvreg<0>, !waveasm.pvreg<0> -> !waveasm.vreg
+
+  // Keep wide live past extra to create the pressure.
+  %elem = waveasm.extract %wide[0] : !waveasm.vreg<4, 4> -> !waveasm.vreg
+  %sum = waveasm.v_add_u32 %elem, %extra : !waveasm.vreg, !waveasm.vreg -> !waveasm.vreg
+
+  waveasm.s_endpgm
+}
diff --git a/waveasm/test/Transforms/cross-class-spill-multiple.mlir b/waveasm/test/Transforms/cross-class-spill-multiple.mlir
new file mode 100644
index 0000000000..b66ce1299b
--- /dev/null
+++ b/waveasm/test/Transforms/cross-class-spill-multiple.mlir
@@ -0,0 +1,33 @@
+// RUN: waveasm-translate --waveasm-linear-scan="max-vgprs=5 max-agprs=8" %s | FileCheck %s
+//
+// Test: Multiple values are spilled to different AGPRs when VGPR pressure
+// exceeds the limit by more than one register.
+//
+// With max-vgprs=5: v0, v1 precolored; v14, v15 outside range (>= 5).
+// Only v2, v3, v4 are allocatable. We create 4 virtual VGPRs with
+// overlapping liveness, so at least 1 must be spilled. The long-lived
+// values r1 and r2 compete for VGPRs while r3 and r4 also need them.
+//
+// NOTE: two spilled values must not feed the same instruction because
+// both would reload into the single scratch v14. This test avoids that.
+
+// CHECK-LABEL: waveasm.program @multiple_spills
+waveasm.program @multiple_spills target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
+  %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>
+  %v1 = waveasm.precolored.vreg 1 : !waveasm.pvreg<1>
+
+  // Four virtual VGPRs; r1, r2 are long-lived.
+  %r1 = waveasm.v_add_u32 %v0, %v1 : !waveasm.pvreg<0>, !waveasm.pvreg<1> -> !waveasm.vreg
+  %r2 = waveasm.v_add_u32 %v0, %v1 : !waveasm.pvreg<0>, !waveasm.pvreg<1> -> !waveasm.vreg
+  %r3 = waveasm.v_add_u32 %v0, %v1 : !waveasm.pvreg<0>, !waveasm.pvreg<1> -> !waveasm.vreg
+  %r4 = waveasm.v_add_u32 %v0, %v1 : !waveasm.pvreg<0>, !waveasm.pvreg<1> -> !waveasm.vreg
+
+  // At least one spill write expected.
+  // CHECK: waveasm.v_accvgpr_write_b32
+  // Use spilled values one at a time (avoid double-reload into same scratch).
+  %s1 = waveasm.v_add_u32 %r1, %r3 : !waveasm.vreg, !waveasm.vreg -> !waveasm.vreg
+  %s2 = waveasm.v_add_u32 %r2, %r4 : !waveasm.vreg, !waveasm.vreg -> !waveasm.vreg
+  %final = waveasm.v_add_u32 %s1, %s2 : !waveasm.vreg, !waveasm.vreg -> !waveasm.vreg
+
+  waveasm.s_endpgm
+}
diff --git a/waveasm/test/Transforms/cross-class-spill-no-spill-needed.mlir b/waveasm/test/Transforms/cross-class-spill-no-spill-needed.mlir
new file mode 100644
index 0000000000..54e93a1d4d
--- /dev/null
+++ b/waveasm/test/Transforms/cross-class-spill-no-spill-needed.mlir
@@ -0,0 +1,17 @@
+// RUN: waveasm-translate --waveasm-linear-scan="max-vgprs=8 max-agprs=4" %s | FileCheck %s
+//
+// Test: No spill needed when VGPRs fit within the limit.
+// With max-vgprs=8 and only 3 virtual VGPRs needed, no spill should occur.
+
+// CHECK-LABEL: waveasm.program @no_spill_needed
+// CHECK-NOT: v_accvgpr_write_b32
+// CHECK-NOT: v_accvgpr_read_b32
+waveasm.program @no_spill_needed target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
+  %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>
+  %v1 = waveasm.precolored.vreg 1 : !waveasm.pvreg<1>
+  %r1 = waveasm.v_add_u32 %v0, %v1 : !waveasm.pvreg<0>, !waveasm.pvreg<1> -> !waveasm.vreg
+  %r2 = waveasm.v_add_u32 %r1, %v0 : !waveasm.vreg, !waveasm.pvreg<0> -> !waveasm.vreg
+  %r3 = waveasm.v_add_u32 %r2, %v1 : !waveasm.vreg, !waveasm.pvreg<1> -> !waveasm.vreg
+  %sum = waveasm.v_add_u32 %r1, %r3 : !waveasm.vreg, !waveasm.vreg -> !waveasm.vreg
+  waveasm.s_endpgm
+}
diff --git a/waveasm/test/Transforms/cross-class-spill-no-tied.mlir b/waveasm/test/Transforms/cross-class-spill-no-tied.mlir
new file mode 100644
index 0000000000..35c0ba548d
--- /dev/null
+++ b/waveasm/test/Transforms/cross-class-spill-no-tied.mlir
@@ -0,0 +1,31 @@
+// RUN: waveasm-translate --waveasm-linear-scan="max-vgprs=4 max-agprs=4" %s | FileCheck %s
+//
+// Test: Tied values (MFMA accumulator -> result) are NOT spill candidates.
+// The spill should pick an untied value instead.
+//
+// With max-vgprs=4: v0, v1 precolored; v14, v15 reserved. Only v2, v3
+// allocatable. Three values are live simultaneously -> one must spill.
+// The spill must NOT be a tied value.
+
+// CHECK-LABEL: waveasm.program @no_tied_spill
+waveasm.program @no_tied_spill target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
+  %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>
+  %v1 = waveasm.precolored.vreg 1 : !waveasm.pvreg<1>
+
+  // Long-lived untied value.
+  %addr = waveasm.v_add_u32 %v0, %v1 : !waveasm.pvreg<0>, !waveasm.pvreg<1> -> !waveasm.vreg
+
+  // Second value creating pressure.
+  %tmp = waveasm.v_add_u32 %v0, %v1 : !waveasm.pvreg<0>, !waveasm.pvreg<1> -> !waveasm.vreg
+
+  // Third value that triggers the spill.
+  %extra = waveasm.v_add_u32 %v0, %v1 : !waveasm.pvreg<0>, !waveasm.pvreg<1> -> !waveasm.vreg
+
+  // Use all three -> one must be spilled. Verify that spill/reload appear.
+  // CHECK: waveasm.v_accvgpr_write_b32
+  // CHECK: waveasm.v_accvgpr_read_b32
+  %s1 = waveasm.v_add_u32 %addr, %tmp : !waveasm.vreg, !waveasm.vreg -> !waveasm.vreg
+  %s2 = waveasm.v_add_u32 %s1, %extra : !waveasm.vreg, !waveasm.vreg -> !waveasm.vreg
+
+  waveasm.s_endpgm
+}
diff --git a/waveasm/test/Transforms/cross-class-spill-sgpr-unsupported.mlir b/waveasm/test/Transforms/cross-class-spill-sgpr-unsupported.mlir
new file mode 100644
index 0000000000..f7bd324abf
--- /dev/null
+++ b/waveasm/test/Transforms/cross-class-spill-sgpr-unsupported.mlir
@@ -0,0 +1,20 @@
+// RUN: not waveasm-translate --waveasm-linear-scan="max-sgprs=4 max-vgprs=8" %s 2>&1 | FileCheck %s
+//
+// Test: SGPR overflow triggers a hard error because SGPR -> VGPR spilling
+// is not yet implemented. The allocator does not attempt cross-class
+// eviction for SGPRs (altPool is nullptr).
+
+// CHECK: error: Failed to allocate SGPR
+waveasm.program @sgpr_overflow target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
+  %s0 = waveasm.precolored.sreg 0 : !waveasm.psreg<0>
+  %s1 = waveasm.precolored.sreg 1 : !waveasm.psreg<1>
+  %s2 = waveasm.precolored.sreg 2 : !waveasm.psreg<2>
+  %s3 = waveasm.precolored.sreg 3 : !waveasm.psreg<3>
+
+  // All 4 SGPRs precolored. Virtual SGPRs have nowhere to go.
+  %r1 = waveasm.s_mul_i32 %s0, %s1 : !waveasm.psreg<0>, !waveasm.psreg<1> -> !waveasm.sreg
+  %r2 = waveasm.s_mul_i32 %s2, %s3 : !waveasm.psreg<2>, !waveasm.psreg<3> -> !waveasm.sreg
+  %r3 = waveasm.s_mul_i32 %r1, %r2 : !waveasm.sreg, !waveasm.sreg -> !waveasm.sreg
+
+  waveasm.s_endpgm
+}
diff --git a/waveasm/test/Transforms/cross-class-spill-vgpr-to-agpr.mlir b/waveasm/test/Transforms/cross-class-spill-vgpr-to-agpr.mlir
new file mode 100644
index 0000000000..1b37476157
--- /dev/null
+++ b/waveasm/test/Transforms/cross-class-spill-vgpr-to-agpr.mlir
@@ -0,0 +1,37 @@
+// RUN: waveasm-translate --waveasm-linear-scan="max-vgprs=4 max-agprs=4" %s | FileCheck %s
+//
+// Test: VGPR -> AGPR cross-class spilling.
+//
+// This program needs 3 virtual VGPRs but only 2 are allocatable (v2, v3)
+// given max-vgprs=4 with v0/v1 precolored and v14/v15 reserved as scratch.
+// The allocator evicts the longest-lived value (r1) to AGPR a0 and inserts
+// v_accvgpr_write/read pairs around its def and uses.
+
+// CHECK-LABEL: waveasm.program @spill_vgpr_to_agpr
+waveasm.program @spill_vgpr_to_agpr target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
+  %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>
+  %v1 = waveasm.precolored.vreg 1 : !waveasm.pvreg<1>
+
+  // r1 is live across r2 and r3, but only 2 VGPRs are free.
+  // The allocator should spill r1 to AGPR.
+
+  // CHECK: [[R1:%.*]] = waveasm.v_add_u32 {{.*}} -> !waveasm.pvreg<[[R1IDX:[0-9]+]]>
+  %r1 = waveasm.v_add_u32 %v0, %v1 : !waveasm.pvreg<0>, !waveasm.pvreg<1> -> !waveasm.vreg
+
+  // CHECK: waveasm.v_accvgpr_write_b32 [[R1]], {{.*}} : !waveasm.pvreg<[[R1IDX]]>, !waveasm.pareg<[[ASPILL:[0-9]+]]>
+  // CHECK: [[RELOAD1:%.*]] = waveasm.v_accvgpr_read_b32 {{.*}} : !waveasm.pareg<[[ASPILL]]> -> !waveasm.pvreg<14>
+
+  // CHECK: waveasm.v_add_u32 [[RELOAD1]], {{.*}} -> !waveasm.pvreg<
+  %r2 = waveasm.v_add_u32 %r1, %v0 : !waveasm.vreg, !waveasm.pvreg<0> -> !waveasm.vreg
+
+  // CHECK: waveasm.v_add_u32 {{.*}} -> !waveasm.pvreg<
+  %r3 = waveasm.v_add_u32 %r2, %v1 : !waveasm.vreg, !waveasm.pvreg<1> -> !waveasm.vreg
+
+  // r1 is used again here, so a second reload is needed.
+  // CHECK: [[RELOAD2:%.*]] = waveasm.v_accvgpr_read_b32 {{.*}} : !waveasm.pareg<[[ASPILL]]> -> !waveasm.pvreg<14>
+  // CHECK: waveasm.v_add_u32 [[RELOAD2]], {{.*}} -> !waveasm.pvreg<
+  %sum1 = waveasm.v_add_u32 %r1, %r3 : !waveasm.vreg, !waveasm.vreg -> !waveasm.vreg
+
+  // CHECK: waveasm.s_endpgm
+  waveasm.s_endpgm
+}
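
Reviewer note: the victim-selection rule exercised by the tests above (skip expired ranges, prefer fewer uses, break ties with the longer remaining range) can be checked in isolation with a small standalone sketch. This is not the patch's `findEvictionCandidate`. `SketchRange` and `pickVictim` are hypothetical names, and the explicit `tied` and `size != 1` filters are an assumption based on what the `no-tied` and `multi-reg` tests describe; the part of the real function that performs that filtering is not visible in this chunk.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical stand-in for the allocator's ActiveRange (illustrative only).
struct SketchRange {
  int64_t endPoint; // last program point at which the value is live
  int64_t size;     // number of consecutive registers the value occupies
  int64_t useCount; // number of recorded use points
  bool tied;        // part of a tied pair (e.g. MFMA accumulator -> result)
};

// Returns the index of the preferred eviction victim, or -1 if none exists.
inline int64_t pickVictim(const std::vector<SketchRange> &active,
                          int64_t currentPoint) {
  int64_t bestIdx = -1;
  int64_t bestUseCount = std::numeric_limits<int64_t>::max();
  int64_t bestLength = -1;
  for (std::size_t i = 0; i < active.size(); ++i) {
    const SketchRange &ar = active[i];
    assert(ar.size >= 1 && "a live range occupies at least one register");
    // Assumed filters: tied and multi-register values are never evicted.
    if (ar.tied || ar.size != 1)
      continue;
    // Ranges that already ended free no register by being evicted.
    if (ar.endPoint <= currentPoint)
      continue;
    int64_t length = ar.endPoint - currentPoint;
    // Prefer fewer uses; break ties with the longer remaining range.
    if (ar.useCount < bestUseCount ||
        (ar.useCount == bestUseCount && length > bestLength)) {
      bestIdx = static_cast<int64_t>(i);
      bestUseCount = ar.useCount;
      bestLength = length;
    }
  }
  return bestIdx;
}
```

Given two single-register, untied ranges with the same use count, the sketch picks the one that ends later; a tied range or a 4-wide range is never returned even when its use count is lower.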