Skip to content

Commit 86d4e15

Browse files
authored
[PyT] Reduce test sizes in fused attn fp8 vs fp16 to avoid OOM (#3020)
* tests/attention: shrink fp8_vs_f16 configs from B=2 to B=1 The 9 fp8_9..fp8_17 configs in `model_configs_fp8_vs_f16` use shapes (B=2, S=4096-8192, H=32-128, D=64-192) for the bf16-vs-fp8 reference comparison. The reference path in `test_dpa_fp8_vs_f16` materializes the full (B, H, S, S) attention matrix in bf16, and keeps a handful of them live (S, P, dP, dS, dropout-mask) simultaneously. At B=2, S=8192, H=64 the per-test peak is ~70 GiB, which exceeds the memory of common 80 GB cards (H100) and pushes the suite into OOM territory on Blackwell (~91 GB measured with the cuDNN caching allocator residue). Halving B to 1 halves the bytes of every (B, H, S, S) tensor. Measured on B200 (SM_100, cuDNN 9.23, TE main): per-test peak `torch.cuda.max_memory_allocated`: before: 70.0 GiB (fp8_14) after : 36.1 GiB (fp8_14) -48% per-test peak `nvidia-smi memory.used`: before: 96.8 GiB after : 51.3 GiB -47% test outcome (B200, develop FE, this TE): identical 618F / 2196P / 891S, wall time within ~3% The shrunk configs still exercise every distinct shape/mask/SWA/GQA combination that the originals did -- only B is smaller. The suite now fits comfortably on 80 GB cards. fp8_19/20 (B=2, S=2048) are left at B=2 because their peak is small (~few GiB) and the larger batch is useful coverage for padding_causal. Signed-off-by: Vedaanta Agarwalla <vagarwalla@nvidia.com> * address changes recommended by Kshitij Signed-off-by: Vedaanta Agarwalla <142048820+vedaanta@users.noreply.github.com> * tests/attention: black format fp8_13 ModelConfig Line was 105 chars; black requires <=100 with the project's preview+ string_processing settings. Signed-off-by: Vedaanta Agarwalla <vagarwalla@nvidia.com> --------- Signed-off-by: Vedaanta Agarwalla <vagarwalla@nvidia.com> Signed-off-by: Vedaanta Agarwalla <142048820+vedaanta@users.noreply.github.com>
1 parent 815bf36 commit 86d4e15

1 file changed

Lines changed: 12 additions & 10 deletions

File tree

tests/pytorch/attention/test_attention.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1926,36 +1926,38 @@ def get_model(dtype, config):
19261926
# test: ModelConfig(b, sq, hq, dqk)
19271927
"fp8_9": ModelConfig(
19281928
2,
1929-
4096,
1929+
2048,
19301930
128,
19311931
192,
19321932
head_dim_v=128,
19331933
),
19341934
"fp8_10": ModelConfig(
1935-
1,
1936-
4096,
1935+
2,
1936+
2048,
19371937
128,
19381938
192,
19391939
head_dim_v=128,
19401940
attn_mask_type="causal",
19411941
),
19421942
"fp8_11": ModelConfig(
19431943
2,
1944-
4096,
1944+
2048,
19451945
128,
19461946
192,
19471947
head_dim_v=128,
19481948
attn_mask_type="causal_bottom_right",
19491949
),
1950-
"fp8_12": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
1951-
"fp8_13": ModelConfig(2, 8192, 32, 128, attn_mask_type="causal", window_size=(128, 0)),
1952-
"fp8_14": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal"),
1953-
"fp8_15": ModelConfig(2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)),
1950+
"fp8_12": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
1951+
"fp8_13": ModelConfig(
1952+
2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0)
1953+
),
1954+
"fp8_14": ModelConfig(2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal"),
1955+
"fp8_15": ModelConfig(1, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)),
19541956
"fp8_16": ModelConfig(
1955-
2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
1957+
1, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
19561958
),
19571959
"fp8_17": ModelConfig(
1958-
2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable"
1960+
2, 4096, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable"
19591961
),
19601962
"fp8_18": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
19611963
"fp8_19": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),

0 commit comments

Comments
 (0)