|
16 | 16 | * How to profile annotated graphs |
17 | 17 | * How to post-process traces with semantic kernel lanes |
18 | 18 | * How to visualize graph execution with custom stream assignments |
| 19 | + * How to annotate communication collectives with the metadata |
| 20 | + (collective type, message size, group, rank) that eager NCCL |
| 21 | + traces expose but CUDA graphs drop |
19 | 22 |
|
20 | 23 | .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites |
21 | 24 | :class-card: card-prerequisites |
|
34 | 37 | labels to kernels within CUDA graphs. These annotations can be merged back into |
35 | 38 | profiler traces to create custom visualization lanes, making it easier to |
36 | 39 | understand and debug complex graph executions. |
| 40 | +
|
| 41 | +Annotations are not limited to compute kernels. One of the most valuable uses |
| 42 | +is annotating **communication collectives**. In eager mode, the profiler |
| 43 | +attaches rich metadata to every NCCL kernel -- the collective type, message |
| 44 | +size, process group, and ranks -- so you can see exactly what each comm is |
| 45 | +doing. Under CUDA graphs that metadata is lost: the collective replays as an |
| 46 | +opaque kernel. This tutorial shows how to re-attach that metadata with |
| 47 | +annotations so graphed comms read just like eager ones. |
37 | 48 | """ |
38 | 49 |
|
39 | 50 | ############################################################################### |
|
62 | 73 | # belong to which logical component of your model. |
63 | 74 | # |
64 | 75 | # .. image:: /_static/img/cuda_graph_trace_before.png |
65 | | -# :width: 100% |
| 76 | +# :width: 80% |
66 | 77 | # :alt: CUDA graph trace before annotations showing all kernels on one stream |
67 | 78 | # |
68 | 79 | # **After annotations:** Kernels are organized into semantic lanes (streams 61 |
69 | 80 | # and 62) with meaningful labels like "attention" and "mlp", making it easy to |
70 | 81 | # identify different components and understand the execution structure. |
71 | 82 | # |
72 | 83 | # .. image:: /_static/img/cuda_graph_trace_after.png |
73 | | -# :width: 100% |
| 84 | +# :width: 80% |
74 | 85 | # :alt: CUDA graph trace after annotations showing kernels organized by function |
75 | 86 | # |
| 87 | +# As another example, here is an AllReduce kernel with annotated metadata: |
| 88 | +# |
| 89 | +# .. image:: /_static/img/annotated_cudagraph.png |
| 90 | +# :width: 80% |
| 91 | +# :alt: AllReduce kernel with annotated metadata |
| 92 | +# |
76 | 93 | # Requirements |
77 | 94 | # ------------ |
78 | 95 | # |
|
95 | 112 |
|
96 | 113 | import copy |
97 | 114 | import math |
| 115 | +import os |
98 | 116 | import pickle |
99 | 117 | import sys |
100 | 118 | from collections import Counter |
101 | 119 | from pathlib import Path |
102 | 120 |
|
103 | 121 | import torch |
| 122 | +import torch.distributed as dist |
| 123 | +import torch.multiprocessing |
104 | 124 | from torch.profiler import profile, ProfilerActivity |
105 | 125 | from torch.cuda._graph_annotations import ( |
106 | 126 | get_kernel_annotations, |
| 127 | + get_stream_for_pg, |
107 | 128 | mark_kernels, |
108 | 129 | _is_tools_id_unavailable, |
109 | 130 | ) |
@@ -398,7 +419,7 @@ def main(): |
398 | 419 | print(f"Raw trace: {raw_trace_path}") |
399 | 420 | print(f"Annotated trace: {annotated_path}") |
400 | 421 | print(f"Annotations: {annotations_path}") |
401 | | - print("\nOpen the annotated trace in chrome://tracing to visualize") |
| 422 | + print("\nOpen the annotated trace in https://ui.perfetto.dev/ to visualize") |
402 | 423 | print("the semantic kernel lanes.") |
403 | 424 | print("="*60) |
404 | 425 |
|
@@ -442,25 +463,216 @@ def main(): |
442 | 463 | # Annotated trace: traces/trace_annotated.json.gz |
443 | 464 | # Annotations: traces/kernel_annotations_rank0_fwd_bwd.pkl |
444 | 465 | # |
445 | | -# Open the annotated trace in chrome://tracing to visualize |
| 466 | +# Open the annotated trace in https://ui.perfetto.dev/ to visualize |
446 | 467 | # the semantic kernel lanes. |
447 | 468 | # ============================================================ |
448 | 469 |
|
449 | 470 | ############################################################################### |
450 | | -# Visualizing Results |
451 | | -# ------------------- |
452 | | -# |
453 | | -# To view the annotated trace: |
| 471 | +# Annotating Communication Collectives |
| 472 | +# ------------------------------------- |
454 | 473 | # |
455 | | -# 1. Open Chrome/Chromium browser |
456 | | -# 2. Navigate to ``chrome://tracing`` |
457 | | -# 3. Click "Load" and select the ``trace_annotated.json.gz`` file |
458 | | -# 4. You should see kernels organized into custom lanes like "qkv_proj", |
459 | | -# "attention", "out_proj", and "mlp" |
| 474 | +# In eager mode the profiler **automatically intercepts** NCCL collectives and |
| 475 | +# records rich metadata: collective type, input/output message sizes, the process |
| 476 | +# group, its size, and the participating ranks. |
| 477 | +# |
| 478 | +# Under CUDA graphs that automatic interception stops working. The collective is |
| 479 | +# captured once and then replayed as an opaque kernel node. The profiler cannot |
| 480 | +# intercept graph replay, so it has nothing to attach the NCCL metadata to. The |
| 481 | +# kernels still show up in the trace (e.g., ``ncclDevKernel_AllReduce_Sum_f32_RING_LL``), |
| 482 | +# but they are opaque: the trace no longer records the collective type, how many
| 483 | +# bytes moved, or which process group each kernel belongs to.
| 484 | +# |
| 485 | +# Annotations close this gap. By wrapping the collective in ``mark_kernels`` |
| 486 | +# with the same fields the profiler auto-attaches in eager mode, we manually |
| 487 | +# re-attach that metadata to the graphed kernel. After post-processing, a |
| 488 | +# graphed collective reads just like an eager one. The helper below builds the |
| 489 | +# metadata dict; using the field names the profiler uses in eager mode
| 490 | +# (``In msg nelems``, ``Group size``, ``Process Group Name``, ...) keeps the |
| 491 | +# annotated trace consistent with non-graphed traces. |
| 492 | + |
| 493 | +def annotate_collective(collective_name, input_tensor, output_tensor, group=None): |
| 494 | + """Annotate a collective with the metadata eager NCCL traces expose. |
| 495 | +
|
| 496 | + Returns a ``mark_kernels`` context manager. Any kernels launched inside |
| 497 | + (i.e. the collective) are tagged with the collective type, message sizes, |
| 498 | + dtype, and the process group's name/description/ranks, and placed on a |
| 499 | + dedicated lane keyed by the process group so comms are visually separated |
| 500 | + from compute. |
| 501 | +
|
| 502 | + The field names match the keys the profiler records for eager collectives |
| 503 | + (``In msg nelems``, ``Group size``, ``Process Group Name``, ...), so an |
| 504 | + annotated graphed collective reads exactly like a non-graphed one. |
| 505 | + """ |
| 506 | + pg = group if group is not None else (dist.group.WORLD if dist.is_initialized() else None) |
| 507 | + ranks = dist.get_process_group_ranks(pg) if pg is not None else [0] |
| 508 | + group_name = getattr(pg, "group_name", "default") |
| 509 | + group_desc = getattr(pg, "group_desc", "default") |
| 510 | + |
| 511 | + # NCCL always uses its own internal stream, so key the lane on the process |
| 512 | + # group (name + description) and give it a stable id (>= 60). |
| 513 | + pg_key = f"{group_name}_{group_desc}" |
| 514 | + annotation = { |
| 515 | + "name": collective_name, |
| 516 | + "In msg nelems": input_tensor.numel(), |
| 517 | + "Out msg nelems": output_tensor.numel(), |
| 518 | + "Group size": len(ranks), |
| 519 | + "dtype": str(input_tensor.dtype).replace("torch.", ""), |
| 520 | + "Process Group Name": group_name, |
| 521 | + "Process Group Description": group_desc, |
| 522 | + "Process Group Ranks": ranks, |
| 523 | + "stream": get_stream_for_pg(pg_key), |
| 524 | + } |
| 525 | + return mark_kernels(annotation) |
| 526 | + |
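The annotation stores message sizes as element counts (``In msg nelems``, ``Out msg nelems``), matching the eager field names. When reading a trace it is often handier to think in bytes. The sketch below shows one way to convert, assuming the ``dtype`` field holds the short name produced above (for example ``float32``). ``DTYPE_BYTES`` and ``message_bytes`` are hypothetical helpers for illustration, not part of ``torch``.

```python
# Hypothetical lookup table: bytes per element for a few common dtype names.
DTYPE_BYTES = {
    "float64": 8,
    "float32": 4,
    "float16": 2,
    "bfloat16": 2,
    "int64": 8,
    "int32": 4,
    "int8": 1,
    "uint8": 1,
}


def message_bytes(annotation):
    """Return (input_bytes, output_bytes) for a collective annotation dict."""
    elem_size = DTYPE_BYTES[annotation["dtype"]]
    return (
        annotation["In msg nelems"] * elem_size,
        annotation["Out msg nelems"] * elem_size,
    )


# Same shape as the tutorial's projection output: 4 x 256 x 1024 float32 values.
example = {
    "name": "all_reduce",
    "In msg nelems": 4 * 256 * 1024,
    "Out msg nelems": 4 * 256 * 1024,
    "dtype": "float32",
}
in_bytes, out_bytes = message_bytes(example)
print(f"{in_bytes / 2**20:.1f} MiB in, {out_bytes / 2**20:.1f} MiB out")
# prints "4.0 MiB in, 4.0 MiB out"
```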
| 527 | +############################################################################### |
| 528 | +# A Block That Mixes Compute and Communication |
| 529 | +# ---------------------------------------------- |
460 | 530 | # |
461 | | -# The custom stream IDs (61, 62) specified in ``mark_kernels`` appear as |
462 | | -# separate lanes, making it easy to see which operations run concurrently |
463 | | -# or sequentially. |
| 531 | +# A tensor- or data-parallel layer interleaves matmuls with collectives. Here |
| 532 | +# the projection output is all-reduced across the group, mirroring the comm in |
| 533 | +# a tensor-parallel linear. The collective is annotated with |
| 534 | +# ``annotate_collective`` and lands on its own lane. |
| 535 | + |
| 536 | +def build_comm_block(group=None): |
| 537 | + """Create a compute + collective block annotated for profiling.""" |
| 538 | + device = "cuda" |
| 539 | + torch.manual_seed(0) |
| 540 | + dim = 1024 |
| 541 | + params = { |
| 542 | + "x": torch.randn(4, 256, dim, device=device), |
| 543 | + "W": torch.randn(dim, dim, device=device) / math.sqrt(dim), |
| 544 | + } |
| 545 | + |
| 546 | + def forward(): |
| 547 | + with mark_kernels({"name": "proj", "stream": 61}): |
| 548 | + h = params["x"] @ params["W"] |
| 549 | + |
| 550 | + # All-reduce the projection output across the group (e.g. tensor |
| 551 | + # parallel). all_reduce is in-place, so the input and output tensors |
| 552 | + # are the same. The annotation re-attaches the NCCL metadata that a |
| 553 | + # CUDA graph would otherwise drop. |
| 554 | + if dist.is_available() and dist.is_initialized(): |
| 555 | + with annotate_collective("all_reduce", h, h, group): |
| 556 | + dist.all_reduce(h) |
| 557 | + return h |
| 558 | + |
| 559 | + return forward |
| 560 | + |
| 561 | +############################################################################### |
| 562 | +# Running the Communication Demo |
| 563 | +# ------------------------------- |
| 564 | +# |
| 565 | + |
| 566 | +WORLD_SIZE = 2 |
| 567 | + |
| 568 | +def init_pg(rank, world_size): |
| 569 | + """Initialize an NCCL group for one rank of the spawned demo."""
| 570 | + os.environ["MASTER_ADDR"] = "127.0.0.1" |
| 571 | + os.environ["MASTER_PORT"] = "29500" |
| 572 | + os.environ["RANK"] = str(rank) |
| 573 | + os.environ["WORLD_SIZE"] = str(world_size) |
| 574 | + # Use loopback interface for single-node setup |
| 575 | + os.environ["NCCL_SOCKET_IFNAME"] = "lo" |
| 576 | + dist.init_process_group("nccl", rank=rank, world_size=world_size) |
| 577 | + torch.cuda.set_device(rank) |
| 578 | + |
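``init_pg`` pins ``MASTER_PORT`` to 29500, which can fail if another job on the same node already holds that port. A minimal sketch of one alternative, assuming a single-node run where the parent process picks the port before spawning workers. ``find_free_port`` is a hypothetical helper, not part of PyTorch, and there is a small race window because the port is released before the rendezvous binds it.

```python
import socket


def find_free_port(host="127.0.0.1"):
    """Ask the OS for a currently unused TCP port on ``host`` (hypothetical helper)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 lets the OS choose any free port.
        sock.bind((host, 0))
        return sock.getsockname()[1]


port = find_free_port()
print(f"MASTER_PORT candidate: {port}")
```

For this to help, the parent would set ``os.environ["MASTER_PORT"] = str(port)`` before spawning, and ``init_pg`` would need to read the inherited value instead of overwriting it.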
| 579 | +def _comm_worker(rank, world_size): |
| 580 | + """Per-rank worker: build, capture, profile, and (on rank 0) post-process.""" |
| 581 | + init_pg(rank, world_size) |
| 582 | + |
| 583 | + output_dir = Path("traces_comm") |
| 584 | + |
| 585 | + if rank == 0: |
| 586 | + print("\nBuilding compute + collective block...") |
| 587 | + model_fn = build_comm_block() |
| 588 | + |
| 589 | + if rank == 0: |
| 590 | + print("Capturing CUDA graph with annotations...") |
| 591 | + graph, _ = capture_graph_with_annotations(model_fn) |
| 592 | + |
| 593 | + # Every rank participates in the collective during profiling, but only |
| 594 | + # rank 0 saves and post-processes the trace. |
| 595 | + if rank == 0: |
| 596 | + annotations_path = save_annotations(output_dir) |
| 597 | + raw_trace_path = profile_graph(graph, output_dir) |
| 598 | + annotated_path, _, annotated_trace = post_process_trace( |
| 599 | + raw_trace_path, annotations_path, output_dir |
| 600 | + ) |
| 601 | + |
| 602 | + # Print the args of the annotated collective kernel(s) to show that the |
| 603 | + # eager-style metadata is now attached to the graphed comm. |
| 604 | + print("\nAnnotated collective kernels (metadata restored):") |
| 605 | + for event in annotated_trace["traceEvents"]: |
| 606 | + args = event.get("args", {}) |
| 607 | + if args.get("In msg nelems") is not None: |
| 608 | + print(f" {event.get('name', '?')[:40]}") |
| 609 | + for key in ( |
| 610 | + "In msg nelems", |
| 611 | + "Out msg nelems", |
| 612 | + "Group size", |
| 613 | + "dtype", |
| 614 | + "Process Group Name", |
| 615 | + "Process Group Description", |
| 616 | + "Process Group Ranks", |
| 617 | + "stream", |
| 618 | + ): |
| 619 | + if key in args: |
| 620 | + print(f" {key}: {args[key]}") |
| 621 | + print(f"\nAnnotated trace: {annotated_path}") |
| 622 | + else: |
| 623 | + # Match rank 0's warmup + profiled replays so the collective completes. |
| 624 | + for _ in range(3): |
| 625 | + graph.replay() |
| 626 | + torch.cuda.synchronize() |
| 627 | + for _ in range(5): |
| 628 | + graph.replay() |
| 629 | + torch.cuda.synchronize() |
| 630 | + |
| 631 | + dist.destroy_process_group() |
| 632 | + |
| 633 | +def comm_annotation_demo(): |
| 634 | + """Spawn a ``world_size=2`` group and surface the comm metadata.""" |
| 635 | + if not (dist.is_available() and torch.cuda.is_available()): |
| 636 | + print("Distributed/NCCL unavailable; skipping comm annotation demo.") |
| 637 | + return |
| 638 | + if torch.cuda.device_count() < WORLD_SIZE: |
| 639 | + print(f"Need {WORLD_SIZE} GPUs for the comm demo; skipping.") |
| 640 | + return |
| 641 | + |
| 642 | + torch.multiprocessing.spawn( |
| 643 | + _comm_worker, args=(WORLD_SIZE,), nprocs=WORLD_SIZE, join=True |
| 644 | + ) |
| 645 | + |
| 646 | +# Invocation and example output (2 GPUs):
| 647 | +# if __name__ == "__main__": |
| 648 | +# comm_annotation_demo() |
| 649 | +# |
| 650 | +# Building compute + collective block... |
| 651 | +# Capturing CUDA graph with annotations... |
| 652 | +# Captured graph with 2 annotated nodes |
| 653 | +# Saved 2 annotations to traces_comm/kernel_annotations_rank0_fwd_bwd.pkl |
| 654 | +# Saved raw trace to traces_comm/trace_raw.json.gz |
| 655 | +# Annotated 5 kernels in the trace |
| 656 | +# Saved annotated trace to traces_comm/trace_annotated.json.gz |
| 657 | +# |
| 658 | +# The all_reduce runs a real NCCL kernel |
| 659 | +# (``ncclDevKernel_AllReduce_Sum_f32_RING_LL``) across the two ranks: |
| 660 | +# |
| 661 | +# Annotated collective kernels (metadata restored): |
| 662 | +# ncclDevKernel_AllReduce_Sum_f32_RING_LL |
| 663 | +# In msg nelems: 1048576 |
| 664 | +# Out msg nelems: 1048576 |
| 665 | +# Group size: 2 |
| 666 | +# dtype: float32 |
| 667 | +# Process Group Name: default |
| 668 | +# Process Group Description: default |
| 669 | +# Process Group Ranks: [0, 1] |
| 670 | +# stream: 60 |
| 671 | +# |
| 672 | +# In the trace viewer, the all-reduce sits on its own dedicated comm lane |
| 673 | +# (stream 60), and selecting it shows the collective type, message sizes, group, |
| 674 | +# and ranks -- the same fields you would see in an eager trace, now recovered |
| 675 | +# for a CUDA-graphed collective. Without annotations, this metadata is lost.
464 | 676 |
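Once the metadata is back in the trace, it can also be inspected offline without a viewer. The sketch below aggregates annotated collectives from a saved trace file, assuming (as in the printing loop above) that the annotation fields sit in each event's ``args`` and that kernel events carry a ``dur`` field in microseconds, as Chrome-trace complete events do. ``summarize_collectives`` is a hypothetical helper, not part of the tutorial's API.

```python
import gzip
import json
from collections import defaultdict


def summarize_collectives(trace_path):
    """Group annotated collective kernels by (name, process group, nelems)."""
    # Traces may be saved gzipped (.json.gz) or as plain JSON.
    opener = gzip.open if str(trace_path).endswith(".gz") else open
    with opener(trace_path, "rt") as f:
        trace = json.load(f)

    summary = defaultdict(lambda: {"count": 0, "total_dur_us": 0.0})
    for event in trace.get("traceEvents", []):
        args = event.get("args", {})
        # Only annotated collectives carry the eager-style message size field.
        if "In msg nelems" not in args:
            continue
        key = (
            event.get("name", "?"),
            args.get("Process Group Name"),
            args["In msg nelems"],
        )
        summary[key]["count"] += 1
        summary[key]["total_dur_us"] += event.get("dur", 0)
    return dict(summary)
```

Calling ``summarize_collectives("traces_comm/trace_annotated.json.gz")`` after the demo would return one entry per distinct collective, with the replay count and summed kernel time.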
|
465 | 677 | ############################################################################### |
466 | 678 | # Understanding the Cleanup Passes |
@@ -528,8 +740,11 @@ def main(): |
528 | 740 | # |
529 | 741 | # - Use ``mark_kernels()`` to label regions during graph capture |
530 | 742 | # - Enable annotations with ``enable_annotations=True`` |
| 743 | +# - Annotate communication collectives to recover the NCCL metadata |
| 744 | +# (collective type, message size, group, rank) that CUDA graphs drop but |
| 745 | +# eager traces expose |
531 | 746 | # - Post-process traces with ``annotate_trace()`` and cleanup passes |
532 | | -# - View results in chrome://tracing for intuitive visualization |
| 747 | +# - View results in https://ui.perfetto.dev/ for intuitive visualization |
533 | 748 | # |
534 | 749 | # This technique is especially valuable for large models with many components, |
535 | 750 | # distributed training setups, or any scenario where understanding the |
|