Skip to content

Commit a24e4aa

Browse files
committed
[Tool] Add example of gemm trace profile
1 parent 8d7b16b commit a24e4aa

4 files changed

Lines changed: 79 additions & 3 deletions

File tree

examples/example_gemm_trace.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#!/usr/bin/env python3
2+
"""Trace an SM90 GEMM kernel and visualize in Perfetto.
3+
4+
Run with: QUACK_TRACE=1 python examples/example_gemm_trace.py
5+
Visualize: Open /tmp/gemm_trace.json in https://ui.perfetto.dev
6+
"""
7+
8+
import math
9+
10+
import torch
11+
import cutlass
12+
from cutlass import Float32
13+
14+
from quack.gemm import gemm
15+
from quack.gemm_default_epi import GemmDefaultSm90
16+
from quack.trace import TraceSession
17+
18+
M, N, K = 4096, 4096, 4096
19+
TILE_M, TILE_N = 128, 192
20+
CLUSTER_M, CLUSTER_N = 2, 1
21+
OUT_PATH = "/tmp/gemm_trace.json"
22+
23+
24+
def main():
25+
A = torch.randn(1, M, K, device="cuda", dtype=torch.float16)
26+
B = torch.randn(1, N, K, device="cuda", dtype=torch.float16)
27+
D = torch.empty(1, M, N, device="cuda", dtype=torch.float16)
28+
29+
# Query the GEMM config for block size (threads_per_cta).
30+
g = GemmDefaultSm90(Float32, cutlass.Float16, (TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N, 1))
31+
# grid_size = math.ceil(M / TILE_M) * math.ceil(N / TILE_N)
32+
grid_size = 132
33+
34+
with TraceSession(OUT_PATH, grid_size=grid_size, block_size=g.threads_per_cta,
35+
region_names=["tma_load", "mma", "epilogue"]) as sess:
36+
gemm(A, B, D, C=None, tile_count_semaphore=None,
37+
tile_M=TILE_M, tile_N=TILE_N,
38+
cluster_M=CLUSTER_M, cluster_N=CLUSTER_N,
39+
persistent=True, pingpong=True, trace_ptr=sess.ptr)
40+
41+
# Verify correctness.
42+
ref = A[0] @ B[0].T
43+
print(f"max error: {(D[0] - ref).abs().max().item():.4f}")
44+
print(f"Open {OUT_PATH} in https://ui.perfetto.dev")
45+
46+
47+
if __name__ == "__main__":
48+
main()

quack/gemm.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def _compile_gemm(
5555
device_capacity,
5656
rounding_mode,
5757
sr_seed_mode,
58+
has_trace_ptr,
5859
):
5960
GemmCls = GemmDefaultSm100 if device_capacity[0] > 9 else GemmDefaultSm90
6061
mA, mB, mD, mC, m, n, k, l = make_fake_gemm_tensors(
@@ -118,6 +119,7 @@ def fake_scalar(mode, dtype=Float32):
118119
epi_args,
119120
scheduler_args,
120121
varlen_args,
122+
has_trace_ptr=has_trace_ptr,
121123
)
122124

123125

@@ -147,6 +149,7 @@ def gemm(
147149
add_to_output: bool = False,
148150
rounding_mode: int = RoundingMode.RN,
149151
sr_seed: int | Tensor = 0,
152+
trace_ptr=None, # Optional Int64 from TraceSession.ptr
150153
) -> None:
151154
varlen_m = cu_seqlens_m is not None
152155
varlen_k = cu_seqlens_k is not None
@@ -216,6 +219,7 @@ def gemm(
216219
device_capacity,
217220
rounding_mode,
218221
sr_seed_mode,
222+
trace_ptr is not None,
219223
)
220224

221225
from quack.cache_utils import COMPILE_ONLY
@@ -251,6 +255,8 @@ def scalar_arg(scalar, mode, dtype=Float32):
251255
varlen_args = make_varlen_args(cu_seqlens_m, cu_seqlens_k, A_idx)
252256

253257
if device_capacity[0] > 9:
254-
compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None)
258+
compiled_fn(
259+
A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, trace_ptr
260+
)
255261
else:
256-
compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args)
262+
compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, trace_ptr)

