Skip to content

Commit 2a1cd8a

Browse files
committed
tests/attention: shrink fp8_vs_f16 configs from B=2 to B=1
The 9 fp8_9..fp8_17 configs in `model_configs_fp8_vs_f16` use shapes (B=2, S=4096-8192, H=32-128, D=64-192) for the bf16-vs-fp8 reference comparison. The reference path in `test_dpa_fp8_vs_f16` materializes the full (B, H, S, S) attention matrix in bf16, and keeps a handful of them live (S, P, dP, dS, dropout-mask) simultaneously. At B=2, S=8192, H=64 the per-test peak is ~70 GiB, which exceeds the memory of common 80 GB cards (H100) and pushes the suite into OOM territory on Blackwell (~91 GB measured with the cuDNN caching allocator residue). Halving B to 1 halves the bytes of every (B, H, S, S) tensor. Measured on B200 (SM_100, cuDNN 9.23, TE main): per-test peak `torch.cuda.max_memory_allocated`: before: 70.0 GiB (fp8_14) after : 36.1 GiB (fp8_14) -48% per-test peak `nvidia-smi memory.used`: before: 96.8 GiB after : 51.3 GiB -47% test outcome (B200, develop FE, this TE): identical 618F / 2196P / 891S, wall time within ~3% The shrunk configs still exercise every distinct shape/mask/SWA/GQA combination that the originals did -- only B is smaller. The suite now fits comfortably on 80 GB cards. fp8_19/20 (B=2, S=2048) are left at B=2 because their peak is small (~few GiB) and the larger batch is useful coverage for padding_causal. Signed-off-by: Vedaanta Agarwalla <vagarwalla@nvidia.com>
1 parent a014300 commit 2a1cd8a

1 file changed

Lines changed: 8 additions & 8 deletions

File tree

tests/pytorch/attention/test_attention.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1911,7 +1911,7 @@ def get_model(dtype, config):
19111911
model_configs_fp8_vs_f16 = {
19121912
# test: ModelConfig(b, sq, hq, dqk)
19131913
"fp8_9": ModelConfig(
1914-
2,
1914+
1,
19151915
4096,
19161916
128,
19171917
192,
@@ -1926,22 +1926,22 @@ def get_model(dtype, config):
19261926
attn_mask_type="causal",
19271927
),
19281928
"fp8_11": ModelConfig(
1929-
2,
1929+
1,
19301930
4096,
19311931
128,
19321932
192,
19331933
head_dim_v=128,
19341934
attn_mask_type="causal_bottom_right",
19351935
),
1936-
"fp8_12": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
1937-
"fp8_13": ModelConfig(2, 8192, 32, 128, attn_mask_type="causal", window_size=(128, 0)),
1938-
"fp8_14": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal"),
1939-
"fp8_15": ModelConfig(2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)),
1936+
"fp8_12": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
1937+
"fp8_13": ModelConfig(1, 8192, 32, 128, attn_mask_type="causal", window_size=(128, 0)),
1938+
"fp8_14": ModelConfig(1, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal"),
1939+
"fp8_15": ModelConfig(1, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)),
19401940
"fp8_16": ModelConfig(
1941-
2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
1941+
1, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
19421942
),
19431943
"fp8_17": ModelConfig(
1944-
2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable"
1944+
1, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable"
19451945
),
19461946
"fp8_18": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
19471947
"fp8_19": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),

0 commit comments

Comments
 (0)