2727 * CUDA-capable GPU
2828 * Driver/CUDA-compat >= 13.1 for annotation support
2929 * cuda-bindings >= 13.1.0
30+ * perfetto (``pip install perfetto``)
3031
3132CUDA graphs are a powerful optimization technique that can significantly reduce
3233kernel launch overhead by capturing and replaying sequences of CUDA operations.
99100# - A CUDA GPU
100101# - Driver/CUDA-compat >= 13.1 for annotation support
101102# - The ``cuda-bindings`` package >= 13.1.0 (``pip install cuda-python``)
103+ # - The ``perfetto`` package for writing the trace (``pip install perfetto``)
102104#
103105# The cuda-bindings package provides the Python bindings for CUDA runtime APIs.
104106# Version 13.1.0+ is required for the ``cudaGraphNodeGetToolsId`` API that
111113# appear in the final trace.
112114
113115import copy
116+ import hashlib
117+ import json
114118import math
115119import os
116120import pickle
117121import sys
118- from collections import Counter
122+ from collections import Counter , defaultdict
119123from pathlib import Path
120124
121125import torch
131135from torch .cuda ._annotate_cuda_graph_trace import (
132136 annotate_trace ,
133137 load_trace ,
134- save_trace ,
135- _fix_overlapping_timestamps ,
136- _move_overlapping_to_stream ,
137138)
138139
139140###############################################################################
@@ -297,13 +298,185 @@ def save_annotations(output_dir):
297298#
298299# 1. Loading the raw trace and annotations
299300# 2. Calling ``annotate_trace()`` to apply the annotations
300- # 3. Running cleanup passes to handle overlapping kernels
301- # 4. Saving the annotated trace
301+ # 3. Emitting a native Perfetto ``.pftrace`` that preserves overlapping kernels
302+ # on their real stream
302303#
303304# The result is a trace where kernels are organized by your semantic labels.
305+ #
306+ # **Why a Perfetto protobuf trace (not Chrome JSON)?** A Chrome JSON trace --
307+ # the format ``torch.profiler.export_chrome_trace`` produces -- has a
308+ # fundamental limitation: a single track (a ``(pid, tid)`` row) can only show
309+ # **properly nested** slices, never crossing/overlapping ones.
310+ #
311+ # Perfetto's native **protobuf** trace (``.pftrace``) solves this
312+ # via the ``TrackDescriptor`` field ``sibling_merge_key``. We split
313+ # overlapping slices across hidden *backing* tracks (so each protobuf
314+ # begin/end stack stays validly nested), then give those backing tracks the
315+ # **same** ``sibling_merge_key`` so the Perfetto UI merges them back into a
316+ # single logical row. Nothing is relocated to a fake stream and no timestamp is
317+ # clamped -- the overlap is shown faithfully on the kernel's real stream.
318+ #
319+ # This converter is adapted from Driss Guessous's `transformer_nuggets
320+ # <https://github.com/drisspg/transformer_nuggets>`_
321+ # (``transformer_nuggets/utils/track_event.py``); we inline a compact,
322+ # self-contained version here. It needs the ``perfetto`` package
323+ # (``pip install perfetto``).
324+
def _stable_uuid(*parts):
    """Derive a deterministic 60-bit track UUID from its identifying parts.

    The parts are stringified, joined with ``":"`` and hashed with SHA-1. The
    leading 15 hex digits of the digest (15 * 4 = 60 bits) are returned as an
    integer, so identical parts always map to the same UUID.
    """
    joined = ":".join(map(str, parts))
    hex_digest = hashlib.sha1(joined.encode()).hexdigest()
    return int(hex_digest[:15], 16)
