|
16 | 16 | * How to profile annotated graphs |
17 | 17 | * How to post-process traces with semantic kernel lanes |
18 | 18 | * How to visualize graph execution with custom stream assignments |
| 19 | + * How to annotate communication collectives with the metadata |
| 20 | + (collective type, message size, group, rank) that eager NCCL |
| 21 | + traces expose but CUDA graphs drop |
19 | 22 |
|
20 | 23 | .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites |
21 | 24 | :class-card: card-prerequisites |
|
34 | 37 | labels to kernels within CUDA graphs. These annotations can be merged back into |
35 | 38 | profiler traces to create custom visualization lanes, making it easier to |
36 | 39 | understand and debug complex graph executions. |
| 40 | +
|
| 41 | +Annotations are not limited to compute kernels. One of the most valuable uses |
| 42 | +is annotating **communication collectives**. In eager mode, the profiler |
| 43 | +attaches rich metadata to every NCCL kernel -- the collective type, message |
| 44 | +size, process group, and ranks -- so you can see exactly what each comm is |
| 45 | +doing. Under CUDA graphs that metadata is lost: the collective replays as an |
| 46 | +opaque kernel. This tutorial shows how to re-attach that metadata with |
| 47 | +annotations so graphed comms read just like eager ones. |
37 | 48 | """ |
38 | 49 |
|
39 | 50 | ############################################################################### |
|
95 | 106 |
|
96 | 107 | import copy |
97 | 108 | import math |
| 109 | +import os |
98 | 110 | import pickle |
99 | 111 | import sys |
100 | 112 | from collections import Counter |
101 | 113 | from pathlib import Path |
102 | 114 |
|
103 | 115 | import torch |
| 116 | +import torch.distributed as dist |
104 | 117 | from torch.profiler import profile, ProfilerActivity |
105 | 118 | from torch.cuda._graph_annotations import ( |
106 | 119 | get_kernel_annotations, |
| 120 | + get_stream_for_pg, |
107 | 121 | mark_kernels, |
108 | 122 | _is_tools_id_unavailable, |
109 | 123 | ) |
@@ -446,6 +460,180 @@ def main(): |
446 | 460 | # the semantic kernel lanes. |
447 | 461 | # ============================================================ |
448 | 462 |
|
| 463 | +############################################################################### |
| 464 | +# Annotating Communication Collectives |
| 465 | +# ------------------------------------- |
| 466 | +# |
| 467 | +# In eager mode the profiler records a NCCL collective with a set of metadata |
| 468 | +# fields -- the collective type, input/output message sizes, the process group, |
| 469 | +# its size, and the participating ranks. When you select an ``all_reduce`` in |
| 470 | +# the trace viewer you see all of it, which is invaluable for spotting an |
| 471 | +# undersized bucket, a collective on the wrong group, or a rank imbalance. |
| 472 | +# |
| 473 | +# Under CUDA graphs that context disappears. The collective is captured once |
| 474 | +# and then replayed as an anonymous kernel node, so the profiler has nothing to |
| 475 | +# attach the NCCL metadata to. The kernels still show up in the trace, but they |
| 476 | +# are opaque: you cannot tell an all-reduce from an all-gather, let alone how |
| 477 | +# many bytes moved. |
| 478 | +# |
| 479 | +# Annotations close this gap. By wrapping the collective in ``mark_kernels`` |
| 480 | +# with the same fields eager records, we re-attach that metadata to the graphed |
| 481 | +# kernel. After post-processing, a graphed collective reads just like an eager |
| 482 | +# one. The helper below builds the metadata dict; using the field names the |
| 483 | +# profiler uses in eager (``Collective name``, ``In msg nelems``, |
| 484 | +# ``Group size``, ...) keeps the annotated trace consistent with non-graphed |
| 485 | +# traces, so the same tooling and muscle memory apply. |
| 486 | + |
def annotate_collective(collective_name, tensor, group=None, *, out_tensor=None):
    """Annotate a collective with the metadata eager NCCL traces expose.

    Returns a ``mark_kernels`` context manager. Any kernels launched inside
    (i.e. the collective) are tagged with the collective type, message size,
    dtype, group name/size, and rank, and placed on a dedicated lane keyed by
    the process group so comms are visually separated from compute.

    Args:
        collective_name: Label for the collective, e.g. ``"all_reduce"``. It
            is stored under both ``"name"`` and ``"Collective name"``.
        tensor: Input tensor of the collective. Supplies the dtype and the
            ``In msg nelems`` element count.
        group: Process group the collective runs on, or ``None`` for the
            default group.
        out_tensor: Optional output tensor. When given, its element count is
            recorded as ``Out msg nelems``. When omitted, the input count is
            reused, which is only accurate for collectives whose output has
            the same number of elements as the input (such as all-reduce).

    Returns:
        Whatever ``mark_kernels(metadata)`` returns.
    """
    initialized = dist.is_available() and dist.is_initialized()
    # Without an initialized process group, describe a single-rank "world".
    world_size = dist.get_world_size(group) if initialized else 1
    rank = dist.get_rank(group) if initialized else 0
    # Test for None explicitly instead of relying on the truthiness of the
    # process-group object.
    if group is None:
        pg_name = "default"
    else:
        pg_name = getattr(group, "group_name", "default")

    in_nelems = tensor.numel()
    out_nelems = in_nelems if out_tensor is None else out_tensor.numel()

    metadata = {
        "name": collective_name,
        # Field names mirror what the profiler records for eager collectives.
        "Collective name": collective_name,
        "dtype": str(tensor.dtype).replace("torch.", ""),
        "In msg nelems": in_nelems,
        "Out msg nelems": out_nelems,
        "Group size": world_size,
        "Process Group Name": pg_name,
        "rank": rank,
        # Give every process group its own lane (a stable id >= 60).
        "stream": get_stream_for_pg(pg_name),
    }
    return mark_kernels(metadata)
| 514 | + |
| 515 | +############################################################################### |
| 516 | +# A Block That Mixes Compute and Communication |
| 517 | +# ---------------------------------------------- |
| 518 | +# |
| 519 | +# A tensor- or data-parallel layer interleaves matmuls with collectives. Here |
| 520 | +# the projection output is all-reduced across the group, mirroring the comm in |
| 521 | +# a tensor-parallel linear. The collective is annotated with |
| 522 | +# ``annotate_collective`` and lands on its own lane. |
| 523 | + |
def build_comm_block(group=None):
    """Create a compute + collective block annotated for profiling.

    Args:
        group: Process group to all-reduce over. ``None`` selects the default
            group.

    Returns:
        A zero-argument ``forward`` callable. Each call multiplies the fixed
        input ``x`` by the fixed weight ``W`` and, when a process group is
        initialized, all-reduces the result in place over ``group``. It
        returns the resulting tensor.
    """
    device = "cuda"
    # Fixed seed so the randomly generated input and weight are reproducible.
    torch.manual_seed(0)
    dim = 1024
    params = {
        "x": torch.randn(4, 256, dim, device=device),
        "W": torch.randn(dim, dim, device=device) / math.sqrt(dim),
    }

    def forward():
        # NOTE(review): annotate_collective describes process-group lanes as
        # ids >= 60; confirm that 61 cannot collide with one of those lanes.
        with mark_kernels({"name": "proj", "stream": 61}):
            h = params["x"] @ params["W"]

        # All-reduce the projection output across the group (e.g. tensor
        # parallel). The annotation re-attaches the NCCL metadata that a
        # CUDA graph would otherwise drop.
        if dist.is_available() and dist.is_initialized():
            with annotate_collective("all_reduce", h, group):
                # Run the collective on the same group that was annotated, so
                # the recorded group name/size/rank describe what executed.
                dist.all_reduce(h, group=group)
        return h

    return forward
| 547 | + |
| 548 | +############################################################################### |
| 549 | +# Running the Communication Demo |
| 550 | +# ------------------------------- |
| 551 | +# |
| 552 | +# Collectives need a process group. In real training the group already exists |
| 553 | +# with ``world_size > 1``; to keep this tutorial runnable on a single GPU we |
| 554 | +# initialize a trivial one-rank NCCL group. The capture, profiling, and |
| 555 | +# post-processing steps are exactly the same helpers used for the compute demo |
| 556 | +# -- comm annotations need no special handling on the trace side. |
| 557 | + |
def maybe_init_single_rank_pg():
    """Initialize a 1-rank NCCL group so the demo runs on a single GPU.

    Returns:
        ``False`` when ``torch.distributed``, CUDA, or the NCCL backend is
        unavailable, so the caller can skip the demo instead of crashing.
        ``True`` otherwise, whether a process group already existed or a
        new one-rank NCCL group was created here.
    """
    if not (dist.is_available() and torch.cuda.is_available()):
        return False
    # A CUDA build can still lack NCCL; report "unavailable" instead of
    # letting init_process_group("nccl") raise.
    if not dist.is_nccl_available():
        return False
    if not dist.is_initialized():
        # Rendezvous defaults for a local single-process run; values already
        # present in the environment are left untouched.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("nccl", rank=0, world_size=1)
    return True
| 567 | + |
def comm_annotation_demo():
    """Capture a compute+collective block and surface the comm metadata.

    Skips with a message when ``maybe_init_single_rank_pg`` reports that no
    process group is available. Otherwise it builds the block, captures it
    with annotations, saves the annotations, profiles the graph, and
    post-processes the trace under ``traces_comm``. It then prints selected
    ``args`` fields of every trace event that has a ``Collective name``.
    """
    pg_ready = maybe_init_single_rank_pg()
    if not pg_ready:
        print("Distributed/NCCL unavailable; skipping comm annotation demo.")
        return

    trace_dir = Path("traces_comm")

    print("\nBuilding compute + collective block...")
    block_fn = build_comm_block()

    print("Capturing CUDA graph with annotations...")
    cuda_graph, _ = capture_graph_with_annotations(block_fn)

    saved_annotations = save_annotations(trace_dir)
    raw_trace = profile_graph(cuda_graph, trace_dir)
    annotated_path, _, annotated_trace = post_process_trace(
        raw_trace, saved_annotations, trace_dir
    )

    # Metadata fields to echo for each annotated collective, in print order.
    shown_fields = (
        "Collective name",
        "dtype",
        "In msg nelems",
        "Group size",
        "Process Group Name",
        "rank",
        "stream",
    )

    # Print the args of the annotated collective kernel(s) to show that the
    # eager-style metadata is now attached to the graphed comm.
    print("\nAnnotated collective kernels (metadata restored):")
    for evt in annotated_trace["traceEvents"]:
        evt_args = evt.get("args", {})
        if not evt_args.get("Collective name"):
            # Not an annotated collective; nothing to report.
            continue
        print(f" {evt.get('name', '?')[:40]}")
        for field in shown_fields:
            if field in evt_args:
                print(f" {field}: {evt_args[field]}")
    print(f"\nAnnotated trace: {annotated_path}")
| 607 | + |
| 608 | +# Example usage and output: |
| 609 | +# if __name__ == "__main__": |
| 610 | +# comm_annotation_demo() |
| 611 | +# |
| 612 | +# Building compute + collective block... |
| 613 | +# Capturing CUDA graph with annotations... |
| 614 | +# Captured graph with 3 annotated nodes |
| 615 | +# Saved 3 annotations to traces_comm/kernel_annotations_rank0_fwd_bwd.pkl |
| 616 | +# Saved raw trace to traces_comm/trace_raw.json.gz |
| 617 | +# Annotated 3 kernels in the trace |
| 618 | +# Saved annotated trace to traces_comm/trace_annotated.json.gz |
| 619 | +# |
| 620 | +# Annotated collective kernels (metadata restored): |
| 621 | +# ncclDevKernel_AllReduce_Sum_f32_RING_LL |
| 622 | +# Collective name: all_reduce |
| 623 | +# dtype: float32 |
| 624 | +# In msg nelems: 1048576 |
| 625 | +# Group size: 1 |
| 626 | +# Process Group Name: default |
| 627 | +# rank: 0 |
| 628 | +# stream: 60 |
| 629 | +# |
| 630 | +# Annotated trace: traces_comm/trace_annotated.json.gz |
| 631 | +# |
| 632 | +# In the trace viewer the all-reduce now sits on its own ``comm`` lane, and |
| 633 | +# selecting it shows the collective type, message size, group, and rank -- |
| 634 | +# the same fields you would see in an eager trace, recovered for a graphed |
| 635 | +# collective. |
| 636 | + |
449 | 637 | ############################################################################### |
450 | 638 | # Visualizing Results |
451 | 639 | # ------------------- |
@@ -528,6 +716,9 @@ def main(): |
528 | 716 | # |
529 | 717 | # - Use ``mark_kernels()`` to label regions during graph capture |
530 | 718 | # - Enable annotations with ``enable_annotations=True`` |
| 719 | +# - Annotate communication collectives to recover the NCCL metadata |
| 720 | +# (collective type, message size, group, rank) that CUDA graphs drop but |
| 721 | +# eager traces expose |
531 | 722 | # - Post-process traces with ``annotate_trace()`` and cleanup passes |
532 | 723 | # - View results in chrome://tracing for intuitive visualization |
533 | 724 | # |
|
0 commit comments