Skip to content

Commit 7883b85

Browse files
committed
less steps for 15b, add fp8 for 650m
Signed-off-by: jwilber <jwilber@nvidia.com>
1 parent 00a7a9d commit 7883b85

2 files changed

Lines changed: 28 additions & 2 deletions

File tree

ci/lepton/model_convergence/configs/recipes/esm2_native_te_15b.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ use_torch_compile: false
6161
# these should match the keys in the recipe's config file
6262
model_tag: nvidia/esm2_t48_15B_UR50D
6363
# task_cmd: train_fsdp2 # mfsdp
64-
num_train_steps: 20_000
64+
num_train_steps: 500
6565
# dataset commands
6666
micro_batch_size: 8
6767
load_dataset_kwargs_path: nvidia/esm2_uniref_pretraining_data

ci/lepton/model_convergence/configs/recipes/esm2_native_te_650m.yaml

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ framework: native # native, accelerate
3232
precision: bf16 # likely bf16 or fp8
3333
te_enabled: true
3434
fp8_enabled: false
35+
fp8_recipe: ""
36+
fp8_format: ""
3537
# thd_enabled: false
3638

3739
# Catchall for additional features/configs
@@ -90,6 +92,28 @@ products:
9092
micro_batch_size: 48
9193
wandb_name: "esm2_native_650m__fsdp2__thd__${now:%Y%m%d-%H%M%S}__${gitsha:}"
9294
job_name: "esm2-native-650m-fsdp2-thd"
95+
# TE bshd perf, FSDP2, FP8
96+
- config: L1_650M
97+
task_cmd: train_fsdp2
98+
parallelism_strategy: fsdp2
99+
thd_enabled: false
100+
fp8_enabled: true
101+
fp8_recipe: transformer_engine.common.recipe.Float8BlockScaling
102+
fp8_format: E4M3
103+
micro_batch_size: 48
104+
wandb_name: "esm2_native_650m__fsdp2__fp8__${now:%Y%m%d-%H%M%S}__${gitsha:}"
105+
job_name: "esm2-native-650m-fsdp2-fp8"
106+
# TE thd perf, FSDP2, FP8
107+
- config: L1_650M
108+
task_cmd: train_fsdp2
109+
parallelism_strategy: fsdp2
110+
thd_enabled: true
111+
fp8_enabled: true
112+
fp8_recipe: transformer_engine.common.recipe.Float8BlockScaling
113+
fp8_format: E4M3
114+
micro_batch_size: 48
115+
wandb_name: "esm2_native_650m__fsdp2__thd__fp8__${now:%Y%m%d-%H%M%S}__${gitsha:}"
116+
job_name: "esm2-native-650m-fsdp2-thd-fp8"
93117
# OSS Convergence Baseline
94118
# - config: L1_650M
95119
# model_tag: facebook/esm2_t33_650M_UR50D
@@ -137,4 +161,6 @@ run_script: |
137161
checkpoint.resume_from_checkpoint=${resume_from_checkpoint} \
138162
+checkpoint.save_checkpoints=${save_checkpoints} \
139163
+checkpoint.use_distributed_checkpoint_fsdp2=${use_distributed_checkpoint_fsdp2} \
140-
fp8_config.enabled=${fp8_enabled}
164+
fp8_config.enabled=${fp8_enabled} \
165+
fp8_config.fp8_recipe=${fp8_recipe} \
166+
fp8_config.fp8_format=${fp8_format}

0 commit comments

Comments
 (0)