Skip to content

Commit cd1a2ef

Browse files
committed
WaveASM dwordx4
Signed-off-by: xintin <gaurav.verma@amd.com>
1 parent be40df5 commit cd1a2ef

17 files changed

Lines changed: 529 additions & 62 deletions

File tree

examples/python/7.1_schedule.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -465,15 +465,21 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_wide_stores(
465465

466466
def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm(
467467
is_debug=False,
468-
shape=(1024, 1024, 8192),
469-
block=(128, 256, 256),
468+
shape=(1024, 3072, 8192),
469+
block=(128, 128, 256),
470470
eliminate_epilogue=False,
471471
):
472-
"""Preshuffle-B MXFP4 GEMM with dynamic M, N, K."""
473-
gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(
474-
shape, block, wave_shape=(1, 4), reorder_workgroups=False
472+
"""Preshuffle-B MXFP4 GEMM with coalesced dwordx4 stores (WaveASM backend).
473+
474+
Same kernel as the LLVM coalesced-stores test but compiled through the
475+
C++ WaveASM backend. Emits v_permlane16_swap_b32 + buffer_store_dwordx4.
476+
"""
477+
gemm, options = get_tagged_mxfp4_gemm_preshuffle_b_wide_store(
478+
shape,
479+
block,
480+
wave_shape=(1, 4),
481+
reorder_workgroups=True,
475482
)
476-
# Make M, N, K dynamic so the compiler does not specialize on problem size.
477483
dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K]
478484
for sym in dynamic_symbols:
479485
del options.subs[sym]
@@ -483,18 +489,17 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm(
483489
options.use_wave_asm_backend = True
484490
options.wave_runtime = True
485491
options.eliminate_epilogue = eliminate_epilogue
486-
options.dump_intermediates = "build/intermediates/"
492+
options.coalesce_epilogue_stores = True
493+
options._skip_vgpr_compaction = True
487494
schedule = get_mxfp4_asymmetric_schedule(
488495
eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True
489496
)
490497
options.print_ir_after = "all" if is_debug else []
491498
options = set_default_run_config(options)
492499
gemm = wave_compile(options, gemm, schedule)
493500

494-
_run_mxfp_gemm_preshuffle(gemm, shape, all=True)
495-
print(
496-
"MXFP GEMM preshuffle-B 4-wave dynamic M, N, K (WaveASM backend) test passed!"
497-
)
501+
_run_mxfp_gemm_preshuffle(gemm, shape, all=True, output_dtype=torch.bfloat16)
502+
print("MXFP GEMM preshuffle-B 4-wave dwordx4 (WaveASM backend) test passed!")
498503

499504

500505
if __name__ == "__main__":

wave_lang/kernel/compiler/wave_codegen/read_write.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,9 +1322,19 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
13221322

13231323
use_llvm_store = flags != MemoryAccessFlags.NONE
13241324

1325+
is_shared = get_custom(memory).type.address_space == SHARED_ADDRESS_SPACE
1326+
is_bf16 = isinstance(element_type, BF16Type)
1327+
1328+
if (
1329+
is_bf16
1330+
and not is_shared
1331+
and emitter.options.use_buffer_ops
1332+
and emitter.options.backend == "asm"
1333+
):
1334+
mask = None
1335+
13251336
if getattr(node, "_permlane_pack_global", False):
1326-
is_shared = get_custom(memory).type.address_space == SHARED_ADDRESS_SPACE
1327-
if not is_shared and isinstance(element_type, BF16Type):
1337+
if not is_shared and is_bf16:
13281338
_write_permlane_pack_to_global(
13291339
emitter,
13301340
insert_vector,
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
"""
6+
Graph pass that coalesces epilogue bf16 stores via permlane16_swap.
7+
8+
Marks eligible Write nodes so the codegen combines each thread's 4 bf16
9+
values with its partner lane's (16 lanes apart) via v_permlane16_swap_b32,
10+
producing 8 consecutive bf16 written as a single buffer_store_dwordx4.
11+
No LDS staging or barriers required.
12+
13+
Precondition: the output memory must have M as the innermost (contiguous)
14+
dimension (i.e. transpose_output=True producing [N, M] layout) so that 8
15+
consecutive bf16 elements span 8 adjacent M rows.
16+
"""
17+
18+
from .._support.tracing import CapturedTrace
19+
from ..lang.global_symbols import GLOBAL_ADDRESS_SPACE
20+
from ..ops.wave_ops import Write, get_custom
21+
from .region_canonicalization import RegionFormat, requires_region_format
22+
from .utils.symbol_utils import subs_idxc
23+
24+
25+
@requires_region_format(RegionFormat.SCHEDULE_SIGNATURE_PLACEHOLDERS)
26+
def coalesce_epilogue_stores(trace: CapturedTrace):
27+
"""Tag epilogue bf16 global writes for permlane16_swap packing.
28+
29+
Walks the root graph and sets ``_permlane_pack_global = True`` on
30+
every Write node that targets global memory with bf16 dtype.
31+
The codegen in ``_write_permlane_pack_to_global`` handles the rest.
32+
"""
33+
import wave_lang.kernel.lang as tkl
34+
35+
root_graph = trace.get_root_graph()
36+
37+
for node in root_graph.nodes:
38+
if node.op != "call_function":
39+
continue
40+
custom = get_custom(node)
41+
if not isinstance(custom, Write):
42+
continue
43+
mem_type = custom.memory_type
44+
if (
45+
subs_idxc(mem_type.address_space) == GLOBAL_ADDRESS_SPACE
46+
and mem_type.dtype == tkl.bf16
47+
):
48+
node._permlane_pack_global = True

wave_lang/kernel/wave/compile.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,11 @@ def build_graph_passes(
534534

535535
graph_passes.append(partial(coalesce_wide_stores, trace))
536536

537+
if options.coalesce_epilogue_stores:
538+
from .coalesce_epilogue_stores import coalesce_epilogue_stores
539+
540+
graph_passes.append(partial(coalesce_epilogue_stores, trace))
541+
537542
graph_passes += [
538543
partial(simplify_indices, trace, launchable.constraints),
539544
partial(
@@ -1370,7 +1375,11 @@ def _generate_asm_code(mb, options):
13701375
"--waveasm-scc-spill-reload",
13711376
"--waveasm-scc-verifier",
13721377
"--waveasm-linear-scan=max-vgprs=256 max-agprs=256",
1373-
"--waveasm-vgpr-compaction",
1378+
*(
1379+
[]
1380+
if getattr(options, "_skip_vgpr_compaction", False)
1381+
else ["--waveasm-vgpr-compaction"]
1382+
),
13741383
waitcnt_flag,
13751384
f"--waveasm-hazard-mitigation=target={options.target}",
13761385
"--emit-assembly",

wave_lang/kernel/wave/compile_options.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ class WaveCompileOptions:
105105
specialize: bool = False
106106
eliminate_epilogue: bool = False
107107

108+
coalesce_epilogue_stores: bool = False
109+
108110
# Cluster barrier signal/wait delay in number of loop iterations
109111
# None - no barriers inside the loop
110112
# 0 - signal and wait on same iteration
@@ -118,11 +120,9 @@ class WaveCompileOptions:
118120
# keep read linearization without annotating every buffer.
119121
allow_noncontiguous_runtime_buffers: bool = False
120122

121-
# Dynamic strides are enabled whenever wave_runtime is active,
122-
# supported by both LLVM and waveasm backends.
123123
@property
124124
def dynamic_strides(self) -> bool:
125-
return self.wave_runtime
125+
return self.wave_runtime and self.backend == "llvm"
126126

127127
# === Print options ===
128128
mlir_print_ir_after_all: bool = False

wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -404,10 +404,6 @@ def _get_tagged_mxfp4_gemm_preshuffle_b_impl(
404404
K_PACKED = tkl.sym.K_PACKED
405405
K_SCALE_SHUFFLED = tkl.sym.K_SCALE_SHUFFLED
406406

407-
if wide_stores:
408-
m_symbol = tkl.sym.m_symbol
409-
n_symbol = tkl.sym.n_symbol
410-
411407
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
412408
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
413409
constraints += [tkw.TilingConstraint(K, BLOCK_K)]
@@ -426,6 +422,8 @@ def _get_tagged_mxfp4_gemm_preshuffle_b_impl(
426422
constraints += [tkw.Assumption(K > BLOCK_K * 6)]
427423

428424
if wide_stores:
425+
m_symbol = tkl.sym.m_symbol
426+
n_symbol = tkl.sym.n_symbol
429427
constraints += [tkw.IteratorBindings({m_symbol: M, n_symbol: N})]
430428
constraints += [tkw.Assumption(Eq(M % BLOCK_M, 0))]
431429
constraints += [tkw.Assumption(Eq(N % BLOCK_N, 0))]

waveasm/include/waveasm/Dialect/WaveASMOps.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,21 @@ def WaveASM_V_READFIRSTLANE_B32 : WAVEASMOp<"v_readfirstlane_b32", [Pure]> {
521521
let assemblyFormat = "$src attr-dict `:` type($src) `->` type($dst)";
522522
}
523523

524+
// Lane swap operations (VGPR <-> VGPR across lanes)
525+
def WaveASM_V_PERMLANE16_SWAP_B32 : WAVEASMOp<"v_permlane16_swap_b32", [Pure]> {
526+
let summary = "Swap VGPR values between lanes 16 apart";
527+
let description = [{
528+
Exchanges a 32-bit value between paired lanes that are 16 positions apart.
529+
Lane i swaps with lane i^16 within each 32-lane half-wave.
530+
The hardware writes the swapped value to dst AND clobbers src.
531+
The handler must ensure the original source value is preserved in a
532+
separate register before invoking this instruction.
533+
}];
534+
let arguments = (ins WaveASM_AnyVGPR:$src);
535+
let results = (outs WaveASM_AnyVGPR:$dst);
536+
let assemblyFormat = "$src attr-dict `:` type($src) `->` type($dst)";
537+
}
538+
524539
// Bit operations
525540
def WaveASM_V_NOT_B32 : VALUUnaryOp<"v_not_b32">;
526541
def WaveASM_V_NOT_B64 : VALUUnaryOp<"v_not_b64">;

waveasm/include/waveasm/Transforms/TranslateFromMLIR.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,24 @@ class ValueMapper {
5252
return valueMap.contains(mlirValue);
5353
}
5454

55+
/// Map a sub-element of a struct-typed MLIR value (for llvm.extractvalue).
56+
void setExtraMapping(mlir::Value structVal, int64_t index,
57+
mlir::Value elemVal) {
58+
extraMap[{structVal, index}] = elemVal;
59+
}
60+
61+
/// Get a sub-element of a struct-typed MLIR value.
62+
std::optional<mlir::Value> getExtraMapping(mlir::Value structVal,
63+
int64_t index) const {
64+
auto it = extraMap.find({structVal, index});
65+
if (it != extraMap.end())
66+
return it->second;
67+
return std::nullopt;
68+
}
69+
5570
private:
5671
llvm::DenseMap<mlir::Value, mlir::Value> valueMap;
72+
llvm::DenseMap<std::pair<mlir::Value, int64_t>, mlir::Value> extraMap;
5773
};
5874

5975
//===----------------------------------------------------------------------===//

waveasm/lib/Transforms/AssemblyEmitter.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,71 @@ std::optional<std::string> KernelGenerator::generateOp(Operation *op) {
972972
return formatter.format("v_cvt_pk_bf16_f32", operands);
973973
})
974974

975+
// V_PERMLANE16_SWAP_B32: swap lanes 16 apart.
976+
// The hardware clobbers BOTH dst and src. When the allocator assigns
977+
// dst==src, we must save the original to a scratch register, swap
978+
// through another scratch, then restore the original.
979+
.Case<V_PERMLANE16_SWAP_B32>(
980+
[&](V_PERMLANE16_SWAP_B32 swapOp) -> std::optional<std::string> {
981+
std::string dst = resolveValue(swapOp.getDst());
982+
std::string src = resolveValue(swapOp.getSrc());
983+
if (dst != src) {
984+
llvm::SmallVector<std::string> operands = {dst, src};
985+
return formatter.format("v_permlane16_swap_b32", operands);
986+
}
987+
// dst==src: save original, swap through scratch, restore original
988+
std::string scratch0 = formatVGPRRange(kScratchVGPR, 1);
989+
std::string scratch1 = formatVGPRRange(kScratchVGPR - 1, 1);
990+
peakVGPRs = std::max(peakVGPRs, kScratchVGPR + 1);
991+
invalidateScratchCache();
992+
// 1. Save original src to scratch0
993+
// 2. Copy src to scratch1 for the swap
994+
// 3. Swap: dst receives the partner lane's scratch1 value; scratch1 is clobbered
995+
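// 4. Restore original from scratch0 back to src
// NOTE(review): in this branch dst and src name the same register, so this restore also overwrites the swap result just written to dst — confirm this is intended.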
996+
return " v_mov_b32 " + scratch0 + ", " + src + "\n" +
997+
" v_mov_b32 " + scratch1 + ", " + src + "\n" +
998+
" v_permlane16_swap_b32 " + dst + ", " + scratch1 + "\n" +
999+
" v_mov_b32 " + src + ", " + scratch0;
1000+
})
1001+
1002+
// V_ACCVGPR_READ_B32: unroll multi-register reads into scalar ops
1003+
.Case<V_ACCVGPR_READ_B32>(
1004+
[&](V_ACCVGPR_READ_B32 readOp) -> std::optional<std::string> {
1005+
Value dst = readOp.getDst();
1006+
Value src = readOp.getSrc();
1007+
int64_t dstSize = getRegSize(dst.getType());
1008+
int64_t srcSize = getRegSize(src.getType());
1009+
int64_t size = std::max(dstSize, srcSize);
1010+
if (size <= 1) {
1011+
return emitDefaultFormat(readOp, "v_accvgpr_read_b32");
1012+
}
1013+
int64_t dstBase = -1, srcBase = -1;
1014+
if (auto pv = dyn_cast<PVRegType>(dst.getType()))
1015+
dstBase = pv.getIndex();
1016+
else if (isVirtualRegType(dst.getType()))
1017+
dstBase = mapping.getPhysReg(dst);
1018+
if (auto pa = dyn_cast<PARegType>(src.getType()))
1019+
srcBase = pa.getIndex();
1020+
else if (isVirtualRegType(src.getType()))
1021+
srcBase = mapping.getPhysReg(src);
1022+
if (dstBase < 0 || srcBase < 0) {
1023+
llvm::errs() << "V_ACCVGPR_READ_B32 fallback: dstBase=" << dstBase
1024+
<< " srcBase=" << srcBase << " dstSize=" << dstSize
1025+
<< " srcSize=" << srcSize
1026+
<< " dstType=" << dst.getType()
1027+
<< " srcType=" << src.getType() << "\n";
1028+
return emitDefaultFormat(readOp, "v_accvgpr_read_b32");
1029+
}
1030+
std::string lines;
1031+
for (int64_t i = 0; i < size; ++i) {
1032+
if (i > 0)
1033+
lines += "\n";
1034+
lines += " v_accvgpr_read_b32 v" + std::to_string(dstBase + i) +
1035+
", a" + std::to_string(srcBase + i);
1036+
}
1037+
return lines;
1038+
})
1039+
9751040
// Carry ops: on GFX9, carry-out is implicit VCC.
9761041
// v_add_co_u32: dst, vcc, src0, src1
9771042
// v_addc_co_u32: dst, vcc, src0, src1, vcc (carry-in).

0 commit comments

Comments
 (0)