Skip to content

Commit b95d72f

Browse files
committed
add comm example
1 parent e4db2a7 commit b95d72f

1 file changed

Lines changed: 191 additions & 0 deletions

File tree

advanced_source/cuda_graph_annotations_tutorial.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
* How to profile annotated graphs
1717
* How to post-process traces with semantic kernel lanes
1818
* How to visualize graph execution with custom stream assignments
19+
* How to annotate communication collectives with the metadata
20+
(collective type, message size, group, rank) that eager NCCL
21+
traces expose but CUDA graphs drop
1922
2023
.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
2124
:class-card: card-prerequisites
@@ -34,6 +37,14 @@
3437
labels to kernels within CUDA graphs. These annotations can be merged back into
3538
profiler traces to create custom visualization lanes, making it easier to
3639
understand and debug complex graph executions.
40+
41+
Annotations are not limited to compute kernels. One of the most valuable uses
42+
is annotating **communication collectives**. In eager mode, the profiler
43+
attaches rich metadata to every NCCL kernel -- the collective type, message
44+
size, process group, and ranks -- so you can see exactly what each comm is
45+
doing. Under CUDA graphs that metadata is lost: the collective replays as an
46+
opaque kernel. This tutorial shows how to re-attach that metadata with
47+
annotations so graphed comms read just like eager ones.
3748
"""
3849

3950
###############################################################################
@@ -95,15 +106,18 @@
95106

96107
import copy
97108
import math
109+
import os
98110
import pickle
99111
import sys
100112
from collections import Counter
101113
from pathlib import Path
102114

103115
import torch
116+
import torch.distributed as dist
104117
from torch.profiler import profile, ProfilerActivity
105118
from torch.cuda._graph_annotations import (
106119
get_kernel_annotations,
120+
get_stream_for_pg,
107121
mark_kernels,
108122
_is_tools_id_unavailable,
109123
)
@@ -446,6 +460,180 @@ def main():
446460
# the semantic kernel lanes.
447461
# ============================================================
448462

463+
###############################################################################
464+
# Annotating Communication Collectives
465+
# -------------------------------------
466+
#
467+
# In eager mode the profiler records an NCCL collective with a set of metadata
468+
# fields -- the collective type, input/output message sizes, the process group,
469+
# its size, and the participating ranks. When you select an ``all_reduce`` in
470+
# the trace viewer you see all of it, which is invaluable for spotting an
471+
# undersized bucket, a collective on the wrong group, or a rank imbalance.
472+
#
473+
# Under CUDA graphs that context disappears. The collective is captured once
474+
# and then replayed as an anonymous kernel node, so the profiler has nothing to
475+
# attach the NCCL metadata to. The kernels still show up in the trace, but they
476+
# are opaque: you cannot tell an all-reduce from an all-gather, let alone how
477+
# many bytes moved.
478+
#
479+
# Annotations close this gap. By wrapping the collective in ``mark_kernels``
480+
# with the same fields eager records, we re-attach that metadata to the graphed
481+
# kernel. After post-processing, a graphed collective reads just like an eager
482+
# one. The helper below builds the metadata dict; using the field names the
483+
# profiler uses in eager (``Collective name``, ``In msg nelems``,
484+
# ``Group size``, ...) keeps the annotated trace consistent with non-graphed
485+
# traces, so the same tooling and muscle memory apply.
486+
487+
def annotate_collective(collective_name, tensor, group=None):
    """Build a ``mark_kernels`` context carrying eager-style collective metadata.

    The metadata dict uses the keys ``name``, ``Collective name``, ``dtype``,
    ``In msg nelems``, ``Out msg nelems``, ``Group size``,
    ``Process Group Name``, ``rank`` and ``stream``.

    Args:
        collective_name: Label stored under both ``name`` and
            ``Collective name`` (for example ``"all_reduce"``).
        tensor: Tensor passed to the collective. Its dtype and element count
            are recorded; the same element count is used for both the input
            and the output message size.
        group: Optional process group. When it is falsy the group name is
            reported as ``"default"``.

    Returns:
        Whatever ``mark_kernels(metadata)`` returns; intended to be used as a
        context manager around the collective call.
    """
    if dist.is_available() and dist.is_initialized():
        group_size = dist.get_world_size(group)
        group_rank = dist.get_rank(group)
    else:
        # No initialized process group: report a single-rank world.
        group_size, group_rank = 1, 0

    pg_name = "default"
    if group:
        pg_name = getattr(group, "group_name", "default")

    dtype_label = str(tensor.dtype).replace("torch.", "")
    nelems = tensor.numel()
    # One lane per process group, looked up by the group's name.
    lane = get_stream_for_pg(pg_name)

    # Key order is kept stable: generic label first, then the collective
    # fields, then the lane assignment.
    metadata = {
        "name": collective_name,
        "Collective name": collective_name,
        "dtype": dtype_label,
        "In msg nelems": nelems,
        "Out msg nelems": nelems,
        "Group size": group_size,
        "Process Group Name": pg_name,
        "rank": group_rank,
        "stream": lane,
    }
    return mark_kernels(metadata)
514+
515+
###############################################################################
516+
# A Block That Mixes Compute and Communication
517+
# ----------------------------------------------
518+
#
519+
# A tensor- or data-parallel layer interleaves matmuls with collectives. Here
520+
# the projection output is all-reduced across the group, mirroring the comm in
521+
# a tensor-parallel linear. The collective is annotated with
522+
# ``annotate_collective`` and lands on its own lane.
523+
524+
def build_comm_block(group=None):
    """Create a compute + collective block annotated for profiling.

    Args:
        group: Optional process group for the all-reduce. ``None`` selects
            the default process group.

    Returns:
        A zero-argument ``forward`` callable. It multiplies a fixed input by
        a fixed weight inside a ``proj`` annotation and, when
        ``torch.distributed`` is initialized, all-reduces the result in place
        on ``group`` inside an ``all_reduce`` annotation. It returns the
        resulting tensor.
    """
    device = "cuda"
    torch.manual_seed(0)  # Deterministic inputs across runs.
    dim = 1024
    params = {
        "x": torch.randn(4, 256, dim, device=device),
        # Scale by 1/sqrt(dim) to keep the projection output well-conditioned.
        "W": torch.randn(dim, dim, device=device) / math.sqrt(dim),
    }

    def forward():
        # NOTE(review): lane 61 is hard-coded while process-group lanes are
        # described as ids >= 60; confirm it cannot collide with a lane
        # handed out by get_stream_for_pg when several groups are in use.
        with mark_kernels({"name": "proj", "stream": 61}):
            h = params["x"] @ params["W"]

        # All-reduce the projection output across the group (e.g. tensor
        # parallel). The annotation re-attaches the NCCL metadata that a
        # CUDA graph would otherwise drop.
        if dist.is_available() and dist.is_initialized():
            with annotate_collective("all_reduce", h, group):
                # Run the collective on the same group the annotation
                # describes, so the recorded metadata matches the comm.
                dist.all_reduce(h, group=group)
        return h

    return forward
547+
548+
###############################################################################
549+
# Running the Communication Demo
550+
# -------------------------------
551+
#
552+
# Collectives need a process group. In real training the group already exists
553+
# with ``world_size > 1``; to keep this tutorial runnable on a single GPU we
554+
# initialize a trivial one-rank NCCL group. The capture, profiling, and
555+
# post-processing steps are exactly the same helpers used for the compute demo
556+
# -- comm annotations need no special handling on the trace side.
557+
558+
def maybe_init_single_rank_pg():
    """Initialize a 1-rank NCCL group so the demo runs on a single GPU.

    Returns:
        ``True`` when a process group is initialized on return (either it
        already existed or this call created it). ``False`` when
        ``torch.distributed`` is unavailable, CUDA is unavailable, or a new
        group would be needed but the NCCL backend is not available.
    """
    if not (dist.is_available() and torch.cuda.is_available()):
        return False
    if dist.is_initialized():
        # Reuse whatever group the caller already set up.
        return True
    if not dist.is_nccl_available():
        # Report "unavailable" rather than letting init_process_group raise,
        # so callers can skip the demo cleanly.
        return False

    # setdefault keeps any rendezvous settings the environment provides.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=0, world_size=1)
    return True
567+
568+
def comm_annotation_demo():
    """Capture a compute+collective block and print the restored comm metadata.

    The demo is skipped with a message when ``maybe_init_single_rank_pg``
    returns a falsy value. Otherwise it builds the block, captures it with
    annotations, saves the annotations, profiles the graph, post-processes
    the trace, and prints selected ``args`` fields for every trace event
    that has a truthy ``Collective name`` entry. Artifacts are written under
    ``traces_comm``.
    """
    pg_ready = maybe_init_single_rank_pg()
    if not pg_ready:
        print("Distributed/NCCL unavailable; skipping comm annotation demo.")
        return

    trace_dir = Path("traces_comm")

    print("\nBuilding compute + collective block...")
    block_fn = build_comm_block()

    print("Capturing CUDA graph with annotations...")
    graph, _ = capture_graph_with_annotations(block_fn)

    annotations_path = save_annotations(trace_dir)
    raw_trace_path = profile_graph(graph, trace_dir)
    processed = post_process_trace(raw_trace_path, annotations_path, trace_dir)
    annotated_path, _, annotated_trace = processed

    # Fields reported for each annotated collective, in display order.
    shown_fields = (
        "Collective name",
        "dtype",
        "In msg nelems",
        "Group size",
        "Process Group Name",
        "rank",
        "stream",
    )

    # Print the args of the annotated collective kernel(s) to show that the
    # eager-style metadata is now attached to the graphed comm.
    print("\nAnnotated collective kernels (metadata restored):")
    for event in annotated_trace["traceEvents"]:
        event_args = event.get("args", {})
        if not event_args.get("Collective name"):
            continue  # Not an annotated collective.
        print(f" {event.get('name', '?')[:40]}")
        for field in shown_fields:
            if field in event_args:
                print(f" {field}: {event_args[field]}")
    print(f"\nAnnotated trace: {annotated_path}")
607+
608+
# Example output:
609+
# if __name__ == "__main__":
610+
# comm_annotation_demo()
611+
#
612+
# Building compute + collective block...
613+
# Capturing CUDA graph with annotations...
614+
# Captured graph with 3 annotated nodes
615+
# Saved 3 annotations to traces_comm/kernel_annotations_rank0_fwd_bwd.pkl
616+
# Saved raw trace to traces_comm/trace_raw.json.gz
617+
# Annotated 3 kernels in the trace
618+
# Saved annotated trace to traces_comm/trace_annotated.json.gz
619+
#
620+
# Annotated collective kernels (metadata restored):
621+
# ncclDevKernel_AllReduce_Sum_f32_RING_LL
622+
# Collective name: all_reduce
623+
# dtype: float32
624+
# In msg nelems: 1048576
625+
# Group size: 1
626+
# Process Group Name: default
627+
# rank: 0
628+
# stream: 60
629+
#
630+
# Annotated trace: traces_comm/trace_annotated.json.gz
631+
#
632+
# In the trace viewer the all-reduce now sits on its own ``comm`` lane, and
633+
# selecting it shows the collective type, message size, group, and rank --
634+
# the same fields you would see in an eager trace, recovered for a graphed
635+
# collective.
636+
449637
###############################################################################
450638
# Visualizing Results
451639
# -------------------
@@ -528,6 +716,9 @@ def main():
528716
#
529717
# - Use ``mark_kernels()`` to label regions during graph capture
530718
# - Enable annotations with ``enable_annotations=True``
719+
# - Annotate communication collectives to recover the NCCL metadata
720+
# (collective type, message size, group, rank) that CUDA graphs drop but
721+
# eager traces expose
531722
# - Post-process traces with ``annotate_trace()`` and cleanup passes
532723
# - View results in chrome://tracing for intuitive visualization
533724
#

0 commit comments

Comments
 (0)