Skip to content

Commit a48e7c5

Browse files
committed
address changes recommended by Kshitij
Signed-off-by: Vedaanta Agarwalla <142048820+vedaanta@users.noreply.github.com>
1 parent 2a1cd8a commit a48e7c5

1 file changed

Lines changed: 9 additions & 9 deletions

File tree

tests/pytorch/attention/test_attention.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1911,37 +1911,37 @@ def get_model(dtype, config):
19111911
model_configs_fp8_vs_f16 = {
19121912
# test: ModelConfig(b, sq, hq, dqk)
19131913
"fp8_9": ModelConfig(
1914-
1,
1915-
4096,
1914+
2,
1915+
2048,
19161916
128,
19171917
192,
19181918
head_dim_v=128,
19191919
),
19201920
"fp8_10": ModelConfig(
1921-
1,
1922-
4096,
1921+
2,
1922+
2048,
19231923
128,
19241924
192,
19251925
head_dim_v=128,
19261926
attn_mask_type="causal",
19271927
),
19281928
"fp8_11": ModelConfig(
1929-
1,
1930-
4096,
1929+
2,
1930+
2048,
19311931
128,
19321932
192,
19331933
head_dim_v=128,
19341934
attn_mask_type="causal_bottom_right",
19351935
),
19361936
"fp8_12": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
1937-
"fp8_13": ModelConfig(1, 8192, 32, 128, attn_mask_type="causal", window_size=(128, 0)),
1938-
"fp8_14": ModelConfig(1, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal"),
1937+
"fp8_13": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0)),
1938+
"fp8_14": ModelConfig(2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal"),
19391939
"fp8_15": ModelConfig(1, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)),
19401940
"fp8_16": ModelConfig(
19411941
1, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
19421942
),
19431943
"fp8_17": ModelConfig(
1944-
1, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable"
1944+
2, 4096, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable"
19451945
),
19461946
"fp8_18": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
19471947
"fp8_19": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),

0 commit comments

Comments
 (0)