diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index db13e9f1e0..7eb34a62e4 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -22,6 +22,24 @@ mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" +# Run CP tests (deterministic + non-deterministic) first so they can be parallelized. +# Each needs 4 GPUs, so >=8 GPUs allows them to run concurrently on disjoint GPU sets. +NUM_GPUS=$(python3 -c "import torch; print(torch.cuda.device_count())") +echo "Detected $NUM_GPUS GPU(s)" +if [ "$NUM_GPUS" -ge 8 ]; then + echo "Running CP tests in parallel: non-deterministic on GPUs 0-3, deterministic on GPUs 4-7" + CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py & + PID_CP_NONDET=$! + CUDA_VISIBLE_DEVICES=4,5,6,7 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py & + PID_CP_DET=$! + wait $PID_CP_NONDET || test_fail "test_attention_with_cp.py" + wait $PID_CP_DET || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention_with_cp.py" +else + echo "Running CP tests sequentially: need >=8 GPUs for parallel execution" + python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" + NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention_with_cp.py" +fi + python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py" @@ -29,7 +47,6 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_newton_schulz.xml $TE_PATH/tests/pytorch/distributed/test_newton_schulz.py || test_fail "test_newton_schulz.py" diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index 642eb93b06..30f1fc38c0 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -2,13 +2,25 @@ # # See LICENSE for license information. -set -e +function error_exit() { + echo "Error: $1" + exit 1 +} + +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" : ${TE_PATH:=/opt/transformerengine} : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -pip3 install pytest==8.2.1 +pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" # Limit parallel build jobs to avoid overwhelming system resources export MAX_JOBS=32 @@ -16,12 +28,18 @@ export MAX_JOBS=32 # Iterate over Flash Attention versions sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); print(sm[0]*10+sm[1])"` export FLASH_ATTN_CUDA_ARCHS=$sm_arch +# CP tests are expensive and run only once per arch: +# - sm90 (H100): FA3 (3.0.0b1) - context_parallel.py only supports FA3 on Hopper +# - sm>90 (B200): latest FA4 - FA3 is not built/installed for sm>90 +# Non-CP tests still run for every FA version in the array. if [ $sm_arch -gt 90 ] then - FA_versions=(2.8.3 4.0.0b8) + FA_versions=(2.8.3 4.0.0b11) + CP_FA_VERSION="${FA_versions[-1]}" elif [ $sm_arch -eq 90 ] then - FA_versions=(2.7.3 2.8.3 3.0.0b1 4.0.0b8) + FA_versions=(2.8.3 3.0.0b1 4.0.0b11) + CP_FA_VERSION="3.0.0b1" fi for fa_version in "${FA_versions[@]}" @@ -35,12 +53,63 @@ do then pip3 install flash-attn-4==${fa_version} nvidia-cutlass-dsl[cu13]==4.4.2 --no-build-isolation else - git clone https://github.com/Dao-AILab/flash-attention.git - cd flash-attention/hopper && python setup.py install - cd ../../ + # FA3 source build (~20 min). Skip if FA3 is already installed. + if python3 -c "import flash_attn_3" 2>/dev/null; then + echo "FA3 already installed (from base image); skipping source build" + else + git clone https://github.com/Dao-AILab/flash-attention.git + cd flash-attention/hopper && python setup.py install + cd ../../ + fi fi + # Ensure local test utils is found before nvidia-cutlass-dsl's utils package + export PYTHONPATH=$TE_PATH/tests/pytorch:${PYTHONPATH:-} + # Run tests - NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py + NUM_GPUS=$(nvidia-smi -L | wc -l) + echo "Detected $NUM_GPUS GPU(s)" + + # Suffix junit XMLs with the FA version so per-iteration results are preserved + # (otherwise pytest.xml is overwritten on each loop iteration and we lose timing + # data for all but the last FA version). + fa_tag="${fa_version//./_}" + XML_ATTN="$XML_LOG_DIR/pytest_test_attention_fa${fa_tag}.xml" + XML_CP="$XML_LOG_DIR/pytest_test_attention_with_cp_fa${fa_tag}.xml" + + if [ "$fa_version" = "$CP_FA_VERSION" ]; then + echo "Running CP tests with FA $fa_version (CP version for sm$sm_arch)" + if [ "$NUM_GPUS" -ge 5 ]; then + CP_NUM_GPUS=$(( NUM_GPUS - 1 > 4 ? 4 : NUM_GPUS - 1 )) + CP_GPUS=$(seq -s, 1 $CP_NUM_GPUS) + echo "Running tests in parallel: test_attention.py on GPU 0, test_attention_with_cp.py on GPUs $CP_GPUS ($CP_NUM_GPUS GPUs)" + + CUDA_VISIBLE_DEVICES=0 NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \ + --junitxml=$XML_ATTN \ + $TE_PATH/tests/pytorch/attention/test_attention.py & + PID_ATTN=$! + CUDA_VISIBLE_DEVICES=$CP_GPUS NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \ + --junitxml=$XML_CP \ + $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py & + PID_CP=$! + + wait $PID_ATTN || test_fail "test_attention.py (FA $fa_version)" + wait $PID_CP || test_fail "test_attention_with_cp.py (FA $fa_version)" + else + echo "Running tests sequentially: need >=5 GPUs for parallel execution (1 for test_attention + 4 for test_attention_with_cp)" + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_ATTN $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py (FA $fa_version)" + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_CP $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py (FA $fa_version)" + fi + else + echo "Skipping CP tests for FA $fa_version (CP only runs with FA $CP_FA_VERSION on sm$sm_arch)" + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_ATTN $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py (FA $fa_version)" + fi done + +if [ "$RET" -ne 0 ]; then + echo "Error in the following test cases:$FAILED_CASES" + exit 1 +fi +echo "All tests passed" +exit 0 diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 9f6b4944e6..6fca61d3c0 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -47,6 +47,7 @@ def generate_input_shapes( config: ModelConfig, world_size: int, kernel_backend: str, + fa_pad_between_seqs: str = "False", ): if qkv_format == "bshd": q_input_shape = ( @@ -115,9 +116,12 @@ def generate_input_shapes( ).cuda() cu_seqlens_q = torch.clone(cu_seqlens_q_padded) - # Since FlashAttention doesn't support pad b/w sequences, and FusedAttention does, - # cu_seqlens_q is updated to reflect non-padded lengths for FusedAttention only. - if kernel_backend == "FusedAttention": + # Generate padded data (cu_seqlens_q reflects non-padded lengths, so it + # differs from cu_seqlens_q_padded) for FusedAttention always, and for + # FlashAttention only when its test param requests it. DPA auto-detects + # pad_between_seqs downstream from the cu_seqlens_q vs cu_seqlens_q_padded + # mismatch. + if kernel_backend == "FusedAttention" or fa_pad_between_seqs == "True": cu_seqlens_q[1:] = seqlens_q.cumsum(0, dtype=torch.int32).cuda() # NOTE: In case of Cross-Attention, `cu_seqlens_kv` and `cu_seqlens_kv_padded` @@ -196,6 +200,7 @@ def run_dpa_with_cp( scaling_mode="delayed", f16_O="False", is_training="True", + fa_pad_between_seqs="False", deterministic="False", log_level=logging.WARNING, ): @@ -314,7 +319,7 @@ def run_dpa_with_cp( cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, - ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend) + ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend, fa_pad_between_seqs) q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() @@ -557,11 +562,11 @@ def run_dpa_with_cp( tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: tensors[0], tensors[5] = tensors_to_deq - for i, tensor in enumerate(tensors): + for tensor, name in zip(tensors, names): # dbias/dbias_ could be None, so skip check for it if tensor is not None: - assert torch.all(~torch.isnan(tensor)), f"{names[i]} contains NaN" - assert torch.all(~torch.isinf(tensor)), f"{names[i]} contains Inf" + assert torch.all(~torch.isnan(tensor)), f"{name} has nan values" + assert torch.all(~torch.isinf(tensor)), f"{name} has inf values" out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors ############ compare results between CP and no-CP ############ @@ -617,49 +622,60 @@ def run_dpa_with_cp( if is_training: dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] - dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] cu_seqlens_q_padded = cu_seqlens_q_padded // world_size cu_seqlens_q = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True ) - cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q - num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] - for x in [dq, out, dq_, out_]: - assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_q[b] == 0 - or torch.count_nonzero( - x[ - (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ - b + 1 - ] - ] - ).item() - == 0 - ) + num_pads_q = (cu_seqlens_q_padded - cu_seqlens_q)[1:] - ( + cu_seqlens_q_padded - cu_seqlens_q + )[:-1] cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size cu_seqlens_kv = get_cu_seqlens_on_cp_rank( cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True ) - cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv - num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1] - for x in [dk, dv, dk_, dv_]: - assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_kv[b] == 0 - or torch.count_nonzero( - x[ - ( - cu_seqlens_kv_padded[b + 1] - num_pads_kv[b] - ) : cu_seqlens_kv_padded[b + 1] - ] - ).item() - == 0 + num_pads_kv = (cu_seqlens_kv_padded - cu_seqlens_kv)[1:] - ( + cu_seqlens_kv_padded - cu_seqlens_kv + )[:-1] + # FA3 leaves garbage at padding positions despite seqused_q/k (tile spillover). + # Forward out_ can't be pre-zeroed because FA3's custom op returns out_ as an + # output rather than mutating it in-place, triggering PyTorch's aliasing constraint. + # Backward dq/dk/dv CAN be pre-zeroed because FA3 marks them as mutated inputs. + if fa_pad_between_seqs == "True": + # out_ is a view inside the CP custom autograd Function, so in-place + # zeroing is blocked by PyTorch. Clone to break the view relationship. + out_ = out_.clone() + for x in [out, out_, dq]: + for b in range(config.batch_size): + x[ + cu_seqlens_q_padded[b + 1] - num_pads_q[b] : cu_seqlens_q_padded[b + 1] + ] = 0.0 + x[cu_seqlens_q_padded[-1] :] = 0.0 + for x in [dk, dv]: + for b in range(config.batch_size): + x[ + cu_seqlens_kv_padded[b + 1] + - num_pads_kv[b] : cu_seqlens_kv_padded[b + 1] + ] = 0.0 + x[cu_seqlens_kv_padded[-1] :] = 0.0 + # Verify CP backward tensors have clean padding (pre-zeroed in context_parallel.py). + for xname, x, cu, np_ in [ + ("dq_", dq_, cu_seqlens_q_padded, num_pads_q), + ("dk_", dk_, cu_seqlens_kv_padded, num_pads_kv), + ("dv_", dv_, cu_seqlens_kv_padded, num_pads_kv), + ]: + nnz = torch.count_nonzero(x[cu[-1] :]).item() + assert nnz == 0, ( + f"{xname} has {nnz} nonzero values in tail padding — " + "context_parallel.py should zero padding positions" ) + for b in range(config.batch_size): + if np_[b] > 0: + nnz = torch.count_nonzero(x[cu[b + 1] - np_[b] : cu[b + 1]]).item() + assert nnz == 0, ( + f"{xname} has {nnz} nonzero values in batch {b} padding — " + "context_parallel.py should zero padding positions" + ) else: - # Forward-only: reshape only out/out_ for comparison out = out.index_select(0, seq_idx_q).contiguous() out_ = out_ diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 32ea1694ee..5c46949f67 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -124,7 +124,7 @@ def reset_global_fp8_state(): @pytest.mark.parametrize("workspace_opt", [True, False]) @pytest.mark.parametrize("qkv_layout", [None]) @pytest.mark.parametrize("swa", [False]) -@pytest.mark.parametrize("pad_between_seqs", [False]) +@pytest.mark.parametrize("pad_between_seqs", [False, True]) def test_dot_product_attention( dtype, model_configs, @@ -157,6 +157,8 @@ def test_dot_product_attention( config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0] + if pad_between_seqs and qkv_format != "thd": + pytest.skip("pad_between_seqs only applies to THD format!") if qkv_format == "thd" and "padding" not in config.attn_mask_type: config.attn_mask_type = ( "padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding" @@ -195,19 +197,6 @@ def test_dot_product_attention( ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends - # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention - # mannually pads and unpads the input and output of FlashAttention for testing purposes - if ( - pad_between_seqs - and FlashAttentionUtils.is_installed - and not ( - config.max_seqlen_q != config.max_seqlen_kv - and config.attn_mask_type in ["causal", "padding_causal"] - ) - and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus) - ): - flash_attn_supported = True - # Skip if only unfused backend is supported if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: pytest.skip("Less than two backends to compare.") @@ -1301,12 +1290,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: block.softmax_offset.requires_grad = True # Run a forward and backward pass - if backend in ["FlashAttention", "UnfusedDotProductAttention"]: + if backend in ["UnfusedDotProductAttention"]: q = inp_orig[0] k = inp_orig[1] v = inp_orig[2] d_out = out_grad_orig - if backend == "FusedAttention": + if backend in ["FusedAttention", "FlashAttention"]: q = inp[0] k = inp[1] v = inp[2] @@ -1322,14 +1311,19 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: max_seqlen_kv=config.max_seqlen_kv, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None, - cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None, + cu_seqlens_q_padded=( + cu_seqlens_q_after_pad if backend in ["FusedAttention", "FlashAttention"] else None + ), + cu_seqlens_kv_padded=( + cu_seqlens_kv_after_pad if backend in ["FusedAttention", "FlashAttention"] else None + ), attn_mask_type=config.attn_mask_type, checkpoint_core_attention=ckpt_attn, core_attention_bias_type=config.attn_bias_type, core_attention_bias=bias, alibi_slopes=alibi_slopes, fast_zero_fill=True, + pad_between_seqs=pad_between_seqs, # Only pass num_splits when exercising the FlashAttention path num_splits=config.num_splits if backend == "FlashAttention" else 1, ) @@ -1343,12 +1337,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: if is_training and config.softmax_type != "vanilla": d_softmax_offset = block.softmax_offset.grad - if backend in ["FlashAttention", "UnfusedDotProductAttention"]: + if backend in ["UnfusedDotProductAttention"]: if is_training: return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset) else: return out, max_logit, (None, None, None, d_softmax_offset) - if backend == "FusedAttention": + if backend in ["FusedAttention", "FlashAttention"]: if qkv_format == "thd" and pad_between_seqs: out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) if is_training: diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index f0d2c27c12..a03f51f6c9 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -306,10 +306,19 @@ def _submit(pool: PoolWorker, **kwargs) -> None: @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("cp_comm_type", cp_comm_types) -def test_cp_with_flash_attention(cp_pool, dtype, model, qkv_format, cp_comm_type): +@pytest.mark.parametrize("pad_between_seqs", [False, True]) +def test_cp_with_flash_attention(cp_pool, dtype, model, qkv_format, cp_comm_type, pad_between_seqs): num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 pool = cp_pool(num_gpus) + if pad_between_seqs: + if qkv_format != "thd": + pytest.skip("pad_between_seqs only applies to THD format!") + if not FlashAttentionUtils.v3_is_installed or get_device_compute_capability() > (9, 0): + pytest.skip("pad_between_seqs with CP requires Flash Attention v3 on Hopper (sm90)!") + if cp_comm_type == "a2a+p2p": + pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!") + config = model_configs_flash_attn[model] config.context_parallel = True config.cp_comm_type = cp_comm_type @@ -361,6 +370,7 @@ def test_cp_with_flash_attention(cp_pool, dtype, model, qkv_format, cp_comm_type qkv_format=qkv_format, kernel_backend="FlashAttention", cp_comm_type=cp_comm_type, + fa_pad_between_seqs=pad_between_seqs, log_level=pytest_logging_level, ) @@ -606,6 +616,7 @@ def test_cp_with_fused_attention( is_training=is_training, deterministic=_deterministic, ) + _, fused_attn_supported, _ = available_backends if fused_attn_supported and config.attn_mask_type in ["causal", "padding_causal"]: config_copy = copy.deepcopy(config) @@ -628,6 +639,23 @@ def test_cp_with_fused_attention( pytest.skip("Deterministic mode does not support non-vanilla softmax with FusedAttention") if _deterministic and config.attn_bias_type == "post_scale_bias" and is_training: pytest.skip("Deterministic mode does not support post_scale_bias with requires_grad") + # Observed: cuDNN det THD backward asks for ~128 * bHSS bytes of workspace + # on sm90; at 1<<30 that's 128 GiB, won't fit on H100's 80 GB. Held exactly + # at b=2 + power-of-2 S in our sweep; for b>=3 the workspace was observed to + # grow super-linearly (b=4 took ~4x the b=2 amount, not 2x) — revisit if a + # config uses b>2. + SM90_DET_FUSED_THD_BWD_MAX_BHSS = 1 << 30 + if ( + _deterministic + and qkv_format == "thd" + and get_device_compute_capability() == (9, 0) + and config.batch_size * config.num_heads * config.max_seqlen_q * config.max_seqlen_kv + >= SM90_DET_FUSED_THD_BWD_MAX_BHSS + ): + pytest.skip( + "Deterministic FusedAttention backward with THD format OOMs on sm90" + " for large bHSS configs (known cuDNN issue)." + ) _submit( pool, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 6e097265ff..6c6adc6e3f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -822,10 +822,13 @@ def forward( fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, quantizers=None, + pad_between_seqs: Optional[bool] = False, inference_params: Optional[InferenceParams] = None, flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"), fp8_output: bool = False, num_splits: Optional[int] = 1, + cu_seqlens_q_padded: Optional[torch.Tensor] = None, + cu_seqlens_kv_padded: Optional[torch.Tensor] = None, ) -> torch.Tensor: """flash-attn fprop""" @@ -1024,8 +1027,16 @@ def forward( cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, - cu_seqlens_q if qkv_format == "thd" else None, - cu_seqlens_kv if qkv_format == "thd" else None, + ( + cu_seqlens_q_padded + if pad_between_seqs + else (cu_seqlens_q if qkv_format == "thd" else None) + ), + ( + cu_seqlens_kv_padded + if pad_between_seqs + else (cu_seqlens_kv if qkv_format == "thd" else None) + ), self.attention_dropout if self.training else 0.0, cp_group, cp_global_ranks, @@ -1037,7 +1048,7 @@ def forward( deterministic=self.deterministic, window_size=window_size, quantizers=quantizers, - pad_between_seqs=False, + pad_between_seqs=pad_between_seqs, use_flash_attn_3=use_flash_attn_3, fp8_output=fp8_output, ) @@ -1082,8 +1093,12 @@ def forward( else: func = flash_attn_with_kvcache_v3 # pylint: disable=possibly-used-before-assignment if not use_flash_attn_4 and (not use_flash_attn_3 or inference_params is None): - fa_optional_forward_args_thd.append(cu_seqlens_q) - fa_optional_forward_args_thd.append(cu_seqlens_kv) + fa_optional_forward_args_thd.append( + cu_seqlens_q_padded if pad_between_seqs else cu_seqlens_q + ) + fa_optional_forward_args_thd.append( + cu_seqlens_kv_padded if pad_between_seqs else cu_seqlens_kv + ) fa_optional_forward_args_thd.append(max_seqlen_q) fa_optional_forward_args_thd.append(max_seqlen_kv) if use_flash_attn_4: @@ -1139,6 +1154,13 @@ def forward( fa_3_optional_forward_kwargs = {} fa_3_optional_forward_kwargs["window_size"] = window_size fa_3_optional_forward_kwargs["num_splits"] = num_splits + if pad_between_seqs: + fa_3_optional_forward_kwargs["seqused_q"] = ( + cu_seqlens_q[1:] - cu_seqlens_q[:-1] + ) + fa_3_optional_forward_kwargs["seqused_k"] = ( + cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + ) if inference_params is None: fa_3_optional_forward_kwargs["deterministic"] = self.deterministic else: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 35684625a5..36847e40ed 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -663,6 +663,8 @@ def get_fa_args( dq=None, dk=None, dv=None, + seqused_q=None, + seqused_k=None, ): """Get forward/backward arguments for flash-attn v2 and v3.""" if use_flash_attn_3: @@ -672,7 +674,9 @@ def get_fa_args( *[None] * 4, # k_new, v_new, qv, out cu_seqlens_q, cu_seqlens_kv, - *[None] * 3, # cu_seqlens_k_new, seqused_q, seqused_k + None, # cu_seqlens_k_new + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_kv, *[None] @@ -690,8 +694,8 @@ def get_fa_args( return [ cu_seqlens_q, cu_seqlens_kv, - None, # sequed_q - None, # sequed_k + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_kv, dq, @@ -701,8 +705,8 @@ def get_fa_args( return [ None, # cu_seqlens_q None, # cu_seqlens_kv - None, # sequed_q - None, # sequed_k + None, # seqused_q + None, # seqused_k max_seqlen_q, max_seqlen_kv, dq, @@ -1020,6 +1024,9 @@ def cp_p2p_fwd_flash_attn( flash_attn_fwd, max_seqlen_q, max_seqlen_kv, + pad_between_seqs, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, q_part, k_part, v_part, @@ -1046,6 +1053,20 @@ def cp_p2p_fwd_flash_attn( fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 + seqused_q = None + seqused_k = None + if pad_between_seqs and use_flash_attn_3 and qkv_format == "thd": + # Derive actual token counts per batch element from cu_seqlens + seqused_q = cu_seqlens_q_per_step[1:] - cu_seqlens_q_per_step[:-1] + seqused_k = cu_seqlens_kv_per_step[1:] - cu_seqlens_kv_per_step[:-1] + # Override cu_seqlens to padded layout for tensor memory layout + cu_seqlens_q_ = cu_seqlens_q_padded + cu_seqlens_kv_ = cu_seqlens_kv_padded + if section == "lower-triangle": + cu_seqlens_kv_ = cu_seqlens_kv_padded // 2 + elif section == "upper-triangle": + cu_seqlens_q_ = cu_seqlens_q_padded // 2 + fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -1054,6 +1075,8 @@ def cp_p2p_fwd_flash_attn( cu_seqlens_kv=cu_seqlens_kv_, max_seqlen_q=max_seqlen_q_, max_seqlen_kv=max_seqlen_kv_, + seqused_q=seqused_q, + seqused_k=seqused_k, ) fa_outputs = flash_attn_fwd( q_part, @@ -1296,6 +1319,9 @@ def cp_p2p_bwd_flash_attn( rng_states, softmax_lse, softmax_lse_, + pad_between_seqs, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, q_part, k_part, v_part, @@ -1304,7 +1330,10 @@ def cp_p2p_bwd_flash_attn( section, ): """Per-tile backward call of CP P2P with FlashAttention backend""" - dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]] + if pad_between_seqs: + dq, dk, dv = [torch.zeros_like(x) for x in [q_part, k_part, v_part]] + else: + dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]] if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size"] = (-1, -1) elif use_flash_attn_3 or fa_utils.v2_7_0_plus: @@ -1329,17 +1358,33 @@ def cp_p2p_bwd_flash_attn( max_seqlen_q_ = max_seqlen_q // 2 softmax_lse__ = softmax_lse_ + seqused_q = None + seqused_k = None + cu_seqlens_q_bwd = cu_seqlens_q_per_step[cp_size - step - 1] + cu_seqlens_kv_bwd = cu_seqlens_kv_per_step[cp_size - step - 1] + if pad_between_seqs and use_flash_attn_3 and qkv_format == "thd": + seqused_q = cu_seqlens_q_bwd[1:] - cu_seqlens_q_bwd[:-1] + seqused_k = cu_seqlens_kv_bwd[1:] - cu_seqlens_kv_bwd[:-1] + cu_seqlens_q_bwd = cu_seqlens_q_padded + cu_seqlens_kv_bwd = cu_seqlens_kv_padded + if section == "lower-triangle": + cu_seqlens_kv_bwd = cu_seqlens_kv_padded // 2 + elif section == "upper-triangle": + cu_seqlens_q_bwd = cu_seqlens_q_padded // 2 + fa_backward_args_thd = get_fa_args( False, use_flash_attn_3, qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - step - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - step - 1], + cu_seqlens_q=cu_seqlens_q_bwd, + cu_seqlens_kv=cu_seqlens_kv_bwd, max_seqlen_q=max_seqlen_q_, max_seqlen_kv=max_seqlen_kv_, dq=dq, dk=dk, dv=dv, + seqused_q=seqused_q, + seqused_k=seqused_k, ) if use_flash_attn_3: fa_backward_kwargs["is_causal"] = causal_ @@ -1779,6 +1824,9 @@ def forward( flash_attn_fwd, max_seqlen_q, max_seqlen_kv, + pad_between_seqs, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, ] # cp_size = 4: @@ -1821,7 +1869,9 @@ def forward( else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( cp_p2p_fwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) ) elif i <= rank: @@ -1848,7 +1898,9 @@ def forward( else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( cp_p2p_fwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) ) else: @@ -1875,7 +1927,9 @@ def forward( else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( cp_p2p_fwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) ) else: @@ -1900,7 +1954,11 @@ def forward( ) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section) else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( - cp_p2p_fwd_flash_attn(*flash_attn_inputs, *prepare_outputs, section) + cp_p2p_fwd_flash_attn( + *flash_attn_inputs, + *prepare_outputs, + section, + ) ) # softmax_lse correction @@ -2150,6 +2208,7 @@ def forward( ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention + ctx.pad_between_seqs = pad_between_seqs ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format ctx.second_half_lse_seqlen = second_half_lse_seqlen ctx.fp8_meta = fp8_meta @@ -2560,6 +2619,9 @@ def backward(ctx, dout, *_args): rng_states, softmax_lse, softmax_lse_, + ctx.pad_between_seqs, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, ] # Reverse the steps in forward. In the cp_size x cp_size (i.e. GPU x step) matrix, @@ -2575,7 +2637,9 @@ def backward(ctx, dout, *_args): ) else: dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) elif i >= (cp_size - rank - 1): section = "lower-triangle" @@ -2586,7 +2650,9 @@ def backward(ctx, dout, *_args): ) else: dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) else: section = "upper-triangle" @@ -2597,7 +2663,9 @@ def backward(ctx, dout, *_args): ) else: dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) else: section = "all" @@ -2608,7 +2676,9 @@ def backward(ctx, dout, *_args): ) else: dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) # dq, dk, dv are reduced across steps in higher precision @@ -3838,6 +3908,7 @@ def forward( cp_group, cp_stream, quantizers, + pad_between_seqs, use_flash_attn_3, softmax_type, softmax_offset, @@ -4073,14 +4144,25 @@ def forward( out_f16 = out_.dequantize(dtype=fwd_nominal_dtype) out_part = out_f16 else: + seqused_q = None + seqused_k = None + fa_cu_seqlens_q = cu_seqlens_q + fa_cu_seqlens_kv = cu_seqlens_kv + if pad_between_seqs and use_flash_attn_3 and qkv_format == "thd": + seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqused_k = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + fa_cu_seqlens_q = cu_seqlens_q_padded + fa_cu_seqlens_kv = cu_seqlens_kv_padded fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, qkv_format, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q=fa_cu_seqlens_q, + cu_seqlens_kv=fa_cu_seqlens_kv, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, + seqused_q=seqused_q, + seqused_k=seqused_k, ) fa_outputs = flash_attn_fwd( q_part, @@ -4217,6 +4299,7 @@ def forward( ctx.fwd_nominal_dtype = fwd_nominal_dtype ctx.fp8_recipe = fp8_recipe ctx.use_flash_attn_3 = use_flash_attn_3 + ctx.pad_between_seqs = pad_between_seqs ctx.softmax_type = softmax_type ctx.dQKV_quantizer = dQKV_quantizer @@ -4405,18 +4488,32 @@ def backward(ctx, dout, *_args): dq, dk, dv = [x._data for x in [dq, dk, dv]] else: softmax_lse, rng_state = aux_ctx_tensors - dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] + if ctx.pad_between_seqs: + dq, dk, dv = [torch.zeros_like(x) for x in [q, k, v]] + else: + dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] + seqused_q = None + seqused_k = None + fa_cu_seqlens_q = cu_seqlens_q + fa_cu_seqlens_kv = cu_seqlens_kv + if ctx.pad_between_seqs and ctx.use_flash_attn_3 and ctx.dqkv_format == "thd": + seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqused_k = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + fa_cu_seqlens_q = cu_seqlens_q_padded + fa_cu_seqlens_kv = cu_seqlens_kv_padded fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, ctx.dqkv_format, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q=fa_cu_seqlens_q, + cu_seqlens_kv=fa_cu_seqlens_kv, max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=ctx.max_seqlen_kv, dq=dq, dk=dk, dv=dv, + seqused_q=seqused_q, + seqused_k=seqused_k, ) if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_state @@ -4524,6 +4621,7 @@ def backward(ctx, dout, *_args): None, None, None, + None, d_softmax_offset, None, ) @@ -4740,6 +4838,7 @@ def attn_forward_func_with_cp( cp_group, cp_stream, quantizers, + pad_between_seqs, use_flash_attn_3, softmax_type, softmax_offset, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index b38b66c3e6..ca848a9480 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1658,10 +1658,13 @@ def forward( fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, quantizers=self.quantizers, + pad_between_seqs=pad_between_seqs, inference_params=inference_params, flash_attention_backend=flash_attention_backend, fp8_output=fp8_output, num_splits=num_splits, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) if use_fused_attention: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 1f1637cecd..6565e9f6f6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -651,7 +651,7 @@ def get_attention_backend( # backend | precision | KV cache | architecture | qkv_format | page_size # --------------------------------------------------------------------------------------- # Fused | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 1 - # Flash v2 | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 256 + # Flash v2 | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | % 256 == 0 # Flash v3 | FP16/BF16 | non-paged/paged | sm90 | bshd,sbhd,thd | >= 1 # | FP8 | non-paged/paged | sm90 | thd | >= 1 # Flash v4 | FP16/BF16 | TODO | sm80+ | bshd,sbhd,thd | TODO @@ -691,9 +691,9 @@ def get_attention_backend( use_fused_attention = False use_unfused_attention = False if inference_params.is_paged: - if use_flash_attention_2 and inference_params.page_size < 256: + if use_flash_attention_2 and inference_params.page_size % 256 != 0: if FlashAttentionUtils.is_installed: - logger.debug("Disabling FlashAttention 2 for page size < 256") + logger.debug("Disabling FlashAttention 2 for page size not divisible by 256") use_flash_attention_2 = False if use_flash_attention_2: if not FlashAttentionUtils.is_installed: @@ -703,6 +703,16 @@ def get_attention_backend( "Disabling FlashAttention 2 as paged attention requires flash-attn 2.5+" ) use_flash_attention_2 = False + else: + # Non-paged KV cache still passes a block_table to FA2 for thd_2bshd support, + # and FA2 enforces page_size % 256 == 0 on the effective page size (max_seqlen_kv). + if use_flash_attention_2 and max_seqlen_kv % 256 != 0: + if FlashAttentionUtils.is_installed: + logger.debug( + "Disabling FlashAttention 2 for non-paged KV cache" + " with max_seqlen_kv not divisible by 256" + ) + use_flash_attention_2 = False if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed: logger.debug("Disabling FlashAttention 4 as it does not support KV cache.") use_flash_attention_4 = False @@ -844,15 +854,18 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if qkv_format == "thd": if pad_between_seqs: if ( # pylint: disable=too-many-boolean-expressions - (use_flash_attention_2 and FlashAttentionUtils.is_installed) - or (use_flash_attention_3 and FlashAttentionUtils.v3_is_installed) - or (use_flash_attention_4 and FlashAttentionUtils.v4_is_installed) - ): + use_flash_attention_2 and FlashAttentionUtils.is_installed + ) or (use_flash_attention_4 and FlashAttentionUtils.v4_is_installed): logger.debug( - "Disabling FlashAttention for qkv_format = thd when there is " + "Disabling FlashAttention 2 and 4 for qkv_format = thd when there is " "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]" ) - use_flash_attention = False + use_flash_attention_2 = False + use_flash_attention_4 = False + # FA3 supports pad_between_seqs via seqused_q/seqused_k + if use_unfused_attention: + logger.debug("Disabling UnfusedDotProductAttention for pad_between_seqs = True") + use_unfused_attention = False if device_compute_capability == (12, 0): if cudnn_version < (9, 18, 1): if use_fused_attention: @@ -1273,9 +1286,12 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt ) use_flash_attention_2 = False if use_flash_attention_3 and deterministic and FlashAttentionUtils.v3_is_installed: - if head_dim_qk >= 256: + if is_training and max(head_dim_qk, head_dim_v) >= 256: logger.debug( - "Disabling FlashAttention 3 for deterministic execution with head_dim_qk >= 256." + "Disabling FlashAttention 3 for deterministic backward with" + " max(head_dim_qk, head_dim_v) >= 256. Found: head_dim_qk = %s, head_dim_v = %s.", + head_dim_qk, + head_dim_v, ) use_flash_attention_3 = False if use_fused_attention and deterministic: