From d5b203c00b2fbc052aa21861f4caf32abf8781c4 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Tue, 28 Apr 2026 15:19:57 -0700 Subject: [PATCH 1/3] Add CP benchmark/profile machinery and config catalog MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move benchmarking infrastructure for CP attention onto a dedicated branch so it persists outside of stash. The core test suite (test_attention_with_cp.py) stays focused on correctness; this branch layers benchmark/profile/stress configs and a cross-backend consistency check on top. run_attention_with_cp.py changes (worker side): - thd_seqlen_pattern arg supports max/half/linear/alternating/random and explicit comma-separated lengths, so benchmark configs can pin a specific variable-length workload instead of randomizing per-run. - benchmark arg drives a 10-warmup + N-iter timing loop wrapped in cudaProfilerStart/Stop and prints ms/iter for nsys/ncu workflows. - torch.manual_seed(1234) for reproducibility across runs. - CP_CROSS_BACKEND_SAVE_DIR env saves per-rank inputs/outputs as .pt for the cross-backend consistency test to compare without re-running. - Soft import from benchmark_cp so the worker can resolve names like cp_thd_0, bench_8k, bariamis_8k, rl16k without test_attention_with_cp.py needing to know about them. benchmark_cp.py (new): - Stress configs (cp_thd_0..3, cp_thd_swa_0..3) — higher batch/longer seqlen than the core suite. - Llama3-8b-shaped configs (bench_8k/16k/32k). - Variable-length training-workload configs (rl16k, bucket32k/64k/128k, mixed32k, outlier64k) with per-config thd_seqlen_pattern. - Worker-only configs (bariamis_*, bench_84992/86016) for manual invocation against the AG spike investigation log shapes. - test_cp_thd_cross_backend_consistency: runs each backend (p2p/all_gather/a2a) on the same input, saves outputs via CP_CROSS_BACKEND_SAVE_DIR, and asserts pairwise agreement within atol=0.1. 
Signed-off-by: Sudhakar Singh --- tests/pytorch/attention/benchmark_cp.py | 266 ++++++++++++++++++ .../attention/run_attention_with_cp.py | 109 ++++++- 2 files changed, 373 insertions(+), 2 deletions(-) create mode 100644 tests/pytorch/attention/benchmark_cp.py diff --git a/tests/pytorch/attention/benchmark_cp.py b/tests/pytorch/attention/benchmark_cp.py new file mode 100644 index 0000000000..1a15ff4c38 --- /dev/null +++ b/tests/pytorch/attention/benchmark_cp.py @@ -0,0 +1,266 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Benchmark/profile configurations and cross-backend consistency test for CP attention. + +Configs here are intended for benchmarking and stress testing on top of the +core correctness suite in test_attention_with_cp.py. Most configs use larger +batch sizes / sequence lengths than the core suite. Variable-length THD inputs +use the `thd_seqlen_pattern` attribute, plumbed through to run_attention_with_cp.py. + +The bench/profile machinery (timing loop, cudaProfilerStart/Stop, save dir, seqlen +pattern arg) lives in run_attention_with_cp.py. +""" + +import os +import sys +import pathlib +import logging +import copy +import tempfile +import pytest +import torch +from transformer_engine.pytorch import ( + get_device_compute_capability, + get_cudnn_version, +) + +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import ModelConfig, get_available_attention_backends, run_distributed + +# Reuse the worker-launch helper from the core test module so we don't duplicate it. +from test_attention_with_cp import get_bash_arguments + +pytest_logging_level = logging.getLevelName(logging.root.level) + + +# Benchmark/stress configs (llama3_8b-like: 32 heads, 8 GQA, d=128). 
+model_configs_fused_attn = { + # Llama3-8b-shaped, varying seqlen + "bench_8k": ModelConfig(2, 8192, 32, 128, num_gqa_groups=8, attn_mask_type="causal"), + "bench_16k": ModelConfig(1, 16384, 32, 128, num_gqa_groups=8, attn_mask_type="causal"), + "bench_32k": ModelConfig(1, 32768, 32, 128, num_gqa_groups=8, attn_mask_type="causal"), + # THD stress: higher batch / longer seqlen than core suite + "cp_thd_0": ModelConfig(8, 8192, 12, 128, attn_mask_type="causal"), # MHA b=8 + "cp_thd_1": ModelConfig(8, 8192, 12, 128), # MHA b=8 non-causal + "cp_thd_2": ModelConfig(16, 4096, 12, 128, attn_mask_type="causal"), # MHA b=16 + "cp_thd_3": ModelConfig(8, 8192, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA b=8 + # THD + SWA + "cp_thd_swa_0": ModelConfig( + 8, 8192, 12, 128, attn_mask_type="causal", window_size=(512, 0) + ), # MHA SWA causal + "cp_thd_swa_1": ModelConfig( + 8, 8192, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) + ), # GQA SWA causal + "cp_thd_swa_2": ModelConfig( + 8, 8192, 12, 128, attn_mask_type="causal", window_size=(512, 512) + ), # MHA SWA causal+right + "cp_thd_swa_3": ModelConfig( + 8, 8192, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512) + ), # GQA SWA causal+right +} + + +# Variable-length training-workload configs. +# Seqlen patterns derived from cp_bench RESULTS.md Section 6B. 
+_training_workloads = { + "bucket32k": ( + ModelConfig(4, 32768, 32, 128, num_gqa_groups=8, attn_mask_type="causal"), + "24576,28672,30720,32768", + ), + "bucket64k": ( + ModelConfig(4, 65536, 32, 128, num_gqa_groups=8, attn_mask_type="causal"), + "57344,61440,63488,65536", + ), + "mixed32k": ( + ModelConfig(8, 32768, 32, 128, num_gqa_groups=8, attn_mask_type="causal"), + "16384,24576,32768,8192,28672,32768,20480,16384", + ), + "rl16k": ( + ModelConfig(8, 16384, 32, 128, num_gqa_groups=8, attn_mask_type="causal"), + "4096,6144,6144,8192,8192,10240,12288,16384", + ), + "outlier64k": ( + ModelConfig(4, 65536, 32, 128, num_gqa_groups=8, attn_mask_type="causal"), + "8192,8192,8192,65536", + ), + "bucket128k": ( + ModelConfig(3, 131072, 32, 128, num_gqa_groups=8, attn_mask_type="causal"), + "114688,122880,131072", + ), +} +for _name, (_cfg, _pat) in _training_workloads.items(): + _cfg.thd_seqlen_pattern = _pat + model_configs_fused_attn[_name] = _cfg + + +# Worker-only configs: not run via pytest, only resolved by name from the worker +# subprocess (e.g. when invoking run_attention_with_cp.py model=bariamis_8k ...). +# Matches dbariamis/cp_comm_attention benchmark log: B=2, H=16 MHA, d=128, causal. +model_configs_fused_attn["bariamis_8k"] = ModelConfig(2, 8192, 16, 128, attn_mask_type="causal") +model_configs_fused_attn["bariamis_262k"] = ModelConfig( + 2, 262144, 16, 128, attn_mask_type="causal" +) +model_configs_fused_attn["bench_84992"] = ModelConfig(2, 84992, 16, 128, attn_mask_type="causal") +model_configs_fused_attn["bench_86016"] = ModelConfig(2, 86016, 16, 128, attn_mask_type="causal") + + +# pytest-runnable subset: skip the worker-only and the very large configs. 
+_pytest_skip_configs = {"bariamis_8k", "bariamis_262k", "bench_84992", "bench_86016"} +_pytest_configs = { + k: v for k, v in model_configs_fused_attn.items() if k not in _pytest_skip_configs +} + + +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") +@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="THD requires sm90+.") +@pytest.mark.parametrize("model", _pytest_configs.keys()) +@pytest.mark.parametrize("qkv_format", ["thd"]) +@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"]) +def test_cp_benchmark_configs(model, qkv_format, cp_comm_type): + """Run benchmark/stress configs through the core CP path for correctness.""" + if 2 > torch.cuda.device_count(): + pytest.skip("Test requires 2 GPUs") + + config = _pytest_configs[model] + config.context_parallel = True + config.cp_comm_type = cp_comm_type + + has_swa = config.window_size != (-1, 0) and config.window_size != (-1, -1) + if has_swa and cp_comm_type == "p2p": + pytest.skip("p2p does not support sliding window") + + # THD uses padding mask types for backend availability check + check_config = copy.deepcopy(config) + if "causal" in check_config.attn_mask_type: + check_config.attn_mask_type = "padding_causal" + else: + check_config.attn_mask_type = "padding" + + available, *_ = get_available_attention_backends( + check_config, + qkv_dtype=torch.bfloat16, + qkv_layout="_".join([qkv_format] * 3), + ) + _, fused_supported, _ = available + if not fused_supported: + pytest.skip("FusedAttention not available for this config") + + extra_kwargs = {} + thd_pat = getattr(config, "thd_seqlen_pattern", None) + if thd_pat is not None: + extra_kwargs["thd_seqlen_pattern"] = thd_pat + + run_distributed( + get_bash_arguments( + num_gpus_per_node=2, + dtype="bf16", + model=model, + qkv_format=qkv_format, + kernel_backend="FusedAttention", + cp_comm_type=cp_comm_type, + log_level=pytest_logging_level, + **extra_kwargs, + ), + ) + + +# Cross-backend 
consistency configs: run the same input through p2p / all_gather / a2a +# and assert the outputs agree within tolerance. +model_configs_cross_backend = { + "cp_thd_0": model_configs_fused_attn["cp_thd_0"], + "cp_thd_1": model_configs_fused_attn["cp_thd_1"], + "cp_thd_2": model_configs_fused_attn["cp_thd_2"], + "cp_thd_3": model_configs_fused_attn["cp_thd_3"], + "cp_thd_swa_0": model_configs_fused_attn["cp_thd_swa_0"], + "cp_thd_swa_1": model_configs_fused_attn["cp_thd_swa_1"], + "cp_thd_swa_2": model_configs_fused_attn["cp_thd_swa_2"], + "cp_thd_swa_3": model_configs_fused_attn["cp_thd_swa_3"], +} +# Add a few training workloads (smaller ones to keep runtime reasonable) +for _name in ["rl16k", "bucket32k", "mixed32k"]: + if _name in model_configs_fused_attn: + model_configs_cross_backend[_name] = model_configs_fused_attn[_name] + + +@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.") +@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="THD requires sm90+.") +@pytest.mark.parametrize("model", model_configs_cross_backend.keys()) +def test_cp_thd_cross_backend_consistency(model): + """Compare outputs of p2p, all_gather, and a2a backends for THD format.""" + if 2 > torch.cuda.device_count(): + pytest.skip("Test requires 2 GPUs") + + config = model_configs_cross_backend[model] + config.context_parallel = True + + has_swa = config.window_size != (-1, 0) and config.window_size != (-1, -1) + # p2p doesn't support sliding window + backends = ["all_gather", "a2a"] if has_swa else ["p2p", "all_gather", "a2a"] + saved_outputs = {} + + # THD uses padding mask types for backend availability check + check_config = copy.deepcopy(config) + if "causal" in check_config.attn_mask_type: + check_config.attn_mask_type = "padding_causal" + else: + check_config.attn_mask_type = "padding" + + with tempfile.TemporaryDirectory() as tmpdir: + for backend in backends: + check_config.cp_comm_type = backend + available, *_ = 
get_available_attention_backends( + check_config, + qkv_dtype=torch.bfloat16, + qkv_layout="thd_thd_thd", + ) + _, fused_supported, _ = available + if not fused_supported: + pytest.skip(f"FusedAttention not available for {backend}") + + save_dir = os.path.join(tmpdir, backend) + env = os.environ.copy() + env["CP_CROSS_BACKEND_SAVE_DIR"] = save_dir + extra_kwargs = {} + thd_pat = getattr(config, "thd_seqlen_pattern", None) + if thd_pat is not None: + extra_kwargs["thd_seqlen_pattern"] = thd_pat + run_distributed( + get_bash_arguments( + num_gpus_per_node=2, + dtype="bf16", + model=model, + qkv_format="thd", + kernel_backend="FusedAttention", + cp_comm_type=backend, + log_level=pytest_logging_level, + **extra_kwargs, + ), + env=env, + ) + saved_outputs[backend] = { + r: torch.load( + os.path.join(save_dir, f"outputs_{backend}_rank{r}.pt"), + weights_only=True, + ) + for r in range(2) + } + + # Compare all backends pairwise against the first as reference. + # Cross-backend diffs compound two independent CP implementations, + # so use a wider tolerance than the per-backend CP-vs-nonCP tests. 
+ ref = backends[0] + atol = 0.1 + for backend in backends[1:]: + for rank in range(2): + ref_out = saved_outputs[ref][rank] + cmp_out = saved_outputs[backend][rank] + for key in ["out", "dq", "dk", "dv"]: + if ref_out[key] is None or cmp_out[key] is None: + continue + diff = (ref_out[key] - cmp_out[key]).abs().max().item() + assert diff < atol, ( + f"{backend} vs {ref} rank{rank} {key}: max_diff={diff} > {atol}" + ) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 697e0601c0..021fbb4893 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -4,6 +4,7 @@ import os import sys +import time import logging from contextlib import nullcontext import torch @@ -14,6 +15,18 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import combine_and_quantize import transformer_engine_torch as tex from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn + +# Merge in benchmark/stress configs from benchmark_cp.py if available so the worker +# can resolve names like cp_thd_0, bench_8k, bariamis_8k, rl16k, etc. 
+try: + from benchmark_cp import ( + model_configs_fused_attn as _bench_cfgs_fused_attn, + ) + + for _k, _v in _bench_cfgs_fused_attn.items(): + model_configs_fused_attn.setdefault(_k, _v) +except ImportError: + pass from transformer_engine.pytorch import ( autocast, DotProductAttention, @@ -38,6 +51,7 @@ def generate_input_shapes( world_size: int, kernel_backend: str, fa_pad_between_seqs: str = "False", + thd_seqlen_pattern: str = "random", ): if qkv_format == "bshd": q_input_shape = ( @@ -96,7 +110,29 @@ def generate_input_shapes( cu_seqlens_q_padded = None cu_seqlens_kv_padded = None elif qkv_format == "thd": - seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32) + b, s = config.batch_size, config.max_seqlen_q + # Custom list: "24576,28672,30720,32768" -> explicit per-seq lengths + if "," in thd_seqlen_pattern: + seqlens_q = torch.tensor( + [int(x) for x in thd_seqlen_pattern.split(",")], dtype=torch.int32 + ) + b = len(seqlens_q) + s = int(seqlens_q.max()) + config.batch_size = b + config.max_seqlen_q = s + config.max_seqlen_kv = s + elif thd_seqlen_pattern == "max": + seqlens_q = torch.full([b], s, dtype=torch.int32) + elif thd_seqlen_pattern == "half": + seqlens_q = torch.full([b], s // 2, dtype=torch.int32) + elif thd_seqlen_pattern == "linear": + seqlens_q = torch.linspace(1, s, b).to(torch.int32) + elif thd_seqlen_pattern == "alternating": + seqlens_q = torch.tensor( + [s if i % 2 == 0 else s // 4 for i in range(b)], dtype=torch.int32 + ) + else: # "random" + seqlens_q = torch.randint(0, s + 1, [b]).to(torch.int32) seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2) cu_seqlens_q_padded = torch.cat( [ @@ -193,8 +229,12 @@ def run_dpa_with_cp( fa_pad_between_seqs="False", deterministic="False", log_level=logging.WARNING, + benchmark="0", + thd_seqlen_pattern="random", ): """Test DotProductAttention module with context parallelism""" + torch.manual_seed(1234) + 
torch.cuda.manual_seed(1234) logging.root.setLevel(log_level) # When is_training is False, gradient outputs are None. is_training = is_training == "True" @@ -293,13 +333,30 @@ def run_dpa_with_cp( cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, - ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend, fa_pad_between_seqs) + ) = generate_input_shapes( + qkv_format, config, world_size, kernel_backend, fa_pad_between_seqs, thd_seqlen_pattern + ) q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() dout_orig = torch.clamp( torch.randn(attn_output_shape, dtype=dtypes[dtype]), min=-1, max=1 ).cuda() + # Save inputs for cross-backend comparison + _save_path = os.environ.get("CP_CROSS_BACKEND_SAVE_DIR") + if _save_path: + os.makedirs(_save_path, exist_ok=True) + torch.save( + { + "q": q_orig, + "k": k_orig, + "v": v_orig, + "dout": dout_orig, + "cu_seqlens_q": cu_seqlens_q, + "cu_seqlens_q_padded": cu_seqlens_q_padded, + }, + os.path.join(_save_path, f"inputs_rank{rank}.pt"), + ) if scaling_mode == "delayed": qkv_quantizer = Float8Quantizer( fp8_dtype=tex.DType.kFloat8E4M3, @@ -525,6 +582,54 @@ def run_dpa_with_cp( dq_, dk_, dv_, dbias_ = None, None, None, None d_softmax_offset_ = None + if _save_path: + torch.save( + { + "out": out_.detach(), + "dq": dq_.detach() if dq_ is not None else None, + "dk": dk_.detach() if dk_ is not None else None, + "dv": dv_.detach() if dv_ is not None else None, + }, + os.path.join(_save_path, f"outputs_{cp_comm_type}_rank{rank}.pt"), + ) + + # Benchmark: re-run forward+backward with timing + benchmark_iters = int(benchmark) + if benchmark_iters > 0: + warmup = 10 + t0 = None + for it in range(warmup + benchmark_iters): + q_b, k_b, v_b = [x.clone().detach().requires_grad_() for x in [q_, k_, 
v_]] + torch.cuda.synchronize() + if it == warmup: + torch.cuda.cudart().cudaProfilerStart() + t0 = time.perf_counter() + with fp8_context: + out_b = core_attn( + q_b, + k_b, + v_b, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias=bias_, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + fp8_output=fp8_mha, + ) + if isinstance(out_b, tuple): + out_b = out_b[0] + if is_training: + out_b.backward(dout_) + torch.cuda.synchronize() + elapsed = (time.perf_counter() - t0) / benchmark_iters * 1000 + torch.cuda.cudart().cudaProfilerStop() + print( + f"[Rank {rank}] {cp_comm_type} {qkv_format} {dtype}: {elapsed:.2f} ms/iter" + f" ({benchmark_iters} iters)", + flush=True, + ) + # get outputs tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_] names = ["out", "dq", "dk", "dv", "dbias", "out_cp", "dq_cp", "dk_cp", "dv_cp", "dbias_cp"] From 0497cc8a1ed3ccf034465cc955748b6397317898 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 30 Apr 2026 21:35:32 -0700 Subject: [PATCH 2/3] Add SWA benchmark configs and CP bench results README MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add 18 SWA training workload configs (6 real workloads × 3 windows) to benchmark_cp.py for benchmarking sliding-window attention with context parallelism. Replace the old single-GPU FusedAttn vs FlashAttn benchmark script with a README documenting full benchmark results (full causal + SWA, cp=2/4/8, p2p/all_gather/a2a) and individual config runner usage. 
Signed-off-by: Sudhakar Singh --- benchmarks/attention/README.md | 171 ++++++++++++ benchmarks/attention/benchmark_attention.py | 278 -------------------- tests/pytorch/attention/benchmark_cp.py | 45 ++++ 3 files changed, 216 insertions(+), 278 deletions(-) create mode 100644 benchmarks/attention/README.md delete mode 100644 benchmarks/attention/benchmark_attention.py diff --git a/benchmarks/attention/README.md b/benchmarks/attention/README.md new file mode 100644 index 0000000000..d670ebf087 --- /dev/null +++ b/benchmarks/attention/README.md @@ -0,0 +1,171 @@ +# Context-Parallel Attention Benchmarks + +Benchmark and profile suite for THD context-parallel attention with three communication backends: **p2p** (ring), **all_gather** (full KV gather), and **a2a** (all-to-all head redistribution). + +## Quick Start + +All commands run from the `tests/pytorch/attention/` directory. The runner (`run_attention_with_cp.py`) accepts `key=value` CLI args after the torch.distributed launcher. + +### Single benchmark run + +```bash +cd tests/pytorch/attention + +# Benchmark: 50 timed iterations on 2 GPUs, bucket32k workload, a2a backend +python -m torch.distributed.launch --nproc-per-node=2 \ + run_attention_with_cp.py \ + dtype=bf16 model=bucket32k qkv_format=thd \ + kernel_backend=FusedAttention cp_comm_type=a2a \ + benchmark=50 log_level=WARNING \ + thd_seqlen_pattern="24576,28672,30720,32768" +``` + +### Single profile run (nsys) + +```bash +# Profile: 5 iterations, rank-0 only capture +NSYS_OUT=my_profile torchrun --nproc-per-node=4 --no-python \ + /path/to/nsys_rank0_only.sh \ + python run_attention_with_cp.py \ + dtype=bf16 model=mixed32k qkv_format=thd \ + kernel_backend=FusedAttention cp_comm_type=p2p \ + benchmark=5 log_level=WARNING \ + thd_seqlen_pattern="16384,24576,32768,8192,28672,32768,20480,16384" +``` + +The `nsys_rank0_only.sh` wrapper runs rank 0 under `nsys profile` and other ranks bare. 
+ +### SWA (Sliding Window Attention) + +SWA configs append `_swaW` to the model name, where W is the window size (`_swa512`, `_swa1024`, or `_swa2048`). p2p does not support SWA — use all_gather or a2a. + +```bash +python -m torch.distributed.launch --nproc-per-node=8 \ + run_attention_with_cp.py \ + dtype=bf16 model=mixed32k_swa512 qkv_format=thd \ + kernel_backend=FusedAttention cp_comm_type=a2a \ + benchmark=50 log_level=WARNING \ + thd_seqlen_pattern="16384,24576,32768,8192,28672,32768,20480,16384" +``` + +## Runner Parameters + +| Parameter | Default | Description | +|---|---|---| +| `dtype` | — | `bf16`, `fp16`, or `fp8` | +| `model` | — | Config name from `benchmark_cp.py` (e.g. `bucket32k`, `mixed32k_swa1024`) | +| `qkv_format` | `bshd` | `bshd`, `sbhd`, or `thd` (variable-length packed) | +| `kernel_backend` | `FlashAttention` | `FusedAttention` (cuDNN) or `FlashAttention` | +| `cp_comm_type` | `p2p` | `p2p`, `all_gather`, or `a2a` | +| `benchmark` | `0` | Number of timed iterations (0 = correctness-only, no timing) | +| `thd_seqlen_pattern` | `random` | Comma-separated per-sequence lengths (overrides the config's batch size and max seqlen), or `random`/`max`/`half`/`linear`/`alternating` | +| `log_level` | `WARNING` | Python logging level | +| `is_training` | `True` | Run backward pass | +| `deterministic` | `False` | Force deterministic cuDNN algorithms | + +## Available Configs + +Configs are defined in `benchmark_cp.py` and auto-merged into the runner's config dict. 
+ +### Uniform THD (constant seqlen with `thd_seqlen_pattern=max`; the runner default is `random`) + +| Config | B | S | H | g | d | mask | +|---|---:|---:|---:|---:|---:|---| +| bench_8k | 2 | 8192 | 32 | 8 | 128 | causal | +| bench_16k | 1 | 16384 | 32 | 8 | 128 | causal | +| bench_32k | 1 | 32768 | 32 | 8 | 128 | causal | +| cp_thd_0 | 8 | 8192 | 12 | 12 | 128 | causal | +| cp_thd_1 | 8 | 8192 | 12 | 12 | 128 | non-causal | +| cp_thd_2 | 16 | 4096 | 12 | 12 | 128 | causal | +| cp_thd_3 | 8 | 8192 | 12 | 2 | 128 | causal | + +### Variable-length training workloads (Llama3-8B-shaped: H=32, g=8, d=128) + +| Workload | B | S_max | thd_seqlen_pattern | +|---|---:|---:|---| +| rl16k | 8 | 16384 | 4096,6144,6144,8192,8192,10240,12288,16384 | +| bucket32k | 4 | 32768 | 24576,28672,30720,32768 | +| mixed32k | 8 | 32768 | 16384,24576,32768,8192,28672,32768,20480,16384 | +| outlier64k | 4 | 65536 | 8192,8192,8192,65536 | +| bucket64k | 4 | 65536 | 57344,61440,63488,65536 | +| bucket128k | 3 | 131072 | 114688,122880,131072 | + +SWA variants: append `_swa512`, `_swa1024`, or `_swa2048` to any training workload name (e.g. `mixed32k_swa1024`). Window is `(W, 0)` — left-only sliding window with causal mask. + +### Skip rules + +- **a2a**: requires `num_heads % cp_size == 0` AND `num_gqa_groups % cp_size == 0` +- **p2p + SWA**: not supported (p2p ring protocol cannot express windowed attention) + +## Benchmark Results + +Hardware: 8× H100 80GB HBM3, NCCL, bf16, FusedAttention (cuDNN ≥ 9.22). +Iters: 50 timed (after 10 warmup). Values in ms/iter (fwd+bwd). 
+ +### Full causal — training workloads + +| Workload | cp=2 p2p | cp=2 AG | cp=2 a2a | cp=4 p2p | cp=4 AG | cp=4 a2a | cp=8 p2p | cp=8 AG | cp=8 a2a | +|---|---:|---:|---:|---:|---:|---:|---:|---:|---:| +| rl16k | **26.50** | 32.63 | 28.03 | 16.31 | 18.59 | **14.50** | 12.57 | 12.97 | **7.99** | +| bucket32k | **103.82** | 124.23 | 105.37 | 56.27 | 62.49 | **53.58** | 33.17 | 34.65 | **27.40** | +| mixed32k | **140.22** | 167.77 | 142.02 | 76.71 | 84.49 | **72.55** | 45.30 | 47.39 | **37.16** | +| outlier64k | **130.87** | 157.70 | 131.88 | 68.99 | 77.18 | **66.86** | 38.68 | 41.03 | **34.23** | +| bucket64k | **435.59** | 524.93 | 439.71 | 227.24 | 252.45 | **220.08** | 123.44 | 131.01 | **111.55** | +| bucket128k | **1253.48** | 1710.59 | 1293.13 | **640.02** | 729.16 | 640.68 | 337.82 | 360.72 | **324.56** | + +**Bold = fastest.** a2a wins at cp≥4 (bucket128k at cp=4 is effectively a tie, p2p ahead by 0.1%); at cp=2, p2p is fastest with a2a within ~6% (network-bottlenecked). + +### SWA — training workloads (all_gather vs a2a) + +**cp=2** + +| Workload | W=512 AG | W=512 a2a | W=1024 AG | W=1024 a2a | W=2048 AG | W=2048 a2a | +|---|---:|---:|---:|---:|---:|---:| +| rl16k | 28.74 | **10.44** | 29.19 | **11.98** | 30.02 | **14.84** | +| bucket32k | 99.83 | **16.60** | 100.67 | **19.26** | 102.16 | **24.93** | +| mixed32k | 135.73 | **24.53** | 136.25 | **28.61** | 139.33 | **37.76** | +| outlier64k | 124.48 | **13.01** | 124.99 | **15.09** | 125.68 | **19.25** | +| bucket64k | 412.27 | **33.53** | 416.65 | **39.41** | 419.81 | **52.62** | +| bucket128k | 1369.76 | **49.45** | 1408.04 | **58.71** | 1415.66 | **78.19** | + +**cp=4** + +| Workload | W=512 AG | W=512 a2a | W=1024 AG | W=1024 a2a | W=2048 AG | W=2048 a2a | +|---|---:|---:|---:|---:|---:|---:| +| rl16k | 17.46 | **6.25** | 17.74 | **7.00** | 18.09 | **8.45** | +| bucket32k | 50.81 | **9.63** | 51.44 | **11.02** | 52.37 | **13.68** | +| mixed32k | 69.47 | **14.27** | 70.26 | **16.40** | 71.65 | **20.47** | +| outlier64k | 61.28 | **7.61** | 61.23 | **8.69** | 61.88 | **10.80** | +| 
bucket64k | 196.76 | **19.32** | 197.83 | **22.25** | 201.31 | **28.24** | +| bucket128k | 547.60 | **27.87** | FAIL* | **32.39** | FAIL* | **41.71** | + +**cp=8** + +| Workload | W=512 AG | W=512 a2a | W=1024 AG | W=1024 a2a | W=2048 AG | W=2048 a2a | +|---|---:|---:|---:|---:|---:|---:| +| rl16k | 12.51 | **3.89** | 12.56 | **4.31** | 12.73 | **5.01** | +| bucket32k | 29.48 | **5.64** | 29.58 | **6.32** | 29.93 | **7.62** | +| mixed32k | 40.79 | **8.04** | 40.86 | **9.12** | 41.44 | **11.18** | +| outlier64k | 33.44 | **4.60** | 33.52 | **5.06** | 33.70 | **6.10** | +| bucket64k | 102.93 | **10.80** | 103.33 | **12.27** | 104.03 | **15.21** | +| bucket128k | FAIL* | **15.56** | FAIL* | **17.76** | FAIL* | **22.22** | + +*bucket128k SWA + all_gather crashes with `cudaErrorIllegalInstruction` at cp=4 (W=1024 and W=2048) and at cp=8 (all windows). See known issues below. + +### Key takeaway: use a2a for SWA + +all_gather gathers the full KV tensor regardless of window size — SWA only reduces compute, not communication. a2a redistributes Q heads so both communication and compute shrink with the window. The speedup ranges from **2× (rl16k)** to **28× (bucket128k)** depending on seqlen. + +## Known Issues + +**bucket128k SWA + all_gather at cp≥4**: crashes with `cudaErrorIllegalInstruction` (cp=4 with W≥1024 and cp=8 at every window; cp=4 with W=512 completes). Only affects the AG path — a2a and full causal AG pass. Likely a cuDNN edge case with SWA masking on very large gathered KV tensors (131072 × 2×cp_size tokens). Workaround: use a2a (also 18–28× faster on bucket128k wherever AG completes). 
+ +## Correctness Tests + +```bash +# Run all CP benchmark configs through correctness checks (2 GPU) +pytest benchmark_cp.py -k "test_cp_benchmark_configs" -x -v + +# Cross-backend consistency (compare p2p/all_gather/a2a outputs) +pytest benchmark_cp.py -k "test_cp_thd_cross_backend_consistency" -x -v +``` diff --git a/benchmarks/attention/benchmark_attention.py b/benchmarks/attention/benchmark_attention.py deleted file mode 100644 index 77b2da0b10..0000000000 --- a/benchmarks/attention/benchmark_attention.py +++ /dev/null @@ -1,278 +0,0 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -import os, sys, time -import subprocess -import pandas as pd -import numpy as np -import torch -import nvtx -import transformer_engine -from tests.pytorch.utils import ( - ModelConfig, - get_available_attention_backends, -) -from tests.pytorch.attention.test_attention import _run_dot_product_attention - -pd.set_option("display.precision", 4) - -# data type -dtype = torch.bfloat16 -# number of iterations after 3 warmup iterations -num_iters = 3 -# checkpointing -ckpt_attn = False -# workspace optimization path for cuDNN attention -workspace_opt = True -# QKV memory layout -qkv_layout = "bshd_bshd_bshd" -# padding between sequences for qkv_format=thd -pad_between_seqs = False -# training mode -is_training = True - -model_configs = { - # test: b, h, hg, d, sq, skv, p, mask, bias - "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq - "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask - "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias - "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA -} - - -def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported): - config = model_configs[model] - if dtype == torch.bfloat16: - tols = 
dict(atol=2.5e-2, rtol=2.5e-2) - else: - tols = dict(atol=5e-3, rtol=5e-3) - - cudnn_times = [] - flash_times = [] - warmup_iters = 3 - for i in range(warmup_iters): - if fused_attn_supported: - fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( - dtype, - config, - "FusedAttention", - ckpt_attn, - qkv_layout, - workspace_opt, - pad_between_seqs, - is_training, - ) - if flash_attn_supported: - flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention( - dtype, - config, - "FlashAttention", - ckpt_attn, - qkv_layout, - workspace_opt, - pad_between_seqs, - is_training, - ) - if fused_attn_supported and flash_attn_supported: - torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols) - for i, _ in enumerate(flash_attn_bwd): - torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols) - - torch.cuda.cudart().cudaProfilerStart() - torch.cuda.synchronize() - fused_attn_start = time.time() - if fused_attn_supported: - for i in range(num_iters): - fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( - dtype, - config, - "FusedAttention", - ckpt_attn, - qkv_layout, - workspace_opt, - pad_between_seqs, - is_training, - ) - torch.cuda.synchronize() - fused_attn_time = time.time() - fused_attn_start if fused_attn_supported else 0 - - torch.cuda.synchronize() - flash_attn_start = time.time() - if flash_attn_supported: - for i in range(num_iters): - flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention( - dtype, - config, - "FlashAttention", - ckpt_attn, - qkv_layout, - workspace_opt, - pad_between_seqs, - is_training, - ) - torch.cuda.synchronize() - flash_attn_time = time.time() - flash_attn_start if flash_attn_supported else 0 - - df = pd.read_csv("times.csv") - df = pd.concat( - [ - df, - pd.DataFrame( - [ - [ - fused_attn_time * 1e3 / num_iters, - 0, - 0, - 0, - flash_attn_time * 1e3 / num_iters, - 0, - 0, - 0, - 0, - ] - ], - columns=df.columns, - ), - ], - ignore_index=True, - ) - df.to_csv("times.csv", index=False) - 
torch.cuda.cudart().cudaProfilerStop() - - -def parse_results(per_cudnn, per_flash, model): - filename = f"prof_{model}_cuda_gpu_trace.csv" - df = pd.read_csv(os.path.join("./", filename)) - df_times = pd.read_csv("times.csv") - row = len(df_times.index) - 1 - - if per_cudnn > 0: - t_cudnn_all = df[df["Name"].str.contains("cudnn")]["Duration (ns)"].to_numpy() - t_cudnn_all = t_cudnn_all.reshape(-1, per_cudnn) - t_cudnn_avg = np.average(t_cudnn_all, axis=0) - df_times.loc[row, "FusedAttention Kernels (fwd)"] = t_cudnn_avg[0] / 1e6 - df_times.loc[row, "FusedAttention Kernels (bwd)"] = t_cudnn_avg[1:4].sum() / 1e6 - df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6 - - if per_flash > 0: - t_flash_all = df[df["Name"].str.contains("flash")]["Duration (ns)"].to_numpy() - t_flash_all = t_flash_all.reshape(-1, per_flash) - t_flash_avg = np.average(t_flash_all, axis=0) - df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6 - df_times.loc[row, "FlashAttention Kernels (bwd)"] = t_flash_avg[1:4].sum() / 1e6 - df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"] = t_flash_avg.sum() / 1e6 - - if per_cudnn > 0 and per_flash > 0: - df_times.loc[row, "Fused vs Flash Kernels Speedup (fwd+bwd)"] = ( - df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"] - / df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] - ) - df_times.to_csv("times.csv", index=False) - - -def main(): - times = pd.DataFrame( - columns=[ - "FusedAttention Module", - "FusedAttention Kernels (fwd)", - "FusedAttention Kernels (bwd)", - "FusedAttention Kernels (fwd+bwd)", - "FlashAttention Module", - "FlashAttention Kernels (fwd)", - "FlashAttention Kernels (bwd)", - "FlashAttention Kernels (fwd+bwd)", - "Fused vs Flash Kernels Speedup (fwd+bwd)", - ] - ) - times.to_csv("times.csv", index=False) - - device_id = torch.cuda.current_device() - device_properties = torch.cuda.get_device_properties(device_id) - print( - f"Device {device_id}: " - 
f"{device_properties.name} GPU, " - f"sm{device_properties.major}{device_properties.minor} compute capability, " - f"{device_properties.total_memory/1024**3:.1f}GB memory" - ) - for model in model_configs.keys(): - config = model_configs[model] - available_backends, _, fused_attn_backends = get_available_attention_backends( - config, - qkv_dtype=dtype, - qkv_layout=qkv_layout, - window_size=config.window_size, - pad_between_seqs=pad_between_seqs, - ) - flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends - - print( - f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}' - f'{" and flash-attention" if flash_attn_supported else ""}...' - ) - - prof_cmd = [ - "nsys", - "profile", - "--capture-range=cudaProfilerApi", - "--capture-range-end=stop-shutdown", - "--force-overwrite=true", - f"--output=prof_{model}", - "python", - "-c", - f""" "import benchmark_attention;""", - f"""benchmark_attention.benchmark_dot_product_attention(""" - f"""'{model}', {fused_attn_supported}, {flash_attn_supported})" """, - ] - prof_cmd = " ".join(prof_cmd) - subprocess.call(prof_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True) - stats_cmd = [ - "nsys", - "stats", - "-q", - "-r", - "cuda_gpu_trace", - "--format", - "csv,column", - "--force-overwrite=true", - "--force-export=true", - f"--output=prof_{model}", - f"prof_{model}.nsys-rep", - ] - if fused_attn_supported: - num_kernels_cudnn = 4 - if config.attn_bias_type == "post_scale_bias": - num_kernels_cudnn = num_kernels_cudnn + 1 - if config.num_heads != config.num_gqa_groups: - num_kernels_cudnn = num_kernels_cudnn + 2 - else: - num_kernels_cudnn = 0 - num_kernels_flash = 4 if flash_attn_supported else 0 - stats_cmd = " ".join(stats_cmd) - subprocess.call(stats_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True) - parse_cmd = [ - "python", - "-c", - f""" "import benchmark_attention;""", - f"""benchmark_attention.parse_results(""" - 
f"""{num_kernels_cudnn}, {num_kernels_flash}, '{model}')" """, - ] - parse_cmd = " ".join(parse_cmd) - subprocess.call(parse_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True) - - df_times = pd.read_csv("times.csv") - df_times.index = list(model_configs.keys()) - a = df_times[ - [ - "FusedAttention Kernels (fwd+bwd)", - "FlashAttention Kernels (fwd+bwd)", - "Fused vs Flash Kernels Speedup (fwd+bwd)", - ] - ] - a.columns = ["cuDNN fwd+bwd (ms)", "flash-attn fwd+bwd (ms)", "cuDNN vs flash speedup"] - print() - print(a) - - -if __name__ == "__main__": - main() diff --git a/tests/pytorch/attention/benchmark_cp.py b/tests/pytorch/attention/benchmark_cp.py index 1a15ff4c38..c6ead15407 100644 --- a/tests/pytorch/attention/benchmark_cp.py +++ b/tests/pytorch/attention/benchmark_cp.py @@ -90,7 +90,52 @@ ModelConfig(3, 131072, 32, 128, num_gqa_groups=8, attn_mask_type="causal"), "114688,122880,131072", ), + # SWA variants of mixed32k for cross-backend SWA correctness checks. + "mixed32k_swa512": ( + ModelConfig( + 8, 32768, 32, 128, num_gqa_groups=8, + attn_mask_type="causal", window_size=(512, 0), + ), + "16384,24576,32768,8192,28672,32768,20480,16384", + ), + "mixed32k_swa1024": ( + ModelConfig( + 8, 32768, 32, 128, num_gqa_groups=8, + attn_mask_type="causal", window_size=(1024, 0), + ), + "16384,24576,32768,8192,28672,32768,20480,16384", + ), + "mixed32k_swa2048": ( + ModelConfig( + 8, 32768, 32, 128, num_gqa_groups=8, + attn_mask_type="causal", window_size=(2048, 0), + ), + "16384,24576,32768,8192,28672,32768,20480,16384", + ), +} + +# SWA variants of all 6 training workloads at windows 512/1024/2048. 
+_swa_base = { + "bucket32k": (4, 32768, "24576,28672,30720,32768"), + "bucket64k": (4, 65536, "57344,61440,63488,65536"), + "mixed32k_full": (8, 32768, "16384,24576,32768,8192,28672,32768,20480,16384"), + "rl16k": (8, 16384, "4096,6144,6144,8192,8192,10240,12288,16384"), + "outlier64k": (4, 65536, "8192,8192,8192,65536"), + "bucket128k": (3, 131072, "114688,122880,131072"), } +for _name, (_b, _s, _pat) in _swa_base.items(): + for _w in (512, 1024, 2048): + # Skip "mixed32k_full": its SWA variants are the explicit mixed32k_swa* entries defined above + if _name == "mixed32k_full": + continue # already added with mixed32k_swa{512,1024,2048} + _key = f"{_name}_swa{_w}" + _training_workloads[_key] = ( + ModelConfig( + _b, _s, 32, 128, num_gqa_groups=8, + attn_mask_type="causal", window_size=(_w, 0), + ), + _pat, + ) for _name, (_cfg, _pat) in _training_workloads.items(): _cfg.thd_seqlen_pattern = _pat model_configs_fused_attn[_name] = _cfg From 27ae68fd114242e1fbfeee940deae3d639ecc6dd Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 1 May 2026 00:22:09 -0700 Subject: [PATCH 3/3] Update CP attention bench results with second-node numbers Re-ran all 6 real-training configs (full causal + SWA{512,1024,2048}) on a second 8x H100 node with cuDNN 9.21 / NCCL 2.29.7 and replaced the prior results tables. cp=2 was re-run serially because 4-wide concurrency on a single node distorted a2a SWA timings ~2x and triggered intermittent cudaErrorIllegalInstruction on AG SWA configs. The original-node bucket128k SWA AG cp>=4 'FAIL' matrix is no longer present on the new node, but a smaller intermittent-crash failure mode (cp=2 SWA AG under heavy concurrency) was observed; documented as a known issue with the serial-run workaround. 
Signed-off-by: Sudhakar Singh --- benchmarks/attention/README.md | 83 ++++++++++++++++++++++------------ 1 file changed, 53 insertions(+), 30 deletions(-) diff --git a/benchmarks/attention/README.md b/benchmarks/attention/README.md index d670ebf087..b4e3b771e0 100644 --- a/benchmarks/attention/README.md +++ b/benchmarks/attention/README.md @@ -99,21 +99,35 @@ SWA variants: append `_swa512`, `_swa1024`, or `_swa2048` to any training worklo ## Benchmark Results -Hardware: 8× H100 80GB HBM3, NCCL, bf16, FusedAttention (cuDNN ≥ 9.22). +Hardware: 8× H100 80GB HBM3 (full NV18 NVLink mesh), cuDNN 9.21, NCCL 2.29.7, bf16, FusedAttention. Iters: 50 timed (after 10 warmup). Values in ms/iter (fwd+bwd). +cp=2 runs in serial; cp=4 and cp=8 used 2-wide / 1-wide GPU partitioning. ### Full causal — training workloads | Workload | cp=2 p2p | cp=2 AG | cp=2 a2a | cp=4 p2p | cp=4 AG | cp=4 a2a | cp=8 p2p | cp=8 AG | cp=8 a2a | |---|---:|---:|---:|---:|---:|---:|---:|---:|---:| -| rl16k | 26.50 | 32.63 | **28.03** | 16.31 | 18.59 | **14.50** | 12.57 | 12.97 | **7.99** | -| bucket32k | **103.82** | 124.23 | 105.37 | 56.27 | 62.49 | **53.58** | 33.17 | 34.65 | **27.40** | -| mixed32k | **140.22** | 167.77 | 142.02 | 76.71 | 84.49 | **72.55** | 45.30 | 47.39 | **37.16** | -| outlier64k | **130.87** | 157.70 | 131.88 | 68.99 | 77.18 | **66.86** | 38.68 | 41.03 | **34.23** | -| bucket64k | **435.59** | 524.93 | 439.71 | 227.24 | 252.45 | **220.08** | 123.44 | 131.01 | **111.55** | -| bucket128k | **1253.48** | 1710.59 | 1293.13 | **640.02** | 729.16 | 640.68 | 337.82 | 360.72 | **324.56** | +| rl16k | **20.20** | 24.95 | 20.90 | 12.98 | 15.16 | **11.26** | 11.87 | 11.12 | **6.33** | +| bucket32k | **38.31** | 46.32 | 39.41 | 22.46 | 24.90 | **20.57** | 15.01 | 14.93 | **10.91** | +| mixed32k | **59.17** | 71.75 | 61.21 | 35.07 | 38.43 | **31.68** | 22.64 | 22.88 | **16.54** | +| outlier64k | **125.09** | 151.13 | 127.00 | 68.29 | 76.22 | **65.03** | 40.01 | 41.04 | **33.48** | +| 
bucket64k | **125.07** | 151.14 | 126.98 | 69.41 | 75.90 | **65.64** | 39.91 | 41.06 | **33.45** | +| bucket128k | **263.77** | 323.10 | 267.57 | 139.93 | 156.81 | **136.45** | 77.15 | 81.32 | **69.50** | -**Bold = fastest.** a2a wins at cp≥4; p2p ties at cp=2 (network-bottlenecked). +**Bold = fastest.** p2p wins at cp=2 (lowest comm cost). a2a wins at cp=4 and cp=8. + +### Scaling efficiency (cp=2 → cp=8, full causal) + +Ideal would be 4×. a2a scales best for every workload except bucket128k, where all_gather (3.97×) edges out a2a (3.85×). + +| Workload | p2p scale | AG scale | a2a scale | +|---|---:|---:|---:| +| rl16k | 1.70× | 2.24× | **3.30×** | +| bucket32k | 2.55× | 3.10× | **3.61×** | +| mixed32k | 2.61× | 3.14× | **3.70×** | +| outlier64k | 3.13× | 3.68× | **3.79×** | +| bucket64k | 3.13× | 3.68× | **3.80×** | +| bucket128k | 3.42× | **3.97×** | 3.85× | ### SWA — training workloads (all_gather vs a2a) @@ -121,44 +135,53 @@ Iters: 50 timed (after 10 warmup). Values in ms/iter (fwd+bwd). | Workload | W=512 AG | W=512 a2a | W=1024 AG | W=1024 a2a | W=2048 AG | W=2048 a2a | |---|---:|---:|---:|---:|---:|---:| -| rl16k | 28.74 | **10.44** | 29.19 | **11.98** | 30.02 | **14.84** | -| bucket32k | 99.83 | **16.60** | 100.67 | **19.26** | 102.16 | **24.93** | -| mixed32k | 135.73 | **24.53** | 136.25 | **28.61** | 139.33 | **37.76** | -| outlier64k | 124.48 | **13.01** | 124.99 | **15.09** | 125.68 | **19.25** | -| bucket64k | 412.27 | **33.53** | 416.65 | **39.41** | 419.81 | **52.62** | -| bucket128k | 1369.76 | **49.45** | 1408.04 | **58.71** | 1415.66 | **78.19** | +| rl16k | 22.76 | **8.93** | 23.09 | **10.18** | 23.83 | **12.46** | +| bucket32k | 39.43 | **9.33** | 39.57 | **10.73** | 40.22 | **13.39** | +| mixed32k | 60.31 | **15.44** | 60.97 | **17.81** | 62.55 | **22.45** | +| outlier64k | 121.29 | **14.98** | 121.88 | **17.36** | 123.41 | **22.11** | +| bucket64k | 121.32 | **14.96** | 121.88 | **17.34** | 123.35 | **22.11** | +| bucket128k | 253.68 | **19.71** | 254.55 | **23.05** | 256.89 | **30.42** | 
**cp=4** | Workload | W=512 AG | W=512 a2a | W=1024 AG | W=1024 a2a | W=2048 AG | W=2048 a2a | |---|---:|---:|---:|---:|---:|---:| -| rl16k | 17.46 | **6.25** | 17.74 | **7.00** | 18.09 | **8.45** | -| bucket32k | 50.81 | **9.63** | 51.44 | **11.02** | 52.37 | **13.68** | -| mixed32k | 69.47 | **14.27** | 70.26 | **16.40** | 71.65 | **20.47** | -| outlier64k | 61.28 | **7.61** | 61.23 | **8.69** | 61.88 | **10.80** | -| bucket64k | 196.76 | **19.32** | 197.83 | **22.25** | 201.31 | **28.24** | -| bucket128k | 547.60 | **27.87** | FAIL* | **32.39** | FAIL* | **41.71** | +| rl16k | 14.29 | **5.38** | 14.59 | **6.01** | 15.07 | **7.15** | +| bucket32k | 21.52 | **5.56** | 21.56 | **6.23** | 21.96 | **7.51** | +| mixed32k | 33.16 | **8.91** | 33.60 | **10.11** | 34.24 | **12.38** | +| outlier64k | 60.28 | **8.59** | 60.84 | **9.83** | 61.41 | **12.21** | +| bucket64k | 60.35 | **8.57** | 60.65 | **9.83** | 61.37 | **12.16** | +| bucket128k | 121.71 | **11.47** | 122.28 | **13.17** | 123.56 | **16.53** | **cp=8** | Workload | W=512 AG | W=512 a2a | W=1024 AG | W=1024 a2a | W=2048 AG | W=2048 a2a | |---|---:|---:|---:|---:|---:|---:| -| rl16k | 12.51 | **3.89** | 12.56 | **4.31** | 12.73 | **5.01** | -| bucket32k | 29.48 | **5.64** | 29.58 | **6.32** | 29.93 | **7.62** | -| mixed32k | 40.79 | **8.04** | 40.86 | **9.12** | 41.44 | **11.18** | -| outlier64k | 33.44 | **4.60** | 33.52 | **5.06** | 33.70 | **6.10** | -| bucket64k | 102.93 | **10.80** | 103.33 | **12.27** | 104.03 | **15.21** | -| bucket128k | FAIL* | **15.56** | FAIL* | **17.76** | FAIL* | **22.22** | - -*bucket128k SWA + all_gather crashes with `cudaErrorIllegalInstruction` at cp≥4. See known issues below. 
+| rl16k | 10.67 | **3.51** | 10.77 | **3.91** | 10.91 | **4.38** | +| bucket32k | 13.80 | **3.71** | 13.99 | **4.00** | 14.17 | **4.65** | +| mixed32k | 21.17 | **5.25** | 21.29 | **5.83** | 21.64 | **6.95** | +| outlier64k | 33.36 | **5.14** | 33.87 | **5.69** | 33.89 | **6.86** | +| bucket64k | 33.44 | **5.06** | 33.55 | **5.73** | 33.90 | **6.83** | +| bucket128k | 64.09 | **6.61** | 64.40 | **7.39** | 64.65 | **9.11** | ### Key takeaway: use a2a for SWA -all_gather gathers the full KV tensor regardless of window size — SWA only reduces compute, not communication. a2a redistributes Q heads so both communication and compute shrink with the window. The speedup ranges from **2× (rl16k)** to **28× (bucket128k)** depending on seqlen. +all_gather gathers the full KV tensor regardless of window size — SWA only reduces compute, not communication. a2a redistributes Q heads so both communication and compute shrink with the window. The AG-vs-a2a speedup ranges from **~2× (rl16k)** to **~13× (bucket128k W=512)** depending on seqlen and window size. + +### a2a vs all_gather speedup with SWA (AG/a2a ratio) + +| Workload | cp=2 W=512 | cp=2 W=1024 | cp=2 W=2048 | cp=4 W=512 | cp=4 W=1024 | cp=4 W=2048 | cp=8 W=512 | cp=8 W=1024 | cp=8 W=2048 | +|---|---:|---:|---:|---:|---:|---:|---:|---:|---:| +| rl16k | 2.5× | 2.3× | 1.9× | 2.7× | 2.4× | 2.1× | 3.0× | 2.8× | 2.5× | +| bucket32k | 4.2× | 3.7× | 3.0× | 3.9× | 3.5× | 2.9× | 3.7× | 3.5× | 3.0× | +| mixed32k | 3.9× | 3.4× | 2.8× | 3.7× | 3.3× | 2.8× | 4.0× | 3.7× | 3.1× | +| outlier64k | 8.1× | 7.0× | 5.6× | 7.0× | 6.2× | 5.0× | 6.5× | 6.0× | 4.9× | +| bucket64k | 8.1× | 7.0× | 5.6× | 7.0× | 6.2× | 5.0× | 6.6× | 5.9× | 5.0× | +| bucket128k | 12.9× | 11.0× | 8.4× | 10.6× | 9.3× | 7.5× | 9.7× | 8.7× | 7.1× | ## Known Issues -**bucket128k SWA + all_gather at cp≥4**: crashes with `cudaErrorIllegalInstruction`. Only affects the AG path — a2a and full causal AG pass. 
Likely a cuDNN edge case with SWA masking on very large gathered KV tensors (131072 × 2×cp_size tokens). Workaround: use a2a (also 5–28× faster). +**SWA + all_gather rare `cudaErrorIllegalInstruction`**: a small number of SWA AG runs at cp=2 with 4-wide parallel-batch execution crashed intermittently. The same configs pass cleanly when run alone or with cp≥4. The crash signature matches an earlier stream-race fix (`cp_stream.wait_stream(...)` after the THD reorder, commit `611d876e`), suggesting another asynchronous race only exposed under heavy concurrent driver load. Workaround: use a2a (always faster anyway), or run cp=2 SWA AG configs serially. ## Correctness Tests