Skip to content

Commit 8d7b16b

Browse files
committed
[Tool] Implement a trace profiler
1 parent 171b15b commit 8d7b16b

5 files changed

Lines changed: 1326 additions & 1 deletion

File tree

examples/example_trace.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#!/usr/bin/env python3
2+
"""Minimal example: intra-kernel trace profiler in CuTe-DSL.
3+
4+
Run with: QUACK_TRACE=1 python examples/example_trace.py
5+
Run without: python examples/example_trace.py
6+
"""
7+
8+
from typing import Optional
9+
10+
import cutlass
11+
import cutlass.cute as cute
12+
from cutlass.cutlass_dsl import Int32, Int64
13+
14+
from quack.trace import TraceContext, TraceSession
15+
16+
ITERS = 1000
17+
BLOCK_THREADS = 128
18+
GRID = 4
19+
REGION_NAMES = ("loop_body", "kernel_total")
20+
21+
22+
@cute.kernel
23+
def trace_record_kernel(trace_ptr: Optional[Int64], iters: Int32):
24+
ctx = TraceContext.create(trace_ptr, region_names=REGION_NAMES)
25+
ctx.b("kernel_total")
26+
for i in cutlass.range(iters):
27+
ctx.b("loop_body")
28+
ctx.e("loop_body")
29+
ctx.e("kernel_total")
30+
ctx.flush()
31+
32+
33+
@cute.jit
34+
def launch(trace_ptr: Optional[Int64], iters: Int32):
35+
trace_record_kernel(trace_ptr, iters).launch(
36+
grid=(GRID, 1, 1), block=(BLOCK_THREADS, 1, 1),
37+
)
38+
39+
40+
def main():
41+
out_path = "/tmp/cute_dsl_trace.json"
42+
43+
# sess.ptr is None when QUACK_TRACE != 1 → TraceContext.create becomes a no-op.
44+
with TraceSession(out_path, grid_size=GRID, block_size=BLOCK_THREADS,
45+
region_names=list(REGION_NAMES)) as sess:
46+
launch(sess.ptr, ITERS)
47+
48+
if sess.ptr is not None:
49+
print(f"Open {out_path} in https://ui.perfetto.dev to visualize.")
50+
51+
52+
if __name__ == "__main__":
53+
main()

quack/copy_utils.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import cutlass.pipeline
1414
from cutlass._mlir.dialects import llvm
1515
from cutlass._mlir import ir
16+
17+
from quack.utils import make_vector
1618
from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir
1719

1820

@@ -1029,3 +1031,87 @@ def copy_fn(src_idx, dst_idx, tma_bar_ptr: cute.Pointer):
10291031
tma_gather4_load_fn(smem_ptr, tma_bar_ptr, col_idx, row_indices)
10301032

10311033
return copy_fn
1034+
1035+
1036+
# ---------------------------------------------------------------------------
1037+
# Store helpers
1038+
# ---------------------------------------------------------------------------
1039+
1040+
1041+
@dsl_user_op
1042+
@cute.jit
1043+
def store(
1044+
ptr: cute.Pointer,
1045+
val,
1046+
pred: Optional[Boolean] = None,
1047+
cop: cutlass.Constexpr = None,
1048+
*,
1049+
loc=None,
1050+
ip=None,
1051+
):
1052+
"""Store a scalar value via cute.arch.store.
1053+
1054+
ptr: cute.Pointer (any address space).
1055+
val: DSL Numeric value.
1056+
pred: None → unconditional. DSL Boolean → skipped when pred == 0.
1057+
cop: Cache operator — "wb" (default), "cg", "cs" (streaming), "wt".
1058+
"""
1059+
if const_expr(pred is None):
1060+
cute.arch.store(ptr.llvm_ptr, type(val)(val), cop=cop, loc=loc, ip=ip)
1061+
else:
1062+
if pred:
1063+
cute.arch.store(ptr.llvm_ptr, type(val)(val), cop=cop, loc=loc, ip=ip)
1064+
1065+
1066+
@dsl_user_op
1067+
@cute.jit
1068+
def store_v2(
1069+
ptr: cute.Pointer,
1070+
v0,
1071+
v1,
1072+
pred: Optional[Boolean] = None,
1073+
cop: cutlass.Constexpr = None,
1074+
*,
1075+
loc=None,
1076+
ip=None,
1077+
):
1078+
"""Vectorized store of 2 elements via cute.arch.store.
1079+
1080+
Packs v0, v1 into an MLIR <2 x T> vector.
1081+
ptr: cute.Pointer (any address space, must be aligned for vector width).
1082+
cop: Cache operator — "wb" (default), "cg", "cs" (streaming), "wt".
1083+
"""
1084+
vec = make_vector(type(v0), v0, v1, loc=loc, ip=ip)
1085+
if const_expr(pred is None):
1086+
cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip)
1087+
else:
1088+
if pred:
1089+
cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip)
1090+
1091+
1092+
@dsl_user_op
1093+
@cute.jit
1094+
def store_v4(
1095+
ptr: cute.Pointer,
1096+
v0,
1097+
v1,
1098+
v2,
1099+
v3,
1100+
pred: Optional[Boolean] = None,
1101+
cop: cutlass.Constexpr = None,
1102+
*,
1103+
loc=None,
1104+
ip=None,
1105+
):
1106+
"""Vectorized store of 4 elements via cute.arch.store.
1107+
1108+
Packs v0–v3 into an MLIR <4 x T> vector.
1109+
ptr: cute.Pointer (any address space, must be aligned for vector width).
1110+
cop: Cache operator — "wb" (default), "cg", "cs" (streaming), "wt".
1111+
"""
1112+
vec = make_vector(type(v0), v0, v1, v2, v3, loc=loc, ip=ip)
1113+
if const_expr(pred is None):
1114+
cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip)
1115+
else:
1116+
if pred:
1117+
cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip)

0 commit comments

Comments
 (0)