Skip to content

Commit 06ea256

Browse files
committed
add comm example
1 parent e4db2a7 commit 06ea256

2 files changed

Lines changed: 232 additions & 17 deletions

File tree

267 KB
Loading

advanced_source/cuda_graph_annotations_tutorial.py

Lines changed: 232 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
* How to profile annotated graphs
1717
* How to post-process traces with semantic kernel lanes
1818
* How to visualize graph execution with custom stream assignments
19+
* How to annotate communication collectives with the metadata
20+
(collective type, message size, group, rank) that eager NCCL
21+
traces expose but CUDA graphs drop
1922
2023
.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
2124
:class-card: card-prerequisites
@@ -34,6 +37,14 @@
3437
labels to kernels within CUDA graphs. These annotations can be merged back into
3538
profiler traces to create custom visualization lanes, making it easier to
3639
understand and debug complex graph executions.
40+
41+
Annotations are not limited to compute kernels. One of the most valuable uses
42+
is annotating **communication collectives**. In eager mode, the profiler
43+
attaches rich metadata to every NCCL kernel -- the collective type, message
44+
size, process group, and ranks -- so you can see exactly what each comm is
45+
doing. Under CUDA graphs that metadata is lost: the collective replays as an
46+
opaque kernel. This tutorial shows how to re-attach that metadata with
47+
annotations so graphed comms read just like eager ones.
3748
"""
3849

3950
###############################################################################
@@ -62,17 +73,23 @@
6273
# belong to which logical component of your model.
6374
#
6475
# .. image:: /_static/img/cuda_graph_trace_before.png
65-
# :width: 100%
76+
# :width: 80%
6677
# :alt: CUDA graph trace before annotations showing all kernels on one stream
6778
#
6879
# **After annotations:** Kernels are organized into semantic lanes (streams 61
6980
# and 62) with meaningful labels like "attention" and "mlp", making it easy to
7081
# identify different components and understand the execution structure.
7182
#
7283
# .. image:: /_static/img/cuda_graph_trace_after.png
73-
# :width: 100%
84+
# :width: 80%
7485
# :alt: CUDA graph trace after annotations showing kernels organized by function
7586
#
87+
# As another example, here is an AllReduce kernel with annotated metadata:
88+
#
89+
# .. image:: /_static/img/annotated_cudagraph.png
90+
# :width: 80%
91+
# :alt: AllReduce kernel with annotated metadata
92+
#
7693
# Requirements
7794
# ------------
7895
#
@@ -95,15 +112,19 @@
95112

96113
import copy
97114
import math
115+
import os
98116
import pickle
99117
import sys
100118
from collections import Counter
101119
from pathlib import Path
102120

103121
import torch
122+
import torch.distributed as dist
123+
import torch.multiprocessing
104124
from torch.profiler import profile, ProfilerActivity
105125
from torch.cuda._graph_annotations import (
106126
get_kernel_annotations,
127+
get_stream_for_pg,
107128
mark_kernels,
108129
_is_tools_id_unavailable,
109130
)
@@ -398,7 +419,7 @@ def main():
398419
print(f"Raw trace: {raw_trace_path}")
399420
print(f"Annotated trace: {annotated_path}")
400421
print(f"Annotations: {annotations_path}")
401-
print("\nOpen the annotated trace in chrome://tracing to visualize")
422+
print("\nOpen the annotated trace in https://ui.perfetto.dev/ to visualize")
402423
print("the semantic kernel lanes.")
403424
print("="*60)
404425

@@ -442,25 +463,216 @@ def main():
442463
# Annotated trace: traces/trace_annotated.json.gz
443464
# Annotations: traces/kernel_annotations_rank0_fwd_bwd.pkl
444465
#
445-
# Open the annotated trace in chrome://tracing to visualize
466+
# Open the annotated trace in https://ui.perfetto.dev/ to visualize
446467
# the semantic kernel lanes.
447468
# ============================================================
448469

449470
###############################################################################
450-
# Visualizing Results
451-
# -------------------
452-
#
453-
# To view the annotated trace:
471+
# Annotating Communication Collectives
472+
# -------------------------------------
454473
#
455-
# 1. Open Chrome/Chromium browser
456-
# 2. Navigate to ``chrome://tracing``
457-
# 3. Click "Load" and select the ``trace_annotated.json.gz`` file
458-
# 4. You should see kernels organized into custom lanes like "qkv_proj",
459-
# "attention", "out_proj", and "mlp"
474+
# In eager mode the profiler **automatically intercepts** NCCL collectives and
475+
# records rich metadata: collective type, input/output message sizes, the process
476+
# group, its size, and the participating ranks.
477+
#
478+
# Under CUDA graphs that automatic interception stops working. The collective is
479+
# captured once and then replayed as an opaque kernel node. The profiler cannot
480+
# intercept graph replay, so it has nothing to attach the NCCL metadata to. The
481+
# kernels still show up in the trace (e.g., ``ncclDevKernel_AllReduce_Sum_f32_RING_LL``),
482+
# but they are opaque: you cannot tell what collective type it is, how many bytes
483+
# moved, or which process group it belongs to.
484+
#
485+
# Annotations close this gap. By wrapping the collective in ``mark_kernels``
486+
# with the same fields the profiler auto-attaches in eager mode, we manually
487+
# re-attach that metadata to the graphed kernel. After post-processing, a
488+
# graphed collective reads just like an eager one. The helper below builds the
489+
# metadata dict; using the field names the profiler uses in eager
490+
# (``In msg nelems``, ``Group size``, ``Process Group Name``, ...) keeps the
491+
# annotated trace consistent with non-graphed traces.
492+
493+
def annotate_collective(collective_name, input_tensor, output_tensor, group=None):
494+
"""Annotate a collective with the metadata eager NCCL traces expose.
495+
496+
Returns a ``mark_kernels`` context manager. Any kernels launched inside
497+
(i.e. the collective) are tagged with the collective type, message sizes,
498+
dtype, and the process group's name/description/ranks, and placed on a
499+
dedicated lane keyed by the process group so comms are visually separated
500+
from compute.
501+
502+
The field names match the keys the profiler records for eager collectives
503+
(``In msg nelems``, ``Group size``, ``Process Group Name``, ...), so an
504+
annotated graphed collective reads exactly like a non-graphed one.
505+
"""
506+
pg = group if group is not None else (dist.group.WORLD if dist.is_initialized() else None)
507+
ranks = dist.get_process_group_ranks(pg) if pg is not None else [0]
508+
group_name = getattr(pg, "group_name", "default")
509+
group_desc = getattr(pg, "group_desc", "default")
510+
511+
# NCCL always uses its own internal stream, so key the lane on the process
512+
# group (name + description) and give it a stable id (>= 60).
513+
pg_key = f"{group_name}_{group_desc}"
514+
annotation = {
515+
"name": collective_name,
516+
"In msg nelems": input_tensor.numel(),
517+
"Out msg nelems": output_tensor.numel(),
518+
"Group size": len(ranks),
519+
"dtype": str(input_tensor.dtype).replace("torch.", ""),
520+
"Process Group Name": group_name,
521+
"Process Group Description": group_desc,
522+
"Process Group Ranks": ranks,
523+
"stream": get_stream_for_pg(pg_key),
524+
}
525+
return mark_kernels(annotation)
526+
527+
###############################################################################
528+
# A Block That Mixes Compute and Communication
529+
# ----------------------------------------------
460530
#
461-
# The custom stream IDs (61, 62) specified in ``mark_kernels`` appear as
462-
# separate lanes, making it easy to see which operations run concurrently
463-
# or sequentially.
531+
# A tensor- or data-parallel layer interleaves matmuls with collectives. Here
532+
# the projection output is all-reduced across the group, mirroring the comm in
533+
# a tensor-parallel linear. The collective is annotated with
534+
# ``annotate_collective`` and lands on its own lane.
535+
536+
def build_comm_block(group=None):
537+
"""Create a compute + collective block annotated for profiling."""
538+
device = "cuda"
539+
torch.manual_seed(0)
540+
dim = 1024
541+
params = {
542+
"x": torch.randn(4, 256, dim, device=device),
543+
"W": torch.randn(dim, dim, device=device) / math.sqrt(dim),
544+
}
545+
546+
def forward():
547+
with mark_kernels({"name": "proj", "stream": 61}):
548+
h = params["x"] @ params["W"]
549+
550+
# All-reduce the projection output across the group (e.g. tensor
551+
# parallel). all_reduce is in-place, so the input and output tensors
552+
# are the same. The annotation re-attaches the NCCL metadata that a
553+
# CUDA graph would otherwise drop.
554+
if dist.is_available() and dist.is_initialized():
555+
with annotate_collective("all_reduce", h, h, group):
556+
dist.all_reduce(h)
557+
return h
558+
559+
return forward
560+
561+
###############################################################################
562+
# Running the Communication Demo
563+
# -------------------------------
564+
#
565+
566+
WORLD_SIZE = 2
567+
568+
def init_pg(rank, world_size):
569+
"""Initialize a NCCL group for one rank of the spawned demo."""
570+
os.environ["MASTER_ADDR"] = "127.0.0.1"
571+
os.environ["MASTER_PORT"] = "29500"
572+
os.environ["RANK"] = str(rank)
573+
os.environ["WORLD_SIZE"] = str(world_size)
574+
# Use loopback interface for single-node setup
575+
os.environ["NCCL_SOCKET_IFNAME"] = "lo"
576+
dist.init_process_group("nccl", rank=rank, world_size=world_size)
577+
torch.cuda.set_device(rank)
578+
579+
def _comm_worker(rank, world_size):
580+
"""Per-rank worker: build, capture, profile, and (on rank 0) post-process."""
581+
init_pg(rank, world_size)
582+
583+
output_dir = Path("traces_comm")
584+
585+
if rank == 0:
586+
print("\nBuilding compute + collective block...")
587+
model_fn = build_comm_block()
588+
589+
if rank == 0:
590+
print("Capturing CUDA graph with annotations...")
591+
graph, _ = capture_graph_with_annotations(model_fn)
592+
593+
# Every rank participates in the collective during profiling, but only
594+
# rank 0 saves and post-processes the trace.
595+
if rank == 0:
596+
annotations_path = save_annotations(output_dir)
597+
raw_trace_path = profile_graph(graph, output_dir)
598+
annotated_path, _, annotated_trace = post_process_trace(
599+
raw_trace_path, annotations_path, output_dir
600+
)
601+
602+
# Print the args of the annotated collective kernel(s) to show that the
603+
# eager-style metadata is now attached to the graphed comm.
604+
print("\nAnnotated collective kernels (metadata restored):")
605+
for event in annotated_trace["traceEvents"]:
606+
args = event.get("args", {})
607+
if args.get("In msg nelems") is not None:
608+
print(f" {event.get('name', '?')[:40]}")
609+
for key in (
610+
"In msg nelems",
611+
"Out msg nelems",
612+
"Group size",
613+
"dtype",
614+
"Process Group Name",
615+
"Process Group Description",
616+
"Process Group Ranks",
617+
"stream",
618+
):
619+
if key in args:
620+
print(f" {key}: {args[key]}")
621+
print(f"\nAnnotated trace: {annotated_path}")
622+
else:
623+
# Match rank 0's warmup + profiled replays so the collective completes.
624+
for _ in range(3):
625+
graph.replay()
626+
torch.cuda.synchronize()
627+
for _ in range(5):
628+
graph.replay()
629+
torch.cuda.synchronize()
630+
631+
dist.destroy_process_group()
632+
633+
def comm_annotation_demo():
634+
"""Spawn a ``world_size=2`` group and surface the comm metadata."""
635+
if not (dist.is_available() and torch.cuda.is_available()):
636+
print("Distributed/NCCL unavailable; skipping comm annotation demo.")
637+
return
638+
if torch.cuda.device_count() < WORLD_SIZE:
639+
print(f"Need {WORLD_SIZE} GPUs for the comm demo; skipping.")
640+
return
641+
642+
torch.multiprocessing.spawn(
643+
_comm_worker, args=(WORLD_SIZE,), nprocs=WORLD_SIZE, join=True
644+
)
645+
646+
# Example output (2 GPUs):
647+
# if __name__ == "__main__":
648+
# comm_annotation_demo()
649+
#
650+
# Building compute + collective block...
651+
# Capturing CUDA graph with annotations...
652+
# Captured graph with 2 annotated nodes
653+
# Saved 2 annotations to traces_comm/kernel_annotations_rank0_fwd_bwd.pkl
654+
# Saved raw trace to traces_comm/trace_raw.json.gz
655+
# Annotated 5 kernels in the trace
656+
# Saved annotated trace to traces_comm/trace_annotated.json.gz
657+
#
658+
# The all_reduce runs a real NCCL kernel
659+
# (``ncclDevKernel_AllReduce_Sum_f32_RING_LL``) across the two ranks:
660+
#
661+
# Annotated collective kernels (metadata restored):
662+
# ncclDevKernel_AllReduce_Sum_f32_RING_LL
663+
# In msg nelems: 1048576
664+
# Out msg nelems: 1048576
665+
# Group size: 2
666+
# dtype: float32
667+
# Process Group Name: default
668+
# Process Group Description: default
669+
# Process Group Ranks: [0, 1]
670+
# stream: 60
671+
#
672+
# In the trace viewer, the all-reduce sits on its own dedicated comm lane
673+
# (stream 60), and selecting it shows the collective type, message sizes, group,
674+
# and ranks -- the same fields you would see in an eager trace, now recovered
675+
# for a CUDA-graphed collective. This metadata is LOST without annotations.
464676

465677
###############################################################################
466678
# Understanding the Cleanup Passes
@@ -528,8 +740,11 @@ def main():
528740
#
529741
# - Use ``mark_kernels()`` to label regions during graph capture
530742
# - Enable annotations with ``enable_annotations=True``
743+
# - Annotate communication collectives to recover the NCCL metadata
744+
# (collective type, message size, group, rank) that CUDA graphs drop but
745+
# eager traces expose
531746
# - Post-process traces with ``annotate_trace()`` and cleanup passes
532-
# - View results in chrome://tracing for intuitive visualization
747+
# - View results in https://ui.perfetto.dev/ for intuitive visualization
533748
#
534749
# This technique is especially valuable for large models with many components,
535750
# distributed training setups, or any scenario where understanding the

0 commit comments

Comments
 (0)