guarding max_logits fused attention for cudnn < 9.21.0 #3091
francesco-bertolotti wants to merge 2 commits into
Greptile Summary
Guards the F16 arbitrary-seqlen FusedAttention backend against
Confidence Score: 5/5
The change is a minimal, well-targeted addition to an existing eligibility condition and correctly prevents a graph-build crash on cuDNN < 9.21.0 without altering any other code paths. The single added condition is logically correct, uses the right version encoding (92100 for 9.21.0), and is consistent with the surrounding guards.
For the PyTorch frontend the fix is effective even without the companion Python-side filter because the Python code already falls back on `No_Backend`. Non-PyTorch frontends that call `nvte_get_fused_attn_backend` directly are also protected.
`transformer_engine/pytorch/attention/dot_product_attention/utils.py` warrants a look: the PR description calls for a companion early-exit filter there that was not added.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["get_attention_backend (Python)"] --> B{use_fused_attention?}
B -->|Yes| C["tex.get_fused_attn_backend (C++)"]
C --> D{F16 arbitrary-seqlen eligibility}
D --> E{"NEW: not return_max_logit OR cudnn >= 92100"}
E -->|False - max_logit + cuDNN older than 9.21| F["return No_Backend"]
E -->|True| G["other conditions..."]
G --> H["return F16_arbitrary_seqlen"]
F --> I["Python: use_fused_attention = False, fall back to Unfused"]
H --> J["Python: proceed with FusedAttention"]
B -->|No| K["Try FlashAttention or Unfused"]
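The decision path in the flowchart can be sketched in plain Python. This is a minimal illustration, not the Transformer Engine code: `encode_cudnn_version`, `f16_arbitrary_seqlen_backend`, `select_attention`, and the `Backend` enum are hypothetical names, the remaining eligibility conditions are collapsed into a single `other_conditions_ok` flag, and the FlashAttention branch is omitted. The version encoding follows the `major * 10000 + minor * 100 + patch` scheme that yields 92100 for 9.21.0.

```python
from enum import Enum

# 9.21.0 in the integer encoding used by the guard (assumed cuDNN 9.x scheme).
CUDNN_MAX_LOGIT_MIN_VERSION = 92100


class Backend(Enum):
    NO_BACKEND = "No_Backend"
    F16_ARBITRARY_SEQLEN = "F16_arbitrary_seqlen"


def encode_cudnn_version(major: int, minor: int, patch: int) -> int:
    """Encode a cuDNN 9.x version triple as a single integer."""
    return major * 10000 + minor * 100 + patch


def f16_arbitrary_seqlen_backend(
    return_max_logit: bool, cudnn_version: int, other_conditions_ok: bool = True
) -> Backend:
    """Hypothetical stand-in for the F16 arbitrary-seqlen eligibility check."""
    # The new condition: max_logit is only allowed on cuDNN >= 9.21.0.
    max_logit_ok = (not return_max_logit) or (
        cudnn_version >= CUDNN_MAX_LOGIT_MIN_VERSION
    )
    if max_logit_ok and other_conditions_ok:
        return Backend.F16_ARBITRARY_SEQLEN
    return Backend.NO_BACKEND


def select_attention(
    use_fused_attention: bool, return_max_logit: bool, cudnn_version: int
) -> str:
    """Caller-side fallback: No_Backend disables FusedAttention."""
    if use_fused_attention:
        backend = f16_arbitrary_seqlen_backend(return_max_logit, cudnn_version)
        if backend is not Backend.NO_BACKEND:
            return "FusedAttention"
    # FlashAttention is left out of this sketch; only the unfused fallback is shown.
    return "UnfusedDotProductAttention"
```

With cuDNN 9.10.2 (the version from the reproduction), `select_attention(True, True, encode_cudnn_version(9, 10, 2))` takes the fallback branch, while the same call without `return_max_logit` still selects FusedAttention.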
Reviews (3): Last reviewed commit: "removing python side check"
@francesco-bertolotti, could you please fix the DCO? There are instructions here. Thanks!
/te-ci pytorch L1
Signed-off-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>
Sorry, forgot about DCO. Should be fixed now.
/te-ci pytorch L1
Description
`get_attention_backend` selects FusedAttention for `return_max_logit=True` regardless of the cuDNN version, but cuDNN only supports emitting `Max` alongside the softmax `Stats` from cuDNN 9.21.0 onward. On older cuDNN versions the forward pass fails at graph-build time with:
Reproduction (observed on A100 / sm80 with cuDNN 9.10.2; the failure is cuDNN-version dependent, not architecture dependent):
Root cause
FusedAttention requests both the `Stats` and `Max` outputs from the SDPA node (`fused_attn_f16_arbitrary_seqlen.cu`, `sdpa_options.set_logit_max(Max)`). In cudnn-frontend, that combination is only representable through the unified softmax descriptor (`CUDNN_ATTR_OPERATION_SDPA_FWD_SOFTMAX_DESC`), which requires a cuDNN backend >= 9.21.0:
- `cudnn-frontend/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h`: the composite SDPA node only wires `Stats`/`Max`/`Sum_exp` into a `UnifiedSoftmaxNode` when `effective_cudnn_ver >= 92100`; on older versions only `Stats` is set (via `CUDNN_ATTR_OPERATION_SDPA_FWD_STATSDESC`).
- `cudnn-frontend/include/cudnn_frontend/node/softmax.h` (`CompositeSoftmaxNode::pre_validate_node`): the legacy composite softmax node only allows the output combinations `{}`, `{stats}`, or `{max, sum_exp}`; the `{stats, max}` combination requested for `return_max_logit=True` is rejected, producing the error above.
- `cudnn-frontend/include/cudnn_frontend/node/sdpa_support_surface.h`: the `Max` output is only added to the allowed outputs for `effective_cudnn_ver >= 92100`.
Fix
Add a filter in `get_attention_backend` that disables FusedAttention for `return_max_logit=True` when the cuDNN version is below 9.21.0, falling back to UnfusedDotProductAttention (FlashAttention is already disabled for `max_logit`). This follows the existing pattern of cuDNN-version filters in `transformer_engine/pytorch/attention/dot_product_attention/utils.py`.
The same policy is also enforced in `nvte_get_fused_attn_backend` (`transformer_engine/common/fused_attn/fused_attn.cpp`) so that non-PyTorch frontends are covered as well; the FP8 branch already checks `!return_max_logit` unconditionally:
Fixes # (issue)
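The root cause described above can be modeled as a small validity check on the set of softmax outputs. This is a rough sketch under stated assumptions, not cudnn-frontend code: `outputs_representable` and `requested_outputs` are hypothetical helpers, the legacy combinations are the three listed in the root-cause section, and the unified path is simplified to accept any subset of `stats`, `max`, and `sum_exp`. It also assumes that only `stats` is requested when `return_max_logit` is off.

```python
# Minimum encoded cuDNN version for the unified softmax descriptor (9.21.0).
UNIFIED_SOFTMAX_MIN_VERSION = 92100

# Output combinations the legacy composite softmax node accepts,
# as listed in the root-cause section.
LEGACY_ALLOWED = (
    frozenset(),
    frozenset({"stats"}),
    frozenset({"max", "sum_exp"}),
)

# Simplifying assumption: the unified path accepts any subset of these outputs.
UNIFIED_OUTPUTS = frozenset({"stats", "max", "sum_exp"})


def requested_outputs(return_max_logit: bool) -> frozenset:
    """Outputs FusedAttention asks for (assumed: stats only without max_logit)."""
    return frozenset({"stats", "max"}) if return_max_logit else frozenset({"stats"})


def outputs_representable(requested, cudnn_version: int) -> bool:
    """Return True if the requested output set can be expressed on this cuDNN."""
    requested = frozenset(requested)
    if cudnn_version >= UNIFIED_SOFTMAX_MIN_VERSION:
        return requested <= UNIFIED_OUTPUTS
    return requested in LEGACY_ALLOWED


if __name__ == "__main__":
    for version in (91002, 92100):
        ok = outputs_representable(requested_outputs(True), version)
        print(f"cuDNN {version}: stats+max representable = {ok}")
```

In this model the version guard and the representability check agree: `{stats, max}` is rejected exactly when the encoded version is below 92100, which is the case the fix routes away from FusedAttention.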
Type of change
Changes
- Disable FusedAttention in `get_attention_backend` when `return_max_logit=True` and cuDNN < 9.21.0, so backend selection falls back to UnfusedDotProductAttention instead of failing at cuDNN graph-build time.
- Apply the same check in `nvte_get_fused_attn_backend` (`fused_attn.cpp`, F16 arbitrary-seqlen condition) so non-PyTorch frontends are covered as well.
Checklist: