@@ -1905,37 +1905,37 @@ def get_model(dtype, config):
19051905model_configs_fp8_vs_f16 = {
19061906 # test: ModelConfig(b, sq, hq, dqk)
19071907 "fp8_9" : ModelConfig (
1908- 1 ,
1909- 4096 ,
1908+ 2 ,
1909+ 2048 ,
19101910 128 ,
19111911 192 ,
19121912 head_dim_v = 128 ,
19131913 ),
19141914 "fp8_10" : ModelConfig (
1915- 1 ,
1916- 4096 ,
1915+ 2 ,
1916+ 2048 ,
19171917 128 ,
19181918 192 ,
19191919 head_dim_v = 128 ,
19201920 attn_mask_type = "causal" ,
19211921 ),
19221922 "fp8_11" : ModelConfig (
1923- 1 ,
1924- 4096 ,
1923+ 2 ,
1924+ 2048 ,
19251925 128 ,
19261926 192 ,
19271927 head_dim_v = 128 ,
19281928 attn_mask_type = "causal_bottom_right" ,
19291929 ),
19301930 "fp8_12" : ModelConfig (1 , 8192 , 32 , 128 , num_gqa_groups = 4 , attn_mask_type = "causal" ),
1931- "fp8_13" : ModelConfig (1 , 8192 , 32 , 128 , attn_mask_type = "causal" , window_size = (128 , 0 )),
1932- "fp8_14" : ModelConfig (1 , 8192 , 64 , 64 , num_gqa_groups = 8 , attn_mask_type = "causal" ),
1931+ "fp8_13" : ModelConfig (2 , 8192 , 32 , 128 , num_gqa_groups = 4 , attn_mask_type = "causal" , window_size = (128 , 0 )),
1932+ "fp8_14" : ModelConfig (2 , 4096 , 64 , 64 , num_gqa_groups = 8 , attn_mask_type = "causal" ),
19331933 "fp8_15" : ModelConfig (1 , 8192 , 64 , 64 , attn_mask_type = "causal" , window_size = (128 , 0 )),
19341934 "fp8_16" : ModelConfig (
19351935 1 , 8192 , 64 , 64 , num_gqa_groups = 8 , attn_mask_type = "causal" , softmax_type = "learnable"
19361936 ),
19371937 "fp8_17" : ModelConfig (
1938- 1 , 8192 , 64 , 64 , attn_mask_type = "causal" , window_size = (128 , 0 ), softmax_type = "learnable"
1938+ 2 , 4096 , 64 , 64 , attn_mask_type = "causal" , window_size = (128 , 0 ), softmax_type = "learnable"
19391939 ),
19401940 "fp8_18" : ModelConfig (1 , 8192 , 32 , 128 , num_gqa_groups = 4 , attn_mask_type = "padding" ),
19411941 "fp8_19" : ModelConfig (2 , 2048 , 16 , 128 , attn_mask_type = "padding_causal" ),
0 commit comments