@@ -117,7 +117,7 @@ def _split_index(
117117 # Compute thread-independent index as `orig_index - thread_dependent_index`
118118 # All thread symbols and dynamic should cancel-out in the result.
119119 diff = src - thread_dependent_index
120- # Avoid sympy.simplify on Piecewise expressions — it recurses into boolean
120+ # Avoid sympy.simplify on Piecewise expressions : it recurses into boolean
121121 # condition simplification and can hang for complex dynamic-shape indices.
122122 # expand() handles basic polynomial cancellation and is O(fast).
123123 if isinstance (diff , sympy .Basic ) and diff .has (sympy .Piecewise ):
@@ -575,7 +575,7 @@ def _cast_buffer_and_encode_stride(
575575 stride_int = _get_constant_value (stride_candidate )
576576 # Emit swizzle stride for both static and dynamic cases.
577577 # Static: only if stride fits in signed i14 (max 8192).
578- # Dynamic: always emit — the SRD swizzle encoding is constant
578+ # Dynamic: always emit : the SRD swizzle encoding is constant
579579 # (0x40400000 + 0x27000) regardless of the actual stride value.
580580 if stride_int is None or stride_int <= 8192 :
581581 swizzle_stride = arith_d .index_cast (uint14 , stride_candidate )
@@ -1326,18 +1326,49 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
13261326 if getattr (node , "_permlane_pack_global" , False ):
13271327 is_shared = get_custom (memory ).type .address_space == SHARED_ADDRESS_SPACE
13281328 if not is_shared and isinstance (element_type , BF16Type ):
1329- _write_permlane_pack_to_global (
1330- emitter ,
1331- insert_vector ,
1332- kb_dest ,
1333- output_shape ,
1334- start_indices ,
1335- start_indices_wg ,
1336- start_indices_th ,
1337- get_custom (memory ),
1338- index ,
1329+ role = getattr (node , "_permlane_pack_role" , "unpaired" )
1330+
1331+ if role == "first" :
1332+ node ._stashed_codegen = {
1333+ "insert_vector" : insert_vector ,
1334+ "kb_dest" : kb_dest ,
1335+ "output_shape" : output_shape ,
1336+ "start_indices" : start_indices ,
1337+ "start_indices_wg" : start_indices_wg ,
1338+ "start_indices_th" : start_indices_th ,
1339+ "memory_custom" : get_custom (memory ),
1340+ "index" : index ,
1341+ }
1342+ return
1343+
1344+ if role == "second" :
1345+ partner = node ._permlane_partner
1346+ s = partner ._stashed_codegen
1347+ _write_permlane_pair_to_global (
1348+ emitter ,
1349+ s ["insert_vector" ],
1350+ insert_vector ,
1351+ s ["kb_dest" ],
1352+ kb_dest ,
1353+ s ["output_shape" ],
1354+ output_shape ,
1355+ s ["start_indices" ],
1356+ s ["start_indices_wg" ],
1357+ s ["start_indices_th" ],
1358+ start_indices ,
1359+ start_indices_wg ,
1360+ start_indices_th ,
1361+ s ["memory_custom" ],
1362+ get_custom (memory ),
1363+ s ["index" ],
1364+ index ,
1365+ )
1366+ return
1367+
1368+ assert False , (
1369+ "Unexpected unpaired wide-store write. "
1370+ "coalesce_wide_stores should pair all eligible writes."
13391371 )
1340- return
13411372
13421373 if use_llvm_store :
13431374 _create_llvm_read_write (
@@ -1360,52 +1391,58 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
13601391 )
13611392
13621393
1363- def _write_permlane_pack_to_global (
1394+ def _write_permlane_pair_to_global (
13641395 emitter : WaveEmitter ,
1365- insert_vector : Value ,
1366- kb_dest : Value ,
1367- output_shape : tuple ,
1368- start_indices : tuple ,
1369- start_indices_wg : tuple ,
1370- start_indices_th : tuple ,
1371- memory_custom ,
1372- index : dict ,
1396+ vec_a : Value ,
1397+ vec_b : Value ,
1398+ kb_dest_a : Value ,
1399+ kb_dest_b : Value ,
1400+ output_shape_a : tuple ,
1401+ output_shape_b : tuple ,
1402+ start_indices_a : tuple ,
1403+ start_indices_wg_a : tuple ,
1404+ start_indices_th_a : tuple ,
1405+ start_indices_b : tuple ,
1406+ start_indices_wg_b : tuple ,
1407+ start_indices_th_b : tuple ,
1408+ memory_custom_a ,
1409+ memory_custom_b ,
1410+ index_a : dict ,
1411+ index_b : dict ,
13731412):
1374- """Pack two lanes' bf16 values via permlane16_swap for wide global stores.
1413+ """Pair two tile groups via permlane16_swap for duplicate-free wide stores.
13751414
1376- Uses ``v_permlane16_swap_b32`` to exchange each thread's 4 bf16 values
1377- (packed as 2 i32 dwords) with a partner lane 16 positions apart.
1378- The result is 8 consecutive bf16 values per lane, written as a single
1379- ``buffer_store_dwordx4`` (128 bits).
1415+ Pairs tile A and tile B by passing them as separate ``old_dst`` /
1416+ ``src`` operands to ``v_permlane16_swap_b32``. Both outputs of the
1417+ swap carry distinct data, so each lane writes a *different* tile
1418+ group's wide store : no duplicate stores:
13801419
1381- Both lane halves write identical data to the same address (benign
1382- duplicate store), avoiding divergent control flow. The buffer
1383- descriptor's ``valid_bytes`` handles out-of-bounds suppression.
1420+ * Lower lane (lane % 32 < 16) writes tile A:
1421+ ``[own_A_lo, own_A_hi, partner_A_lo, partner_A_hi]``
1422+ * Upper lane (lane % 32 >= 16) writes tile B:
1423+ ``[partner_B_lo, partner_B_hi, own_B_lo, own_B_hi]``
13841424
1385- TODO: Eliminate duplicate stores by using both outputs of
1386- ``permlane16_swap``, letting each lane write the partner's assembled
1387- data to the partner's destination address so every lane performs a
1388- unique store.
1425+ This halves both the ``permlane16_swap`` count and the global store
1426+ count compared to the single-write approach.
13891427
1390- Preconditions:
1391- - The kernel must use swapped MFMA operands (B as LHS, A as RHS)
1392- so the accumulator's 4-contiguous values align with the output
1393- memory's stride-1 dimension.
1394- - The Write node must be tagged with ``_permlane_pack_global=True``
1395- by the ``coalesce_wide_stores`` pass.
1428+ Preconditions (same as ``_write_permlane_pack_to_global``):
1429+ - Swapped MFMA operands, F32_16x16x128_F8F6F4 layout, bf16 output.
1430+ - Both Write nodes tagged by ``coalesce_wide_stores``.
13961431
13971432 .. note::
1398- Currently assumes F32_16x16x128_F8F6F4 MMA layout (4 values
1399- along MMA-M per thread, 16-lane groups). Generalizing to other
1400- MMA types requires parameterizing the lane group size and
1401- elements per thread.
1433+ The store is emitted using tile A's ``output_shape``,
1434+ ``memory_custom``, ``kb_dest``, and ``index``. This is correct
1435+ when both tiles target the same output buffer with identical
1436+ shape and buffer descriptor (the standard MXFP4 GEMM case).
1437+ If the two tiles ever target different buffers, the
1438+ ``_create_vec_read_write`` call would need to be split into
1439+ two lane-divergent stores.
14021440 """
1403- vec_type = insert_vector .type
1404- num_elems = vec_type .shape [0 ] if hasattr (vec_type , "shape" ) else 1
1405- assert num_elems == 4 , (
1406- f"_write_permlane_pack_to_global expects 4 bf16 elements per thread "
1407- f"(F32_16x16x128_F8F6F4 MMA layout), got { num_elems } . "
1408- f"Other MMA types are not yet supported."
1441+ num_elems_a = vec_a .type .shape [0 ] if hasattr (vec_a .type , "shape" ) else 1
1442+ num_elems_b = vec_b .type .shape [0 ] if hasattr (vec_b .type , "shape" ) else 1
1443+ assert num_elems_a == 4 and num_elems_b == 4 , (
1444+ f"_write_permlane_pair_to_global expects 4 bf16 elements per thread "
1445+ f"per tile, got { num_elems_a } and { num_elems_b } ."
14091446 )
14101447
14111448 bf16_type = BF16Type .get ()
@@ -1415,57 +1452,86 @@ def _write_permlane_pack_to_global(
14151452 v4i32_type = VectorType .get ([4 ], i32_type )
14161453 v8bf16_type = VectorType .get ([8 ], bf16_type )
14171454
1418- i32_vec = vector_d .bitcast (v2i32_type , insert_vector )
1419- own_lo = vector_d .extract (i32_vec , static_position = [0 ], dynamic_position = [])
1420- own_hi = vector_d .extract (i32_vec , static_position = [1 ], dynamic_position = [])
1455+ i32_a = vector_d .bitcast (v2i32_type , vec_a )
1456+ a_lo = vector_d .extract (i32_a , static_position = [0 ], dynamic_position = [])
1457+ a_hi = vector_d .extract (i32_a , static_position = [1 ], dynamic_position = [])
1458+
1459+ i32_b = vector_d .bitcast (v2i32_type , vec_b )
1460+ b_lo = vector_d .extract (i32_b , static_position = [0 ], dynamic_position = [])
1461+ b_hi = vector_d .extract (i32_b , static_position = [1 ], dynamic_position = [])
14211462
14221463 swap_type = llvm_d .StructType .get_literal ([i32_type , i32_type ])
1423- partner_lo = llvm_d .extractvalue (
1424- i32_type , rocdl_d .permlane16_swap (swap_type , own_lo , own_lo , False , False ), [0 ]
1425- )
1426- partner_hi = llvm_d .extractvalue (
1427- i32_type , rocdl_d .permlane16_swap (swap_type , own_hi , own_hi , False , False ), [0 ]
1428- )
1464+
1465+ # old_dst = a, src = b → result[0] = partner's b, result[1] = partner's a
1466+ swap_lo = rocdl_d .permlane16_swap (swap_type , a_lo , b_lo , False , False )
1467+ swap_hi = rocdl_d .permlane16_swap (swap_type , a_hi , b_hi , False , False )
1468+
1469+ partner_b_lo = llvm_d .extractvalue (i32_type , swap_lo , [0 ])
1470+ partner_a_lo = llvm_d .extractvalue (i32_type , swap_lo , [1 ])
1471+ partner_b_hi = llvm_d .extractvalue (i32_type , swap_hi , [0 ])
1472+ partner_a_hi = llvm_d .extractvalue (i32_type , swap_hi , [1 ])
14291473
14301474 lane_in_wave = arith_d .remui (emitter .thread_ids [0 ], arith_d .constant (idx_type , 64 ))
14311475 half_pos = arith_d .remui (lane_in_wave , arith_d .constant (idx_type , 32 ))
14321476 is_lower = arith_d .cmpi (
14331477 arith_d .CmpIPredicate .ult , half_pos , arith_d .constant (idx_type , 16 )
14341478 )
14351479
1436- d0 = arith_d .select (is_lower , own_lo , partner_lo )
1437- d1 = arith_d .select (is_lower , own_hi , partner_hi )
1438- d2 = arith_d .select (is_lower , partner_lo , own_lo )
1439- d3 = arith_d .select (is_lower , partner_hi , own_hi )
1480+ # Lower lane: [own_A_lo, own_A_hi, partner_A_lo, partner_A_hi]
1481+ # Upper lane: [partner_B_lo, partner_B_hi, own_B_lo, own_B_hi]
1482+ d0 = arith_d .select (is_lower , a_lo , partner_b_lo )
1483+ d1 = arith_d .select (is_lower , a_hi , partner_b_hi )
1484+ d2 = arith_d .select (is_lower , partner_a_lo , b_lo )
1485+ d3 = arith_d .select (is_lower , partner_a_hi , b_hi )
14401486
14411487 wide_i32 = vector_d .from_elements (v4i32_type , [d0 , d1 , d2 , d3 ])
14421488 wide_vec = vector_d .bitcast (v8bf16_type , wide_i32 )
14431489
1444- elems_per_thread = arith_d .constant (idx_type , num_elems )
1490+ elems_per_thread = arith_d .constant (idx_type , 4 )
1491+
1492+ # Lower lane uses tile A's address; upper lane uses tile B's address.
1493+ # Upper lane subtracts elems_per_thread from the last dim to align
1494+ # to the lower lane's column position (same as the single-write path).
1495+ adj_th = list (start_indices_th_a )
1496+ adj_full = list (start_indices_a )
1497+ for dim_idx in range (len (adj_th )):
1498+ if dim_idx == len (adj_th ) - 1 :
1499+ adj_b_th = arith_d .subi (start_indices_th_b [- 1 ], elems_per_thread )
1500+ adj_b_full = arith_d .subi (start_indices_b [- 1 ], elems_per_thread )
1501+ adj_th [dim_idx ] = arith_d .select (is_lower , adj_th [dim_idx ], adj_b_th )
1502+ adj_full [dim_idx ] = arith_d .select (is_lower , adj_full [dim_idx ], adj_b_full )
1503+ else :
1504+ adj_th [dim_idx ] = arith_d .select (
1505+ is_lower , start_indices_th_a [dim_idx ], start_indices_th_b [dim_idx ]
1506+ )
1507+ adj_full [dim_idx ] = arith_d .select (
1508+ is_lower , start_indices_a [dim_idx ], start_indices_b [dim_idx ]
1509+ )
14451510
1446- adj_th = list (start_indices_th )
1447- adj_th [- 1 ] = arith_d .select (
1448- is_lower , adj_th [- 1 ], arith_d .subi (adj_th [- 1 ], elems_per_thread )
1449- )
1511+ adj_wg = list (start_indices_wg_a )
1512+ for dim_idx in range (len (adj_wg )):
1513+ adj_wg [dim_idx ] = arith_d .select (
1514+ is_lower , start_indices_wg_a [dim_idx ], start_indices_wg_b [dim_idx ]
1515+ )
14501516
1451- adj_full = list ( start_indices )
1452- adj_full [ - 1 ] = arith_d . select (
1453- is_lower , adj_full [ - 1 ], arith_d . subi ( adj_full [ - 1 ], elems_per_thread )
1454- )
1517+ sel_output_shape = output_shape_a
1518+ sel_memory_custom = memory_custom_a
1519+ sel_kb_dest = kb_dest_a
1520+ sel_index = index_a
14551521
14561522 _create_vec_read_write (
14571523 emitter ,
1458- output_shape ,
1459- kb_dest ,
1524+ sel_output_shape ,
1525+ sel_kb_dest ,
14601526 wide_vec ,
14611527 None ,
14621528 tuple (adj_full ),
1463- start_indices_wg ,
1529+ tuple ( adj_wg ) ,
14641530 tuple (adj_th ),
14651531 8 ,
1466- memory_custom ,
1532+ sel_memory_custom ,
14671533 None ,
1468- node_index = index ,
1534+ node_index = sel_index ,
14691535 )
14701536
14711537
0 commit comments