Skip to content

Commit a824696

Browse files
committed
add ptrace
1 parent 06ea256 commit a824696

1 file changed

Lines changed: 201 additions & 28 deletions

File tree

advanced_source/cuda_graph_annotations_tutorial.py

Lines changed: 201 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
* CUDA-capable GPU
2828
* Driver/CUDA-compat >= 13.1 for annotation support
2929
* cuda-bindings >= 13.1.0
30+
* perfetto (``pip install perfetto``)
3031
3132
CUDA graphs are a powerful optimization technique that can significantly reduce
3233
kernel launch overhead by capturing and replaying sequences of CUDA operations.
@@ -99,6 +100,7 @@
99100
# - A CUDA GPU
100101
# - Driver/CUDA-compat >= 13.1 for annotation support
101102
# - The ``cuda-bindings`` package >= 13.1.0 (``pip install cuda-python``)
103+
# - The ``perfetto`` package for writing the trace (``pip install perfetto``)
102104
#
103105
# The cuda-bindings package provides the Python bindings for CUDA runtime APIs.
104106
# Version 13.1.0+ is required for the ``cudaGraphNodeGetToolsId`` API that
@@ -111,11 +113,13 @@
111113
# appear in the final trace.
112114

113115
import copy
116+
import hashlib
117+
import json
114118
import math
115119
import os
116120
import pickle
117121
import sys
118-
from collections import Counter
122+
from collections import Counter, defaultdict
119123
from pathlib import Path
120124

121125
import torch
@@ -131,9 +135,6 @@
131135
from torch.cuda._annotate_cuda_graph_trace import (
132136
annotate_trace,
133137
load_trace,
134-
save_trace,
135-
_fix_overlapping_timestamps,
136-
_move_overlapping_to_stream,
137138
)
138139

139140
###############################################################################
@@ -297,13 +298,185 @@ def save_annotations(output_dir):
297298
#
298299
# 1. Loading the raw trace and annotations
299300
# 2. Calling ``annotate_trace()`` to apply the annotations
300-
# 3. Running cleanup passes to handle overlapping kernels
301-
# 4. Saving the annotated trace
301+
# 3. Emitting a native Perfetto ``.pftrace`` that preserves overlapping kernels
302+
# on their real stream
302303
#
303304
# The result is a trace where kernels are organized by your semantic labels.
305+
#
306+
# **Why a Perfetto protobuf trace (not Chrome JSON)?** A Chrome JSON trace --
307+
# the format ``torch.profiler.export_chrome_trace`` produces -- has a
308+
# fundamental limitation: a single track (a ``(pid, tid)`` row) can only show
309+
# **properly nested** slices, never partially overlapping (crossing) ones.
310+
#
311+
# Perfetto's native **protobuf** trace (``.pftrace``) solves this
312+
# via the ``TrackDescriptor`` field ``sibling_merge_key``. We split
313+
# overlapping slices across hidden *backing* tracks (so each protobuf
314+
# begin/end stack stays validly nested), then give those backing tracks the
315+
# **same** ``sibling_merge_key`` so the Perfetto UI merges them back into a
316+
# single logical row. Nothing is relocated to a fake stream and no timestamp is
317+
# clamped -- the overlap is shown faithfully on the kernel's real stream.
318+
#
319+
# This converter is adapted from Driss Guessous's `transformer_nuggets
320+
# <https://github.com/drisspg/transformer_nuggets>`_
321+
# (``transformer_nuggets/utils/track_event.py``); we inline a compact,
322+
# self-contained version here. It needs the ``perfetto`` package
323+
# (``pip install perfetto``).
324+
325+
def _stable_uuid(*parts):
    """Derive a deterministic 60-bit track UUID from ``parts``.

    Every part is stringified, the strings are joined with ``:``, and the
    result is hashed with SHA-1. The leading 15 hex digits of the digest
    (15 * 4 = 60 bits) are returned as an ``int``, so identical parts always
    yield the identical UUID across runs.
    """
    identity = ":".join(map(str, parts))
    hex_digest = hashlib.sha1(identity.encode()).hexdigest()
    return int(hex_digest[:15], 16)
329+
330+
331+
def _assign_nesting_lanes(slices):
    """Distribute slices over backing lanes so every lane nests cleanly.

    Within one lane any two slices are either disjoint or one fully contains
    the other, so a begin/end stack replayed on that lane never sees crossing
    slices. A lane is a *backing* track index, not a user-visible stream --
    lanes that belong to the same stream are merged back together in the UI.

    Each element of ``slices`` is a dict providing ``"ts"``, ``"end"`` and
    ``"index"``. Slices are visited by ascending start, then longest first
    (so an enclosing slice is placed before the slices it contains), then by
    ``"index"``; each one goes to the first lane that can accept it.

    Returns ``(lane_of_index, lane_count)`` where ``lane_of_index`` maps a
    position in ``slices`` to its lane number.
    """

    def visit_key(pos):
        item = slices[pos]
        return (item["ts"], -item["end"], item["index"])

    lane_of = {}
    open_ends_by_lane = []  # per lane: stack of end times of still-open slices
    for pos in sorted(range(len(slices)), key=visit_key):
        start = slices[pos]["ts"]
        stop = slices[pos]["end"]
        for lane, open_ends in enumerate(open_ends_by_lane):
            # Retire slices on this lane that finished by the time we start.
            while open_ends and open_ends[-1] <= start:
                open_ends.pop()
            # Would cross the innermost open slice: try the next lane.
            if open_ends and stop > open_ends[-1]:
                continue
            open_ends.append(stop)
            lane_of[pos] = lane
            break
        else:
            # No existing lane fits; open a fresh one for this slice.
            open_ends_by_lane.append([stop])
            lane_of[pos] = len(open_ends_by_lane) - 1
    return lane_of, len(open_ends_by_lane)
361+
362+
363+
def _add_debug_annotation(track_event, name, value):
    """Carry a Chrome event arg over as a typed Perfetto debug annotation.

    Appends one entry to ``track_event.debug_annotations`` and stores
    ``value`` in the field matching its Python type:

    * ``bool`` -> ``bool_value``
    * ``int`` -> ``int_value`` when it fits a signed 64-bit integer,
      ``uint_value`` when it only fits an unsigned 64-bit integer, otherwise
      ``string_value`` holding its decimal text
    * ``float`` -> ``double_value``
    * ``None`` -> ``string_value`` ``"null"``
    * ``str`` -> ``string_value``
    * anything else -> ``legacy_json_value`` (JSON text; objects that are not
      JSON-serializable are stringified)

    Args:
        track_event: The ``TrackEvent`` message receiving the annotation.
        name: Annotation name; coerced to ``str``.
        value: The Chrome arg value to record.
    """
    ann = track_event.debug_annotations.add()
    ann.name = str(name)
    # bool must be checked before int (bool is a subclass of int in Python).
    if isinstance(value, bool):
        ann.bool_value = value
    elif isinstance(value, int):
        # Python ints are unbounded but the proto fields are 64-bit; assigning
        # an out-of-range value would raise, so pick a field that can hold it.
        if -(2 ** 63) <= value < 2 ** 63:
            ann.int_value = value
        elif 0 <= value < 2 ** 64:
            ann.uint_value = value
        else:
            ann.string_value = str(value)
    elif isinstance(value, float):
        ann.double_value = value
    elif value is None:
        ann.string_value = "null"
    elif isinstance(value, str):
        ann.string_value = value
    else:
        ann.legacy_json_value = json.dumps(value, default=str)