quack/gemm_sm90.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@ def __call__(
375375
scheduler_args: TileSchedulerOptions,
376376
varlen_args: Optional[VarlenArguments],
377377
stream: cuda.CUstream,
378+
trace_ptr: Optional[cutlass.Int64] = None,
378379
):
379380
"""Execute the GEMM operation in steps:
380381
- Setup static attributes
@@ -542,6 +543,7 @@ class SharedStorage:
542543
self.epi_c_smem_layout_staged,
543544
tile_sched_params,
544545
TileSchedulerCls,
546+
trace_ptr,
545547
).launch(
546548
grid=grid,
547549
block=[self.threads_per_cta, 1, 1],
@@ -573,6 +575,7 @@ def kernel(
573575
epi_c_smem_layout: cute.ComposedLayout,
574576
tile_sched_params,
575577
TileSchedulerCls: cutlass.Constexpr[Callable],
578+
trace_ptr: Optional[cutlass.Int64] = None,
576579
):
577580
"""
578581
GPU device kernel performing the batched GEMM computation.
@@ -601,6 +604,11 @@ def kernel(
601604
:type epi_smem_layout: cute.ComposedLayout
602605
"""
603606

607+
from quack.trace import TraceContext
608+
609+
GEMM_REGIONS = ("tma_load", "mma", "epilogue")
610+
tctx = TraceContext.create(trace_ptr, region_names=GEMM_REGIONS)
611+
604612
varlen_m = const_expr(varlen_params.cu_seqlens_m is not None)
605613
varlen_k = const_expr(varlen_params.cu_seqlens_k is not None)
606614
assert not (varlen_m and varlen_k)
@@ -703,6 +711,7 @@ def kernel(
703711
pipeline.PipelineUserType.Producer, self.ab_stage
704712
)
705713
while work_tile.is_valid_tile:
714+
tctx.b("tma_load")
706715
tile_coord_mnkl = work_tile.tile_idx
707716
batch_idx = tile_coord_mnkl[3]
708717
# Local_tile partition global tensors
@@ -804,6 +813,7 @@ def kernel(
804813
k_tile_cnt,
805814
varlen_m=varlen_m,
806815
)
816+
tctx.e("tma_load")
807817
tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
808818
work_tile = tile_scheduler.get_current_work()
809819
# End of persistent scheduler loop
@@ -882,16 +892,19 @@ def kernel(
882892
batch_idx = tile_coord_mnkl[3]
883893
len_k = varlen_manager.len_k(batch_idx)
884894
k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
895+
tctx.b("mma")
885896
ab_read_state = self.mma(
886897
ab_pipeline, ab_read_state, mma_fn, acc, acc_slow, k_tile_cnt, warp_group_idx
887898
)
899+
tctx.e("mma")
888900
if const_expr(varlen_k):
889901
if k_tile_cnt == 0:
890902
acc.fill(0.0)
891903

892904
# EPILOGUE
893905
if const_expr(self.pingpong):
894906
self.pingpong_barrier_sync(warp_group_idx, "epi")
907+
tctx.b("epilogue")
895908

896909
copy_D = None
897910
if const_expr(has_D):
@@ -966,6 +979,8 @@ def kernel(
966979
epi_store_pipeline.producer_tail()
967980
self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")
968981

982+
tctx.e("epilogue")
983+
969984
if const_expr(not self.pingpong):
970985
tile_scheduler.advance_to_next_work()
971986
work_tile = tile_scheduler.get_current_work()
@@ -994,6 +1009,8 @@ def kernel(
9941009
if is_tma_warp:
9951010
epi_store_pipeline.producer_tail()
9961011

1012+
tctx.flush()
1013+
9971014
@cute.jit
9981015
def load_AB(
9991016
self,

quack/gemm_tvm_ffi_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
import cutlass.cute as cute
8-
from cutlass import Int32, Float32
8+
from cutlass import Int32, Int64, Float32
99
from cutlass.cute.runtime import make_ptr
1010

1111
from quack.compile_utils import make_fake_tensor as fake_tensor
@@ -185,6 +185,7 @@ def compile_gemm_kernel(
185185
post_init=None,
186186
mSFA=None,
187187
mSFB=None,
188+
has_trace_ptr=False,
188189
):
189190
"""Build GemmCls instance, apply SM90 partial, and cute.compile with TVM-FFI."""
190191
if device_capacity[0] == 9:
@@ -202,6 +203,9 @@ def compile_gemm_kernel(
202203
post_init(gemm_obj)
203204
stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)
204205
sf_args = () if device_capacity[0] == 9 else (mSFA, mSFB)
206+
# Trace pointer: Optional[Int64]. Compile with Int64(0) when tracing is
207+
# requested, None otherwise. TVM-FFI caches each variant separately.
208+
trace_ptr = Int64(0) if has_trace_ptr else None
205209
return cute.compile(
206210
gemm_obj,
207211
mA,
@@ -213,5 +217,6 @@ def compile_gemm_kernel(
213217
varlen_args,
214218
stream,
215219
*sf_args,
220+
trace_ptr,
216221
options="--enable-tvm-ffi",
217222
)

0 commit comments

Comments
 (0)