
guarding max_logits fused attention for cudnn < 9.21.0 #3091

Open
francesco-bertolotti wants to merge 2 commits into NVIDIA:main from francesco-bertolotti:f14-max-logits-guard

@francesco-bertolotti

Contributor

Description

get_attention_backend selects FusedAttention for return_max_logit=True regardless of the cuDNN version, but cuDNN supports emitting Max alongside the softmax Stats only from version 9.21.0 onward. On older cuDNN versions, the forward pass fails at graph-build time with:

RuntimeError: transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu:419
in function operator(): cuDNN Error: CompositeSoftmaxNode can only output certain
combinations of stats, max and sum_exp: stats only, max and sum_exp only, or none of the above.

Reproduction (observed on A100 / sm80 with cuDNN 9.10.2; the failure is cuDNN-version dependent, not architecture dependent):

pytest -x tests/pytorch/attention/test_attention.py::test_dpa_max_logit
# fails in the first non-skipped case, e.g.
# test_dpa_max_logit[sbhd_sbhd_sbhd-max_logit_1-model_configs0-dtype0]

Root cause

FusedAttention requests both the Stats and Max outputs from the SDPA node (fused_attn_f16_arbitrary_seqlen.cu, sdpa_options.set_logit_max(Max)). In cudnn-frontend, that combination is only representable through the unified softmax descriptor (CUDNN_ATTR_OPERATION_SDPA_FWD_SOFTMAX_DESC), which requires a cuDNN backend >= 9.21.0:

  • cudnn-frontend/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h: the composite SDPA node only wires Stats/Max/Sum_exp into a UnifiedSoftmaxNode when effective_cudnn_ver >= 92100; on older versions only Stats is set (via CUDNN_ATTR_OPERATION_SDPA_FWD_STATSDESC).
  • cudnn-frontend/include/cudnn_frontend/node/softmax.h (CompositeSoftmaxNode::pre_validate_node): the legacy composite softmax node only allows the output combinations {}, {stats}, or {max, sum_exp} — the {stats, max} combination requested for return_max_logit=True is rejected, producing the error above.
  • cudnn-frontend/include/cudnn_frontend/node/sdpa_support_surface.h: the Max output is only added to the allowed outputs for effective_cudnn_ver >= 92100.
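
The three observations above can be condensed into a small model. The sketch below is not cudnn-frontend code: the function name `softmax_outputs_supported` and the string labels are hypothetical, and it only encodes the rules as stated in this description. In particular, treating every subset of the three outputs as valid on 9.21.0 and newer is an assumption made for illustration.

```python
# Hypothetical model of the output-combination rules described above.
# Not cudnn-frontend code; names and structure are illustrative only.

UNIFIED_SOFTMAX_MIN_VERSION = 92100  # cuDNN 9.21.0, as stated in the description

# Combinations the legacy composite softmax node accepts, per the description.
LEGACY_ALLOWED = {
    frozenset(),
    frozenset({"stats"}),
    frozenset({"max", "sum_exp"}),
}


def softmax_outputs_supported(outputs, effective_cudnn_ver):
    """Return True if the requested softmax outputs are assumed representable."""
    requested = frozenset(outputs)
    if effective_cudnn_ver >= UNIFIED_SOFTMAX_MIN_VERSION:
        # Assumption: the unified softmax node accepts any subset of the outputs.
        return requested <= {"stats", "max", "sum_exp"}
    # Older versions: only the legacy combinations pass validation.
    return requested in LEGACY_ALLOWED
```

Under this model, the `{"stats", "max"}` request made for return_max_logit=True is rejected for a pre-9.21.0 version such as 91002 and accepted for 92100, which matches the failure reported above.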

Fix

Add a filter in get_attention_backend that disables FusedAttention for return_max_logit=True when the cuDNN version is below 9.21.0, falling back to UnfusedDotProductAttention (FlashAttention is already disabled for max_logit). This follows the existing pattern of cuDNN-version filters in transformer_engine/pytorch/attention/dot_product_attention/utils.py.

# Filter: Return max_logit
if return_max_logit:
    ...
    # FusedAttention emits max_logit alongside the softmax stats, which cuDNN only
    # supports through the unified softmax node introduced in cuDNN 9.21.0. On older
    # cuDNN the composite softmax node rejects the stats+max combination, so fall back
    # to UnfusedDotProductAttention.
    if use_fused_attention and cudnn_version < (9, 21, 0):
        use_fused_attention = False
        logger.debug("Disabling FusedAttention for max_logit for cuDNN < 9.21.0")

The same policy is also enforced in nvte_get_fused_attn_backend (transformer_engine/common/fused_attn/fused_attn.cpp) so that non-PyTorch frontends are covered as well. The FP8 branch already checks !return_max_logit unconditionally; the F16 arbitrary-seqlen condition gains the following term:

        // max_logit
        // pre-9.21: no (the composite softmax node rejects the Stats + Max output combination)
        // 9.21+: yes (Stats + Max via the unified softmax node)
        (!return_max_logit || cudnn_runtime_version >= 92100) &&
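
The Python filter compares against the tuple (9, 21, 0), while the C++ term compares against the integer 92100. A minimal sketch of how the two line up, assuming the cuDNN 9.x scheme of major*10000 + minor*100 + patch; the helper names `encode_cudnn_version` and `max_logit_term` are hypothetical and exist only for this illustration:

```python
def encode_cudnn_version(major, minor, patch):
    """Encode a cuDNN 9.x version as major*10000 + minor*100 + patch (assumed scheme)."""
    return major * 10000 + minor * 100 + patch


def max_logit_term(return_max_logit, cudnn_runtime_version):
    """Python mirror of (!return_max_logit || cudnn_runtime_version >= 92100)."""
    return (not return_max_logit) or cudnn_runtime_version >= 92100


# The version from the reproduction (9.10.2) fails the term only when
# max_logit is requested; 9.21.0 passes in both cases.
old = encode_cudnn_version(9, 10, 2)
new = encode_cudnn_version(9, 21, 0)
print(max_logit_term(True, old), max_logit_term(False, old), max_logit_term(True, new))
```

Because the term is an implication (max_logit requested implies cuDNN >= 9.21.0), configurations that do not request max_logit are unaffected on any cuDNN version.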

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Disable FusedAttention in get_attention_backend when return_max_logit=True and cuDNN < 9.21.0, so backend selection falls back to UnfusedDotProductAttention instead of failing at cuDNN graph-build time.
  • Enforce the same requirement in nvte_get_fused_attn_backend (fused_attn.cpp, F16 arbitrary-seqlen condition) so non-PyTorch frontends are covered as well.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions github-actions Bot added the community-contribution label (PRs from external contributor outside the core maintainers, representing community-driven work) on Jun 5, 2026
@greptile-apps

greptile-apps Bot commented Jun 5, 2026

Contributor

Greptile Summary

Guards the F16 arbitrary-seqlen FusedAttention backend against return_max_logit=True on cuDNN < 9.21.0, where the composite softmax node rejects the Stats + Max output combination and fails at graph-build time. The fix adds a single boolean term to the eligibility condition in nvte_get_fused_attn_backend, mirroring the unconditional !return_max_logit guard already present in the FP8 branch.

  • Adds (!return_max_logit || cudnn_runtime_version >= 92100) to the flag_arb condition in fused_attn.cpp, correctly encoding cuDNN 9.21.0 as 92100 and matching the project's existing version-check style.
  • The PR description also describes a companion Python-side filter in get_attention_backend (utils.py) that would provide an early exit and debug log for the PyTorch frontend, but that change is not present in the diff; the C++ fix alone is sufficient to prevent the crash since Python respects the No_Backend return value.

Confidence Score: 5/5

The change is a minimal, well-targeted addition to an existing eligibility condition and correctly prevents a graph-build crash on cuDNN < 9.21.0 without altering any other code paths.

The single added condition is logically correct, uses the right version encoding (92100 for 9.21.0), and is consistent with the surrounding guards. For the PyTorch frontend the fix is effective even without the companion Python-side filter because the Python code already falls back on No_Backend. Non-PyTorch frontends that call nvte_get_fused_attn_backend directly are also protected.

transformer_engine/pytorch/attention/dot_product_attention/utils.py warrants a look — the PR description calls for a companion early-exit filter there that was not added.

Important Files Changed

Filename Overview
transformer_engine/common/fused_attn/fused_attn.cpp Adds (!return_max_logit

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["get_attention_backend (Python)"] --> B{use_fused_attention?}
    B -->|Yes| C["tex.get_fused_attn_backend (C++)"]
    C --> D{F16 arbitrary-seqlen eligibility}
    D --> E{"NEW: not return_max_logit OR cudnn >= 92100"}
    E -->|False - max_logit + cuDNN older than 9.21| F["return No_Backend"]
    E -->|True| G["other conditions..."]
    G --> H["return F16_arbitrary_seqlen"]
    F --> I["Python: use_fused_attention = False, fall back to Unfused"]
    H --> J["Python: proceed with FusedAttention"]
    B -->|No| K["Try FlashAttention or Unfused"]
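
The max_logit path of the flowchart can be sketched in a few lines of Python. This is a simplified stand-in, not Transformer Engine code: `query_fused_backend` and `resolve_use_fused_attention` are hypothetical names, only the new eligibility term is modelled, and all other conditions are assumed to pass.

```python
from enum import Enum


class Backend(Enum):
    NO_BACKEND = "No_Backend"
    F16_ARBITRARY_SEQLEN = "F16_arbitrary_seqlen"


def query_fused_backend(return_max_logit, cudnn_runtime_version):
    """Stand-in for the C++ eligibility check; only the new term is modelled."""
    if return_max_logit and cudnn_runtime_version < 92100:
        return Backend.NO_BACKEND
    # Assumption: every other eligibility condition is satisfied.
    return Backend.F16_ARBITRARY_SEQLEN


def resolve_use_fused_attention(use_fused_attention, return_max_logit, cudnn_runtime_version):
    """Return whether FusedAttention stays enabled after the backend query."""
    if not use_fused_attention:
        # Fused attention was already ruled out; nothing to query.
        return False
    backend = query_fused_backend(return_max_logit, cudnn_runtime_version)
    # No_Backend disables FusedAttention so selection can fall back to Unfused.
    return backend is not Backend.NO_BACKEND
```

With return_max_logit=True and a runtime version of 91002, the function returns False, which corresponds to the "fall back to Unfused" branch in the flowchart.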

Reviews (3): Last reviewed commit: "removing python side check"

Comment thread transformer_engine/pytorch/attention/dot_product_attention/utils.py Outdated
@cyanguwa

cyanguwa commented Jun 10, 2026

Collaborator

@francesco-bertolotti, could you please fix the DCO? There are instructions here. Thanks!

@cyanguwa

Collaborator

/te-ci pytorch L1

Signed-off-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>
Signed-off-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>
@francesco-bertolotti

Contributor Author

Sorry, forgot about DCO. Should be fixed now.

@cyanguwa

Collaborator

/te-ci pytorch L1


Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.


2 participants