Skip to content

Commit 9f1c6cb

Browse files
committed
remove redundant writes
Signed-off-by: xintin <gaurav.verma@amd.com>
1 parent be40df5 commit 9f1c6cb

2 files changed

Lines changed: 173 additions & 78 deletions

File tree

wave_lang/kernel/compiler/wave_codegen/read_write.py

Lines changed: 143 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def _split_index(
117117
# Compute thread-independent index as `orig_index - thread_dependent_index`
118118
# All thread symbols and dynamic should cancel-out in the result.
119119
diff = src - thread_dependent_index
120-
# Avoid sympy.simplify on Piecewise expressions it recurses into boolean
120+
# Avoid sympy.simplify on Piecewise expressions : it recurses into boolean
121121
# condition simplification and can hang for complex dynamic-shape indices.
122122
# expand() handles basic polynomial cancellation and is O(fast).
123123
if isinstance(diff, sympy.Basic) and diff.has(sympy.Piecewise):
@@ -575,7 +575,7 @@ def _cast_buffer_and_encode_stride(
575575
stride_int = _get_constant_value(stride_candidate)
576576
# Emit swizzle stride for both static and dynamic cases.
577577
# Static: only if stride fits in signed i14 (max 8192).
578-
# Dynamic: always emit the SRD swizzle encoding is constant
578+
# Dynamic: always emit : the SRD swizzle encoding is constant
579579
# (0x40400000 + 0x27000) regardless of the actual stride value.
580580
if stride_int is None or stride_int <= 8192:
581581
swizzle_stride = arith_d.index_cast(uint14, stride_candidate)
@@ -1325,18 +1325,49 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
13251325
if getattr(node, "_permlane_pack_global", False):
13261326
is_shared = get_custom(memory).type.address_space == SHARED_ADDRESS_SPACE
13271327
if not is_shared and isinstance(element_type, BF16Type):
1328-
_write_permlane_pack_to_global(
1329-
emitter,
1330-
insert_vector,
1331-
kb_dest,
1332-
output_shape,
1333-
start_indices,
1334-
start_indices_wg,
1335-
start_indices_th,
1336-
get_custom(memory),
1337-
index,
1328+
role = getattr(node, "_permlane_pack_role", "unpaired")
1329+
1330+
if role == "first":
1331+
node._stashed_codegen = {
1332+
"insert_vector": insert_vector,
1333+
"kb_dest": kb_dest,
1334+
"output_shape": output_shape,
1335+
"start_indices": start_indices,
1336+
"start_indices_wg": start_indices_wg,
1337+
"start_indices_th": start_indices_th,
1338+
"memory_custom": get_custom(memory),
1339+
"index": index,
1340+
}
1341+
return
1342+
1343+
if role == "second":
1344+
partner = node._permlane_partner
1345+
s = partner._stashed_codegen
1346+
_write_permlane_pair_to_global(
1347+
emitter,
1348+
s["insert_vector"],
1349+
insert_vector,
1350+
s["kb_dest"],
1351+
kb_dest,
1352+
s["output_shape"],
1353+
output_shape,
1354+
s["start_indices"],
1355+
s["start_indices_wg"],
1356+
s["start_indices_th"],
1357+
start_indices,
1358+
start_indices_wg,
1359+
start_indices_th,
1360+
s["memory_custom"],
1361+
get_custom(memory),
1362+
s["index"],
1363+
index,
1364+
)
1365+
return
1366+
1367+
assert False, (
1368+
"Unexpected unpaired wide-store write. "
1369+
"coalesce_wide_stores should pair all eligible writes."
13381370
)
1339-
return
13401371

13411372
if use_llvm_store:
13421373
_create_llvm_read_write(
@@ -1359,52 +1390,58 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
13591390
)
13601391

13611392

1362-
def _write_permlane_pack_to_global(
1393+
def _write_permlane_pair_to_global(
13631394
emitter: WaveEmitter,
1364-
insert_vector: Value,
1365-
kb_dest: Value,
1366-
output_shape: tuple,
1367-
start_indices: tuple,
1368-
start_indices_wg: tuple,
1369-
start_indices_th: tuple,
1370-
memory_custom,
1371-
index: dict,
1395+
vec_a: Value,
1396+
vec_b: Value,
1397+
kb_dest_a: Value,
1398+
kb_dest_b: Value,
1399+
output_shape_a: tuple,
1400+
output_shape_b: tuple,
1401+
start_indices_a: tuple,
1402+
start_indices_wg_a: tuple,
1403+
start_indices_th_a: tuple,
1404+
start_indices_b: tuple,
1405+
start_indices_wg_b: tuple,
1406+
start_indices_th_b: tuple,
1407+
memory_custom_a,
1408+
memory_custom_b,
1409+
index_a: dict,
1410+
index_b: dict,
13721411
):
1373-
"""Pack two lanes' bf16 values via permlane16_swap for wide global stores.
1412+
"""Pair two tile groups via permlane16_swap for duplicate-free wide stores.
13741413
1375-
Uses ``v_permlane16_swap_b32`` to exchange each thread's 4 bf16 values
1376-
(packed as 2 i32 dwords) with a partner lane 16 positions apart.
1377-
The result is 8 consecutive bf16 values per lane, written as a single
1378-
``buffer_store_dwordx4`` (128 bits).
1414+
Pairs tile A and tile B by passing them as separate ``old_dst`` /
1415+
``src`` operands to ``v_permlane16_swap_b32``. Both outputs of the
1416+
swap carry distinct data, so each lane writes a *different* tile
1417+
group's wide store : no duplicate stores:
13791418
1380-
Both lane halves write identical data to the same address (benign
1381-
duplicate store), avoiding divergent control flow. The buffer
1382-
descriptor's ``valid_bytes`` handles out-of-bounds suppression.
1419+
* Lower lane (lane % 32 < 16) writes tile A:
1420+
``[own_A_lo, own_A_hi, partner_A_lo, partner_A_hi]``
1421+
* Upper lane (lane % 32 >= 16) writes tile B:
1422+
``[partner_B_lo, partner_B_hi, own_B_lo, own_B_hi]``
13831423
1384-
TODO: Eliminate duplicate stores by using both outputs of
1385-
``permlane16_swap``, letting each lane write the partner's assembled
1386-
data to the partner's destination address so every lane performs a
1387-
unique store.
1424+
This halves both the ``permlane16_swap`` count and the global store
1425+
count compared to the single-write approach.
13881426
1389-
Preconditions:
1390-
- The kernel must use swapped MFMA operands (B as LHS, A as RHS)
1391-
so the accumulator's 4-contiguous values align with the output
1392-
memory's stride-1 dimension.
1393-
- The Write node must be tagged with ``_permlane_pack_global=True``
1394-
by the ``coalesce_wide_stores`` pass.
1427+
Preconditions (same as ``_write_permlane_pack_to_global``):
1428+
- Swapped MFMA operands, F32_16x16x128_F8F6F4 layout, bf16 output.
1429+
- Both Write nodes tagged by ``coalesce_wide_stores``.
13951430
13961431
.. note::
1397-
Currently assumes F32_16x16x128_F8F6F4 MMA layout (4 values
1398-
along MMA-M per thread, 16-lane groups). Generalizing to other
1399-
MMA types requires parameterizing the lane group size and
1400-
elements per thread.
1432+
The store is emitted using tile A's ``output_shape``,
1433+
``memory_custom``, ``kb_dest``, and ``index``. This is correct
1434+
when both tiles target the same output buffer with identical
1435+
shape and buffer descriptor (the standard MXFP4 GEMM case).
1436+
If the two tiles ever target different buffers, the
1437+
``_create_vec_read_write`` call would need to be split into
1438+
two lane-divergent stores.
14011439
"""
1402-
vec_type = insert_vector.type
1403-
num_elems = vec_type.shape[0] if hasattr(vec_type, "shape") else 1
1404-
assert num_elems == 4, (
1405-
f"_write_permlane_pack_to_global expects 4 bf16 elements per thread "
1406-
f"(F32_16x16x128_F8F6F4 MMA layout), got {num_elems}. "
1407-
f"Other MMA types are not yet supported."
1440+
num_elems_a = vec_a.type.shape[0] if hasattr(vec_a.type, "shape") else 1
1441+
num_elems_b = vec_b.type.shape[0] if hasattr(vec_b.type, "shape") else 1
1442+
assert num_elems_a == 4 and num_elems_b == 4, (
1443+
f"_write_permlane_pair_to_global expects 4 bf16 elements per thread "
1444+
f"per tile, got {num_elems_a} and {num_elems_b}."
14081445
)
14091446

14101447
bf16_type = BF16Type.get()
@@ -1414,57 +1451,86 @@ def _write_permlane_pack_to_global(
14141451
v4i32_type = VectorType.get([4], i32_type)
14151452
v8bf16_type = VectorType.get([8], bf16_type)
14161453

1417-
i32_vec = vector_d.bitcast(v2i32_type, insert_vector)
1418-
own_lo = vector_d.extract(i32_vec, static_position=[0], dynamic_position=[])
1419-
own_hi = vector_d.extract(i32_vec, static_position=[1], dynamic_position=[])
1454+
i32_a = vector_d.bitcast(v2i32_type, vec_a)
1455+
a_lo = vector_d.extract(i32_a, static_position=[0], dynamic_position=[])
1456+
a_hi = vector_d.extract(i32_a, static_position=[1], dynamic_position=[])
1457+
1458+
i32_b = vector_d.bitcast(v2i32_type, vec_b)
1459+
b_lo = vector_d.extract(i32_b, static_position=[0], dynamic_position=[])
1460+
b_hi = vector_d.extract(i32_b, static_position=[1], dynamic_position=[])
14201461

14211462
swap_type = llvm_d.StructType.get_literal([i32_type, i32_type])
1422-
partner_lo = llvm_d.extractvalue(
1423-
i32_type, rocdl_d.permlane16_swap(swap_type, own_lo, own_lo, False, False), [0]
1424-
)
1425-
partner_hi = llvm_d.extractvalue(
1426-
i32_type, rocdl_d.permlane16_swap(swap_type, own_hi, own_hi, False, False), [0]
1427-
)
1463+
1464+
# old_dst = a, src = b → result[0] = partner's b, result[1] = partner's a
1465+
swap_lo = rocdl_d.permlane16_swap(swap_type, a_lo, b_lo, False, False)
1466+
swap_hi = rocdl_d.permlane16_swap(swap_type, a_hi, b_hi, False, False)
1467+
1468+
partner_b_lo = llvm_d.extractvalue(i32_type, swap_lo, [0])
1469+
partner_a_lo = llvm_d.extractvalue(i32_type, swap_lo, [1])
1470+
partner_b_hi = llvm_d.extractvalue(i32_type, swap_hi, [0])
1471+
partner_a_hi = llvm_d.extractvalue(i32_type, swap_hi, [1])
14281472

14291473
lane_in_wave = arith_d.remui(emitter.thread_ids[0], arith_d.constant(idx_type, 64))
14301474
half_pos = arith_d.remui(lane_in_wave, arith_d.constant(idx_type, 32))
14311475
is_lower = arith_d.cmpi(
14321476
arith_d.CmpIPredicate.ult, half_pos, arith_d.constant(idx_type, 16)
14331477
)
14341478

1435-
d0 = arith_d.select(is_lower, own_lo, partner_lo)
1436-
d1 = arith_d.select(is_lower, own_hi, partner_hi)
1437-
d2 = arith_d.select(is_lower, partner_lo, own_lo)
1438-
d3 = arith_d.select(is_lower, partner_hi, own_hi)
1479+
# Lower lane: [own_A_lo, own_A_hi, partner_A_lo, partner_A_hi]
1480+
# Upper lane: [partner_B_lo, partner_B_hi, own_B_lo, own_B_hi]
1481+
d0 = arith_d.select(is_lower, a_lo, partner_b_lo)
1482+
d1 = arith_d.select(is_lower, a_hi, partner_b_hi)
1483+
d2 = arith_d.select(is_lower, partner_a_lo, b_lo)
1484+
d3 = arith_d.select(is_lower, partner_a_hi, b_hi)
14391485

14401486
wide_i32 = vector_d.from_elements(v4i32_type, [d0, d1, d2, d3])
14411487
wide_vec = vector_d.bitcast(v8bf16_type, wide_i32)
14421488

1443-
elems_per_thread = arith_d.constant(idx_type, num_elems)
1489+
elems_per_thread = arith_d.constant(idx_type, 4)
1490+
1491+
# Lower lane uses tile A's address; upper lane uses tile B's address.
1492+
# Upper lane subtracts elems_per_thread from the last dim to align
1493+
# to the lower lane's column position (same as the single-write path).
1494+
adj_th = list(start_indices_th_a)
1495+
adj_full = list(start_indices_a)
1496+
for dim_idx in range(len(adj_th)):
1497+
if dim_idx == len(adj_th) - 1:
1498+
adj_b_th = arith_d.subi(start_indices_th_b[-1], elems_per_thread)
1499+
adj_b_full = arith_d.subi(start_indices_b[-1], elems_per_thread)
1500+
adj_th[dim_idx] = arith_d.select(is_lower, adj_th[dim_idx], adj_b_th)
1501+
adj_full[dim_idx] = arith_d.select(is_lower, adj_full[dim_idx], adj_b_full)
1502+
else:
1503+
adj_th[dim_idx] = arith_d.select(
1504+
is_lower, start_indices_th_a[dim_idx], start_indices_th_b[dim_idx]
1505+
)
1506+
adj_full[dim_idx] = arith_d.select(
1507+
is_lower, start_indices_a[dim_idx], start_indices_b[dim_idx]
1508+
)
14441509

1445-
adj_th = list(start_indices_th)
1446-
adj_th[-1] = arith_d.select(
1447-
is_lower, adj_th[-1], arith_d.subi(adj_th[-1], elems_per_thread)
1448-
)
1510+
adj_wg = list(start_indices_wg_a)
1511+
for dim_idx in range(len(adj_wg)):
1512+
adj_wg[dim_idx] = arith_d.select(
1513+
is_lower, start_indices_wg_a[dim_idx], start_indices_wg_b[dim_idx]
1514+
)
14491515

1450-
adj_full = list(start_indices)
1451-
adj_full[-1] = arith_d.select(
1452-
is_lower, adj_full[-1], arith_d.subi(adj_full[-1], elems_per_thread)
1453-
)
1516+
sel_output_shape = output_shape_a
1517+
sel_memory_custom = memory_custom_a
1518+
sel_kb_dest = kb_dest_a
1519+
sel_index = index_a
14541520

14551521
_create_vec_read_write(
14561522
emitter,
1457-
output_shape,
1458-
kb_dest,
1523+
sel_output_shape,
1524+
sel_kb_dest,
14591525
wide_vec,
14601526
None,
14611527
tuple(adj_full),
1462-
start_indices_wg,
1528+
tuple(adj_wg),
14631529
tuple(adj_th),
14641530
8,
1465-
memory_custom,
1531+
sel_memory_custom,
14661532
None,
1467-
node_index=index,
1533+
node_index=sel_index,
14681534
)
14691535

14701536

wave_lang/kernel/wave/wide_store_coalescing.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
codegen emits v_permlane16_swap_b32 + buffer_store_dwordx4 instead of
1414
scalar buffer_store_short.
1515
16+
Eligible writes are paired so that each lane in a (lane, lane+16)
17+
pair writes a *different* tile group's wide store, eliminating the
18+
duplicate stores that occur when both lanes write the same data.
19+
The number of eligible writes must be even (asserted).
20+
1621
Only tags writes that satisfy ALL conditions:
1722
1. Target memory is global address space
1823
2. Output dtype is bf16
@@ -34,11 +39,16 @@ def coalesce_wide_stores(trace: CapturedTrace):
3439
pattern (swapped MFMA operands, as produced by the wide_store kernel
3540
variant). Writes without source/target are left untouched, making
3641
this pass safe to run unconditionally.
42+
43+
Writes are paired for register-only deduplication: the first node
44+
in each pair stashes its codegen state, and the second triggers a
45+
paired ``permlane16_swap`` that lets each lane emit a unique store.
3746
"""
3847
import wave_lang.kernel.lang as tkl
3948

4049
root_graph = trace.get_root_graph()
4150

51+
eligible_writes = []
4252
for node in root_graph.nodes:
4353
if node.op != "call_function":
4454
continue
@@ -52,4 +62,23 @@ def coalesce_wide_stores(trace: CapturedTrace):
5262
subs_idxc(mem_type.address_space) == GLOBAL_ADDRESS_SPACE
5363
and mem_type.dtype == tkl.bf16
5464
):
55-
node._permlane_pack_global = True
65+
eligible_writes.append(node)
66+
67+
# TODO: Add a fallback path for odd number of writes.
68+
assert len(eligible_writes) % 2 == 0, (
69+
f"Expected even number of eligible wide-store writes, "
70+
f"got {len(eligible_writes)}."
71+
)
72+
73+
# Pair adjacent writes so the codegen can pass both tiles as
74+
# separate old_dst / src operands to a single permlane16_swap.
75+
# The "first" node stashes its codegen state; the "second" node
76+
# retrieves it and emits one unique store per lane (lower lane
77+
# writes tile A, upper lane writes tile B) — no duplicate stores.
78+
for first, second in zip(eligible_writes[0::2], eligible_writes[1::2]):
79+
first._permlane_pack_global = True
80+
first._permlane_pack_role = "first"
81+
first._permlane_partner = second
82+
second._permlane_pack_global = True
83+
second._permlane_pack_role = "second"
84+
second._permlane_partner = first

0 commit comments

Comments
 (0)