Added thd cudnn guard#3092
Conversation
277847e to
9755745
Compare
Greptile SummaryThis PR fixes a backend-selection bug where mixed THD layouts (e.g.
Confidence Score: 5/5Safe to merge for the primary goal of fixing the crash on Ampere/Ada with older cuDNN; the new SM8X + cuDNN 9.18.1+ code path in the .cu file is unvalidated but unreachable on the tested hardware. The backend-selection fix in fused_attn.cpp cleanly closes the OR-masking hole and aligns with the cudnn-frontend support surface. The Python filter change is a matching, correct early-out. The .cu implementation changes are internally consistent with the BHSD-like SM12X treatment, and the new SM8X + cuDNN 9.18.1+ path is unreachable in the tested environment (A100 + cuDNN 9.10.2). fused_attn_f16_arbitrary_seqlen.cu — the SM8X + cuDNN 9.18.1+ BHSD-like path (use_ragged_stats = false, no token-count substitution) is new and untested; worth verifying against cudnn-frontend's SM8X support surface before a broad release. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[nvte_get_fused_attn_backend called] --> B{qkv_format == NVTE_THD?}
B -- Yes --> C{sm_arch >= 90 OR cuDNN >= 9.18.1?}
C -- No --> FAIL[Return No_Backend → fallback to FA/UDPA]
C -- Yes --> D{cuDNN >= 9.0.1 + MHA==GQA, or cuDNN >= 9.0.6?}
D -- No --> FAIL
D -- Yes --> GUARD
B -- No --> E{Mixed-format branch: q_format/kv_format check AND cuDNN >= 9.0.7?}
E -- No --> FAIL
E -- Yes --> GUARD
GUARD{NEW GUARD: q_format != THD AND kv_format != THD?}
GUARD -- Yes --> PASS[Continue to other checks]
GUARD -- No, THD present --> GUARD2{sm_arch >= 90 OR cuDNN >= 9.18.1?}
GUARD2 -- No --> FAIL
GUARD2 -- Yes --> PASS
PASS --> IMPL[fused_attn_f16_arbitrary_seqlen_fwd/bwd_impl]
IMPL --> SM{sm_arch >= 90 AND sm_arch != 120?}
SM -- Yes SM9X --> RAGGED[Token-count dims + ragged Stats offset]
SM -- No SM8X/SM12X --> BHSD[BHSD-like dims + no ragged Stats offset]
Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
|
/te-ci L0 |
|
It's good that we enabled THD on sm8x and cuDNN >=9.18.1, but I think we missed a bit about the LSE shapes, which is why I'm seeing these failures from the CI: On sm90+, with THD, there is a memory optimization for LSE, which is to create it in TH1 shape, rather than BHS1. On sm8x/12x, it looks like such optimization doesn't exist in cuDNN. I looked through fused_attn_f16_arbitrary_seqlen.cu file - could you make the following changes please so we comply with the supported shapes in cuDNN? Thanks! L89/593: L104/606: L388: L808: L1145/1155: @KshitijLakhani, I think we missed L1145 when changing the LSE (Stats) shape to BSH1 for sm120 in #2693? The THD support for sm8x and sm12x is enabled after cuDNN 9.18.1, and the Stats or Max shape should be BSH1, not TH1. |
|
Also, utils.py should probably have I wonder if the sm8x logic (L1004) should be merged with the sm12x logic (L983). |
5a14343 to
28608a7
Compare
Signed-off-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>
…ls.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>
Signed-off-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>
for more information, see https://pre-commit.ci
fc591d9 to
514d032
Compare
|
Hi @cyanguwa, I've applied all the requested changes. I have to be honest that I don't have deep familiarity with this part of the codebase, I followed your instructions closely but would appreciate a review pass to catch anything I may have misapplied. Happy to make further corrections! |
Description
nvte_get_fused_attn_backendselects the F16 arbitrary-seqlen backend for mixed THD layouts (e.g.thd_bshd_bshd, used by KV caching) on Ampere/Ada GPUs, where cuDNN does not support THD (ragged offset) tensors before 9.18.1. Instead of falling back to another attention backend, the forward pass fails at cuDNN graph-build time:Reproduction (observed on A100 / sm80 with cuDNN 9.10.2):
Root cause
The supported-format reference is cudnn-frontend's SDPA support surface (
cudnn-frontend/include/cudnn_frontend/node/sdpa_support_surface.h), which rejects ragged tensors whensm < 90 && backend_version < 91801, matching the cuDNN 9.18.1 release notes ("Support for scaled dot-product attention backward with THD layout on RTX-PRO 6000 and Ampere-architecture GPUs for the F16 datatype has been added").TE's backend selection intends to enforce this ("THD requires sm90",
fused_attn.cppqkv-format clause), but mixed THD layouts slip through both of its layers:nvte_get_fused_attn_backend(fused_attn.cpp): forthd_bshd_bshd,nvte_get_qkv_formatreturnsNVTE_THD_2BSHD, notNVTE_THD, so the sm90-gated pure-THD branch never applies. The layout falls to the mixed-format branch, which ORs the q_format and kv_format conditions together —kv_format == NVTE_BSHDalone satisfies it, so the failed(q_format == NVTE_THD && sm_arch_ >= 90)disjunct is simply skipped instead of vetoing the backend. The clause implements "at least one of q/kv has a supported format" where the requirement is "both do".get_attention_backend(utils.py): the "Filter: QKV layout" section only gatesqkv_format == "thd", butget_qkv_formatreports"thd_2bshd"for these layouts, and the only architecture it checks for THD is sm120.As a result the F16 arbitrary-seqlen backend is selected and the failure surfaces later as a cuDNN graph-build error, rather than backend selection falling back to FlashAttention/UnfusedDotProductAttention.
Fix
Mirror the cudnn-frontend rule: THD requires
sm90+, or cuDNN 9.18.1+ on Ampere/Ada, in both selection layers:fused_attn.cpp: add a guard that closes the OR-masking hole for any layout involving THD:fused_attn.cpp: relax the threesm_arch_ >= 90THD conditions in the qkv-format clause to(sm_arch_ >= 90 || cudnn_runtime_version >= 91801), so that pure and mixed THD layouts are consistently enabled on Ampere/Ada with cuDNN 9.18.1+.get_attention_backend(utils.py): add the equivalent filter so PyTorch users get a clear debug message; it checksq_format/kv_formatrather thanqkv_formatto also cover thethd_2bshd/thd_2sbhdKV-cache layouts (this requires capturingkv_formatat theget_qkv_formatcall site, where it was previously discarded):With these guards, backend selection on sm80/sm89 with cuDNN < 9.18.1 returns
No_Backendfor THD layouts and falls back to FlashAttention/UnfusedDotProductAttention, andtests/pytorch/attention/test_kv_cache.pypasses (FusedAttention thd cases are skipped as unsupported).Fixes # (issue)
Type of change
Changes
nvte_get_fused_attn_backend, closing the hole where a validkv_formatmasked an unsupported THDq_formaton sm80/sm89 with cuDNN < 9.18.1.sm_arch_ >= 90THD conditions to also accept cuDNN >= 9.18.1, matching cudnn-frontend's support surface and the cuDNN 9.18.1 release notes.get_attention_backend(checkingq_format/kv_formatto coverthd_2bshd/thd_2sbhdKV-cache layouts) so PyTorch backend selection logs a clear reason and falls back cleanly.Checklist: