|
13 | 13 | import cutlass.pipeline |
14 | 14 | from cutlass._mlir.dialects import llvm |
15 | 15 | from cutlass._mlir import ir |
| 16 | +from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir |
16 | 17 |
|
| 18 | +from quack import layout_utils |
17 | 19 | from quack.utils import make_vector |
18 | | -from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir |
19 | 20 |
|
20 | 21 |
|
21 | 22 | Sm100MmaPeerBitMask = 0xFEFFFFFF |
@@ -1023,16 +1024,83 @@ def gather_m_get_tma_copy_fn( |
1023 | 1024 | tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group) |
1024 | 1025 |
|
1025 | 1026 | def copy_fn(src_idx, dst_idx, tma_bar_ptr: cute.Pointer): |
| 1027 | + tSR_sA_cur = tSR_sA[None, None, None, dst_idx] |
1026 | 1028 | col_idx = tile_K * src_idx |
1027 | 1029 | for m in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True): |
1028 | 1030 | row_indices = [tSR_rAIdx[v, m] for v in range(4)] |
1029 | | - smem_ptr = tSR_sA[None, m, None, dst_idx].iterator |
| 1031 | + smem_ptr = tSR_sA_cur[None, m, None].iterator |
1030 | 1032 | with cute.arch.elect_one(): |
1031 | 1033 | tma_gather4_load_fn(smem_ptr, tma_bar_ptr, col_idx, row_indices) |
1032 | 1034 |
|
1033 | 1035 | return copy_fn |
1034 | 1036 |
|
1035 | 1037 |
|
| 1038 | +@cute.jit |
| 1039 | +def gather_k_get_tma_copy_fn( |
| 1040 | + tma_atom: cute.CopyAtom, |
| 1041 | + sA: cute.Tensor, # ((4, tile_K/4), (tile_M,), STAGE) — K-grouped load layout |
| 1042 | + sAIdx: cute.Tensor, # (tile_K, a_prefetch_stage) — K indices in smem |
| 1043 | + col_idx: Int32, # M offset in global tensor (contiguous dim for M-major) |
| 1044 | + warp_idx: Int32, |
| 1045 | + num_warps: int, |
| 1046 | + num_cta: int = 1, |
| 1047 | +) -> Tuple[Callable, Callable]: |
| 1048 | + """Build a copy function for TMA gather4 in K dimension (M-major A). |
| 1049 | +
|
| 1050 | + Each gather4 instruction loads 4 K-columns × tile_M contiguous M-elements. |
| 1051 | + col_idx is the absolute M position in the global tensor. |
| 1052 | + K indices come from sAIdx (prefetched to smem by the scheduler warp). |
| 1053 | +
|
| 1054 | + Returns copy_fn(src_idx, dst_idx, tma_bar_ptr) which: |
| 1055 | + Issues gather4 calls with those K indices as row_indices |
| 1056 | + """ |
| 1057 | + tile_K = cute.size(sAIdx, mode=[0]) |
| 1058 | + assert tile_K % 4 == 0 |
| 1059 | + cta_group = num_cta |
| 1060 | + |
| 1061 | + # Tiled copy for loading K indices from smem to registers (4 per vector, across warps) |
| 1062 | + copy_AIdx_s2r = cute.make_tiled_copy_tv( |
| 1063 | + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Int32, num_bits_per_copy=128), |
| 1064 | + cute.make_layout(num_warps), # thr_layout |
| 1065 | + cute.make_layout(4), # val_layout — 4 K indices per gather4 |
| 1066 | + ) |
| 1067 | + warp_idx = cute.arch.make_warp_uniform(warp_idx) |
| 1068 | + warp_copy_AIdx_s2r = copy_AIdx_s2r.get_slice(warp_idx) |
| 1069 | + tSR_sAIdx = warp_copy_AIdx_s2r.partition_S(sAIdx) # (((4,1),4,4)) |
| 1070 | + # ((4,1),4,(64,2),(1,4)):((64,0),1024,(1,4096),(0,8192)) |
| 1071 | + tSR_sA = warp_copy_AIdx_s2r.partition_S(layout_utils.transpose_view(sA)) |
| 1072 | + tma_desc_ptr = get_tma_desc_addr(tma_atom) |
| 1073 | + tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group) |
| 1074 | + |
| 1075 | + def prefetch_from_smem_fn( |
| 1076 | + a_prefetch_pipeline, |
| 1077 | + src_idx, |
| 1078 | + dst_idx, |
| 1079 | + a_prefetch_consumer_state, |
| 1080 | + ) -> cute.Tensor: |
| 1081 | + a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state) |
| 1082 | + tSR_rAIdx = load_s2r(tSR_sAIdx[None, None, dst_idx]) |
| 1083 | + cute.arch.sync_warp() |
| 1084 | + with cute.arch.elect_one(): |
| 1085 | + a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state) |
| 1086 | + return tSR_rAIdx |
| 1087 | + |
| 1088 | + def copy_fn(src_idx, dst_idx, tSR_rAIdx, tma_bar_ptr: cute.Pointer): |
| 1089 | + # Issue gather4: col_idx = M position, row_indices = 4 K positions |
| 1090 | + tSR_sA_cur = tSR_sA[None, None, None, dst_idx] |
| 1091 | + gather_dim = cute.size(tSR_sA_cur, mode=[2, 0]) # Typically 64 |
| 1092 | + for k in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True): |
| 1093 | + row_indices = [tSR_rAIdx[v, k] for v in range(4)] |
| 1094 | + for m in cutlass.range(cute.size(tSR_sA_cur, mode=[2, 1]), unroll_full=True): |
| 1095 | + smem_ptr = tSR_sA_cur[None, k, (None, m)].iterator |
| 1096 | + with cute.arch.elect_one(): |
| 1097 | + tma_gather4_load_fn( |
| 1098 | + smem_ptr, tma_bar_ptr, col_idx + m * gather_dim, row_indices |
| 1099 | + ) |
| 1100 | + |
| 1101 | + return copy_fn, prefetch_from_smem_fn |
| 1102 | + |
| 1103 | + |
1036 | 1104 | # --------------------------------------------------------------------------- |
1037 | 1105 | # Store helpers |
1038 | 1106 | # --------------------------------------------------------------------------- |
|
0 commit comments