Skip to content

Commit 6b4720a

Browse files
committed
address changes recommended by Kshitij
Signed-off-by: Vedaanta Agarwalla <142048820+vedaanta@users.noreply.github.com>
1 parent ead680c commit 6b4720a

1 file changed

Lines changed: 9 additions & 9 deletions

File tree

tests/pytorch/attention/test_attention.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,37 +1905,37 @@ def get_model(dtype, config):
19051905
model_configs_fp8_vs_f16 = {
19061906
# test: ModelConfig(b, sq, hq, dqk)
19071907
"fp8_9": ModelConfig(
1908-
1,
1909-
4096,
1908+
2,
1909+
2048,
19101910
128,
19111911
192,
19121912
head_dim_v=128,
19131913
),
19141914
"fp8_10": ModelConfig(
1915-
1,
1916-
4096,
1915+
2,
1916+
2048,
19171917
128,
19181918
192,
19191919
head_dim_v=128,
19201920
attn_mask_type="causal",
19211921
),
19221922
"fp8_11": ModelConfig(
1923-
1,
1924-
4096,
1923+
2,
1924+
2048,
19251925
128,
19261926
192,
19271927
head_dim_v=128,
19281928
attn_mask_type="causal_bottom_right",
19291929
),
19301930
"fp8_12": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
1931-
"fp8_13": ModelConfig(1, 8192, 32, 128, attn_mask_type="causal", window_size=(128, 0)),
1932-
"fp8_14": ModelConfig(1, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal"),
1931+
"fp8_13": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0)),
1932+
"fp8_14": ModelConfig(2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal"),
19331933
"fp8_15": ModelConfig(1, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)),
19341934
"fp8_16": ModelConfig(
19351935
1, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
19361936
),
19371937
"fp8_17": ModelConfig(
1938-
1, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable"
1938+
2, 4096, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable"
19391939
),
19401940
"fp8_18": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
19411941
"fp8_19": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),

0 commit comments

Comments
 (0)