380+
381+
382+
def write_perfetto_trace(trace, output_path):
    """Convert a Chrome JSON trace dict to a native Perfetto ``.pftrace``.

    Each Chrome ``(pid, tid)`` row becomes one ``TrackDescriptor`` per backing
    lane; each ``ph='X'`` slice becomes a ``TYPE_SLICE_BEGIN`` /
    ``TYPE_SLICE_END`` pair. Overlapping slices are split across backing lanes
    that share a ``sibling_merge_key`` so the UI re-merges them onto their
    real stream.

    Args:
        trace: Chrome trace dict. Only ``trace["traceEvents"]`` is read:
            ``ph='M'`` events named ``process_name`` / ``thread_name`` supply
            track names and ``ph='X'`` events supply the slices, with ``ts``
            and ``dur`` in microseconds. Every other event is ignored.
        output_path: File path the serialized trace is written to.

    Returns:
        ``output_path``, unchanged.

    Raises:
        ImportError: If the ``perfetto`` package is not installed.
        KeyError: If ``trace`` has no ``"traceEvents"`` entry.
    """
    from perfetto.trace_builder.proto_builder import TraceProtoBuilder
    from perfetto.protos.perfetto.trace.perfetto_trace_pb2 import (
        TrackDescriptor,
        TrackEvent,
    )

    events = trace["traceEvents"]

    # Collect the process/thread names emitted as metadata ('M') events.
    process_names, thread_names = {}, {}
    for e in events:
        if e.get("ph") == "M":
            if e.get("name") == "process_name":
                process_names[e.get("pid")] = e.get("args", {}).get("name", "")
            elif e.get("name") == "thread_name":
                key = (e.get("pid"), e.get("tid"))
                thread_names[key] = e.get("args", {}).get("name", "")

    # Group complete ('X') slices by their (pid, tid) track.
    slices_by_track = defaultdict(list)
    for i, e in enumerate(events):
        if e.get("ph") == "X":
            ts = float(e.get("ts", 0) or 0)
            dur = float(e.get("dur", 0) or 0)
            slices_by_track[(e.get("pid"), e.get("tid"))].append(
                {"event": e, "index": i, "ts": ts, "end": ts + dur}
            )

    def ts_us_to_ns(value):
        return int(round(value * 1000.0))

    builder = TraceProtoBuilder()
    SEQ = 1

    # Ordering ranks for markers that share a timestamp (see the sort below).
    RANK_END = 0  # end of a slice that began at an earlier timestamp
    RANK_BEGIN = 1
    RANK_END_ZERO_LENGTH = 2  # end of a slice that began at this timestamp

    # One descriptor per process, in first-seen order so output is
    # deterministic (a set of mixed int/str pids has no stable order).
    for pid in dict.fromkeys(pid for (pid, _tid) in slices_by_track):
        pkt = builder.add_packet()
        desc = pkt.track_descriptor
        desc.uuid = _stable_uuid("process", pid)
        desc.name = process_names.get(pid, f"process {pid}")

    # One descriptor per backing lane; emit begin/end markers per slice.
    # Marker layout: (ts_ns, rank, tiebreak, event_index, uuid, kind, event).
    markers = []
    for (pid, tid), slices in slices_by_track.items():
        lane_of, lane_count = _assign_nesting_lanes(slices)
        name = thread_names.get((pid, tid), f"stream {tid}")
        lane_uuids = []
        for lane in range(lane_count):
            uuid = _stable_uuid("track", pid, tid, lane)
            lane_uuids.append(uuid)
            pkt = builder.add_packet()
            desc = pkt.track_descriptor
            desc.uuid = uuid
            desc.parent_uuid = _stable_uuid("process", pid)
            desc.name = name
            # Multiple lanes for one stream -> merge them into one UI row.
            if lane_count > 1:
                desc.sibling_merge_behavior = (
                    TrackDescriptor.SIBLING_MERGE_BEHAVIOR_BY_SIBLING_MERGE_KEY
                )
                desc.sibling_merge_key = f"{pid}:{tid}:{name}"
        for i, s in enumerate(slices):
            uuid = lane_uuids[lane_of[i]]
            begin_ns = ts_us_to_ns(s["ts"])
            end_ns = ts_us_to_ns(s["end"])
            # Among begins at one timestamp the longest slice goes first
            # (tiebreak = -end_ns) so the enclosing slice is opened before the
            # slices nested inside it.
            markers.append(
                (begin_ns, RANK_BEGIN, -end_ns, s["index"], uuid, "begin", s["event"])
            )
            # A slice whose begin and end land on the same nanosecond must
            # still be opened before it is closed.
            end_rank = RANK_END_ZERO_LENGTH if end_ns == begin_ns else RANK_END
            markers.append((end_ns, end_rank, 0, s["index"], uuid, "end", s["event"]))

    # An END packet carries no name; it closes the most recently opened slice
    # on its track, so packet order at a shared timestamp matters:
    #   1. ends of slices that started earlier (so back-to-back slices do not
    #      nest into each other),
    #   2. begins, outermost (longest) first,
    #   3. ends of zero-length slices opened in step 2.
    markers.sort(key=lambda m: m[:4])
    for ts_ns, _rank, _tiebreak, _index, uuid, kind, event in markers:
        pkt = builder.add_packet()
        pkt.timestamp = ts_ns
        pkt.trusted_packet_sequence_id = SEQ
        track_event = pkt.track_event
        track_event.track_uuid = uuid
        if kind == "begin":
            track_event.type = TrackEvent.TYPE_SLICE_BEGIN
            # Proto string fields reject non-str values; coerce defensively.
            track_event.name = str(event.get("name", "slice"))
            if event.get("cat"):
                track_event.categories.append(str(event["cat"]))
            for key, value in (event.get("args") or {}).items():
                _add_debug_annotation(track_event, key, value)
        else:
            track_event.type = TrackEvent.TYPE_SLICE_END

    Path(output_path).write_bytes(builder.serialize())
    return output_path
476+
304477

305478
def post_process_trace(raw_trace_path, annotations_path, output_dir):
306-
"""Merge annotations into the trace and apply cleanup."""
479+
"""Merge annotations into the trace and emit a Perfetto ``.pftrace``."""
307480
output_dir = Path(output_dir)
308481

309482
# Load raw trace and annotations
@@ -318,13 +491,11 @@ def post_process_trace(raw_trace_path, annotations_path, output_dir):
318491
num_annotated = annotate_trace(annotated_trace, annotations)
319492
print(f"Annotated {num_annotated} kernels in the trace")
320493

321-
# Cleanup passes: move overlapping kernels and fix timestamps
322-
_move_overlapping_to_stream(annotated_trace)
323-
_fix_overlapping_timestamps(annotated_trace)
324-
325-
# Save the annotated trace
326-
annotated_path = output_dir / "trace_annotated.json.gz"
327-
save_trace(annotated_trace, annotated_path)
494+
# Emit a native Perfetto protobuf trace. Overlapping kernels are split onto
495+
# backing lanes that re-merge in the UI -- no kernel is relocated to a fake
496+
# stream and no timestamp is mutated.
497+
annotated_path = output_dir / "trace_annotated.pftrace"
498+
write_perfetto_trace(annotated_trace, annotated_path)
328499
print(f"Saved annotated trace to {annotated_path}")
329500

330501
return annotated_path, raw_trace, annotated_trace
@@ -442,7 +613,7 @@ def main():
442613
#
443614
# 5. Post-processing: merging annotations into trace...
444615
# Annotated 65 kernels in the trace
445-
# Saved annotated trace to traces/trace_annotated.json.gz
616+
# Saved annotated trace to traces/trace_annotated.pftrace
446617
#
447618
# 6. Comparing traces...
448619
#
@@ -460,7 +631,7 @@ def main():
460631
# SUMMARY
461632
# ============================================================
462633
# Raw trace: traces/trace_raw.json.gz
463-
# Annotated trace: traces/trace_annotated.json.gz
634+
# Annotated trace: traces/trace_annotated.pftrace
464635
# Annotations: traces/kernel_annotations_rank0_fwd_bwd.pkl
465636
#
466637
# Open the annotated trace in https://ui.perfetto.dev/ to visualize
@@ -653,7 +824,7 @@ def comm_annotation_demo():
653824
# Saved 2 annotations to traces_comm/kernel_annotations_rank0_fwd_bwd.pkl
654825
# Saved raw trace to traces_comm/trace_raw.json.gz
655826
# Annotated 5 kernels in the trace
656-
# Saved annotated trace to traces_comm/trace_annotated.json.gz
827+
# Saved annotated trace to traces_comm/trace_annotated.pftrace
657828
#
658829
# The all_reduce runs a real NCCL kernel
659830
# (``ncclDevKernel_AllReduce_Sum_f32_RING_LL``) across the two ranks:
@@ -675,21 +846,23 @@ def comm_annotation_demo():
675846
# for a CUDA-graphed collective. This metadata is LOST without annotations.
676847

677848
###############################################################################
678-
# Understanding the Cleanup Passes
679-
# ---------------------------------
849+
# How Overlapping Kernels Are Handled
850+
# ------------------------------------
680851
#
681-
# The post-processing applies two cleanup functions:
852+
# Graphed CUDA kernels often overlap slightly, and a single trace track can
853+
# only render properly nested slices. The Perfetto converter handles this
854+
# faithfully:
682855
#
683-
# 1. ``_move_overlapping_to_stream()``: If kernels on the same lane overlap
684-
# in time, move one to a different lane. This prevents visual overlap in
685-
# the trace viewer.
856+
# 1. ``_assign_nesting_lanes()``: For each stream, overlapping slices are split
857+
# across hidden *backing* lanes so that each lane's begin/end stack is validly
858+
# nested. A lane is a backing track index, **not** a user-visible stream.
686859
#
687-
# 2. ``_fix_overlapping_timestamps()``: Adjust timestamps slightly if
688-
# overlapping kernels would cause confusion. This is a last resort to
689-
# ensure the trace renders correctly.
860+
# 2. ``sibling_merge_key``: All backing lanes for one stream are given the same
861+
# merge key, so the Perfetto UI merges them back into a single logical row.
690862
#
691-
# These passes ensure that the trace is both accurate and readable, even
692-
# when the original execution has complex concurrency patterns.
863+
# The result: overlaps render correctly on the kernel's **real** stream. No
864+
# kernel is relocated to a fabricated stream, and no timestamp is mutated --
865+
# unlike the legacy Chrome-JSON workaround, which had to do both.
693866

694867
###############################################################################
695868
# Performance Considerations

0 commit comments

Comments
 (0)