Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
436 changes: 436 additions & 0 deletions docs/wave/cross-class-spilling.md

Large diffs are not rendered by default.

36 changes: 31 additions & 5 deletions waveasm/include/waveasm/Transforms/RegAlloc.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,22 @@

namespace waveasm {

//===----------------------------------------------------------------------===//
// Spill Record
//===----------------------------------------------------------------------===//

/// Records a cross-class spill decision made during register allocation.
/// The victim value is evicted from its original register class into the
/// alternate class (e.g. VGPR -> AGPR). The pass uses these records to
/// insert spill/reload ops after allocation completes.
struct SpillRecord {
mlir::Value victim; ///< The SSA value being spilled.
int64_t sourcePhysReg; ///< Physical register freed in original class.
int64_t targetPhysReg; ///< Physical register allocated in target class.
RegClass sourceClass; ///< Original register class (e.g. VGPR).
RegClass targetClass; ///< Target register class (e.g. AGPR).
};

//===----------------------------------------------------------------------===//
// Allocation Statistics
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -206,6 +222,10 @@ class RegPool {
}
}

/// Check if the pool has at least one free register.
bool hasFree() const { return free.any(); }

/// Get peak usage.
int64_t getPeakUsage() const { return peak; }

int64_t getCurrentUsage() const { return currentUsage; }
Expand Down Expand Up @@ -294,11 +314,17 @@ class LinearScanRegAlloc {
vgprStrategy = std::move(strategy);
}

/// Run allocation on a kernel program
/// Returns the physical mapping and statistics, or failure if allocation
/// fails
mlir::FailureOr<std::pair<PhysicalMapping, AllocationStats>>
allocate(ProgramOp program);
/// Result bundle returned by allocate().
struct AllocResult {
PhysicalMapping mapping;
AllocationStats stats;
llvm::SmallVector<SpillRecord> spills;
};

/// Run allocation on a kernel program.
/// Returns the physical mapping, statistics, and any cross-class spill
/// records, or failure if allocation fails.
mlir::FailureOr<AllocResult> allocate(ProgramOp program);

private:
/// Process active ranges, expiring those that end before currentPoint
Expand Down
20 changes: 17 additions & 3 deletions waveasm/lib/Transforms/HazardMitigation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
// Hazard Mitigation Pass - Insert s_nop instructions for hardware hazards
//
// This pass handles hardware-specific hazards that require NOP insertion:
// - VALU → v_readfirstlane hazard (gfx940+)
// - Trans → non-Trans VALU forwarding hazard (gfx940+)
// - VALU -> v_readfirstlane hazard (gfx940+)
// - Trans -> non-Trans VALU forwarding hazard (gfx940+)
// - v_accvgpr_read_b32 -> VALU RAW hazard (gfx940+)
//===----------------------------------------------------------------------===//

#include "waveasm/Dialect/WaveASMAttrs.h"
Expand Down Expand Up @@ -74,7 +75,7 @@ bool isVALUOp(Operation *op) {
return true;
}

/// Check if an operation is v_readfirstlane
/// Check if an operation is v_readfirstlane.
bool isReadfirstlaneOp(Operation *op) { return isa<V_READFIRSTLANE_B32>(op); }

/// Check if an operation does NOT emit an assembly instruction.
Expand All @@ -99,6 +100,9 @@ Operation *findPrecedingEmittingOp(ArrayRef<Operation *> ops, size_t start) {
return nullptr;
}

/// Check if an operation is v_accvgpr_read_b32 (AGPR -> VGPR move).
bool isAccVgprReadOp(Operation *op) { return isa<V_ACCVGPR_READ_B32>(op); }

/// Check if an operation is a transcendental instruction (uses the Trans
/// pipeline which has different latency characteristics from the main VALU).
bool isTransOp(Operation *op) {
Expand Down Expand Up @@ -214,6 +218,16 @@ struct HazardMitigationPass
if (pred && isTransOp(pred) && hasVGPRConflict(pred, op))
insertionPoints.push_back(op);
}

// Check for v_accvgpr_read_b32 -> VALU RAW hazard (gfx940+).
// On GFX950 the VGPR destination of v_accvgpr_read_b32 is not
// immediately available; a VALU that consumes it in the next cycle
// reads stale data. Insert s_nop 0 to cover the 1-cycle wait.
if (isVALUOp(op) && i > 0) {
Operation *pred = findPrecedingEmittingOp(ops, i);
if (pred && isAccVgprReadOp(pred) && hasVGPRConflict(pred, op))
insertionPoints.push_back(op);
}
}

// Insert s_nop instructions
Expand Down
115 changes: 109 additions & 6 deletions waveasm/lib/Transforms/LinearScanPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,101 @@ struct LinearScanPass
/// Get the accumulator operand from an MFMA op using the interface.
/// Returns nullptr if the operation is not an MFMA.
Value getMFMAAccumulator(Operation *op) {
if (auto mfmaOp = dyn_cast<MFMAOpInterface>(op)) {
if (auto mfmaOp = dyn_cast<MFMAOpInterface>(op))
return mfmaOp.getAcc();
}
return nullptr;
}

/// Scratch VGPR index used for cross-class spill reloads.
static constexpr int64_t kSpillScratchVGPR = 14;

