Skip to content

Commit 2913ea0

Browse files
authored
add fp8 to config (#1434)
Add fp8 to esm2 15b config. Also: - reduce non-15b steps from 20k to 10k. - remove an esm2 650m job - remove esm2 3b from automated tracking --------- Signed-off-by: jwilber <jwilber@nvidia.com>
1 parent 8c4a616 commit 2913ea0

4 files changed

Lines changed: 42 additions & 16 deletions

File tree

.github/workflows/convergence-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ jobs:
3939
runs-on: ubuntu-latest
4040
strategy:
4141
matrix:
42-
model_config: ${{ github.event_name == 'schedule' && fromJSON('["esm2_native_te_650m", "esm2_native_te_3b", "esm2_native_te_15b", "codonfm_ptl_te"]') || fromJSON(format('["{0}"]', github.event.inputs.model_config)) }}
42+
model_config: ${{ github.event_name == 'schedule' && fromJSON('["esm2_native_te_650m", "esm2_native_te_15b", "codonfm_ptl_te"]') || fromJSON(format('["{0}"]', github.event.inputs.model_config)) }}
4343
fail-fast: false
4444
steps:
4545
- name: Checkout

ci/lepton/model_convergence/configs/recipes/esm2_native_te_15b.yaml

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ framework: native # native, accelerate
3232
precision: bf16 # likely bf16 or fp8
3333
te_enabled: true
3434
fp8_enabled: false
35+
fp8_recipe: ""
36+
fp8_format: ""
3537
# thd_enabled: false
3638

3739
# Catchall for additional features/configs
@@ -123,6 +125,28 @@ products:
123125
micro_batch_size: 4
124126
wandb_name: "esm2_native_15b__fsdp2__baseline__${now:%Y%m%d-%H%M%S}__${gitsha:}"
125127
job_name: "esm2-native-15b-fsdp2-baseline"
128+
# TE bshd perf, FSDP2, FP8
129+
- config: L1_15B_perf_test
130+
task_cmd: train_fsdp2
131+
parallelism_strategy: fsdp2
132+
thd_enabled: false
133+
fp8_enabled: true
134+
fp8_recipe: transformer_engine.common.recipe.Float8BlockScaling
135+
fp8_format: E4M3
136+
micro_batch_size: 4
137+
wandb_name: "esm2_native_15b__fsdp2__fp8__${now:%Y%m%d-%H%M%S}__${gitsha:}"
138+
job_name: "esm2-native-15b-fsdp2-fp8"
139+
# TE thd perf, FSDP2, FP8
140+
- config: L1_15B_perf_test
141+
task_cmd: train_fsdp2
142+
parallelism_strategy: fsdp2
143+
thd_enabled: true
144+
fp8_enabled: true
145+
fp8_recipe: transformer_engine.common.recipe.Float8BlockScaling
146+
fp8_format: E4M3
147+
micro_batch_size: 4
148+
wandb_name: "esm2_native_15b__fsdp2__thd__fp8__${now:%Y%m%d-%H%M%S}__${gitsha:}"
149+
job_name: "esm2-native-15b-fsdp2-thd-fp8"
126150

127151
############################################################
128152
# run script
@@ -156,4 +180,6 @@ run_script: |
156180
checkpoint.resume_from_checkpoint=${resume_from_checkpoint} \
157181
+checkpoint.save_checkpoints=${save_checkpoints} \
158182
+checkpoint.use_distributed_checkpoint_fsdp2=${use_distributed_checkpoint_fsdp2} \
159-
fp8_config.enabled=${fp8_enabled}
183+
fp8_config.enabled=${fp8_enabled} \
184+
fp8_config.fp8_recipe=${fp8_recipe} \
185+
fp8_config.fp8_format=${fp8_format}

ci/lepton/model_convergence/configs/recipes/esm2_native_te_3b.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,15 @@ wandb_init_args:
5858
# these should match the keys in the recipe's config file
5959
model_tag: nvidia/esm2_t36_3B_UR50D
6060
# task_cmd: train_fsdp2 # mfsdp
61-
num_train_steps: 20_000
61+
num_train_steps: 10_000
6262
# dataset commands
6363
micro_batch_size: 16
6464
load_dataset_kwargs_path: nvidia/esm2_uniref_pretraining_data
6565
load_dataset_kwargs_streaming: true
6666
load_dataset_kwargs_revision: 4ac1d2973567e46b8ca95901f4b4793a21305995 # pragma: allowlist secret
6767

6868
# lr commands
69-
num_warmup_steps: 2_000
69+
num_warmup_steps: 1_000
7070
# checkpoint controls
7171
ckpt_dir: ""
7272
save_checkpoints: false

ci/lepton/model_convergence/configs/recipes/esm2_native_te_650m.yaml

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ wandb_init_args:
5858
# these should match the keys in the recipe's config file
5959
model_tag: nvidia/esm2_t36_650M_UR50D
6060
# task_cmd: train_fsdp2 # mfsdp
61-
num_train_steps: 20_000
61+
num_train_steps: 10_000
6262
# dataset commands
6363
micro_batch_size: 16
6464
load_dataset_kwargs_path: nvidia/esm2_uniref_pretraining_data
@@ -67,7 +67,7 @@ load_dataset_kwargs_revision: 4ac1d2973567e46b8ca95901f4b4793a21305995 # pragma:
6767
num_workers: 1
6868

6969
# lr commands
70-
num_warmup_steps: 2_000
70+
num_warmup_steps: 1_000
7171
# checkpoint controls
7272
ckpt_dir: ""
7373
save_checkpoints: false
@@ -91,16 +91,16 @@ products:
9191
wandb_name: "esm2_native_650m__fsdp2__thd__${now:%Y%m%d-%H%M%S}__${gitsha:}"
9292
job_name: "esm2-native-650m-fsdp2-thd"
9393
# OSS Convergence Baseline
94-
- config: L1_650M
95-
model_tag: facebook/esm2_t33_650M_UR50D
96-
num_nodes: 8
97-
num_devices: 8
98-
task_cmd: train_fsdp2
99-
parallelism_strategy: fsdp2
100-
thd_enabled: false
101-
micro_batch_size: 32
102-
wandb_name: "esm2_native_650m__fsdp2__${now:%Y%m%d-%H%M%S}__${gitsha:}"
103-
job_name: "esm2-native-650m-fsdp2"
94+
# - config: L1_650M
95+
# model_tag: facebook/esm2_t33_650M_UR50D
96+
# num_nodes: 8
97+
# num_devices: 8
98+
# task_cmd: train_fsdp2
99+
# parallelism_strategy: fsdp2
100+
# thd_enabled: false
101+
# micro_batch_size: 32
102+
# wandb_name: "esm2_native_650m__fsdp2__${now:%Y%m%d-%H%M%S}__${gitsha:}"
103+
# job_name: "esm2-native-650m-fsdp2"
104104

105105
############################################################
106106
# run script

0 commit comments

Comments
 (0)