Skip to content

Commit 1563b10

Browse files
Name the OOM-skip threshold and explain the 128*bHSS workspace observation
Address review nits on the deterministic THD-backward OOM guard: 1. Replace the magic number 1_000_000_000 with the named constant SM90_DET_FUSED_THD_BWD_MAX_BHSS = 1 << 30, so the value is searchable and labeled. 2. Replace the prefatory comment with a short note tying the number to cuDNN's actual workspace request (~128 * bHSS bytes, measured on cuDNN 9.21.0 sm90 — see local sweep). At bHSS = 1<<30 the request is 128 GiB, which doesn't fit on H100's 80 GB. 3. Flag the b>=3 caveat for future readers: cuDNN rounds the batch up internally so workspace grows super-linearly past b=2 (b=4 asks for 4x the b=2 workspace, not 2x). The current fused-essential matrix is all b=2, so the threshold stays correct for what the test exercises; the note is there so the next person doesn't have to rediscover it. Skip set is unchanged — cp_2_0, cp_2_1, cp_3_1, cp_4_2, cp_4_3. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
1 parent 3b1e4ce commit 1563b10

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

tests/pytorch/attention/test_attention_with_cp.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -639,15 +639,17 @@ def test_cp_with_fused_attention(
639639
pytest.skip("Deterministic mode does not support non-vanilla softmax with FusedAttention")
640640
if _deterministic and config.attn_bias_type == "post_scale_bias" and is_training:
641641
pytest.skip("Deterministic mode does not support post_scale_bias with requires_grad")
642-
# Det FusedAttention backward with THD on sm90 OOMs because cuDNN reserves
643-
# workspace proportional to b*H*S*S. Gate on that product, not num_heads,
644-
# so the skip stays correct if a new config has small b/S but H >= 20.
642+
# cuDNN det THD backward workspace on sm90 is ~128 * bHSS bytes; at 1<<30
643+
# that's 128 GiB, won't fit on H100's 80 GB. Exact at b=2 + power-of-2 S;
644+
# for b>=3 cuDNN rounds batch up internally so workspace grows super-linearly
645+
# (e.g. b=4 wants 4x b=2's workspace, not 2x) — revisit if a config uses b>2.
646+
SM90_DET_FUSED_THD_BWD_MAX_BHSS = 1 << 30
645647
if (
646648
_deterministic
647649
and qkv_format == "thd"
648650
and get_device_compute_capability() == (9, 0)
649651
and config.batch_size * config.num_heads * config.max_seqlen_q * config.max_seqlen_kv
650-
>= 1_000_000_000
652+
>= SM90_DET_FUSED_THD_BWD_MAX_BHSS
651653
):
652654
pytest.skip(
653655
"Deterministic FusedAttention backward with THD format OOMs on sm90"

0 commit comments

Comments
 (0)