329+
330+
def _assign_nesting_lanes(slices):
    """Distribute slices over backing lanes so that every lane is nestable.

    Slices are visited by ascending ``ts``, then descending ``end`` (so an
    enclosing slice is seen before the slices it contains), then ``index``.
    Each slice goes to the first lane where it either starts after everything
    still open there has ended, or fits entirely inside the innermost open
    slice. If no lane qualifies, a new lane is opened.

    Args:
        slices: Sequence of dicts carrying ``"ts"``, ``"end"`` and ``"index"``.

    Returns:
        ``(lane_of, lane_count)`` where ``lane_of`` maps a position in
        ``slices`` to its lane number. A lane is a *backing* track index, not
        a user-visible stream.
    """

    def nesting_key(pos):
        item = slices[pos]
        return (item["ts"], -item["end"], item["index"])

    lane_of = {}
    # One stack of still-open end times per lane; the innermost slice is last.
    open_ends_per_lane = []
    for pos in sorted(range(len(slices)), key=nesting_key):
        current = slices[pos]
        chosen = None
        for lane, open_ends in enumerate(open_ends_per_lane):
            # Drop slices that finished at or before this slice's start.
            while open_ends and open_ends[-1] <= current["ts"]:
                open_ends.pop()
            # Crossing the innermost open slice makes this lane unusable.
            if open_ends and not (current["end"] <= open_ends[-1]):
                continue
            open_ends.append(current["end"])
            chosen = lane
            break
        if chosen is None:
            chosen = len(open_ends_per_lane)
            open_ends_per_lane.append([current["end"]])
        lane_of[pos] = chosen
    return lane_of, len(open_ends_per_lane)
361+
362+
def _add_debug_annotation(track_event, name, value):
    """Carry a Chrome event arg over as a typed Perfetto debug annotation.

    Args:
        track_event: Object exposing ``debug_annotations.add()``, which returns
            the annotation message to populate.
        name: Annotation name; stringified before being stored.
        value: The arg value. ``bool``, ``int``, ``float`` and ``str`` map to
            the matching typed field, ``None`` is stored as the string
            ``"null"``, and anything else is stored as JSON text.
    """
    ann = track_event.debug_annotations.add()
    ann.name = str(name)
    # bool must be checked before int (bool is a subclass of int in Python).
    if isinstance(value, bool):
        ann.bool_value = value
    elif isinstance(value, int):
        # ``int_value`` is a signed 64-bit field; assigning a larger Python int
        # raises, so route big unsigned values (e.g. 64-bit ids/handles) to
        # ``uint_value`` and anything wider to a string instead of crashing.
        if -(2**63) <= value < 2**63:
            ann.int_value = value
        elif 2**63 <= value < 2**64:
            ann.uint_value = value
        else:
            ann.string_value = str(value)
    elif isinstance(value, float):
        ann.double_value = value
    elif value is None:
        ann.string_value = "null"
    elif isinstance(value, str):
        ann.string_value = value
    else:
        # Lists/dicts and other objects: keep the structure as JSON text;
        # ``default=str`` stringifies anything json cannot encode natively.
        ann.legacy_json_value = json.dumps(value, default=str)
