@@ -1911,37 +1911,37 @@ def get_model(dtype, config):
19111911model_configs_fp8_vs_f16 = {
19121912 # test: ModelConfig(b, sq, hq, dqk)
19131913 "fp8_9" : ModelConfig (
1914- 1 ,
1915- 4096 ,
1914+ 2 ,
1915+ 2048 ,
19161916 128 ,
19171917 192 ,
19181918 head_dim_v = 128 ,
19191919 ),
19201920 "fp8_10" : ModelConfig (
1921- 1 ,
1922- 4096 ,
1921+ 2 ,
1922+ 2048 ,
19231923 128 ,
19241924 192 ,
19251925 head_dim_v = 128 ,
19261926 attn_mask_type = "causal" ,
19271927 ),
19281928 "fp8_11" : ModelConfig (
1929- 1 ,
1930- 4096 ,
1929+ 2 ,
1930+ 2048 ,
19311931 128 ,
19321932 192 ,
19331933 head_dim_v = 128 ,
19341934 attn_mask_type = "causal_bottom_right" ,
19351935 ),
19361936 "fp8_12" : ModelConfig (1 , 8192 , 32 , 128 , num_gqa_groups = 4 , attn_mask_type = "causal" ),
1937- "fp8_13" : ModelConfig (1 , 8192 , 32 , 128 , attn_mask_type = "causal" , window_size = (128 , 0 )),
1938- "fp8_14" : ModelConfig (1 , 8192 , 64 , 64 , num_gqa_groups = 8 , attn_mask_type = "causal" ),
1937+ "fp8_13" : ModelConfig (2 , 8192 , 32 , 128 , num_gqa_groups = 4 , attn_mask_type = "causal" , window_size = (128 , 0 )),
1938+ "fp8_14" : ModelConfig (2 , 4096 , 64 , 64 , num_gqa_groups = 8 , attn_mask_type = "causal" ),
19391939 "fp8_15" : ModelConfig (1 , 8192 , 64 , 64 , attn_mask_type = "causal" , window_size = (128 , 0 )),
19401940 "fp8_16" : ModelConfig (
19411941 1 , 8192 , 64 , 64 , num_gqa_groups = 8 , attn_mask_type = "causal" , softmax_type = "learnable"
19421942 ),
19431943 "fp8_17" : ModelConfig (
1944- 1 , 8192 , 64 , 64 , attn_mask_type = "causal" , window_size = (128 , 0 ), softmax_type = "learnable"
1944+ 2 , 4096 , 64 , 64 , attn_mask_type = "causal" , window_size = (128 , 0 ), softmax_type = "learnable"
19451945 ),
19461946 "fp8_18" : ModelConfig (1 , 8192 , 32 , 128 , num_gqa_groups = 4 , attn_mask_type = "padding" ),
19471947 "fp8_19" : ModelConfig (2 , 2048 , 16 , 128 , attn_mask_type = "padding_causal" ),
0 commit comments