@@ -117,7 +117,7 @@ def _split_index(
117117 # Compute thread-independent index as `orig_index - thread_dependent_index`
118118 # All thread symbols and dynamic should cancel-out in the result.
119119 diff = src - thread_dependent_index
120- # Avoid sympy.simplify on Piecewise expressions — it recurses into boolean
120+ # Avoid sympy.simplify on Piecewise expressions: it recurses into boolean
121121 # condition simplification and can hang for complex dynamic-shape indices.
122122 # expand() handles basic polynomial cancellation and is O(fast).
123123 if isinstance (diff , sympy .Basic ) and diff .has (sympy .Piecewise ):
@@ -575,7 +575,7 @@ def _cast_buffer_and_encode_stride(
575575 stride_int = _get_constant_value (stride_candidate )
576576 # Emit swizzle stride for both static and dynamic cases.
577577 # Static: only if stride fits in signed i14 (max 8192).
578- # Dynamic: always emit — the SRD swizzle encoding is constant
578+ # Dynamic: always emit, since the SRD swizzle encoding is constant
579579 # (0x40400000 + 0x27000) regardless of the actual stride value.
580580 if stride_int is None or stride_int <= 8192 :
581581 swizzle_stride = arith_d .index_cast (uint14 , stride_candidate )
@@ -1325,18 +1325,49 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
13251325 if getattr (node , "_permlane_pack_global" , False ):
13261326 is_shared = get_custom (memory ).type .address_space == SHARED_ADDRESS_SPACE
13271327 if not is_shared and isinstance (element_type , BF16Type ):
1328- _write_permlane_pack_to_global (
1329- emitter ,
1330- insert_vector ,
1331- kb_dest ,
1332- output_shape ,
1333- start_indices ,
1334- start_indices_wg ,
1335- start_indices_th ,
1336- get_custom (memory ),
1337- index ,
1328+ role = getattr (node , "_permlane_pack_role" , "unpaired" )
1329+
1330+ if role == "first" :
1331+ node ._stashed_codegen = {
1332+ "insert_vector" : insert_vector ,
1333+ "kb_dest" : kb_dest ,
1334+ "output_shape" : output_shape ,
1335+ "start_indices" : start_indices ,
1336+ "start_indices_wg" : start_indices_wg ,
1337+ "start_indices_th" : start_indices_th ,
1338+ "memory_custom" : get_custom (memory ),
1339+ "index" : index ,
1340+ }
1341+ return
1342+
1343+ if role == "second" :
1344+ partner = node ._permlane_partner
1345+ s = partner ._stashed_codegen
1346+ _write_permlane_pair_to_global (
1347+ emitter ,
1348+ s ["insert_vector" ],
1349+ insert_vector ,
1350+ s ["kb_dest" ],
1351+ kb_dest ,
1352+ s ["output_shape" ],
1353+ output_shape ,
1354+ s ["start_indices" ],
1355+ s ["start_indices_wg" ],
1356+ s ["start_indices_th" ],
1357+ start_indices ,
1358+ start_indices_wg ,
1359+ start_indices_th ,
1360+ s ["memory_custom" ],
1361+ get_custom (memory ),
1362+ s ["index" ],
1363+ index ,
1364+ )
1365+ return
1366+
1367+ assert False , (
1368+ "Unexpected unpaired wide-store write. "
1369+ "coalesce_wide_stores should pair all eligible writes."
13381370 )
1339- return
13401371
13411372 if use_llvm_store :
13421373 _create_llvm_read_write (
@@ -1359,52 +1390,58 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
13591390 )
13601391
13611392
1362- def _write_permlane_pack_to_global (
1393+ def _write_permlane_pair_to_global (
13631394 emitter : WaveEmitter ,
1364- insert_vector : Value ,
1365- kb_dest : Value ,
1366- output_shape : tuple ,
1367- start_indices : tuple ,
1368- start_indices_wg : tuple ,
1369- start_indices_th : tuple ,
1370- memory_custom ,
1371- index : dict ,
1395+ vec_a : Value ,
1396+ vec_b : Value ,
1397+ kb_dest_a : Value ,
1398+ kb_dest_b : Value ,
1399+ output_shape_a : tuple ,
1400+ output_shape_b : tuple ,
1401+ start_indices_a : tuple ,
1402+ start_indices_wg_a : tuple ,
1403+ start_indices_th_a : tuple ,
1404+ start_indices_b : tuple ,
1405+ start_indices_wg_b : tuple ,
1406+ start_indices_th_b : tuple ,
1407+ memory_custom_a ,
1408+ memory_custom_b ,
1409+ index_a : dict ,
1410+ index_b : dict ,
13721411):
1373- """Pack two lanes' bf16 values via permlane16_swap for wide global stores.
1412+ """Pair two tile groups via permlane16_swap for duplicate-free wide stores.
13741413
1375- Uses ``v_permlane16_swap_b32`` to exchange each thread's 4 bf16 values
1376- (packed as 2 i32 dwords) with a partner lane 16 positions apart.
1377- The result is 8 consecutive bf16 values per lane, written as a single
1378- ``buffer_store_dwordx4`` (128 bits).
1414+ Pairs tile A and tile B by passing them as separate ``old_dst`` /
1415+ ``src`` operands to ``v_permlane16_swap_b32``. Both outputs of the
1416+ swap carry distinct data, so each lane writes a *different* tile
1417+ group's wide store, with no duplicate stores:
13791418
1380- Both lane halves write identical data to the same address (benign
1381- duplicate store), avoiding divergent control flow. The buffer
1382- descriptor's ``valid_bytes`` handles out-of-bounds suppression.
1419+ * Lower lane (lane % 32 < 16) writes tile A:
1420+ ``[own_A_lo, own_A_hi, partner_A_lo, partner_A_hi]``
1421+ * Upper lane (lane % 32 >= 16) writes tile B:
1422+ ``[partner_B_lo, partner_B_hi, own_B_lo, own_B_hi]``
13831423
1384- TODO: Eliminate duplicate stores by using both outputs of
1385- ``permlane16_swap``, letting each lane write the partner's assembled
1386- data to the partner's destination address so every lane performs a
1387- unique store.
1424+ This halves both the ``permlane16_swap`` count and the global store
1425+ count compared to the single-write approach.
13881426
1389- Preconditions:
1390- - The kernel must use swapped MFMA operands (B as LHS, A as RHS)
1391- so the accumulator's 4-contiguous values align with the output
1392- memory's stride-1 dimension.
1393- - The Write node must be tagged with ``_permlane_pack_global=True``
1394- by the ``coalesce_wide_stores`` pass.
1427+ Preconditions:
1428+ - Swapped MFMA operands, F32_16x16x128_F8F6F4 layout, bf16 output.
1429+ - Both Write nodes tagged by ``coalesce_wide_stores``.
13951430
13961431 .. note::
1397- Currently assumes F32_16x16x128_F8F6F4 MMA layout (4 values
1398- along MMA-M per thread, 16-lane groups). Generalizing to other
1399- MMA types requires parameterizing the lane group size and
1400- elements per thread.
1432+ The store is emitted using tile A's ``output_shape``,
1433+ ``memory_custom``, ``kb_dest``, and ``index``. This is correct
1434+ when both tiles target the same output buffer with identical
1435+ shape and buffer descriptor (the standard MXFP4 GEMM case).
1436+ If the two tiles ever target different buffers, the
1437+ ``_create_vec_read_write`` call would need to be split into
1438+ two lane-divergent stores.
14011439 """
1402- vec_type = insert_vector .type
1403- num_elems = vec_type .shape [0 ] if hasattr (vec_type , "shape" ) else 1
1404- assert num_elems == 4 , (
1405- f"_write_permlane_pack_to_global expects 4 bf16 elements per thread "
1406- f"(F32_16x16x128_F8F6F4 MMA layout), got { num_elems } . "
1407- f"Other MMA types are not yet supported."
1440+ num_elems_a = vec_a .type .shape [0 ] if hasattr (vec_a .type , "shape" ) else 1
1441+ num_elems_b = vec_b .type .shape [0 ] if hasattr (vec_b .type , "shape" ) else 1
1442+ assert num_elems_a == 4 and num_elems_b == 4 , (
1443+ f"_write_permlane_pair_to_global expects 4 bf16 elements per thread "
1444+ f"per tile, got { num_elems_a } and { num_elems_b } ."
14081445 )
14091446
14101447 bf16_type = BF16Type .get ()
@@ -1414,57 +1451,86 @@ def _write_permlane_pack_to_global(
14141451 v4i32_type = VectorType .get ([4 ], i32_type )
14151452 v8bf16_type = VectorType .get ([8 ], bf16_type )
14161453
1417- i32_vec = vector_d .bitcast (v2i32_type , insert_vector )
1418- own_lo = vector_d .extract (i32_vec , static_position = [0 ], dynamic_position = [])
1419- own_hi = vector_d .extract (i32_vec , static_position = [1 ], dynamic_position = [])
1454+ i32_a = vector_d .bitcast (v2i32_type , vec_a )
1455+ a_lo = vector_d .extract (i32_a , static_position = [0 ], dynamic_position = [])
1456+ a_hi = vector_d .extract (i32_a , static_position = [1 ], dynamic_position = [])
1457+
1458+ i32_b = vector_d .bitcast (v2i32_type , vec_b )
1459+ b_lo = vector_d .extract (i32_b , static_position = [0 ], dynamic_position = [])
1460+ b_hi = vector_d .extract (i32_b , static_position = [1 ], dynamic_position = [])
14201461
14211462 swap_type = llvm_d .StructType .get_literal ([i32_type , i32_type ])
1422- partner_lo = llvm_d .extractvalue (
1423- i32_type , rocdl_d .permlane16_swap (swap_type , own_lo , own_lo , False , False ), [0 ]
1424- )
1425- partner_hi = llvm_d .extractvalue (
1426- i32_type , rocdl_d .permlane16_swap (swap_type , own_hi , own_hi , False , False ), [0 ]
1427- )
1463+
1464+ # old_dst = a, src = b -> result[0] = partner's b, result[1] = partner's a
1465+ swap_lo = rocdl_d .permlane16_swap (swap_type , a_lo , b_lo , False , False )
1466+ swap_hi = rocdl_d .permlane16_swap (swap_type , a_hi , b_hi , False , False )
1467+
1468+ partner_b_lo = llvm_d .extractvalue (i32_type , swap_lo , [0 ])
1469+ partner_a_lo = llvm_d .extractvalue (i32_type , swap_lo , [1 ])
1470+ partner_b_hi = llvm_d .extractvalue (i32_type , swap_hi , [0 ])
1471+ partner_a_hi = llvm_d .extractvalue (i32_type , swap_hi , [1 ])
14281472
14291473 lane_in_wave = arith_d .remui (emitter .thread_ids [0 ], arith_d .constant (idx_type , 64 ))
14301474 half_pos = arith_d .remui (lane_in_wave , arith_d .constant (idx_type , 32 ))
14311475 is_lower = arith_d .cmpi (
14321476 arith_d .CmpIPredicate .ult , half_pos , arith_d .constant (idx_type , 16 )
14331477 )
14341478
1435- d0 = arith_d .select (is_lower , own_lo , partner_lo )
1436- d1 = arith_d .select (is_lower , own_hi , partner_hi )
1437- d2 = arith_d .select (is_lower , partner_lo , own_lo )
1438- d3 = arith_d .select (is_lower , partner_hi , own_hi )
1479+ # Lower lane: [own_A_lo, own_A_hi, partner_A_lo, partner_A_hi]
1480+ # Upper lane: [partner_B_lo, partner_B_hi, own_B_lo, own_B_hi]
1481+ d0 = arith_d .select (is_lower , a_lo , partner_b_lo )
1482+ d1 = arith_d .select (is_lower , a_hi , partner_b_hi )
1483+ d2 = arith_d .select (is_lower , partner_a_lo , b_lo )
1484+ d3 = arith_d .select (is_lower , partner_a_hi , b_hi )
14391485
14401486 wide_i32 = vector_d .from_elements (v4i32_type , [d0 , d1 , d2 , d3 ])
14411487 wide_vec = vector_d .bitcast (v8bf16_type , wide_i32 )
14421488
1443- elems_per_thread = arith_d .constant (idx_type , num_elems )
1489+ elems_per_thread = arith_d .constant (idx_type , 4 )
1490+
1491+ # Lower lane uses tile A's address; upper lane uses tile B's address.
1492+ # Upper lane subtracts elems_per_thread from the last dim to align
1493+ # to the lower lane's column position.
1494+ adj_th = list (start_indices_th_a )
1495+ adj_full = list (start_indices_a )
1496+ for dim_idx in range (len (adj_th )):
1497+ if dim_idx == len (adj_th ) - 1 :
1498+ adj_b_th = arith_d .subi (start_indices_th_b [- 1 ], elems_per_thread )
1499+ adj_b_full = arith_d .subi (start_indices_b [- 1 ], elems_per_thread )
1500+ adj_th [dim_idx ] = arith_d .select (is_lower , adj_th [dim_idx ], adj_b_th )
1501+ adj_full [dim_idx ] = arith_d .select (is_lower , adj_full [dim_idx ], adj_b_full )
1502+ else :
1503+ adj_th [dim_idx ] = arith_d .select (
1504+ is_lower , start_indices_th_a [dim_idx ], start_indices_th_b [dim_idx ]
1505+ )
1506+ adj_full [dim_idx ] = arith_d .select (
1507+ is_lower , start_indices_a [dim_idx ], start_indices_b [dim_idx ]
1508+ )
14441509
1445- adj_th = list (start_indices_th )
1446- adj_th [- 1 ] = arith_d .select (
1447- is_lower , adj_th [- 1 ], arith_d .subi (adj_th [- 1 ], elems_per_thread )
1448- )
1510+ adj_wg = list (start_indices_wg_a )
1511+ for dim_idx in range (len (adj_wg )):
1512+ adj_wg [dim_idx ] = arith_d .select (
1513+ is_lower , start_indices_wg_a [dim_idx ], start_indices_wg_b [dim_idx ]
1514+ )
14491515
1450- adj_full = list ( start_indices )
1451- adj_full [ - 1 ] = arith_d . select (
1452- is_lower , adj_full [ - 1 ], arith_d . subi ( adj_full [ - 1 ], elems_per_thread )
1453- )
1516+ sel_output_shape = output_shape_a
1517+ sel_memory_custom = memory_custom_a
1518+ sel_kb_dest = kb_dest_a
1519+ sel_index = index_a
14541520
14551521 _create_vec_read_write (
14561522 emitter ,
1457- output_shape ,
1458- kb_dest ,
1523+ sel_output_shape ,
1524+ sel_kb_dest ,
14591525 wide_vec ,
14601526 None ,
14611527 tuple (adj_full ),
1462- start_indices_wg ,
1528+ tuple ( adj_wg ) ,
14631529 tuple (adj_th ),
14641530 8 ,
1465- memory_custom ,
1531+ sel_memory_custom ,
14661532 None ,
1467- node_index = index ,
1533+ node_index = sel_index ,
14681534 )
14691535
14701536
0 commit comments