Added thd cudnn guard #3092

Open
francesco-bertolotti wants to merge 4 commits into NVIDIA:main from francesco-bertolotti:f14-thd-guard

Conversation

@francesco-bertolotti

Contributor

Description

nvte_get_fused_attn_backend selects the F16 arbitrary-seqlen backend for mixed THD layouts (e.g. thd_bshd_bshd, used by KV caching) on Ampere/Ada GPUs, where cuDNN does not support THD (ragged offset) tensors before 9.18.1. Instead of falling back to another attention backend, the forward pass fails at cuDNN graph-build time:

RuntimeError: transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu:418
in function operator(): cuDNN Error: THD (ragged offset) is only supported in Hopper and
above : 80.

Reproduction (observed on A100 / sm80 with cuDNN 9.10.2):

pytest -x tests/pytorch/attention/test_kv_cache.py
# fails in the first non-skipped thd case, e.g.
# test_kv_cache[False-False-TransformerLayer-FusedAttention-False-thd-infer_0-dtype0]
# with qkv_layout = thd_bshd_bshd (non-paged); paged_kv_thd_* layouts fail the same way

Root cause

The supported-format reference is cudnn-frontend's SDPA support surface (cudnn-frontend/include/cudnn_frontend/node/sdpa_support_surface.h), which rejects ragged tensors when sm < 90 && backend_version < 91801, matching the cuDNN 9.18.1 release notes ("Support for scaled dot-product attention backward with THD layout on RTX-PRO 6000 and Ampere-architecture GPUs for the F16 datatype has been added").

TE's backend selection intends to enforce this ("THD requires sm90", fused_attn.cpp qkv-format clause), but mixed THD layouts slip through both of its layers:

  • nvte_get_fused_attn_backend (fused_attn.cpp): for thd_bshd_bshd, nvte_get_qkv_format returns NVTE_THD_2BSHD, not NVTE_THD, so the sm90-gated pure-THD branch never applies. The layout falls to the mixed-format branch, which ORs the q_format and kv_format conditions together — kv_format == NVTE_BSHD alone satisfies it, so the failed (q_format == NVTE_THD && sm_arch_ >= 90) disjunct is simply skipped instead of vetoing the backend. The clause implements "at least one of q/kv has a supported format" where the requirement is "both do".
  • get_attention_backend (utils.py): the "Filter: QKV layout" section only gates qkv_format == "thd", but get_qkv_format reports "thd_2bshd" for these layouts, and the only architecture it checks for THD is sm120.

As a result the F16 arbitrary-seqlen backend is selected and the failure surfaces later as a cuDNN graph-build error, rather than backend selection falling back to FlashAttention/UnfusedDotProductAttention.
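The OR-masking described above can be reproduced with a small model. This is a simplified sketch, not the TE source: the function names are hypothetical, formats are plain strings instead of NVTE_QKV_Format values, and the cuDNN version condition attached to the mixed-format branch is omitted.

```python
def _side_ok(fmt: str, sm_arch: int) -> bool:
    """One side (q or kv) is acceptable if it is SBHD/BSHD, or THD on sm90+."""
    return fmt in ("sbhd", "bshd") or (fmt == "thd" and sm_arch >= 90)


def mixed_clause_before_fix(q_format: str, kv_format: str, sm_arch: int) -> bool:
    """Model of the mixed-format branch: q and kv conditions are ORed together."""
    return _side_ok(q_format, sm_arch) or _side_ok(kv_format, sm_arch)


def mixed_clause_intended(q_format: str, kv_format: str, sm_arch: int) -> bool:
    """Model of the stated requirement: both q and kv must have a supported format."""
    return _side_ok(q_format, sm_arch) and _side_ok(kv_format, sm_arch)


if __name__ == "__main__":
    # thd_bshd_bshd on sm80: q is THD (unsupported here), kv is BSHD (supported)
    print(mixed_clause_before_fix("thd", "bshd", 80))  # kv alone satisfies the OR
    print(mixed_clause_intended("thd", "bshd", 80))    # the THD q side vetoes
```

In this model the unsupported THD q_format is never consulted once kv_format passes, which is the behaviour the root-cause analysis attributes to the real clause.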

Fix

Mirror the cudnn-frontend rule: THD requires sm90+, or cuDNN 9.18.1+ on Ampere/Ada, in both selection layers:

  1. fused_attn.cpp: add a guard that closes the OR-masking hole for any layout involving THD:
        // THD (ragged offset) support: Hopper+ (sm90) always; Ampere/Ada (sm80/sm89) only
        // from cuDNN 9.18.1 ("SDPA backward with THD layout on RTX-PRO 6000 and
        // Ampere-architecture GPUs"; fprop on Ampere is undocumented, so gate both).
        // The qkv format clause above ORs q_format and kv_format conditions together, so a
        // valid kv_format (e.g. paged_kv_thd_bshd_bshd, where kv is BSHD) would otherwise
        // mask an invalid THD q_format on sm80 with older cuDNN.
        ((q_format != NVTE_QKV_Format::NVTE_THD && kv_format != NVTE_QKV_Format::NVTE_THD) ||
         sm_arch_ >= 90 || cudnn_runtime_version >= 91801) &&
  2. fused_attn.cpp: relax the three sm_arch_ >= 90 THD conditions in the qkv-format clause to (sm_arch_ >= 90 || cudnn_runtime_version >= 91801), so that pure and mixed THD layouts are consistently enabled on Ampere/Ada with cuDNN 9.18.1+.

  3. get_attention_backend (utils.py): add the equivalent filter so PyTorch users get a clear debug message; it checks q_format/kv_format rather than qkv_format to also cover the thd_2bshd/thd_2sbhd KV-cache layouts (this requires capturing kv_format at the get_qkv_format call site, where it was previously discarded):

    # THD support on Ampere/Ada requires cuDNN 9.18.1+ ("SDPA backward with THD layout on
    # RTX-PRO 6000 and Ampere-architecture GPUs"). Check q_format/kv_format, not just
    # qkv_format, since KV-cache layouts (e.g. paged_kv_thd_bshd_bshd) have
    # qkv_format = thd_2bshd.
    if "thd" in (q_format, kv_format) and device_compute_capability < (9, 0):
        if cudnn_version < (9, 18, 1):
            if use_fused_attention:
                logger.debug(
                    "Disabling FusedAttention as qkv_format = thd is not supported for"
                    " compute capability < sm90 and cuDNN version < 9.18.1"
                )
            use_fused_attention = False

With these guards, backend selection on sm80/sm89 with cuDNN < 9.18.1 returns No_Backend for THD layouts and falls back to FlashAttention/UnfusedDotProductAttention, and tests/pytorch/attention/test_kv_cache.py passes (FusedAttention thd cases are skipped as unsupported).
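As a quick sanity check of the rule, the guard predicate can be restated in Python. This is an illustrative sketch under stated assumptions: thd_guard and encode_cudnn_version are hypothetical names, formats are plain strings, and the version integer is assumed to pack as major*10000 + minor*100 + patch, which is consistent with 91801 standing for 9.18.1.

```python
def encode_cudnn_version(major: int, minor: int, patch: int) -> int:
    """Pack a cuDNN 9.x version triple into a single integer (assumed encoding)."""
    return major * 10000 + minor * 100 + patch


def thd_guard(q_format: str, kv_format: str, sm_arch: int, cudnn_version: int) -> bool:
    """Return True if the layout passes the THD architecture/version guard.

    Layouts without THD always pass; layouts with THD on either side need
    sm90+ or cuDNN 9.18.1+.
    """
    no_thd = q_format != "thd" and kv_format != "thd"
    return no_thd or sm_arch >= 90 or cudnn_version >= encode_cudnn_version(9, 18, 1)


if __name__ == "__main__":
    cases = [
        ("thd", "bshd", 80, encode_cudnn_version(9, 10, 2)),   # reproduction setup
        ("thd", "bshd", 80, encode_cudnn_version(9, 18, 1)),   # Ampere, new cuDNN
        ("thd", "bshd", 90, encode_cudnn_version(9, 10, 2)),   # Hopper
        ("bshd", "bshd", 80, encode_cudnn_version(9, 10, 2)),  # no THD involved
    ]
    for case in cases:
        print(case, thd_guard(*case))
```

Only the first case, which matches the A100 / cuDNN 9.10.2 reproduction, is rejected by the guard in this model.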

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add a THD architecture/version guard to the F16 arbitrary-seqlen condition in nvte_get_fused_attn_backend, closing the hole where a valid kv_format masked an unsupported THD q_format on sm80/sm89 with cuDNN < 9.18.1.
  • Relax the existing sm_arch_ >= 90 THD conditions to also accept cuDNN >= 9.18.1, matching cudnn-frontend's support surface and the cuDNN 9.18.1 release notes.
  • Add the equivalent filter to get_attention_backend (checking q_format/kv_format to cover thd_2bshd/thd_2sbhd KV-cache layouts) so PyTorch backend selection logs a clear reason and falls back cleanly.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 5, 2026
@greptile-apps

greptile-apps Bot commented Jun 5, 2026

Contributor

Greptile Summary

This PR fixes a backend-selection bug where mixed THD layouts (e.g. thd_bshd_bshd used by KV caching) slipped through the existing sm_arch_ >= 90 guards on Ampere/Ada, causing a crash at cuDNN graph-build time rather than a clean fallback to FlashAttention.

  • fused_attn.cpp: Relaxes three pure-THD sm_arch_ >= 90 conditions to (sm_arch_ >= 90 || cudnn_runtime_version >= 91801) and adds a closing AND-guard that blocks any layout where q_format or kv_format is NVTE_THD on sm < 90 without cuDNN 9.18.1+, fixing the OR-masking hole in the mixed-format branch.
  • fused_attn_f16_arbitrary_seqlen.cu: Adds sm_arch_ >= 90 to use_ragged_stats and the token-count substitution block, routing SM8X (when cuDNN ≥ 9.18.1 now admits it) through the BHSD-like path also used for SM12X; all call sites that conditioned on the old use_ragged_stats are updated consistently.
  • utils.py: Widens the "Filter: QKV layout" guard from qkv_format == "thd" to "thd" in (q_format, kv_format) to catch KV-cache layouts, and adds a new explicit sm < 9.0 + cuDNN < 9.18.1 check with a clear debug message.

Confidence Score: 5/5

Safe to merge for the primary goal of fixing the crash on Ampere/Ada with older cuDNN; the new SM8X + cuDNN 9.18.1+ code path in the .cu file is unvalidated but unreachable on the tested hardware.

The backend-selection fix in fused_attn.cpp cleanly closes the OR-masking hole and aligns with the cudnn-frontend support surface. The Python filter change is a matching, correct early-out. The .cu implementation changes are internally consistent with the BHSD-like SM12X treatment, and the new SM8X + cuDNN 9.18.1+ path is unreachable in the tested environment (A100 + cuDNN 9.10.2).

fused_attn_f16_arbitrary_seqlen.cu — the SM8X + cuDNN 9.18.1+ BHSD-like path (use_ragged_stats = false, no token-count substitution) is new and untested; worth verifying against cudnn-frontend's SM8X support surface before a broad release.

Important Files Changed

Filename | Overview
transformer_engine/common/fused_attn/fused_attn.cpp | Backend-selection guard extended: pure-THD sm_arch_ checks relaxed to (sm_arch_ >= 90
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | use_ragged_stats and the token-count substitution block (b=max_b, s_q=max_t_q) now require sm_arch_ >= 90, mirroring SM12X treatment for SM8X on cuDNN 9.18.1+. This SM8X + cuDNN 9.18.1+ code path is new and unvalidated by the author (tested A100 runs cuDNN 9.10.2).
transformer_engine/pytorch/attention/dot_product_attention/utils.py | QKV-layout filter widened from qkv_format == "thd" to "thd" in (q_format, kv_format) to catch mixed KV-cache layouts; new explicit sm < 9.0 + cuDNN < 9.18.1 guard added with clear debug log. Both changes are correct.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[nvte_get_fused_attn_backend called] --> B{qkv_format == NVTE_THD?}
    B -- Yes --> C{sm_arch >= 90 OR cuDNN >= 9.18.1?}
    C -- No --> FAIL[Return No_Backend → fallback to FA/UDPA]
    C -- Yes --> D{cuDNN >= 9.0.1 + MHA==GQA, or cuDNN >= 9.0.6?}
    D -- No --> FAIL
    D -- Yes --> GUARD
    B -- No --> E{Mixed-format branch: q_format/kv_format check AND cuDNN >= 9.0.7?}
    E -- No --> FAIL
    E -- Yes --> GUARD
    GUARD{NEW GUARD: q_format != THD AND kv_format != THD?}
    GUARD -- Yes --> PASS[Continue to other checks]
    GUARD -- No, THD present --> GUARD2{sm_arch >= 90 OR cuDNN >= 9.18.1?}
    GUARD2 -- No --> FAIL
    GUARD2 -- Yes --> PASS
    PASS --> IMPL[fused_attn_f16_arbitrary_seqlen_fwd/bwd_impl]
    IMPL --> SM{sm_arch >= 90 AND sm_arch != 120?}
    SM -- Yes SM9X --> RAGGED[Token-count dims + ragged Stats offset]
    SM -- No SM8X/SM12X --> BHSD[BHSD-like dims + no ragged Stats offset]

Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment thread transformer_engine/pytorch/attention/dot_product_attention/utils.py
@cyanguwa

cyanguwa commented Jun 5, 2026

Collaborator

/te-ci L0

cyanguwa
cyanguwa previously approved these changes Jun 5, 2026
@cyanguwa

cyanguwa commented Jun 10, 2026

Collaborator

It's good that we enabled THD on sm8x and cuDNN >=9.18.1, but I think we missed a bit about the LSE shapes, which is why I'm seeing these failures from the CI: cuDNN Error: Packed/ragged LSE is not supported for bprop thd on SM8X and SM12X GPUs. They originated from this cuDNN call.

On sm90+, with THD, there is a memory optimization for LSE: it is created in TH1 shape rather than BHS1. On sm8x/12x, it looks like this optimization doesn't exist in cuDNN. I looked through the fused_attn_f16_arbitrary_seqlen.cu file. Could you please make the following changes so we comply with the shapes cuDNN supports? Thanks!

L89/593:
bool use_ragged_stats = is_ragged_q && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120;

L104/606:
if (sm_arch_ >= 90 && sm_arch_ != 120) {

L388:
if (use_ragged_stats) {

L808:
if (is_ragged_kv && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120) {

L1145/1155:
bool use_ragged_stats = is_ragged_q && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120;
if (use_ragged_stats) {

@KshitijLakhani, I think we missed L1145 when changing the LSE (Stats) shape to BHS1 for sm120 in #2693? The THD support for sm8x and sm12x is enabled after cuDNN 9.18.1, and the Stats or Max shape should be BHS1, not TH1.
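The shape distinction in this comment can be illustrated with a small sketch. It assumes TH1 means (total_tokens, heads, 1) and BHS1 means (batch, heads, max_seqlen, 1); the function names and the example sequence lengths are hypothetical, and the predicate simply restates the condition requested above with the version packed as an integer.

```python
from math import prod


def use_ragged_stats(is_ragged_q: bool, sm_arch: int, cudnn_version: int) -> bool:
    """Packed (TH1) Stats only on sm90+ other than sm120, with cuDNN >= 9.6.0 (90600)."""
    return is_ragged_q and sm_arch >= 90 and cudnn_version >= 90600 and sm_arch != 120


def stats_shape(seqlens_q: list[int], num_heads: int, ragged: bool) -> tuple[int, ...]:
    """Return the assumed Stats (LSE) shape for a batch of variable-length sequences."""
    if ragged:
        # TH1: one row per actual token, no padding
        return (sum(seqlens_q), num_heads, 1)
    # BHS1: every sequence padded to the longest one
    return (len(seqlens_q), num_heads, max(seqlens_q), 1)


if __name__ == "__main__":
    seqlens = [5, 100, 3]  # illustrative values only
    for sm_arch in (80, 90, 120):
        ragged = use_ragged_stats(True, sm_arch, 91801)
        shape = stats_shape(seqlens, num_heads=8, ragged=ragged)
        print(f"sm{sm_arch}: ragged={ragged} shape={shape} elements={prod(shape)}")
```

Under these assumptions the packed layout stores one entry per real token, while the padded layout scales with batch size times the longest sequence, which is why it is described as a memory optimization.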

@cyanguwa

Collaborator

Also, utils.py should probably have "thd" in (q_format, kv_format) here too: https://github.com/francesco-bertolotti/TransformerEngine/blob/463e491027c91cb79fe82e9ce0992d3e0904158c/transformer_engine/pytorch/attention/dot_product_attention/utils.py#L968

I wonder if the sm8x logic (L1004) should be merged with the sm12x logic (L983).

francesco-bertolotti and others added 4 commits June 11, 2026 06:55
Signed-off-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>
…ls.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>
Signed-off-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>
@francesco-bertolotti

Contributor Author

Hi @cyanguwa, I've applied all the requested changes.

To be honest, I don't have deep familiarity with this part of the codebase. I followed your instructions closely, but I would appreciate a review pass to catch anything I may have misapplied. Happy to make further corrections!