381+
def write_perfetto_trace(trace, output_path):
    """Convert a Chrome JSON trace dict to a native Perfetto ``.pftrace``.

    Each Chrome ``(pid, tid)`` row becomes a ``TrackDescriptor``; each ``ph='X'``
    slice becomes a ``TYPE_SLICE_BEGIN`` / ``TYPE_SLICE_END`` pair. Overlapping
    slices are split across backing lanes that share a ``sibling_merge_key`` so
    the UI re-merges them onto their real stream.

    Args:
        trace: Chrome trace dict with a ``"traceEvents"`` list. Timestamps and
            durations are in microseconds and are written out in nanoseconds.
        output_path: Destination file for the serialized protobuf trace.

    Returns:
        ``output_path``, unchanged.

    Raises:
        ImportError: If the ``perfetto`` package is not installed.
    """
    from perfetto.trace_builder.proto_builder import TraceProtoBuilder
    from perfetto.protos.perfetto.trace.perfetto_trace_pb2 import (
        TrackDescriptor,
        TrackEvent,
    )

    events = trace["traceEvents"]

    # Collect the process/thread names emitted as metadata ('M') events.
    process_names, thread_names = {}, {}
    for e in events:
        if e.get("ph") != "M":
            continue
        # ``args`` may be present but null; treat that like a missing name.
        label = (e.get("args") or {}).get("name", "")
        if e.get("name") == "process_name":
            process_names[e.get("pid")] = label
        elif e.get("name") == "thread_name":
            thread_names[(e.get("pid"), e.get("tid"))] = label

    # Group complete ('X') slices by their (pid, tid) track.
    slices_by_track = defaultdict(list)
    for i, e in enumerate(events):
        if e.get("ph") == "X":
            ts = float(e.get("ts", 0) or 0)
            dur = float(e.get("dur", 0) or 0)
            slices_by_track[(e.get("pid"), e.get("tid"))].append(
                {"event": e, "index": i, "ts": ts, "end": ts + dur}
            )

    def ts_us_to_ns(value):
        return int(round(value * 1000.0))

    builder = TraceProtoBuilder()
    SEQ = 1

    # One descriptor per process. ``dict.fromkeys`` de-duplicates while keeping
    # first-seen order, so the output is deterministic (a set is not).
    for pid in dict.fromkeys(pid for (pid, _tid) in slices_by_track):
        pkt = builder.add_packet()
        desc = pkt.track_descriptor
        desc.uuid = _stable_uuid("process", pid)
        desc.name = process_names.get(pid, f"process {pid}")

    # Marker ranks at one timestamp. TYPE_SLICE_END carries no identity -- it
    # closes whatever slice is innermost on the track -- so the order of
    # markers that share a timestamp decides which slice gets closed:
    #   1. ends of slices that began earlier (frees the lane for a
    #      back-to-back successor),
    #   2. begins, outermost (longest) first so shorter slices nest inside,
    #   3. ends of zero-length slices, which must follow their own begin.
    END_EARLIER, BEGIN, END_ZERO_LENGTH = 0, 1, 2

    # One descriptor per backing lane; emit begin/end markers per slice.
    markers = []
    for (pid, tid), slices in slices_by_track.items():
        lane_of, lane_count = _assign_nesting_lanes(slices)
        name = thread_names.get((pid, tid), f"stream {tid}")
        lane_uuids = []
        for lane in range(lane_count):
            uuid = _stable_uuid("track", pid, tid, lane)
            lane_uuids.append(uuid)
            pkt = builder.add_packet()
            desc = pkt.track_descriptor
            desc.uuid = uuid
            desc.parent_uuid = _stable_uuid("process", pid)
            desc.name = name
            # Multiple lanes for one stream -> merge them into one UI row.
            if lane_count > 1:
                desc.sibling_merge_behavior = (
                    TrackDescriptor.SIBLING_MERGE_BEHAVIOR_BY_SIBLING_MERGE_KEY
                )
                desc.sibling_merge_key = f"{pid}:{tid}:{name}"
        for i, s in enumerate(slices):
            uuid = lane_uuids[lane_of[i]]
            begin_ns = ts_us_to_ns(s["ts"])
            # A malformed negative ``dur`` must not put the end before the
            # begin; treat it as a zero-length slice.
            end_ns = max(begin_ns, ts_us_to_ns(s["end"]))
            end_rank = END_ZERO_LENGTH if end_ns == begin_ns else END_EARLIER
            # Sort key is the first four fields: (timestamp, rank, tie, index).
            # For begins the tie is ``-end_ns`` so the longest slice opens
            # first; it is computed on the rounded values so slices whose
            # starts collapse to the same nanosecond still nest correctly.
            markers.append(
                (begin_ns, BEGIN, -end_ns, s["index"], uuid, s["event"])
            )
            markers.append((end_ns, end_rank, 0, s["index"], uuid, None))

    # Only the leading numeric fields take part in ordering; the uuid/event
    # payload is never compared.
    markers.sort(key=lambda m: m[:4])
    for ts_ns, rank, _tie, _index, uuid, event in markers:
        pkt = builder.add_packet()
        pkt.timestamp = ts_ns
        pkt.trusted_packet_sequence_id = SEQ
        track_event = pkt.track_event
        track_event.track_uuid = uuid
        if rank == BEGIN:
            track_event.type = TrackEvent.TYPE_SLICE_BEGIN
            track_event.name = event.get("name", "slice")
            if event.get("cat"):
                track_event.categories.append(event["cat"])
            for key, value in (event.get("args") or {}).items():
                _add_debug_annotation(track_event, key, value)
        else:
            track_event.type = TrackEvent.TYPE_SLICE_END

    Path(output_path).write_bytes(builder.serialize())
    return output_path
476+
304477
305478def post_process_trace (raw_trace_path , annotations_path , output_dir ):
306- """Merge annotations into the trace and apply cleanup ."""
479+ """Merge annotations into the trace and emit a Perfetto ``.pftrace`` ."""
307480 output_dir = Path (output_dir )
308481
309482 # Load raw trace and annotations
@@ -318,13 +491,11 @@ def post_process_trace(raw_trace_path, annotations_path, output_dir):
318491 num_annotated = annotate_trace (annotated_trace , annotations )
319492 print (f"Annotated { num_annotated } kernels in the trace" )
320493
321- # Cleanup passes: move overlapping kernels and fix timestamps
322- _move_overlapping_to_stream (annotated_trace )
323- _fix_overlapping_timestamps (annotated_trace )
324-
325- # Save the annotated trace
326- annotated_path = output_dir / "trace_annotated.json.gz"
327- save_trace (annotated_trace , annotated_path )
494+ # Emit a native Perfetto protobuf trace. Overlapping kernels are split onto
495+ # backing lanes that re-merge in the UI -- no kernel is relocated to a fake
496+ # stream and no timestamp is mutated.
497+ annotated_path = output_dir / "trace_annotated.pftrace"
498+ write_perfetto_trace (annotated_trace , annotated_path )
328499 print (f"Saved annotated trace to { annotated_path } " )
329500
330501 return annotated_path , raw_trace , annotated_trace
@@ -442,7 +613,7 @@ def main():
442613#
443614# 5. Post-processing: merging annotations into trace...
444615# Annotated 65 kernels in the trace
445- # Saved annotated trace to traces/trace_annotated.json.gz
616+ # Saved annotated trace to traces/trace_annotated.pftrace
446617#
447618# 6. Comparing traces...
448619#
@@ -460,7 +631,7 @@ def main():
460631# SUMMARY
461632# ============================================================
462633# Raw trace: traces/trace_raw.json.gz
463- # Annotated trace: traces/trace_annotated.json.gz
634+ # Annotated trace: traces/trace_annotated.pftrace
464635# Annotations: traces/kernel_annotations_rank0_fwd_bwd.pkl
465636#
466637# Open the annotated trace in https://ui.perfetto.dev/ to visualize
@@ -653,7 +824,7 @@ def comm_annotation_demo():
653824# Saved 2 annotations to traces_comm/kernel_annotations_rank0_fwd_bwd.pkl
654825# Saved raw trace to traces_comm/trace_raw.json.gz
655826# Annotated 5 kernels in the trace
656- # Saved annotated trace to traces_comm/trace_annotated.json.gz
827+ # Saved annotated trace to traces_comm/trace_annotated.pftrace
657828#
658829# The all_reduce runs a real NCCL kernel
659830# (``ncclDevKernel_AllReduce_Sum_f32_RING_LL``) across the two ranks:
@@ -675,21 +846,23 @@ def comm_annotation_demo():
675846# for a CUDA-graphed collective. This metadata is LOST without annotations.
676847
677848###############################################################################
678- # Understanding the Cleanup Passes
679- # ---------------------------------
849+ # How Overlapping Kernels Are Handled
850+ # ------------------------------------
680851#
681- # The post-processing applies two cleanup functions:
852+ # Graphed CUDA kernels often overlap slightly, and a single trace track can
853+ # only render properly nested slices. The Perfetto converter handles this
854+ # faithfully:
682855#
683- # 1. ``_move_overlapping_to_stream ()``: If kernels on the same lane overlap
684- # in time, move one to a different lane. This prevents visual overlap in
685- # the trace viewer .
856+ # 1. ``_assign_nesting_lanes ()``: For each stream, overlapping slices are split
857+ # across hidden *backing* lanes so that each lane's begin/end stack is validly
858+ # nested. A lane is a backing track index, **not** a user-visible stream .
686859#
687- # 2. ``_fix_overlapping_timestamps()``: Adjust timestamps slightly if
688- # overlapping kernels would cause confusion. This is a last resort to
689- # ensure the trace renders correctly.
860+ # 2. ``sibling_merge_key``: All backing lanes for one stream are given the same
861+ # merge key, so the Perfetto UI merges them back into a single logical row.
690862#
691- # These passes ensure that the trace is both accurate and readable, even
692- # when the original execution has complex concurrency patterns.
863+ # The result: overlaps render correctly on the kernel's **real** stream. No
864+ # kernel is relocated to a fabricated stream, and no timestamp is mutated --
865+ # unlike the legacy Chrome-JSON workaround, which had to do both.
693866
694867###############################################################################
695868# Performance Considerations
0 commit comments