@@ -32,6 +32,8 @@ framework: native # native, accelerate
3232precision : bf16 # likely bf16 or fp8
3333te_enabled : true
3434fp8_enabled : false
35+ fp8_recipe : " "
36+ fp8_format : " "
3537# thd_enabled: false
3638
3739# Catchall for additional features/configs
@@ -90,6 +92,28 @@ products:
9092 micro_batch_size : 48
9193 wandb_name : " esm2_native_650m__fsdp2__thd__${now:%Y%m%d-%H%M%S}__${gitsha:}"
9294 job_name : " esm2-native-650m-fsdp2-thd"
95+ # TE bshd perf, FSDP2, FP8
96+ - config : L1_650M
97+ task_cmd : train_fsdp2
98+ parallelism_strategy : fsdp2
99+ thd_enabled : false
100+ fp8_enabled : true
101+ fp8_recipe : transformer_engine.common.recipe.Float8BlockScaling
102+ fp8_format : E4M3
103+ micro_batch_size : 48
104+ wandb_name : " esm2_native_650m__fsdp2__fp8__${now:%Y%m%d-%H%M%S}__${gitsha:}"
105+ job_name : " esm2-native-650m-fsdp2-fp8"
106+ # TE thd perf, FSDP2, FP8
107+ - config : L1_650M
108+ task_cmd : train_fsdp2
109+ parallelism_strategy : fsdp2
110+ thd_enabled : true
111+ fp8_enabled : true
112+ fp8_recipe : transformer_engine.common.recipe.Float8BlockScaling
113+ fp8_format : E4M3
114+ micro_batch_size : 48
115+ wandb_name : " esm2_native_650m__fsdp2__thd__fp8__${now:%Y%m%d-%H%M%S}__${gitsha:}"
116+ job_name : " esm2-native-650m-fsdp2-thd-fp8"
93117 # OSS Convergence Baseline
94118 # - config: L1_650M
95119 # model_tag: facebook/esm2_t33_650M_UR50D
@@ -137,4 +161,6 @@ run_script: |
137161 checkpoint.resume_from_checkpoint=${resume_from_checkpoint} \
138162 +checkpoint.save_checkpoints=${save_checkpoints} \
139163 +checkpoint.use_distributed_checkpoint_fsdp2=${use_distributed_checkpoint_fsdp2} \
140- fp8_config.enabled=${fp8_enabled}
164+ fp8_config.enabled=${fp8_enabled} \
165+ fp8_config.fp8_recipe=${fp8_recipe} \
166+ fp8_config.fp8_format=${fp8_format}
0 commit comments