/// Insert spill/reload ops for cross-class evictions.
///
/// For each VGPR->AGPR spill:
/// - After the victim's def: v_accvgpr_write_b32 vSRC, aDST
/// - Before each use: v_accvgpr_read_b32 aSRC -> vSCRATCH
/// and rewrite the use to consume the reload result.
///
/// The ops are created with physical register types so the subsequent
/// type transformation pass does not need to touch them.
LogicalResult insertSpillReloads(ProgramOp program,
ArrayRef<SpillRecord> spills,
PhysicalMapping &mapping) {
// Validate that all spills use the implemented VGPR -> AGPR direction.
// SGPR -> VGPR and AGPR -> VGPR are wired in the allocator but not yet
// supported by the spill insertion logic.
for (const SpillRecord &sr : spills) {
if (sr.sourceClass != RegClass::VGPR ||
sr.targetClass != RegClass::AGPR) {
return program.emitOpError()
<< "unsupported spill direction: only VGPR -> AGPR is "
"currently implemented";
}
}

// Build a lookup from victim Value -> SpillRecord.
llvm::DenseMap<Value, const SpillRecord *> spillMap;
for (const SpillRecord &sr : spills)
spillMap[sr.victim] = &sr;

MLIRContext *ctx = program.getContext();

// For each spill, create a precolored AGPR value to use as the
// source/destination in read/write ops. This gives us an SSA handle
// to thread through the spill ops.
//
// We insert a PrecoloredARegOp at program entry to materialise the
// AGPR "slot". It does not generate any assembly; it just provides
// an SSA value with the right physical type for the dialect ops.
llvm::DenseMap<Value, Value> spillSlots; // victim -> AGPR SSA value
{
OpBuilder entryBuilder(ctx);
Block &entry = program.getBodyBlock();
entryBuilder.setInsertionPointToStart(&entry);
for (const SpillRecord &sr : spills) {
auto physType = PARegType::get(ctx, sr.targetPhysReg, 1);
Value slot = PrecoloredARegOp::create(entryBuilder, program.getLoc(),
physType, sr.targetPhysReg,
/*size=*/1);
spillSlots[sr.victim] = slot;
}
}

// Collect all ops in program order.
llvm::SmallVector<Operation *> ops;
collectOpsRecursive(program.getBodyBlock(), ops);

for (Operation *op : ops) {
// --- Insert spills after defs. ---
for (Value result : op->getResults()) {
auto it = spillMap.find(result);
if (it == spillMap.end())
continue;
const SpillRecord &sr = *it->second;
OpBuilder builder(ctx);
builder.setInsertionPointAfter(op);
Value slot = spillSlots[sr.victim];
V_ACCVGPR_WRITE_B32::create(builder, op->getLoc(), result, slot);
}

// --- Insert reloads before uses. ---
for (unsigned i = 0; i < op->getNumOperands(); ++i) {
Value operand = op->getOperand(i);
auto it = spillMap.find(operand);
if (it == spillMap.end())
continue;
OpBuilder builder(op);
auto scratchType = PVRegType::get(ctx, kSpillScratchVGPR, 1);
Value slot = spillSlots[operand];
Value reloaded = V_ACCVGPR_READ_B32::create(builder, op->getLoc(),
scratchType, slot);
op->setOperand(i, reloaded);
}
}

return success();
}

LogicalResult processProgram(ProgramOp program) {
// Collect precolored values from precolored.vreg and precolored.sreg ops
llvm::DenseMap<Value, int64_t> precoloredValues;
Expand All @@ -291,6 +380,11 @@ struct LinearScanPass
// generates v_mov_b32 v15, <literal> before such instructions.
reservedVGPRs.insert(15);

// Reserve v14 as scratch VGPR for cross-class spill reloads.
// When a VGPR is spilled to an AGPR, reloads use v_accvgpr_read_b32
// into this scratch before the consuming instruction.
reservedVGPRs.insert(14);

// ABI SGPRs (kernarg ptr, preload regs, workgroup IDs, SRDs) are
// reserved via PrecoloredSRegOp ops emitted during translation. The
// collection loop below picks those up and adds their indices to
Expand Down Expand Up @@ -404,13 +498,12 @@ struct LinearScanPass
allocator.setVGPRStrategy(
std::make_unique<BidirectionalStrategy>(bidirectionalThreshold));

// Run allocation
// Run allocation.
auto result = allocator.allocate(program);
if (failed(result)) {
if (failed(result))
return failure();
}

auto [mapping, stats] = *result;
auto &[mapping, stats, spills] = *result;

// Handle waveasm.extract ops: result = source[offset].
// Set the extract result's physical register = source's physReg + offset.
Expand Down Expand Up @@ -497,6 +590,16 @@ struct LinearScanPass
if (packResult.wasInterrupted())
return failure();

// Insert cross-class spill/reload ops for any evicted values.
// For each spill record (e.g. VGPR -> AGPR):
// - After the victim's def: v_accvgpr_write_b32 aX, vY (spill)
// - Before each use: v_accvgpr_read_b32 v14, aX (reload)
// and rewrite the use to consume v14 instead of vY.
if (!spills.empty()) {
if (failed(insertSpillReloads(program, spills, mapping)))
return failure();
}

// Transform the IR: replace virtual register types with physical types
OpBuilder builder(program.getContext());

Expand Down
Loading
Loading