Skip to content

Commit f01106c

Browse files
committed
remove redundant writes
Signed-off-by: xintin <gaurav.verma@amd.com>
1 parent 2e578c6 commit f01106c

2 files changed

Lines changed: 173 additions & 78 deletions

File tree

wave_lang/kernel/compiler/wave_codegen/read_write.py

Lines changed: 143 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def _split_index(
117117
# Compute thread-independent index as `orig_index - thread_dependent_index`
118118
# All thread symbols and dynamic should cancel-out in the result.
119119
diff = src - thread_dependent_index
120-
# Avoid sympy.simplify on Piecewise expressions it recurses into boolean
120+
# Avoid sympy.simplify on Piecewise expressions: it recurses into boolean
121121
# condition simplification and can hang for complex dynamic-shape indices.
122122
# expand() handles basic polynomial cancellation and is O(fast).
123123
if isinstance(diff, sympy.Basic) and diff.has(sympy.Piecewise):
@@ -575,7 +575,7 @@ def _cast_buffer_and_encode_stride(
575575
stride_int = _get_constant_value(stride_candidate)
576576
# Emit swizzle stride for both static and dynamic cases.
577577
# Static: only if stride fits in signed i14 (max 8192).
578-
# Dynamic: always emit the SRD swizzle encoding is constant
578+
# Dynamic: always emit, since the SRD swizzle encoding is constant
579579
# (0x40400000 + 0x27000) regardless of the actual stride value.
580580
if stride_int is None or stride_int <= 8192:
581581
swizzle_stride = arith_d.index_cast(uint14, stride_candidate)
@@ -1326,18 +1326,49 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
13261326
if getattr(node, "_permlane_pack_global", False):
13271327
is_shared = get_custom(memory).type.address_space == SHARED_ADDRESS_SPACE
13281328
if not is_shared and isinstance(element_type, BF16Type):
1329-
_write_permlane_pack_to_global(
1330-
emitter,
1331-
insert_vector,
1332-
kb_dest,
1333-
output_shape,
1334-
start_indices,
1335-
start_indices_wg,
1336-
start_indices_th,
1337-
get_custom(memory),
1338-
index,
1329+
role = getattr(node, "_permlane_pack_role", "unpaired")
1330+
1331+
if role == "first":
1332+
node._stashed_codegen = {
1333+
"insert_vector": insert_vector,
1334+
"kb_dest": kb_dest,
1335+
"output_shape": output_shape,
1336+
"start_indices": start_indices,
1337+
"start_indices_wg": start_indices_wg,
1338+
"start_indices_th": start_indices_th,
1339+
"memory_custom": get_custom(memory),
1340+
"index": index,
1341+
}
1342+
return
1343+
1344+
if role == "second":
1345+
partner = node._permlane_partner
1346+
s = partner._stashed_codegen
1347+
_write_permlane_pair_to_global(
1348+
emitter,
1349+
s["insert_vector"],
1350+
insert_vector,
1351+
s["kb_dest"],
1352+
kb_dest,
1353+
s["output_shape"],
1354+
output_shape,
1355+
s["start_indices"],
1356+
s["start_indices_wg"],
1357+
s["start_indices_th"],
1358+
start_indices,
1359+
start_indices_wg,
1360+
start_indices_th,
1361+
s["memory_custom"],
1362+
get_custom(memory),
1363+
s["index"],
1364+
index,
1365+
)
1366+
return
1367+
1368+
assert False, (
1369+
"Unexpected unpaired wide-store write. "
1370+
"coalesce_wide_stores should pair all eligible writes."
13391371
)
1340-
return
13411372

13421373
if use_llvm_store:
13431374
_create_llvm_read_write(
@@ -1360,52 +1391,58 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
13601391
)
13611392

13621393

1363-
def _write_permlane_pack_to_global(
1394+
def _write_permlane_pair_to_global(
13641395
emitter: WaveEmitter,
1365-
insert_vector: Value,
1366-
kb_dest: Value,
1367-
output_shape: tuple,
1368-
start_indices: tuple,
1369-
start_indices_wg: tuple,
1370-
start_indices_th: tuple,
1371-
memory_custom,
1372-
index: dict,
1396+
vec_a: Value,
1397+
vec_b: Value,
1398+
kb_dest_a: Value,
1399+
kb_dest_b: Value,
1400+
output_shape_a: tuple,
1401+
output_shape_b: tuple,
1402+
start_indices_a: tuple,
1403+
start_indices_wg_a: tuple,
1404+
start_indices_th_a: tuple,
1405+
start_indices_b: tuple,
1406+
start_indices_wg_b: tuple,
1407+
start_indices_th_b: tuple,
1408+
memory_custom_a,
1409+
memory_custom_b,
1410+
index_a: dict,
1411+
index_b: dict,
13731412
):
1374-
"""Pack two lanes' bf16 values via permlane16_swap for wide global stores.
1413+
"""Pair two tile groups via permlane16_swap for duplicate-free wide stores.
13751414
1376-
Uses ``v_permlane16_swap_b32`` to exchange each thread's 4 bf16 values
1377-
(packed as 2 i32 dwords) with a partner lane 16 positions apart.
1378-
The result is 8 consecutive bf16 values per lane, written as a single
1379-
``buffer_store_dwordx4`` (128 bits).
1415+
Pairs tile A and tile B by passing them as separate ``old_dst`` /
1416+
``src`` operands to ``v_permlane16_swap_b32``. Both outputs of the
1417+
swap carry distinct data, so each lane writes a *different* tile
1418+
group's wide store, with no duplicate stores:
13801419
1381-
Both lane halves write identical data to the same address (benign
1382-
duplicate store), avoiding divergent control flow. The buffer
1383-
descriptor's ``valid_bytes`` handles out-of-bounds suppression.
1420+
* Lower lane (lane % 32 < 16) writes tile A:
1421+
``[own_A_lo, own_A_hi, partner_A_lo, partner_A_hi]``
1422+
* Upper lane (lane % 32 >= 16) writes tile B:
1423+
``[partner_B_lo, partner_B_hi, own_B_lo, own_B_hi]``
13841424
1385-
TODO: Eliminate duplicate stores by using both outputs of
1386-
``permlane16_swap``, letting each lane write the partner's assembled
1387-
data to the partner's destination address so every lane performs a
1388-
unique store.
1425+
This halves both the ``permlane16_swap`` count and the global store
1426+
count compared to the single-write approach.
13891427
1390-
Preconditions:
1391-
- The kernel must use swapped MFMA operands (B as LHS, A as RHS)
1392-
so the accumulator's 4-contiguous values align with the output
1393-
memory's stride-1 dimension.
1394-
- The Write node must be tagged with ``_permlane_pack_global=True``
1395-
by the ``coalesce_wide_stores`` pass.
1428+
Preconditions (unchanged from the former single-write ``_write_permlane_pack_to_global``):
1429+
- Swapped MFMA operands, F32_16x16x128_F8F6F4 layout, bf16 output.
1430+
- Both Write nodes tagged by ``coalesce_wide_stores``.
13961431
13971432
.. note::
1398-
Currently assumes F32_16x16x128_F8F6F4 MMA layout (4 values
1399-
along MMA-M per thread, 16-lane groups). Generalizing to other
1400-
MMA types requires parameterizing the lane group size and
1401-
elements per thread.
1433+
The store is emitted using tile A's ``output_shape``,
1434+
``memory_custom``, ``kb_dest``, and ``index``. This is correct
1435+
when both tiles target the same output buffer with identical
1436+
shape and buffer descriptor (the standard MXFP4 GEMM case).
1437+
If the two tiles ever target different buffers, the
1438+
``_create_vec_read_write`` call would need to be split into
1439+
two lane-divergent stores.
14021440
"""
1403-
vec_type = insert_vector.type
1404-
num_elems = vec_type.shape[0] if hasattr(vec_type, "shape") else 1
1405-
assert num_elems == 4, (
1406-
f"_write_permlane_pack_to_global expects 4 bf16 elements per thread "
1407-
f"(F32_16x16x128_F8F6F4 MMA layout), got {num_elems}. "
1408-
f"Other MMA types are not yet supported."
1441+
num_elems_a = vec_a.type.shape[0] if hasattr(vec_a.type, "shape") else 1
1442+
num_elems_b = vec_b.type.shape[0] if hasattr(vec_b.type, "shape") else 1
1443+
assert num_elems_a == 4 and num_elems_b == 4, (
1444+
f"_write_permlane_pair_to_global expects 4 bf16 elements per thread "
1445+
f"per tile, got {num_elems_a} and {num_elems_b}."
14091446
)
14101447

14111448
bf16_type = BF16Type.get()
@@ -1415,57 +1452,86 @@ def _write_permlane_pack_to_global(
14151452
v4i32_type = VectorType.get([4], i32_type)
14161453
v8bf16_type = VectorType.get([8], bf16_type)
14171454

1418-
i32_vec = vector_d.bitcast(v2i32_type, insert_vector)
1419-
own_lo = vector_d.extract(i32_vec, static_position=[0], dynamic_position=[])
1420-
own_hi = vector_d.extract(i32_vec, static_position=[1], dynamic_position=[])
1455+
i32_a = vector_d.bitcast(v2i32_type, vec_a)
1456+
a_lo = vector_d.extract(i32_a, static_position=[0], dynamic_position=[])
1457+
a_hi = vector_d.extract(i32_a, static_position=[1], dynamic_position=[])
1458+
1459+
i32_b = vector_d.bitcast(v2i32_type, vec_b)
1460+
b_lo = vector_d.extract(i32_b, static_position=[0], dynamic_position=[])
1461+
b_hi = vector_d.extract(i32_b, static_position=[1], dynamic_position=[])
14211462

14221463
swap_type = llvm_d.StructType.get_literal([i32_type, i32_type])
1423-
partner_lo = llvm_d.extractvalue(
1424-
i32_type, rocdl_d.permlane16_swap(swap_type, own_lo, own_lo, False, False), [0]
1425-
)
1426-
partner_hi = llvm_d.extractvalue(
1427-
i32_type, rocdl_d.permlane16_swap(swap_type, own_hi, own_hi, False, False), [0]
1428-
)
1464+
1465+
# old_dst = a, src = b → result[0] = partner's b, result[1] = partner's a
1466+
swap_lo = rocdl_d.permlane16_swap(swap_type, a_lo, b_lo, False, False)
1467+
swap_hi = rocdl_d.permlane16_swap(swap_type, a_hi, b_hi, False, False)
1468+
1469+
partner_b_lo = llvm_d.extractvalue(i32_type, swap_lo, [0])
1470+
partner_a_lo = llvm_d.extractvalue(i32_type, swap_lo, [1])
1471+
partner_b_hi = llvm_d.extractvalue(i32_type, swap_hi, [0])
1472+
partner_a_hi = llvm_d.extractvalue(i32_type, swap_hi, [1])
14291473

14301474
lane_in_wave = arith_d.remui(emitter.thread_ids[0], arith_d.constant(idx_type, 64))
14311475
half_pos = arith_d.remui(lane_in_wave, arith_d.constant(idx_type, 32))
14321476
is_lower = arith_d.cmpi(
14331477
arith_d.CmpIPredicate.ult, half_pos, arith_d.constant(idx_type, 16)
14341478
)
14351479

1436-
d0 = arith_d.select(is_lower, own_lo, partner_lo)
1437-
d1 = arith_d.select(is_lower, own_hi, partner_hi)
1438-
d2 = arith_d.select(is_lower, partner_lo, own_lo)
1439-
d3 = arith_d.select(is_lower, partner_hi, own_hi)
1480+
# Lower lane: [own_A_lo, own_A_hi, partner_A_lo, partner_A_hi]
1481+
# Upper lane: [partner_B_lo, partner_B_hi, own_B_lo, own_B_hi]
1482+
d0 = arith_d.select(is_lower, a_lo, partner_b_lo)
1483+
d1 = arith_d.select(is_lower, a_hi, partner_b_hi)
1484+
d2 = arith_d.select(is_lower, partner_a_lo, b_lo)
1485+
d3 = arith_d.select(is_lower, partner_a_hi, b_hi)
14401486

14411487
wide_i32 = vector_d.from_elements(v4i32_type, [d0, d1, d2, d3])
14421488
wide_vec = vector_d.bitcast(v8bf16_type, wide_i32)
14431489

1444-
elems_per_thread = arith_d.constant(idx_type, num_elems)
1490+
elems_per_thread = arith_d.constant(idx_type, 4)
1491+
1492+
# Lower lane uses tile A's address; upper lane uses tile B's address.
1493+
# Upper lane subtracts elems_per_thread from the last dim to align
1494+
# to the lower lane's column position (same as the single-write path).
1495+
adj_th = list(start_indices_th_a)
1496+
adj_full = list(start_indices_a)
1497+
for dim_idx in range(len(adj_th)):
1498+
if dim_idx == len(adj_th) - 1:
1499+
adj_b_th = arith_d.subi(start_indices_th_b[-1], elems_per_thread)
1500+
adj_b_full = arith_d.subi(start_indices_b[-1], elems_per_thread)
1501+
adj_th[dim_idx] = arith_d.select(is_lower, adj_th[dim_idx], adj_b_th)
1502+
adj_full[dim_idx] = arith_d.select(is_lower, adj_full[dim_idx], adj_b_full)
1503+
else:
1504+
adj_th[dim_idx] = arith_d.select(
1505+
is_lower, start_indices_th_a[dim_idx], start_indices_th_b[dim_idx]
1506+
)
1507+
adj_full[dim_idx] = arith_d.select(
1508+
is_lower, start_indices_a[dim_idx], start_indices_b[dim_idx]
1509+
)
14451510

1446-
adj_th = list(start_indices_th)
1447-
adj_th[-1] = arith_d.select(
1448-
is_lower, adj_th[-1], arith_d.subi(adj_th[-1], elems_per_thread)
1449-
)
1511+
adj_wg = list(start_indices_wg_a)
1512+
for dim_idx in range(len(adj_wg)):
1513+
adj_wg[dim_idx] = arith_d.select(
1514+
is_lower, start_indices_wg_a[dim_idx], start_indices_wg_b[dim_idx]
1515+
)
14501516

1451-
adj_full = list(start_indices)
1452-
adj_full[-1] = arith_d.select(
1453-
is_lower, adj_full[-1], arith_d.subi(adj_full[-1], elems_per_thread)
1454-
)
1517+
sel_output_shape = output_shape_a
1518+
sel_memory_custom = memory_custom_a
1519+
sel_kb_dest = kb_dest_a
1520+
sel_index = index_a
14551521

14561522
_create_vec_read_write(
14571523
emitter,
1458-
output_shape,
1459-
kb_dest,
1524+
sel_output_shape,
1525+
sel_kb_dest,
14601526
wide_vec,
14611527
None,
14621528
tuple(adj_full),
1463-
start_indices_wg,
1529+
tuple(adj_wg),
14641530
tuple(adj_th),
14651531
8,
1466-
memory_custom,
1532+
sel_memory_custom,
14671533
None,
1468-
node_index=index,
1534+
node_index=sel_index,
14691535
)
14701536

14711537

wave_lang/kernel/wave/wide_store_coalescing.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
codegen emits v_permlane16_swap_b32 + buffer_store_dwordx4 instead of
1414
scalar buffer_store_short.
1515
16+
Eligible writes are paired so that each lane in a (lane, lane+16)
17+
pair writes a *different* tile group's wide store, eliminating the
18+
duplicate stores that occur when both lanes write the same data.
19+
The number of eligible writes must be even (asserted).
20+
1621
Only tags writes that satisfy ALL conditions:
1722
1. Target memory is global address space
1823
2. Output dtype is bf16
@@ -34,11 +39,16 @@ def coalesce_wide_stores(trace: CapturedTrace):
3439
pattern (swapped MFMA operands, as produced by the wide_store kernel
3540
variant). Writes without source/target are left untouched, making
3641
this pass safe to run unconditionally.
42+
43+
Writes are paired for register-only deduplication: the first node
44+
in each pair stashes its codegen state, and the second triggers a
45+
paired ``permlane16_swap`` that lets each lane emit a unique store.
3746
"""
3847
import wave_lang.kernel.lang as tkl
3948

4049
root_graph = trace.get_root_graph()
4150

51+
eligible_writes = []
4252
for node in root_graph.nodes:
4353
if node.op != "call_function":
4454
continue
@@ -52,4 +62,23 @@ def coalesce_wide_stores(trace: CapturedTrace):
5262
subs_idxc(mem_type.address_space) == GLOBAL_ADDRESS_SPACE
5363
and mem_type.dtype == tkl.bf16
5464
):
55-
node._permlane_pack_global = True
65+
eligible_writes.append(node)
66+
67+
# TODO: Add a fallback path for an odd number of eligible writes.
68+
assert len(eligible_writes) % 2 == 0, (
69+
f"Expected even number of eligible wide-store writes, "
70+
f"got {len(eligible_writes)}."
71+
)
72+
73+
# Pair adjacent writes so the codegen can pass both tiles as
74+
# separate old_dst / src operands to a single permlane16_swap.
75+
# The "first" node stashes its codegen state; the "second" node
76+
# retrieves it and emits one unique store per lane (lower lane
77+
# writes tile A, upper lane writes tile B) — no duplicate stores.
78+
for first, second in zip(eligible_writes[0::2], eligible_writes[1::2]):
79+
first._permlane_pack_global = True
80+
first._permlane_pack_role = "first"
81+
first._permlane_partner = second
82+
second._permlane_pack_global = True
83+
second._permlane_pack_role = "second"
84+
second._permlane_partner = first

0 commit comments

Comments
 (0)