@@ -18,15 +18,16 @@ import cuTile as ct
1818 Helper: 2D swizzle (same pattern as matmul.jl)
1919=============================================================================#
2020
21- @inline function swizzle_2d (M, N, tm, tn, GROUP_SIZE_M, bid)
21+ function swizzle_2d (M, N, tm, tn, GROUP_SIZE_M, bid)
2222 num_bid_m = cld (M, Int32 (tm))
2323 num_bid_n = cld (N, Int32 (tn))
2424 num_bid_in_group = Int32 (GROUP_SIZE_M) * num_bid_n
25- group_id = fld (bid, num_bid_in_group)
25+ bid0 = bid - Int32 (1 )
26+ group_id = fld (bid0, num_bid_in_group)
2627 first_bid_m = group_id * Int32 (GROUP_SIZE_M)
2728 group_size_m = min (num_bid_m - first_bid_m, Int32 (GROUP_SIZE_M))
28- bid_m = first_bid_m + rem (bid , group_size_m)
29- bid_n = fld (rem (bid , num_bid_in_group), group_size_m)
29+ bid_m = first_bid_m + rem (bid0 , group_size_m) + Int32 ( 1 )
30+ bid_n = fld (rem (bid0 , num_bid_in_group), group_size_m) + Int32 ( 1 )
3031 return bid_m, bid_n
3132end
3233
@@ -53,19 +54,18 @@ function fused_moe_kernel(A::ct.TileArray{T, 2}, B::ct.TileArray{T, 3},
5354 K = size (B, 1 )
5455 N = size (B, 2 )
5556
56- bid = ct. bid (1 ) - Int32 ( 1 ) # 0-indexed for swizzle
57+ bid = ct. bid (1 )
5758 bid_m, bid_n = swizzle_2d (M, N, TILE_M, TILE_N, Int32 (8 ), bid)
5859
5960 # Gather 1-indexed token IDs for this block
60- token_id_indices = bid_m * Int32 (TILE_M) .+ ct. arange (TILE_M)
61+ token_id_indices = ( bid_m - Int32 ( 1 )) * Int32 (TILE_M) .+ ct. arange (TILE_M)
6162 token_ids = ct. gather (sorted_token_ids, token_id_indices)
6263
63- # Map 1-indexed flat token_id to 1-indexed column in A
64- # token_id k → original token = (k-1) ÷ num_token_replicas + 1
65- a_tok_indices = (token_ids .- Int32 (1 )) .÷ Int32 ( num_token_replicas) .+ Int32 ( 1 )
64+ # 1-indexed flat token_id → 1-indexed column in A. Each original token
65+ # has `num_token_replicas` consecutive ids; ceil-divide recovers it.
66+ a_tok_indices = cld . (token_ids, Int32 (num_token_replicas))
6667
67- # Expert for this block (scalar, 1-indexed tile index for load)
68- expert_id = sorted_expert_ids[bid_m + Int32 (1 )]
68+ expert_id = sorted_expert_ids[bid_m]
6969
7070 acc = zeros (Float32, TILE_N, TILE_M)
7171 num_k = cld (K, Int32 (TILE_K))
@@ -81,7 +81,7 @@ function fused_moe_kernel(A::ct.TileArray{T, 2}, B::ct.TileArray{T, 3},
8181 # B is (K, N, num_experts): load (TILE_N, TILE_K) slice for this expert
8282 # order=(2,1,3) folds the transpose into the load via dim_map, matching
8383 # Python cuTile's order=(0,2,1) and avoiding an explicit permute in Tile IR.
84- b = ct. load (B; index= (bid_n + Int32 ( 1 ) , k, expert_id),
84+ b = ct. load (B; index= (bid_n, k, expert_id),
8585 shape= (TILE_N, TILE_K, 1 ), order= (2 , 1 , 3 ),
8686 padding_mode= ct. PaddingMode. Zero)
8787 b = reshape (b, (TILE_N, TILE_K))
@@ -97,7 +97,7 @@ function fused_moe_kernel(A::ct.TileArray{T, 2}, B::ct.TileArray{T, 3},
9797
9898 # Scatter result to C at token_id positions
9999 # C is (N, total_tokens): dim 1 = N, dim 2 = tokens
100- c_n_indices = bid_n * Int32 (TILE_N) .+ ct. arange (TILE_N) # 1-indexed
100+ c_n_indices = ( bid_n - Int32 ( 1 )) * Int32 (TILE_N) .+ ct. arange (TILE_N) # 1-indexed
101101 output = convert (ct. Tile{T}, acc)
102102 ct. scatter (C, (reshape (c_n_indices, (TILE_N, 1 )),
103103 reshape (token_ids, (1 , TILE_M))), output)
@@ -204,24 +204,12 @@ function invoke_silu_and_mul_kernel(AB, C)
204204 inter = size (C, 1 ) # C is (intermediate, total_tokens)
205205 total = size (AB, 2 )
206206
207- # Split AB(inter*2, total) into gate and up halves along dim 1.
208- # A_half = AB[1:inter, :]
209- # B_half = AB[inter+1:2*inter, :]
210- # FIXME : CUDA.jl's CuArray indexing (AB[1:inter, :]) uses a slow generic kernel.
211- # Use unsafe_copy2d! (cuMemcpy2D) for hardware-accelerated pitched 2D copy instead.
212- T = eltype (AB)
213- A_half = similar (AB, inter, total)
214- B_half = similar (AB, inter, total)
215- src_pitch = size (AB, 1 ) * sizeof (T)
216- dst_pitch = inter * sizeof (T)
217- CUDACore. unsafe_copy2d! (pointer (A_half), CUDACore. DeviceMemory,
218- pointer (AB), CUDACore. DeviceMemory,
219- inter, total; srcPitch= src_pitch, dstPitch= dst_pitch,
220- async= true )
221- CUDACore. unsafe_copy2d! (pointer (B_half), CUDACore. DeviceMemory,
222- pointer (AB) + inter * sizeof (T), CUDACore. DeviceMemory,
223- inter, total; srcPitch= src_pitch, dstPitch= dst_pitch,
224- async= true )
207+ # Split AB(inter*2, total) into gate and up halves along dim 1 — mirrors
208+ # cuTile Python's `AB.chunk(2, dim=-1)`. Views are non-contiguous along
209+ # dim 2 but each block only loads a (TILE_N, 1) tile, so codegen is
210+ # unaffected.
211+ A_half = view (AB, 1 : inter, :)
212+ B_half = view (AB, (inter + 1 ): (2 * inter), :)
225213
226214 tile_n = nextpow (2 , inter)
227215 @cuda backend= cuTile blocks= total silu_and_mul_kernel (A_half, B_half, C, ct. Constant (tile_n))
0 commit comments