@@ -64,9 +64,11 @@ jit_initializers: True
6464# Set true to load weights from pytorch
6565from_pt : True
6666split_head_dim : True
67- attention : ' flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom
67+ attention : ' flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
6868use_base2_exp : True
6969use_experimental_scheduler : True
70+ # For attention=ulysses_ring: hidden Ulysses shard count; ring shard count = context / ulysses_shards.
71+ ulysses_shards : -1
7072flash_min_seq_length : 4096
7173dropout : 0.0
7274
@@ -81,38 +83,38 @@ mask_padding_tokens: True
8183attention_sharding_uniform : True
8284
8385flash_block_sizes : {
84- " block_q" : 512,
85- " block_kv_compute" : 512,
86- " block_kv" : 512,
87- " block_q_dkv" : 512,
88- " block_kv_dkv" : 512,
89- " block_kv_dkv_compute" : 512,
90- " block_q_dq" : 512,
91- " block_kv_dq" : 512,
86+ " block_q " : 512,
87+ " block_kv_compute " : 512,
88+ " block_kv " : 512,
89+ " block_q_dkv " : 512,
90+ " block_kv_dkv " : 512,
91+ " block_kv_dkv_compute " : 512,
92+ " block_q_dq " : 512,
93+ " block_kv_dq " : 512,
9294 " use_fused_bwd_kernel " : False,
9395}
9496# Use on v6e
9597# flash_block_sizes: {
96- # "block_q" : 3024,
97- # "block_kv_compute" : 1024,
98- # "block_kv" : 2048,
99- # "block_q_dkv" : 3024,
100- # "block_kv_dkv" : 2048,
101- # "block_kv_dkv_compute" : 1024,
102- # "block_q_dq" : 3024,
103- # "block_kv_dq" : 2048,
98+ # "block_q": 3024,
99+ # "block_kv_compute": 1024,
100+ # "block_kv": 2048,
101+ # "block_q_dkv": 3024,
102+ # "block_kv_dkv": 2048,
103+ # "block_kv_dkv_compute": 1024,
104+ # "block_q_dq": 3024,
105+ # "block_kv_dq": 2048,
104106# "use_fused_bwd_kernel": False,
105107# }
106108# Use on v5p
107109# flash_block_sizes: {
108- # "block_q" : 3024,
109- # "block_kv_compute" : 1024,
110- # "block_kv" : 2048,
111- # "block_q_dkv" : 1024,
112- # "block_kv_dkv" : 3072,
113- # "block_kv_dkv_compute" : 256,
114- # "block_q_dq" : 1024,
115- # "block_kv_dq" : 3072
110+ # "block_q": 3024,
111+ # "block_kv_compute": 1024,
112+ # "block_kv": 2048,
113+ # "block_q_dkv": 1024,
114+ # "block_kv_dkv": 3072,
115+ # "block_kv_dkv_compute": 256,
116+ # "block_q_dq": 1024,
117+ # "block_kv_dq": 3072
116118# }
117119# GroupNorm groups
118120norm_num_groups : 32
0 commit comments