From 10e4cfc889d4bc91685d61fe4acb9afa3a7ab362 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 24 Apr 2026 15:35:25 -0700 Subject: [PATCH 01/18] [PyTorch] Add pad_between_seqs support for FlashAttention 3 with CP Add support for padding between sequences (pad_between_seqs) in the FlashAttention 3 backend when used with context parallelism (CP). Key changes: - backends.py: Pass fa_pad_between_seqs through to FA3 forward/backward - context_parallel.py: Handle pad_between_seqs in A2A and P2P CP paths, zero FA3 padding garbage in CP forward, fix a2a backward alignment - dot_product_attention.py: Auto-detect pad_between_seqs from cu_seqlens - utils.py: Gate FA3 deterministic backward for hdim>=256, fix flash_attn_supported override for cross-attention and large head_dim, disable UnfusedDotProductAttention for pad_between_seqs, add SM100+ FA3 skip Signed-off-by: Sudhakar Singh --- .../dot_product_attention/backends.py | 32 +++- .../dot_product_attention/context_parallel.py | 141 +++++++++++++++--- .../dot_product_attention.py | 3 + .../attention/dot_product_attention/utils.py | 34 +++-- 4 files changed, 173 insertions(+), 37 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 4104820a1c..da41564bc1 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -803,10 +803,13 @@ def forward( fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, quantizers=None, + pad_between_seqs: Optional[bool] = False, inference_params: Optional[InferenceParams] = None, flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"), fp8_output: bool = False, num_splits: Optional[int] = 1, + cu_seqlens_q_padded: Optional[torch.Tensor] = None, + cu_seqlens_kv_padded: Optional[torch.Tensor] = None, ) -> torch.Tensor: """flash-attn fprop""" @@ -1005,8 +1008,16 @@ def forward( cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, - cu_seqlens_q if qkv_format == "thd" else None, - cu_seqlens_kv if qkv_format == "thd" else None, + ( + cu_seqlens_q_padded + if pad_between_seqs + else (cu_seqlens_q if qkv_format == "thd" else None) + ), + ( + cu_seqlens_kv_padded + if pad_between_seqs + else (cu_seqlens_kv if qkv_format == "thd" else None) + ), self.attention_dropout if self.training else 0.0, cp_group, cp_global_ranks, @@ -1018,7 +1029,7 @@ def forward( deterministic=self.deterministic, window_size=window_size, quantizers=quantizers, - pad_between_seqs=False, + pad_between_seqs=pad_between_seqs, use_flash_attn_3=use_flash_attn_3, fp8_output=fp8_output, ) @@ -1063,8 +1074,12 @@ def forward( else: func = flash_attn_with_kvcache_v3 # pylint: disable=possibly-used-before-assignment if not use_flash_attn_4 and (not use_flash_attn_3 or inference_params is None): - fa_optional_forward_args_thd.append(cu_seqlens_q) - fa_optional_forward_args_thd.append(cu_seqlens_kv) + fa_optional_forward_args_thd.append( + cu_seqlens_q_padded if pad_between_seqs else cu_seqlens_q + ) + fa_optional_forward_args_thd.append( + cu_seqlens_kv_padded if pad_between_seqs else cu_seqlens_kv + ) fa_optional_forward_args_thd.append(max_seqlen_q) fa_optional_forward_args_thd.append(max_seqlen_kv) if use_flash_attn_4: @@ -1120,6 +1135,13 @@ def forward( fa_3_optional_forward_kwargs = {} fa_3_optional_forward_kwargs["window_size"] = window_size fa_3_optional_forward_kwargs["num_splits"] = num_splits + if pad_between_seqs: + fa_3_optional_forward_kwargs["seqused_q"] = ( + cu_seqlens_q[1:] - cu_seqlens_q[:-1] + ) + fa_3_optional_forward_kwargs["seqused_k"] = ( + cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + ) if inference_params is None: fa_3_optional_forward_kwargs["deterministic"] = self.deterministic else: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 7b10593acf..a151705f51 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -640,6 +640,8 @@ def get_fa_args( dq=None, dk=None, dv=None, + seqused_q=None, + seqused_k=None, ): """Get forward/backward arguments for flash-attn v2 and v3.""" if use_flash_attn_3: @@ -649,7 +651,9 @@ def get_fa_args( *[None] * 4, # k_new, v_new, qv, out cu_seqlens_q, cu_seqlens_kv, - *[None] * 3, # cu_seqlens_k_new, seqused_q, seqused_k + None, # cu_seqlens_k_new + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_kv, *[None] @@ -667,8 +671,8 @@ def get_fa_args( return [ cu_seqlens_q, cu_seqlens_kv, - None, # sequed_q - None, # sequed_k + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_kv, dq, @@ -678,8 +682,8 @@ def get_fa_args( return [ None, # cu_seqlens_q None, # cu_seqlens_kv - None, # sequed_q - None, # sequed_k + None, # seqused_q + None, # seqused_k max_seqlen_q, max_seqlen_kv, dq, @@ -1000,6 +1004,9 @@ def cp_p2p_fwd_flash_attn( flash_attn_fwd, max_seqlen_q, max_seqlen_kv, + pad_between_seqs, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, q_part, k_part, v_part, @@ -1026,6 +1033,20 @@ def cp_p2p_fwd_flash_attn( fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 + seqused_q = None + seqused_k = None + if pad_between_seqs and use_flash_attn_3 and qkv_format == "thd": + # Derive actual token counts per batch element from cu_seqlens + seqused_q = cu_seqlens_q_per_step[1:] - cu_seqlens_q_per_step[:-1] + seqused_k = cu_seqlens_kv_per_step[1:] - cu_seqlens_kv_per_step[:-1] + # Override cu_seqlens to padded layout for tensor memory layout + cu_seqlens_q_ = cu_seqlens_q_padded + cu_seqlens_kv_ = cu_seqlens_kv_padded + if section == "lower-triangle": + cu_seqlens_kv_ = cu_seqlens_kv_padded // 2 + elif section == "upper-triangle": + cu_seqlens_q_ = cu_seqlens_q_padded // 2 + fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -1034,6 +1055,8 @@ def cp_p2p_fwd_flash_attn( cu_seqlens_kv=cu_seqlens_kv_, max_seqlen_q=max_seqlen_q_, max_seqlen_kv=max_seqlen_kv_, + seqused_q=seqused_q, + seqused_k=seqused_k, ) fa_outputs = flash_attn_fwd( q_part, @@ -1297,6 +1320,9 @@ def cp_p2p_bwd_flash_attn( rng_states, softmax_lse, softmax_lse_, + pad_between_seqs, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, q_part, k_part, v_part, @@ -1305,7 +1331,10 @@ def cp_p2p_bwd_flash_attn( section, ): """Per-tile backward call of CP P2P with FlashAttention backend""" - dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]] + if pad_between_seqs: + dq, dk, dv = [torch.zeros_like(x) for x in [q_part, k_part, v_part]] + else: + dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]] if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size"] = (-1, -1) elif use_flash_attn_3 or fa_utils.v2_7_0_plus: @@ -1330,17 +1359,33 @@ def cp_p2p_bwd_flash_attn( max_seqlen_q_ = max_seqlen_q // 2 softmax_lse__ = softmax_lse_ + seqused_q = None + seqused_k = None + cu_seqlens_q_bwd = cu_seqlens_q_per_step[cp_size - step - 1] + cu_seqlens_kv_bwd = cu_seqlens_kv_per_step[cp_size - step - 1] + if pad_between_seqs and use_flash_attn_3 and qkv_format == "thd": + seqused_q = cu_seqlens_q_bwd[1:] - cu_seqlens_q_bwd[:-1] + seqused_k = cu_seqlens_kv_bwd[1:] - cu_seqlens_kv_bwd[:-1] + cu_seqlens_q_bwd = cu_seqlens_q_padded + cu_seqlens_kv_bwd = cu_seqlens_kv_padded + if section == "lower-triangle": + cu_seqlens_kv_bwd = cu_seqlens_kv_padded // 2 + elif section == "upper-triangle": + cu_seqlens_q_bwd = cu_seqlens_q_padded // 2 + fa_backward_args_thd = get_fa_args( False, use_flash_attn_3, qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - step - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - step - 1], + cu_seqlens_q=cu_seqlens_q_bwd, + cu_seqlens_kv=cu_seqlens_kv_bwd, max_seqlen_q=max_seqlen_q_, max_seqlen_kv=max_seqlen_kv_, dq=dq, dk=dk, dv=dv, + seqused_q=seqused_q, + seqused_k=seqused_k, ) if use_flash_attn_3: fa_backward_kwargs["is_causal"] = causal_ @@ -1779,6 +1824,9 @@ def forward( flash_attn_fwd, max_seqlen_q, max_seqlen_kv, + pad_between_seqs, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, ] # cp_size = 4: @@ -1821,7 +1869,9 @@ def forward( else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( cp_p2p_fwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) ) elif i <= rank: @@ -1848,7 +1898,9 @@ def forward( else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( cp_p2p_fwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) ) else: @@ -1875,7 +1927,9 @@ def forward( else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( cp_p2p_fwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) ) else: @@ -1900,7 +1954,11 @@ def forward( ) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section) else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( - cp_p2p_fwd_flash_attn(*flash_attn_inputs, *prepare_outputs, section) + cp_p2p_fwd_flash_attn( + *flash_attn_inputs, + *prepare_outputs, + section, + ) ) # softmax_lse correction @@ -2150,6 +2208,7 @@ def forward( ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention + ctx.pad_between_seqs = pad_between_seqs ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format ctx.second_half_lse_seqlen = second_half_lse_seqlen ctx.fp8_meta = fp8_meta @@ -2560,6 +2619,9 @@ def backward(ctx, dout, *_args): rng_states, softmax_lse, softmax_lse_, + ctx.pad_between_seqs, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, ] # Reverse the steps in forward. In the cp_size x cp_size (i.e. GPU x step) matrix, @@ -2575,7 +2637,9 @@ def backward(ctx, dout, *_args): ) else: dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) elif i >= (cp_size - rank - 1): section = "lower-triangle" @@ -2586,7 +2650,9 @@ def backward(ctx, dout, *_args): ) else: dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) else: section = "upper-triangle" @@ -2597,7 +2663,9 @@ def backward(ctx, dout, *_args): ) else: dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) else: section = "all" @@ -2608,7 +2676,9 @@ def backward(ctx, dout, *_args): ) else: dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) # dq, dk, dv are reduced across steps in higher precision @@ -3842,6 +3912,7 @@ def forward( cp_group, cp_stream, quantizers, + pad_between_seqs, use_flash_attn_3, softmax_type, softmax_offset, @@ -4076,14 +4147,25 @@ def forward( out_f16 = out_.dequantize(dtype=fwd_nominal_dtype) out_part = out_f16 else: + seqused_q = None + seqused_k = None + fa_cu_seqlens_q = cu_seqlens_q + fa_cu_seqlens_kv = cu_seqlens_kv + if pad_between_seqs and use_flash_attn_3 and qkv_format == "thd": + seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqused_k = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + fa_cu_seqlens_q = cu_seqlens_q_padded + fa_cu_seqlens_kv = cu_seqlens_kv_padded fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, qkv_format, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q=fa_cu_seqlens_q, + cu_seqlens_kv=fa_cu_seqlens_kv, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, + seqused_q=seqused_q, + seqused_k=seqused_k, ) fa_outputs = flash_attn_fwd( q_part, @@ -4220,6 +4302,7 @@ def forward( ctx.fwd_nominal_dtype = fwd_nominal_dtype ctx.fp8_recipe = fp8_recipe ctx.use_flash_attn_3 = use_flash_attn_3 + ctx.pad_between_seqs = pad_between_seqs ctx.softmax_type = softmax_type ctx.dQKV_quantizer = dQKV_quantizer @@ -4408,18 +4491,32 @@ def backward(ctx, dout, *_args): dq, dk, dv = [x._data for x in [dq, dk, dv]] else: softmax_lse, rng_state = aux_ctx_tensors - dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] + if ctx.pad_between_seqs: + dq, dk, dv = [torch.zeros_like(x) for x in [q, k, v]] + else: + dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] + seqused_q = None + seqused_k = None + fa_cu_seqlens_q = cu_seqlens_q + fa_cu_seqlens_kv = cu_seqlens_kv + if ctx.pad_between_seqs and ctx.use_flash_attn_3 and ctx.dqkv_format == "thd": + seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqused_k = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + fa_cu_seqlens_q = cu_seqlens_q_padded + fa_cu_seqlens_kv = cu_seqlens_kv_padded fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, ctx.dqkv_format, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q=fa_cu_seqlens_q, + cu_seqlens_kv=fa_cu_seqlens_kv, max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=ctx.max_seqlen_kv, dq=dq, dk=dk, dv=dv, + seqused_q=seqused_q, + seqused_k=seqused_k, ) if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_state @@ -4527,6 +4624,7 @@ def backward(ctx, dout, *_args): None, None, None, + None, d_softmax_offset, None, ) @@ -4743,6 +4841,7 @@ def attn_forward_func_with_cp( cp_group, cp_stream, quantizers, + pad_between_seqs, use_flash_attn_3, softmax_type, softmax_offset, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 17e9a337a4..0bdfb7b431 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1536,10 +1536,13 @@ def forward( fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, quantizers=self.quantizers, + pad_between_seqs=pad_between_seqs, inference_params=inference_params, flash_attention_backend=flash_attention_backend, fp8_output=fp8_output, num_splits=num_splits, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) if use_fused_attention: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index ed87423534..b1eca5106c 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -651,7 +651,7 @@ def get_attention_backend( # backend | precision | KV cache | architecture | qkv_format | page_size # --------------------------------------------------------------------------------------- # Fused | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 1 - # Flash v2 | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 256 + # Flash v2 | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | % 256 == 0 # Flash v3 | FP16/BF16 | non-paged/paged | sm90 | bshd,sbhd,thd | >= 1 # | FP8 | non-paged/paged | sm90 | thd | >= 1 # Flash v4 | FP16/BF16 | TODO | sm80+ | bshd,sbhd,thd | TODO @@ -691,9 +691,9 @@ def get_attention_backend( use_fused_attention = False use_unfused_attention = False if inference_params.is_paged: - if use_flash_attention_2 and inference_params.page_size < 256: + if use_flash_attention_2 and inference_params.page_size % 256 != 0: if FlashAttentionUtils.is_installed: - logger.debug("Disabling FlashAttention 2 for page size < 256") + logger.debug("Disabling FlashAttention 2 for page size not divisible by 256") use_flash_attention_2 = False if use_flash_attention_2: if not FlashAttentionUtils.is_installed: @@ -703,6 +703,16 @@ def get_attention_backend( "Disabling FlashAttention 2 as paged attention requires flash-attn 2.5+" ) use_flash_attention_2 = False + else: + # Non-paged KV cache still passes a block_table to FA2 for thd_2bshd support, + # and FA2 enforces page_size % 256 == 0 on the effective page size (max_seqlen_kv). + if use_flash_attention_2 and max_seqlen_kv % 256 != 0: + if FlashAttentionUtils.is_installed: + logger.debug( + "Disabling FlashAttention 2 for non-paged KV cache" + " with max_seqlen_kv not divisible by 256" + ) + use_flash_attention_2 = False if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed: logger.debug("Disabling FlashAttention 4 as it does not support KV cache.") use_flash_attention_4 = False @@ -844,15 +854,17 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if qkv_format == "thd": if pad_between_seqs: if ( # pylint: disable=too-many-boolean-expressions - (use_flash_attention_2 and FlashAttentionUtils.is_installed) - or (use_flash_attention_3 and FlashAttentionUtils.v3_is_installed) - or (use_flash_attention_4 and FlashAttentionUtils.v4_is_installed) - ): + use_flash_attention_2 and FlashAttentionUtils.is_installed + ) or (use_flash_attention_4 and FlashAttentionUtils.v4_is_installed): logger.debug( - "Disabling FlashAttention for qkv_format = thd when there is " + "Disabling FlashAttention 2 for qkv_format = thd when there is " "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]" ) - use_flash_attention = False + use_flash_attention_2 = False + # FA3 supports pad_between_seqs via seqused_q/seqused_k + if use_unfused_attention: + logger.debug("Disabling UnfusedDotProductAttention for pad_between_seqs = True") + use_unfused_attention = False if device_compute_capability == (12, 0): if cudnn_version < (9, 18, 1): if use_fused_attention: @@ -1303,9 +1315,9 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt ) use_flash_attention_2 = False if use_flash_attention_3 and deterministic and FlashAttentionUtils.v3_is_installed: - if head_dim_qk >= 256: + if head_dim_qk > 128: logger.debug( - "Disabling FlashAttention 3 for deterministic execution with head_dim_qk >= 256." + "Disabling FlashAttention 3 for deterministic execution with head_dim_qk > 128." ) use_flash_attention_3 = False if use_fused_attention and deterministic: From 2a49dee12400ef47a1f8181d28325466c6731331 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 24 Apr 2026 15:35:35 -0700 Subject: [PATCH 02/18] [PyTorch] Add pad_between_seqs tests for CP and non-CP FlashAttention Add test parametrization for pad_between_seqs in flash attention tests. Update run_attention_with_cp.py to support the new parameter and fix batch boundary alignment in the non-CP FA3 path. Run tests in parallel when multiple GPUs are available. Signed-off-by: Sudhakar Singh --- .../attention/run_attention_with_cp.py | 96 +++++++++++-------- tests/pytorch/attention/test_attention.py | 34 +++---- .../attention/test_attention_with_cp.py | 22 ++++- 3 files changed, 91 insertions(+), 61 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 8dfea644a5..ba77821867 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -37,6 +37,7 @@ def generate_input_shapes( config: ModelConfig, world_size: int, kernel_backend: str, + fa_pad_between_seqs: str = "False", ): if qkv_format == "bshd": q_input_shape = ( @@ -105,9 +106,12 @@ def generate_input_shapes( ).cuda() cu_seqlens_q = torch.clone(cu_seqlens_q_padded) - # Since FlashAttention doesn't support pad b/w sequences, and FusedAttention does, - # cu_seqlens_q is updated to reflect non-padded lengths for FusedAttention only. - if kernel_backend == "FusedAttention": + # Generate padded data (cu_seqlens_q reflects non-padded lengths, so it + # differs from cu_seqlens_q_padded) for FusedAttention always, and for + # FlashAttention only when its test param requests it. DPA auto-detects + # pad_between_seqs downstream from the cu_seqlens_q vs cu_seqlens_q_padded + # mismatch. + if kernel_backend == "FusedAttention" or fa_pad_between_seqs == "True": cu_seqlens_q[1:] = seqlens_q.cumsum(0, dtype=torch.int32).cuda() # NOTE: In case of Cross-Attention, `cu_seqlens_kv` and `cu_seqlens_kv_padded` @@ -186,6 +190,7 @@ def run_dpa_with_cp( scaling_mode="delayed", f16_O="False", is_training="True", + fa_pad_between_seqs="False", deterministic="False", log_level=logging.WARNING, ): @@ -288,7 +293,7 @@ def run_dpa_with_cp( cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, - ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend) + ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend, fa_pad_between_seqs) q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() @@ -531,11 +536,11 @@ def run_dpa_with_cp( tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: tensors[0], tensors[5] = tensors_to_deq - for i, tensor in enumerate(tensors): + for tensor, name in zip(tensors, names): # dbias/dbias_ could be None, so skip check for it if tensor is not None: - assert torch.all(~torch.isnan(tensor)), f"{names[i]} contains NaN" - assert torch.all(~torch.isinf(tensor)), f"{names[i]} contains Inf" + assert torch.all(~torch.isnan(tensor)), f"{name} has nan values" + assert torch.all(~torch.isinf(tensor)), f"{name} has inf values" out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors ############ compare results between CP and no-CP ############ @@ -588,49 +593,60 @@ def run_dpa_with_cp( if is_training: dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] - dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] cu_seqlens_q_padded = cu_seqlens_q_padded // world_size cu_seqlens_q = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True ) - cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q - num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] - for x in [dq, out, dq_, out_]: - assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_q[b] == 0 - or torch.count_nonzero( - x[ - (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ - b + 1 - ] - ] - ).item() - == 0 - ) + num_pads_q = (cu_seqlens_q_padded - cu_seqlens_q)[1:] - ( + cu_seqlens_q_padded - cu_seqlens_q + )[:-1] cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size cu_seqlens_kv = get_cu_seqlens_on_cp_rank( cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True ) - cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv - num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1] - for x in [dk, dv, dk_, dv_]: - assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_kv[b] == 0 - or torch.count_nonzero( - x[ - ( - cu_seqlens_kv_padded[b + 1] - num_pads_kv[b] - ) : cu_seqlens_kv_padded[b + 1] - ] - ).item() - == 0 + num_pads_kv = (cu_seqlens_kv_padded - cu_seqlens_kv)[1:] - ( + cu_seqlens_kv_padded - cu_seqlens_kv + )[:-1] + # FA3 leaves garbage at padding positions despite seqused_q/k (tile spillover). + # Forward out_ can't be pre-zeroed because FA3's custom op returns out_ as an + # output rather than mutating it in-place, triggering PyTorch's aliasing constraint. + # Backward dq/dk/dv CAN be pre-zeroed because FA3 marks them as mutated inputs. + if fa_pad_between_seqs == "True": + # out_ is a view inside the CP custom autograd Function, so in-place + # zeroing is blocked by PyTorch. Clone to break the view relationship. + out_ = out_.clone() + for x in [out, out_, dq]: + for b in range(config.batch_size): + x[ + cu_seqlens_q_padded[b + 1] - num_pads_q[b] : cu_seqlens_q_padded[b + 1] + ] = 0.0 + x[cu_seqlens_q_padded[-1] :] = 0.0 + for x in [dk, dv]: + for b in range(config.batch_size): + x[ + cu_seqlens_kv_padded[b + 1] + - num_pads_kv[b] : cu_seqlens_kv_padded[b + 1] + ] = 0.0 + x[cu_seqlens_kv_padded[-1] :] = 0.0 + # Verify CP backward tensors have clean padding (pre-zeroed in context_parallel.py). + for xname, x, cu, np_ in [ + ("dq_", dq_, cu_seqlens_q_padded, num_pads_q), + ("dk_", dk_, cu_seqlens_kv_padded, num_pads_kv), + ("dv_", dv_, cu_seqlens_kv_padded, num_pads_kv), + ]: + nnz = torch.count_nonzero(x[cu[-1] :]).item() + assert nnz == 0, ( + f"{xname} has {nnz} nonzero values in tail padding — " + "context_parallel.py should zero padding positions" ) + for b in range(config.batch_size): + if np_[b] > 0: + nnz = torch.count_nonzero(x[cu[b + 1] - np_[b] : cu[b + 1]]).item() + assert nnz == 0, ( + f"{xname} has {nnz} nonzero values in batch {b} padding — " + "context_parallel.py should zero padding positions" + ) else: - # Forward-only: reshape only out/out_ for comparison out = out.index_select(0, seq_idx_q).contiguous() out_ = out_ diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index c9ea791444..56aaed72fa 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -124,7 +124,7 @@ def reset_global_fp8_state(): @pytest.mark.parametrize("workspace_opt", [True, False]) @pytest.mark.parametrize("qkv_layout", [None]) @pytest.mark.parametrize("swa", [False]) -@pytest.mark.parametrize("pad_between_seqs", [False]) +@pytest.mark.parametrize("pad_between_seqs", [False, True]) def test_dot_product_attention( dtype, model_configs, @@ -157,6 +157,8 @@ def test_dot_product_attention( config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0] + if pad_between_seqs and qkv_format != "thd": + pytest.skip("pad_between_seqs only applies to THD format!") if qkv_format == "thd" and "padding" not in config.attn_mask_type: config.attn_mask_type = ( "padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding" @@ -195,19 +197,6 @@ def test_dot_product_attention( ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends - # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention - # mannually pads and unpads the input and output of FlashAttention for testing purposes - if ( - pad_between_seqs - and FlashAttentionUtils.is_installed - and not ( - config.max_seqlen_q != config.max_seqlen_kv - and config.attn_mask_type in ["causal", "padding_causal"] - ) - and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus) - ): - flash_attn_supported = True - # Skip if only unfused backend is supported if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: pytest.skip("Less than two backends to compare.") @@ -1330,12 +1319,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: block.softmax_offset.requires_grad = True # Run a forward and backward pass - if backend in ["FlashAttention", "UnfusedDotProductAttention"]: + if backend in ["UnfusedDotProductAttention"]: q = inp_orig[0] k = inp_orig[1] v = inp_orig[2] d_out = out_grad_orig - if backend == "FusedAttention": + if backend in ["FusedAttention", "FlashAttention"]: q = inp[0] k = inp[1] v = inp[2] @@ -1351,14 +1340,19 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: max_seqlen_kv=config.max_seqlen_kv, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None, - cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None, + cu_seqlens_q_padded=( + cu_seqlens_q_after_pad if backend in ["FusedAttention", "FlashAttention"] else None + ), + cu_seqlens_kv_padded=( + cu_seqlens_kv_after_pad if backend in ["FusedAttention", "FlashAttention"] else None + ), attn_mask_type=config.attn_mask_type, checkpoint_core_attention=ckpt_attn, core_attention_bias_type=config.attn_bias_type, core_attention_bias=bias, alibi_slopes=alibi_slopes, fast_zero_fill=True, + pad_between_seqs=pad_between_seqs, # Only pass num_splits when exercising the FlashAttention path num_splits=config.num_splits if backend == "FlashAttention" else 1, ) @@ -1372,12 +1366,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: if is_training and config.softmax_type != "vanilla": d_softmax_offset = block.softmax_offset.grad - if backend in ["FlashAttention", "UnfusedDotProductAttention"]: + if backend in ["UnfusedDotProductAttention"]: if is_training: return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset) else: return out, max_logit, (None, None, None, d_softmax_offset) - if backend == "FusedAttention": + if backend in ["FusedAttention", "FlashAttention"]: if qkv_format == "thd" and pad_between_seqs: out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) if is_training: diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 23d1bfdd85..302f6a88ab 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -91,11 +91,20 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("cp_comm_type", cp_comm_types) -def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): +@pytest.mark.parametrize("pad_between_seqs", [False, True]) +def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type, pad_between_seqs): num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") + if pad_between_seqs: + if qkv_format != "thd": + pytest.skip("pad_between_seqs only applies to THD format!") + if not FlashAttentionUtils.v3_is_installed: + pytest.skip("pad_between_seqs with CP requires Flash Attention v3!") + if cp_comm_type == "a2a+p2p": + pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!") + config = model_configs_flash_attn[model] config.context_parallel = True config.cp_comm_type = cp_comm_type @@ -148,6 +157,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): qkv_format=qkv_format, kernel_backend="FlashAttention", cp_comm_type=cp_comm_type, + fa_pad_between_seqs=pad_between_seqs, log_level=pytest_logging_level, ), ) @@ -386,6 +396,7 @@ def test_cp_with_fused_attention( is_training=is_training, deterministic=_deterministic, ) + _, fused_attn_supported, _ = available_backends if fused_attn_supported and config.attn_mask_type in ["causal", "padding_causal"]: config_copy = copy.deepcopy(config) @@ -404,6 +415,15 @@ def test_cp_with_fused_attention( if not fused_attn_supported: pytest.skip("No attention backend available.") + deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + if deterministic: + if config.softmax_type != "vanilla": + pytest.skip( + "Deterministic mode does not support non-vanilla softmax with FusedAttention" + ) + if config.attn_bias_type == "post_scale_bias" and is_training: + pytest.skip("Deterministic mode does not support post_scale_bias with requires_grad") + run_distributed( get_bash_arguments( num_gpus_per_node=num_gpus, From 34e3d62e5de0de169d0abc0c0493e109b4a2edef Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 24 Apr 2026 15:35:43 -0700 Subject: [PATCH 03/18] [QA] Add CP deterministic tests to L3 and support TE_PATH in FA test Add deterministic CP test runs to L3 FA versions test. Support TE_PATH positional arg and fix GPU threshold for parallel test execution. Signed-off-by: Sudhakar Singh --- qa/L1_pytorch_distributed_unittest/test.sh | 1 + qa/L3_pytorch_FA_versions_test/test.sh | 47 ++++++++++++++++++++-- 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index db13e9f1e0..6319b78a53 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -30,6 +30,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" +NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention_with_cp.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_newton_schulz.xml $TE_PATH/tests/pytorch/distributed/test_newton_schulz.py || test_fail "test_newton_schulz.py" diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index 642eb93b06..edc0f4d353 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -2,13 +2,25 @@ # # See LICENSE for license information. -set -e +function error_exit() { + echo "Error: $1" + exit 1 +} + +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" : ${TE_PATH:=/opt/transformerengine} : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -pip3 install pytest==8.2.1 +pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" # Limit parallel build jobs to avoid overwhelming system resources export MAX_JOBS=32 @@ -41,6 +53,35 @@ do fi # Run tests - NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py + NUM_GPUS=$(nvidia-smi -L | wc -l) + echo "Detected $NUM_GPUS GPU(s)" + if [ "$NUM_GPUS" -ge 5 ]; then + CP_NUM_GPUS=$(( NUM_GPUS - 1 > 4 ? 4 : NUM_GPUS - 1 )) + CP_GPUS=$(seq -s, 1 $CP_NUM_GPUS) + echo "Running tests in parallel: test_attention.py on GPU 0, test_attention_with_cp.py on GPUs $CP_GPUS ($CP_NUM_GPUS GPUs)" + + CUDA_VISIBLE_DEVICES=0 NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \ + --junitxml=$XML_LOG_DIR/pytest.xml \ + $TE_PATH/tests/pytorch/attention/test_attention.py & + PID_ATTN=$! + + CUDA_VISIBLE_DEVICES=$CP_GPUS NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \ + --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml \ + $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py & + PID_CP=$! + wait $PID_ATTN || test_fail "test_attention.py" + wait $PID_CP || test_fail "test_attention_with_cp.py" + else + echo "Running tests sequentially: need >=5 GPUs for parallel execution (1 for test_attention + 4 for test_attention_with_cp)" + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" + fi done + +if [ "$RET" -ne 0 ]; then + echo "Error in the following test cases:$FAILED_CASES" + exit 1 +fi +echo "All tests passed" +exit 0 From 4745f98282387df069fc132906b0f1b4e8451980 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 24 Apr 2026 16:01:36 -0700 Subject: [PATCH 04/18] [PyTorch] Fix FA3 deterministic gate to match upstream backward constraint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous check disabled FA3 for deterministic mode whenever head_dim_qk > 128, which was overly conservative — FA3 forward supports deterministic execution at any head dim. The actual constraint from flash_api.cpp is that the backward pass does not support deterministic mode when max(head_size, head_size_v) >= 256. Narrow the gate to only disable FA3 during training (backward) and raise the threshold to >= 256, checking both head_dim_qk and head_dim_v to handle MLA configs with asymmetric head dimensions. Ref: https://github.com/Dao-AILab/flash-attention/blob/ac6f2eb5/hopper/flash_api.cpp#L1370 Signed-off-by: Sudhakar Singh --- .../pytorch/attention/dot_product_attention/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index b1eca5106c..f02aab53dc 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1315,9 +1315,12 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt ) use_flash_attention_2 = False if use_flash_attention_3 and deterministic and FlashAttentionUtils.v3_is_installed: - if head_dim_qk > 128: + if is_training and max(head_dim_qk, head_dim_v) >= 256: logger.debug( - "Disabling FlashAttention 3 for deterministic execution with head_dim_qk > 128." + "Disabling FlashAttention 3 for deterministic backward with" + " max(head_dim_qk, head_dim_v) >= 256. Found: head_dim_qk = %s, head_dim_v = %s.", + head_dim_qk, + head_dim_v, ) use_flash_attention_3 = False if use_fused_attention and deterministic: From 4be004f3cc63afde3174adaf8ad5de8d0689812b Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 24 Apr 2026 16:49:55 -0700 Subject: [PATCH 05/18] [PyTorch] Disable FlashAttention 4 for pad_between_seqs with THD MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The pad_between_seqs gate in get_attention_backend only disabled FlashAttention 2, letting FA4 leak through to the test-time fused-vs-flash comparison. On B200 runners that install flash-attn-4, this caused test_dpa_qkv_layout_thd to compare FusedAttention against an FA4 output whose padded positions contain garbage, producing 48 numerics failures in L3_pytorch_FA_versions_test--B200_1GPU. The log message already claimed FA4 would be disabled — this change makes the code match the message: set use_flash_attention_4 = False alongside use_flash_attention_2 when pad_between_seqs is True. FA3 continues to support pad_between_seqs via seqused_q/seqused_k. Signed-off-by: Sudhakar Singh --- .../pytorch/attention/dot_product_attention/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index f02aab53dc..852ceb7662 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -857,10 +857,11 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt use_flash_attention_2 and FlashAttentionUtils.is_installed ) or (use_flash_attention_4 and FlashAttentionUtils.v4_is_installed): logger.debug( - "Disabling FlashAttention 2 for qkv_format = thd when there is " + "Disabling FlashAttention 2 and 4 for qkv_format = thd when there is " "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]" ) use_flash_attention_2 = False + use_flash_attention_4 = False # FA3 supports pad_between_seqs via seqused_q/seqused_k if use_unfused_attention: logger.debug("Disabling UnfusedDotProductAttention for pad_between_seqs = True") From a2b0f1b9d6aeb73ed574cc593f2cae182b793c55 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 24 Apr 2026 21:40:55 -0700 Subject: [PATCH 06/18] [QA] Fix cutlass-dsl utils shadow in FA versions test FA4 install brings in nvidia-cutlass-dsl, whose `import cutlass` adds cutlass/base_dsl/ to sys.path. That directory contains a utils/ package that shadows tests/pytorch/utils.py, breaking collection of test_attention_with_cp.py with: ImportError: cannot import name 'ModelConfig' from 'utils' Prepend $TE_PATH/tests/pytorch to PYTHONPATH so the local utils.py is always resolved first, regardless of what FA4 dependencies install. Signed-off-by: Sudhakar Singh --- qa/L3_pytorch_FA_versions_test/test.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index edc0f4d353..514cbc19f9 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -52,6 +52,9 @@ do cd ../../ fi + # Ensure local test utils is found before nvidia-cutlass-dsl's utils package + export PYTHONPATH=$TE_PATH/tests/pytorch:${PYTHONPATH:-} + # Run tests NUM_GPUS=$(nvidia-smi -L | wc -l) echo "Detected $NUM_GPUS GPU(s)" From fc9182fde84df3146d432a2ecde9f47b79eac91f Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Wed, 29 Apr 2026 11:11:20 -0700 Subject: [PATCH 07/18] skip tests which OOM in deterministic+backward+hopper+large_configs as its a known cudnn issue Signed-off-by: Sudhakar Singh --- tests/pytorch/attention/test_attention_with_cp.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 302f6a88ab..8fab2b4c0d 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -423,6 +423,12 @@ def test_cp_with_fused_attention( ) if config.attn_bias_type == "post_scale_bias" and is_training: pytest.skip("Deterministic mode does not support post_scale_bias with requires_grad") + if qkv_format == "thd" and config.num_heads >= 20 and get_device_compute_capability() == (9, 0): + pytest.skip( + "Deterministic FusedAttention backward with THD format OOMs on sm90" + " for this particular test config since cuDNN reserves memory" + " proportional to bHSS (known cuDNN issue)." + ) run_distributed( get_bash_arguments( From 7928bc9f07d22e714953b4636d24eafd16c3bbab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 Apr 2026 18:12:16 +0000 Subject: [PATCH 08/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention_with_cp.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 8fab2b4c0d..20a23eab0d 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -423,7 +423,11 @@ def test_cp_with_fused_attention( ) if config.attn_bias_type == "post_scale_bias" and is_training: pytest.skip("Deterministic mode does not support post_scale_bias with requires_grad") - if qkv_format == "thd" and config.num_heads >= 20 and get_device_compute_capability() == (9, 0): + if ( + qkv_format == "thd" + and config.num_heads >= 20 + and get_device_compute_capability() == (9, 0) + ): pytest.skip( "Deterministic FusedAttention backward with THD format OOMs on sm90" " for this particular test config since cuDNN reserves memory" From 2464f433834b8d5d4d60a0fa073dc1b3d7c976be Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 30 Apr 2026 12:45:46 -0700 Subject: [PATCH 09/18] make cp det and nondet tests run in parallel whenever possible Signed-off-by: Sudhakar Singh --- qa/L1_pytorch_distributed_unittest/test.sh | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 6319b78a53..7eb34a62e4 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -22,6 +22,24 @@ mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" +# Run CP tests (deterministic + non-deterministic) first so they can be parallelized. +# Each needs 4 GPUs, so >=8 GPUs allows them to run concurrently on disjoint GPU sets. +NUM_GPUS=$(python3 -c "import torch; print(torch.cuda.device_count())") +echo "Detected $NUM_GPUS GPU(s)" +if [ "$NUM_GPUS" -ge 8 ]; then + echo "Running CP tests in parallel: non-deterministic on GPUs 0-3, deterministic on GPUs 4-7" + CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py & + PID_CP_NONDET=$! + CUDA_VISIBLE_DEVICES=4,5,6,7 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py & + PID_CP_DET=$! + wait $PID_CP_NONDET || test_fail "test_attention_with_cp.py" + wait $PID_CP_DET || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention_with_cp.py" +else + echo "Running CP tests sequentially: need >=8 GPUs for parallel execution" + python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" + NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention_with_cp.py" +fi + python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py" @@ -29,8 +47,6 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" -NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention_with_cp.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_newton_schulz.xml $TE_PATH/tests/pytorch/distributed/test_newton_schulz.py || test_fail "test_newton_schulz.py" From 13ba00468afdf4a433c06977cbb806b0e2d606c5 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Tue, 5 May 2026 15:06:38 -0700 Subject: [PATCH 10/18] [QA] L3: gate CP tests per-arch to avoid CI timeout PR 2596 added deterministic CP runs to the L3 FA-versions matrix, multiplying CP wall time across every FA version and causing CI timeouts (pipeline 50243000). Run CP tests once per arch instead, picking the FA version each arch's CP code path actually supports: - sm90 (H100): FA3 3.0.0b1 - context_parallel.py is FA3-only on Hopper (use_flash_attn_3 threaded throughout, FA4 not wired in; pad_between_seqs gated on use_flash_attn_3 at lines 1038, 1366) - sm>90 (B200): latest FA4 - FA3 is not built/installed for sm>90 Non-CP test_attention.py still runs for every FA version in the array. Also drop FA 2.7.3 from the sm90 list (no longer maintained as a target) and bump the FA4 pin from 4.0.0b8 to 4.0.0b11. b8 has an SM90 backward kernel bug fixed by upstream PR #2513 in b11 (get_smem_store_C() got multiple values for argument 'transpose'). Signed-off-by: Sudhakar Singh --- qa/L3_pytorch_FA_versions_test/test.sh | 60 ++++++++++++++++---------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index 514cbc19f9..e86224ab66 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -30,10 +30,20 @@ sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); pri export FLASH_ATTN_CUDA_ARCHS=$sm_arch if [ $sm_arch -gt 90 ] then - FA_versions=(2.8.3 4.0.0b8) + FA_versions=(2.8.3 4.0.0b11) elif [ $sm_arch -eq 90 ] then - FA_versions=(2.7.3 2.8.3 3.0.0b1 4.0.0b8) + FA_versions=(2.8.3 3.0.0b1 4.0.0b11) +fi + +# CP tests are expensive and run only once per arch: +# - sm90 (H100): FA3 (3.0.0b1) - context_parallel.py only supports FA3 on Hopper +# - sm>90 (B200): latest FA4 - FA3 is not built/installed for sm>90 +# Non-CP tests still run for every FA version in the array. +if [ $sm_arch -eq 90 ]; then + CP_FA_VERSION="3.0.0b1" +else + CP_FA_VERSION="${FA_versions[-1]}" fi for fa_version in "${FA_versions[@]}" @@ -58,27 +68,33 @@ do # Run tests NUM_GPUS=$(nvidia-smi -L | wc -l) echo "Detected $NUM_GPUS GPU(s)" - if [ "$NUM_GPUS" -ge 5 ]; then - CP_NUM_GPUS=$(( NUM_GPUS - 1 > 4 ? 4 : NUM_GPUS - 1 )) - CP_GPUS=$(seq -s, 1 $CP_NUM_GPUS) - echo "Running tests in parallel: test_attention.py on GPU 0, test_attention_with_cp.py on GPUs $CP_GPUS ($CP_NUM_GPUS GPUs)" - - CUDA_VISIBLE_DEVICES=0 NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \ - --junitxml=$XML_LOG_DIR/pytest.xml \ - $TE_PATH/tests/pytorch/attention/test_attention.py & - PID_ATTN=$! - - CUDA_VISIBLE_DEVICES=$CP_GPUS NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \ - --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml \ - $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py & - PID_CP=$! - - wait $PID_ATTN || test_fail "test_attention.py" - wait $PID_CP || test_fail "test_attention_with_cp.py" + if [ "$fa_version" = "$CP_FA_VERSION" ]; then + echo "Running CP tests with FA $fa_version (CP version for sm$sm_arch)" + if [ "$NUM_GPUS" -ge 5 ]; then + CP_NUM_GPUS=$(( NUM_GPUS - 1 > 4 ? 4 : NUM_GPUS - 1 )) + CP_GPUS=$(seq -s, 1 $CP_NUM_GPUS) + echo "Running tests in parallel: test_attention.py on GPU 0, test_attention_with_cp.py on GPUs $CP_GPUS ($CP_NUM_GPUS GPUs)" + + CUDA_VISIBLE_DEVICES=0 NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \ + --junitxml=$XML_LOG_DIR/pytest.xml \ + $TE_PATH/tests/pytorch/attention/test_attention.py & + PID_ATTN=$! + + CUDA_VISIBLE_DEVICES=$CP_GPUS NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \ + --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml \ + $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py & + PID_CP=$! + + wait $PID_ATTN || test_fail "test_attention.py" + wait $PID_CP || test_fail "test_attention_with_cp.py" + else + echo "Running tests sequentially: need >=5 GPUs for parallel execution (1 for test_attention + 4 for test_attention_with_cp)" + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" + fi else - echo "Running tests sequentially: need >=5 GPUs for parallel execution (1 for test_attention + 4 for test_attention_with_cp)" - NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" - NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" + echo "Skipping CP tests for FA $fa_version (CP only runs with FA $CP_FA_VERSION on sm$sm_arch)" + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py (FA $fa_version)" fi done From e41bb968e25e3c73edde532a5ae19e714c081f81 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Tue, 5 May 2026 16:23:45 -0700 Subject: [PATCH 11/18] [QA] L3: skip pre-installed FA3 build, per-FA junit XMLs Three follow-ups on top of 13ba0046 (L3 per-arch CP gating): 1. Skip the inline FA3 source build when flash_attn_interface is already importable. This makes the script a no-op on FA3 install when the base image has FA3 baked in (companion to TE !573 on te_ci, which auto-sets INSTALL_FA3=${RUN_L3_TESTS} so FA3 is preinstalled for L3 pipelines). Saves ~20 min of L3 H100 wall time once both land. Falls back to the existing inline build when FA3 is not pre-installed. 2. Suffix junit XMLs with the FA version (pytest_test_attention_fa2_8_3.xml etc.) so per-iteration results are preserved instead of overwritten. Pipeline 50348672 had no per-FA timing visibility because pytest.xml was clobbered by each loop iteration. 3. Include FA version in test_fail messages so CI dashboards show which FA iteration caused a failure (was "test_attention.py", now "test_attention.py (FA 2.8.3)"). Also fold the CP_FA_VERSION assignment into the same if-block as FA_versions (was a separate if-block immediately after) since the two are arch-keyed in lockstep. Signed-off-by: Sudhakar Singh --- qa/L3_pytorch_FA_versions_test/test.sh | 50 +++++++++++++++----------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index e86224ab66..09ff15d90b 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -28,22 +28,18 @@ export MAX_JOBS=32 # Iterate over Flash Attention versions sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); print(sm[0]*10+sm[1])"` export FLASH_ATTN_CUDA_ARCHS=$sm_arch +# CP tests are expensive and run only once per arch: +# - sm90 (H100): FA3 (3.0.0b1) - context_parallel.py only supports FA3 on Hopper +# - sm>90 (B200): latest FA4 - FA3 is not built/installed for sm>90 +# Non-CP tests still run for every FA version in the array. if [ $sm_arch -gt 90 ] then FA_versions=(2.8.3 4.0.0b11) + CP_FA_VERSION="${FA_versions[-1]}" elif [ $sm_arch -eq 90 ] then FA_versions=(2.8.3 3.0.0b1 4.0.0b11) -fi - -# CP tests are expensive and run only once per arch: -# - sm90 (H100): FA3 (3.0.0b1) - context_parallel.py only supports FA3 on Hopper -# - sm>90 (B200): latest FA4 - FA3 is not built/installed for sm>90 -# Non-CP tests still run for every FA version in the array. -if [ $sm_arch -eq 90 ]; then - CP_FA_VERSION="3.0.0b1" -else - CP_FA_VERSION="${FA_versions[-1]}" + CP_FA_VERSION="3.0.0b1" fi for fa_version in "${FA_versions[@]}" @@ -57,9 +53,15 @@ do then pip3 install flash-attn-4==${fa_version} nvidia-cutlass-dsl[cu13]==4.4.2 --no-build-isolation else - git clone https://github.com/Dao-AILab/flash-attention.git - cd flash-attention/hopper && python setup.py install - cd ../../ + # FA3 source build (~20 min). Skip if already pre-installed in the base image + # via Dockerfile.base INSTALL_FA3=1 (auto-set when RUN_L3_TESTS=1). + if python3 -c "import flash_attn_interface" 2>/dev/null; then + echo "FA3 already installed (from base image); skipping source build" + else + git clone https://github.com/Dao-AILab/flash-attention.git + cd flash-attention/hopper && python setup.py install + cd ../../ + fi fi # Ensure local test utils is found before nvidia-cutlass-dsl's utils package @@ -68,6 +70,14 @@ do # Run tests NUM_GPUS=$(nvidia-smi -L | wc -l) echo "Detected $NUM_GPUS GPU(s)" + + # Suffix junit XMLs with the FA version so per-iteration results are preserved + # (otherwise pytest.xml is overwritten on each loop iteration and we lose timing + # data for all but the last FA version). + fa_tag="${fa_version//./_}" + XML_ATTN="$XML_LOG_DIR/pytest_test_attention_fa${fa_tag}.xml" + XML_CP="$XML_LOG_DIR/pytest_test_attention_with_cp_fa${fa_tag}.xml" + if [ "$fa_version" = "$CP_FA_VERSION" ]; then echo "Running CP tests with FA $fa_version (CP version for sm$sm_arch)" if [ "$NUM_GPUS" -ge 5 ]; then @@ -76,25 +86,25 @@ do echo "Running tests in parallel: test_attention.py on GPU 0, test_attention_with_cp.py on GPUs $CP_GPUS ($CP_NUM_GPUS GPUs)" CUDA_VISIBLE_DEVICES=0 NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \ - --junitxml=$XML_LOG_DIR/pytest.xml \ + --junitxml=$XML_ATTN \ $TE_PATH/tests/pytorch/attention/test_attention.py & PID_ATTN=$! CUDA_VISIBLE_DEVICES=$CP_GPUS NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \ - --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml \ + --junitxml=$XML_CP \ $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py & PID_CP=$! - wait $PID_ATTN || test_fail "test_attention.py" - wait $PID_CP || test_fail "test_attention_with_cp.py" + wait $PID_ATTN || test_fail "test_attention.py (FA $fa_version)" + wait $PID_CP || test_fail "test_attention_with_cp.py (FA $fa_version)" else echo "Running tests sequentially: need >=5 GPUs for parallel execution (1 for test_attention + 4 for test_attention_with_cp)" - NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" - NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_ATTN $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py (FA $fa_version)" + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_CP $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py (FA $fa_version)" fi else echo "Skipping CP tests for FA $fa_version (CP only runs with FA $CP_FA_VERSION on sm$sm_arch)" - NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py (FA $fa_version)" + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_ATTN $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py (FA $fa_version)" fi done From 7b8ca1e70006dc87918916b4558580f6545ca22d Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Wed, 6 May 2026 12:05:08 -0700 Subject: [PATCH 12/18] b200 shouldnt run FA3 even if present Signed-off-by: Sudhakar Singh --- tests/pytorch/attention/test_attention_with_cp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 20a23eab0d..29940899b2 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -100,8 +100,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type, pad_bet if pad_between_seqs: if qkv_format != "thd": pytest.skip("pad_between_seqs only applies to THD format!") - if not FlashAttentionUtils.v3_is_installed: - pytest.skip("pad_between_seqs with CP requires Flash Attention v3!") + if not FlashAttentionUtils.v3_is_installed or get_device_compute_capability() > (9, 0): + pytest.skip("pad_between_seqs with CP requires Flash Attention v3 on Hopper (sm90)!") if cp_comm_type == "a2a+p2p": pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!") From 8794aa80124e314332f03143f11ba810d2c52304 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 22 May 2026 01:03:56 +0000 Subject: [PATCH 13/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention_with_cp.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index a406f8d867..2aed5c319c 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -307,9 +307,7 @@ def _submit(pool: PoolWorker, **kwargs) -> None: @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("cp_comm_type", cp_comm_types) @pytest.mark.parametrize("pad_between_seqs", [False, True]) -def test_cp_with_flash_attention( - cp_pool, dtype, model, qkv_format, cp_comm_type, pad_between_seqs -): +def test_cp_with_flash_attention(cp_pool, dtype, model, qkv_format, cp_comm_type, pad_between_seqs): num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 pool = cp_pool(num_gpus) From c4b6e076b4d976ba3c5bfdb62462525b8e48cef2 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 22 May 2026 04:03:32 -0700 Subject: [PATCH 14/18] L3: drop stale RUN_L3_TESTS=1 note; use flash_attn_3 for FA3 check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address two pending review comments: 1. The "auto-set when RUN_L3_TESTS=1" annotation on the base-image FA3 preinstall is no longer accurate; drop it so readers don't grep for a coupling that doesn't exist. 2. `flash_attn_interface` reads like a generic FA API even though the top-level shim is only created by the FA3 install. Switching to `import flash_attn_3` makes the FA3-specific intent unambiguous and matches the FA3 package layout produced by the source build. Local validation on H100 (sm90) with FA3 active, TE worktree resolving to the editable install (verified via three-layer import check from /tmp): test_attention_with_cp.py parallel det+nondet — 45 passed / 0 failed nondet (3:52), 33 passed / 0 failed det (2:55). 33 pad-True nondet passes + 21 pad-True det passes confirm the FA3+THD+CP path is exercised; 5 det OOM cases skip cleanly via the existing inline guard. Same test scope is exercised by L1_pytorch_distributed_unittest (parallel det+nondet) and the FA3 iteration of L3_pytorch_FA_versions_test; the changes here are L3-only documentation/detection tweaks and do not alter the Python test code, but the L1+L3 CP execution was re-run on the cleaned PR head end-to-end as proof. Signed-off-by: Sudhakar Singh --- qa/L3_pytorch_FA_versions_test/test.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index 09ff15d90b..d32896ca8d 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -54,8 +54,8 @@ do pip3 install flash-attn-4==${fa_version} nvidia-cutlass-dsl[cu13]==4.4.2 --no-build-isolation else # FA3 source build (~20 min). Skip if already pre-installed in the base image - # via Dockerfile.base INSTALL_FA3=1 (auto-set when RUN_L3_TESTS=1). - if python3 -c "import flash_attn_interface" 2>/dev/null; then + # via Dockerfile.base INSTALL_FA3=1. + if python3 -c "import flash_attn_3" 2>/dev/null; then echo "FA3 already installed (from base image); skipping source build" else git clone https://github.com/Dao-AILab/flash-attention.git From d3bd4e4aea8fb3a61b5fc742228f5fc634f4fa5b Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 22 May 2026 14:11:25 -0700 Subject: [PATCH 15/18] Address review nits: bHSS-gated OOM skip; drop Dockerfile.base specifics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Det FusedAttention backward THD/sm90 OOM skip: gate on the actual memory pressure (b*H*S*S) instead of num_heads >= 20. The cuDNN workspace is proportional to bHSS, so a future config with H >= 20 but small b or S would be needlessly skipped under the old guard, while a config with H < 20 but large b*S that hit the same OOM wouldn't be caught. Threshold 1e9 empirically matches the existing 5-case skip set on the test_essential fused subset (cp_2_0, cp_2_2, cp_3_1, cp_4_2, cp_4_3 — bHSS in 1.07B–4.29B) and lets cp_1_0/ cp_2_1/cp_2_4/cp_3_2/cp_3_4 (bHSS ~0.40B) keep running. 2. L3 FA3 install comment: drop the "Dockerfile.base INSTALL_FA3=1" reference. The detection check is the contract; mentioning a specific image variable couples this script to an out-of-tree provisioning detail that may evolve independently. Local validation on H100 (sm90) with FA3 active and TE worktree resolving to editable (verified via /tmp-cwd three-layer import check after reinstall — the /usr/local TE shadow had reappeared between sessions): test_attention_with_cp.py parallel det+nondet — 45 passed / 0 failed nondet (4:09), 33 passed / 0 failed det (3:14). 33 pad-True nondet passes + 21 pad-True det passes; 5 det OOM cases skip via the new bHSS gate — same cases as the old num_heads-only gate. Signed-off-by: Sudhakar Singh --- qa/L3_pytorch_FA_versions_test/test.sh | 3 +-- tests/pytorch/attention/test_attention_with_cp.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index d32896ca8d..30f1fc38c0 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -53,8 +53,7 @@ do then pip3 install flash-attn-4==${fa_version} nvidia-cutlass-dsl[cu13]==4.4.2 --no-build-isolation else - # FA3 source build (~20 min). Skip if already pre-installed in the base image - # via Dockerfile.base INSTALL_FA3=1. + # FA3 source build (~20 min). Skip if FA3 is already installed. if python3 -c "import flash_attn_3" 2>/dev/null; then echo "FA3 already installed (from base image); skipping source build" else diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 2aed5c319c..9a95793b9c 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -639,16 +639,22 @@ def test_cp_with_fused_attention( pytest.skip("Deterministic mode does not support non-vanilla softmax with FusedAttention") if _deterministic and config.attn_bias_type == "post_scale_bias" and is_training: pytest.skip("Deterministic mode does not support post_scale_bias with requires_grad") + # Det FusedAttention backward with THD on sm90 OOMs because cuDNN reserves + # workspace proportional to b*H*S*S. Gate on that product, not num_heads, + # so the skip stays correct if a new config has small b/S but H >= 20. if ( _deterministic and qkv_format == "thd" - and config.num_heads >= 20 and get_device_compute_capability() == (9, 0) + and config.batch_size + * config.num_heads + * config.max_seqlen_q + * config.max_seqlen_kv + >= 1_000_000_000 ): pytest.skip( "Deterministic FusedAttention backward with THD format OOMs on sm90" - " for this particular test config since cuDNN reserves memory" - " proportional to bHSS (known cuDNN issue)." + " for large bHSS configs (known cuDNN issue)." ) _submit( From 0638d5823edfa3e40c3b195a0b513ae4db96fa9b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 22 May 2026 21:12:57 +0000 Subject: [PATCH 16/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention_with_cp.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 9a95793b9c..53a66490d9 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -646,10 +646,7 @@ def test_cp_with_fused_attention( _deterministic and qkv_format == "thd" and get_device_compute_capability() == (9, 0) - and config.batch_size - * config.num_heads - * config.max_seqlen_q - * config.max_seqlen_kv + and config.batch_size * config.num_heads * config.max_seqlen_q * config.max_seqlen_kv >= 1_000_000_000 ): pytest.skip( From 1563b1056a9ede0bc29cb91739b4796c912be682 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 22 May 2026 17:05:58 -0700 Subject: [PATCH 17/18] Name the OOM-skip threshold and explain the 128*bHSS workspace observation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address review nits on the deterministic THD-backward OOM guard: 1. Replace the magic number 1_000_000_000 with the named constant SM90_DET_FUSED_THD_BWD_MAX_BHSS = 1 << 30, so the value is searchable and labeled. 2. Replace the prefatory comment with a short note tying the number to cuDNN's actual workspace request (~128 * bHSS bytes, measured on cuDNN 9.21.0 sm90 — see local sweep). At bHSS = 1<<30 the request is 128 GiB, which doesn't fit on H100's 80 GB. 3. Flag the b>=3 caveat for future readers: cuDNN rounds the batch up internally so workspace grows super-linearly past b=2 (b=4 asks for 4x the b=2 workspace, not 2x). The current fused-essential matrix is all b=2, so the threshold stays correct for what the test exercises; the note is there so the next person doesn't have to rediscover it. Skip set is unchanged — cp_2_0, cp_2_1, cp_3_1, cp_4_2, cp_4_3. Signed-off-by: Sudhakar Singh --- tests/pytorch/attention/test_attention_with_cp.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 53a66490d9..5d9ac3eda1 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -639,15 +639,17 @@ def test_cp_with_fused_attention( pytest.skip("Deterministic mode does not support non-vanilla softmax with FusedAttention") if _deterministic and config.attn_bias_type == "post_scale_bias" and is_training: pytest.skip("Deterministic mode does not support post_scale_bias with requires_grad") - # Det FusedAttention backward with THD on sm90 OOMs because cuDNN reserves - # workspace proportional to b*H*S*S. Gate on that product, not num_heads, - # so the skip stays correct if a new config has small b/S but H >= 20. + # cuDNN det THD backward workspace on sm90 is ~128 * bHSS bytes; at 1<<30 + # that's 128 GiB, won't fit on H100's 80 GB. Exact at b=2 + power-of-2 S; + # for b>=3 cuDNN rounds batch up internally so workspace grows super-linearly + # (e.g. b=4 wants 4x b=2's workspace, not 2x) — revisit if a config uses b>2. + SM90_DET_FUSED_THD_BWD_MAX_BHSS = 1 << 30 if ( _deterministic and qkv_format == "thd" and get_device_compute_capability() == (9, 0) and config.batch_size * config.num_heads * config.max_seqlen_q * config.max_seqlen_kv - >= 1_000_000_000 + >= SM90_DET_FUSED_THD_BWD_MAX_BHSS ): pytest.skip( "Deterministic FusedAttention backward with THD format OOMs on sm90" From a27e30154ac621392bd7fdbf928e8afc9e7208cc Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 22 May 2026 17:09:38 -0700 Subject: [PATCH 18/18] Reword OOM-skip comment as observations, not cuDNN-internal claims We measured the workspace request from outside cuDNN, so the comment should say "observed" rather than asserting what cuDNN does. Reframes the ~128 * bHSS bytes formula and the super-linear b>=3 behavior as empirical observations from our sweep. No code change. Signed-off-by: Sudhakar Singh --- tests/pytorch/attention/test_attention_with_cp.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 5d9ac3eda1..a03f51f6c9 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -639,10 +639,11 @@ def test_cp_with_fused_attention( pytest.skip("Deterministic mode does not support non-vanilla softmax with FusedAttention") if _deterministic and config.attn_bias_type == "post_scale_bias" and is_training: pytest.skip("Deterministic mode does not support post_scale_bias with requires_grad") - # cuDNN det THD backward workspace on sm90 is ~128 * bHSS bytes; at 1<<30 - # that's 128 GiB, won't fit on H100's 80 GB. Exact at b=2 + power-of-2 S; - # for b>=3 cuDNN rounds batch up internally so workspace grows super-linearly - # (e.g. b=4 wants 4x b=2's workspace, not 2x) — revisit if a config uses b>2. + # Observed: cuDNN det THD backward asks for ~128 * bHSS bytes of workspace + # on sm90; at 1<<30 that's 128 GiB, won't fit on H100's 80 GB. Held exactly + # at b=2 + power-of-2 S in our sweep; for b>=3 the workspace was observed to + # grow super-linearly (b=4 took ~4x the b=2 amount, not 2x) — revisit if a + # config uses b>2. SM90_DET_FUSED_THD_BWD_MAX_BHSS = 1 << 30 if ( _deterministic