Skip to content

Commit d9d0dba

Browse files
authored
Merge branch 'main' into mvle/evo2-pretrain
2 parents 060e196 + a29272f commit d9d0dba

31 files changed

Lines changed: 982 additions & 52 deletions

File tree

.gitmodules

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
url = https://github.com/NVIDIA/Megatron-LM.git
44
[submodule "3rdparty/NeMo"]
55
path = 3rdparty/NeMo
6-
url = https://github.com/NVIDIA/NeMo.git
6+
url = https://github.com/NVIDIA-NeMo/NeMo.git
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
scope: partial-conv
2+
time_limit: 14400
3+
key_segments:
4+
# Modify keys to be renamed (str) or excluded (False) from run identifier. By default, all args under script_args are included.
5+
dataset_config: False
6+
dataset_dir: False
7+
data_base_path: False
8+
num_workers: False
9+
limit_val_batches: False
10+
val_check_interval: False
11+
experiment_name: False
12+
workspace: False
13+
restore_from_checkpoint_path: False
14+
activation_checkpoint_layers: False
15+
lora_enabled: False
16+
lr: False
17+
min_lr: False
18+
warmup_steps: False
19+
accumulate_grad_batches: False
20+
clip_grad: False
21+
weight_decay: False
22+
attention_dropout: False
23+
hidden_dropout: False
24+
precision: False
25+
seq_length: False
26+
script_args:
27+
# All arguments referenced in the script string must be specified here.
28+
# Arguments not referenced in the script string must have the 'arg' field specified.
29+
# See jet/core/configs.py for the specification of the configuration class
30+
workspace: /workspace/bionemo2
31+
data_base_path: /data/evo2
32+
restore_from_checkpoint_path: checkpoints/nemo2_evo2_1b_8k
33+
nodes: 1
34+
model: evo2
35+
config_name: 1b
36+
num_workers: 1
37+
limit_val_batches: 20
38+
dataset_config: training_data_config.yaml
39+
dataset_dir: preprocessed_data
40+
val_check_interval: 5
41+
seq_length: 8192
42+
warmup_steps: 10
43+
activation_checkpoint_layers: 2
44+
lr: 0.000015
45+
min_lr: 0.0000149
46+
accumulate_grad_batches: 4
47+
max_steps: 1000
48+
gpus: 1
49+
clip_grad: 250
50+
weight_decay: 0.001
51+
attention_dropout: 0.01
52+
hidden_dropout: 0.01
53+
stop_steps: 100
54+
batch_size: 2
55+
variant: finetune
56+
precision: fp8
57+
products:
58+
- variant: finetune
59+
lora_enabled: ""
60+
task: finetune_from_ckpt
61+
experiment_name: evo2-finetune
62+
- variant: lora_finetune
63+
lora_enabled: "--lora-finetune"
64+
task: lora_finetune_from_ckpt
65+
experiment_name: evo2-lora-finetune
66+
script: |-
67+
WANDB_API_KEY=$BIONEMO_WANDB_API_KEY train_${model} \
68+
-d ${data_base_path}/${dataset_config} \
69+
--dataset-dir=${data_base_path}/${dataset_dir} \
70+
--ckpt-dir=${data_base_path}/${restore_from_checkpoint_path} \
71+
${lora_enabled} \
72+
--model-size=${config_name} \
73+
--max-steps=${max_steps} \
74+
--experiment-name=${experiment_name}_${batch_size}bs_${nodes}node_${gpus}gpu_${max_steps}s \
75+
--lr=${lr} \
76+
--min-lr=${min_lr} \
77+
--warmup-steps=${warmup_steps} \
78+
--result-dir=${tensorboard_dir} \
79+
--micro-batch-size=${batch_size} \
80+
--grad-acc-batches=${accumulate_grad_batches} \
81+
--limit-val-batches=${limit_val_batches} \
82+
--seq-length=${seq_length} \
83+
--clip-grad=${clip_grad} \
84+
--wd=${weight_decay} \
85+
--attention-dropout=${attention_dropout} \
86+
--hidden-dropout=${hidden_dropout} \
87+
--num-layers 4 \
88+
--hybrid-override-pattern 'SDH*' \
89+
--devices=${gpus} \
90+
--num-nodes=${nodes} \
91+
--val-check-interval=${val_check_interval} \
92+
--wandb-project=${wandb_project_name} \
93+
--wandb-group=${model}_${variant}_${config_name}_${task}_${target} \
94+
--create-tensorboard-logger \
95+
--activation-checkpoint-recompute-num-layers=${activation_checkpoint_layers} \
96+
--disable-checkpointing \
97+
--early-stop-on-step=${stop_steps} \
98+
--garbage-collect-at-inference;

models/esm2/src/esm/modeling_esm_te.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,12 @@ def __init__(self, config: NVEsmConfig):
138138
self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
139139
if config.position_embedding_type == "rotary":
140140
self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
141-
self.te_rope_emb = self.rotary_embeddings(max_seq_len=config.max_position_embeddings).cuda()
141+
# Keep on CPU, pin for faster non_blocking H2D; don't persist in state_dict.
142+
self.register_buffer(
143+
"te_rope_emb",
144+
self.rotary_embeddings(max_seq_len=config.max_position_embeddings).cpu().pin_memory(),
145+
persistent=False,
146+
)
142147
else:
143148
self.te_rope_emb = None
144149

@@ -157,14 +162,28 @@ def forward(
157162
"""
158163
all_hidden_states = () if output_hidden_states else None
159164

165+
if self.te_rope_emb is not None:
166+
te_rope_emb = self.te_rope_emb.to(
167+
device=hidden_states.device, dtype=hidden_states.dtype, non_blocking=True
168+
)
169+
seq_len = hidden_states.shape[1]
170+
if te_rope_emb.size(0) < seq_len:
171+
raise RuntimeError(
172+
f"ROPE length {te_rope_emb.size(0)} < input seq length {seq_len}. "
173+
f"Increase max_position_embeddings."
174+
)
175+
te_rope_emb = te_rope_emb[:seq_len]
176+
else:
177+
te_rope_emb = None
178+
160179
for layer_module in self.layers:
161180
if output_hidden_states:
162181
all_hidden_states = (*all_hidden_states, hidden_states)
163182

164183
hidden_states = layer_module(
165184
hidden_states,
166185
attention_mask,
167-
rotary_pos_emb=self.te_rope_emb,
186+
rotary_pos_emb=te_rope_emb,
168187
)
169188

170189
hidden_states = self.emb_layer_norm_after(hidden_states)

recipes/amplify_accelerate_te_fp8/accelerate_config/bf16_config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ compute_environment: LOCAL_MACHINE
22
debug: false
33
distributed_type: MULTI_GPU
44
downcast_bf16: 'no'
5-
dynamo_config:
6-
dynamo_backend: INDUCTOR
75
enable_cpu_affinity: false
86
machine_rank: 0
97
main_training_function: main
@@ -16,3 +14,5 @@ tpu_env: []
1614
tpu_use_cluster: false
1715
tpu_use_sudo: false
1816
use_cpu: false
17+
dynamo_config:
18+
dynamo_backend: "NO"
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
17+
from transformers.training_args import TrainingArguments
18+
19+
20+
class StopAfterNStepsCallback(TrainerCallback):
21+
"""Callback to interrupt training after a specified number of steps.
22+
23+
This allows us to use a learning rate scheduler consistent with the full training run while
24+
stopping after a pre-determined number of steps.
25+
"""
26+
27+
def __init__(self, max_steps: int):
28+
"""Initialize the callback."""
29+
self.max_steps = max_steps
30+
31+
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
32+
"""Interrupt training after a specified number of steps."""
33+
if state.global_step >= self.max_steps:
34+
control.should_training_stop = True

recipes/amplify_accelerate_te_fp8/dataset.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@
2626

2727

2828
def create_datasets_and_collator(
29-
pretained_model: str | os.PathLike,
29+
pretrained_model: str | os.PathLike,
3030
max_length: int,
3131
data_size: Literal["full", "sanity", "parquet"],
3232
) -> tuple[Dataset, Dataset, DataCollatorForLanguageModeling]:
3333
"""Create the datasets and the data collator.
3434
3535
Args:
36-
pretained_model: The path or tag of the pre-trained model to load the tokenizer from.
36+
pretrained_model: The path or tag of the pre-trained model to load the tokenizer from.
3737
max_length: The maximum length of the sequences.
3838
data_size: The size of the dataset to load. If "full", use and pre-process the full UR100P
3939
CSV dataset. This takes a long time without a cached dataset. If "small", use and
@@ -43,7 +43,7 @@ def create_datasets_and_collator(
4343
Returns:
4444
A tuple containing the train dataset, the eval dataset, and the data collator.
4545
"""
46-
tokenizer = AutoTokenizer.from_pretrained(pretained_model)
46+
tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
4747

4848
def tokenize(examples):
4949
"""Tokenize the examples."""
@@ -84,8 +84,7 @@ def tokenize(examples):
8484
train_dataset = train_dataset.shuffle(seed=42)
8585

8686
for dataset in [train_dataset, eval_dataset]:
87-
dataset.set_transform(tokenize, output_all_columns=True)
88-
dataset.remove_columns(["sequence", "name"])
87+
dataset.set_transform(tokenize)
8988

9089
data_collator = DataCollatorForLanguageModeling(
9190
tokenizer=tokenizer,

recipes/amplify_accelerate_te_fp8/test_train.py

Lines changed: 67 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@
3434

3535
_fp8_available, _fp8_reason = check_fp8_support()
3636

37+
requires_multi_gpu = pytest.mark.skipif(
38+
not torch.cuda.is_available() or torch.cuda.device_count() < 2,
39+
reason="Test requires at least 2 GPUs",
40+
)
41+
3742

3843
@pytest.fixture(scope="session")
3944
def session_temp_dir(tmp_path_factory):
@@ -175,26 +180,70 @@ def test_accelerate_launch(accelerate_config, tmp_path):
175180
assert train_py.exists(), f"train.py not found at {train_py}"
176181
assert accelerate_config_path.exists(), f"{accelerate_config} not found at {accelerate_config_path}"
177182

178-
# Run 'accelerate launch train.py' as a subprocess
179-
env = os.environ.copy()
180-
181-
subprocess.run(
182-
[
183-
sys.executable,
184-
"-m",
185-
"accelerate.commands.launch",
186-
"--config_file",
187-
str(accelerate_config_path),
188-
str(train_py),
189-
"--config-name",
190-
"L0_sanity",
191-
f"trainer.output_dir={tmp_path}",
192-
],
193-
cwd=recipe_dir,
183+
cmd = [
184+
sys.executable,
185+
"-m",
186+
"accelerate.commands.launch",
187+
"--config_file",
188+
str(accelerate_config_path),
189+
str(train_py),
190+
"--config-name",
191+
"L0_sanity",
192+
f"trainer.output_dir={tmp_path}",
193+
"trainer.save_steps=1000",
194+
"trainer.eval_steps=1000",
195+
"trainer.do_eval=false",
196+
]
197+
198+
result = subprocess.run(
199+
cmd,
200+
check=False,
201+
text=True,
194202
stdout=subprocess.PIPE,
195203
stderr=subprocess.PIPE,
204+
timeout=240,
205+
)
206+
207+
if result.returncode != 0:
208+
print(f"STDOUT:\n{result.stdout}")
209+
print(f"STDERR:\n{result.stderr}")
210+
pytest.fail(f"Command:\n{' '.join(cmd)}\nfailed with exit code {result.returncode}")
211+
212+
213+
@requires_multi_gpu
214+
def test_accelerate_launch_multi_gpu(tmp_path):
215+
"""Test that accelerate launch runs successfully."""
216+
# Find the recipe directory and train.py
217+
recipe_dir = Path(__file__).parent
218+
train_py = recipe_dir / "train.py"
219+
220+
cmd = [
221+
sys.executable,
222+
"-m",
223+
"accelerate.commands.launch",
224+
"--config_file",
225+
str(recipe_dir / "accelerate_config" / "bf16_config.yaml"),
226+
"--num_processes",
227+
"2",
228+
str(train_py),
229+
"--config-name",
230+
"L0_sanity",
231+
f"trainer.output_dir={tmp_path}",
232+
"trainer.save_steps=1000",
233+
"trainer.eval_steps=1000",
234+
"trainer.do_eval=false",
235+
]
236+
237+
result = subprocess.run(
238+
cmd,
239+
check=False,
196240
text=True,
197-
check=True,
241+
stdout=subprocess.PIPE,
242+
stderr=subprocess.PIPE,
198243
timeout=240,
199-
env=env,
200244
)
245+
246+
if result.returncode != 0:
247+
print(f"STDOUT:\n{result.stdout}")
248+
print(f"STDERR:\n{result.stderr}")
249+
pytest.fail(f"Command:\n{' '.join(cmd)}\nfailed with exit code {result.returncode}")

recipes/amplify_accelerate_te_fp8/train.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
from omegaconf import DictConfig
2323
from transformers import AutoConfig, AutoModelForMaskedLM
2424
from transformers.trainer import Trainer
25-
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
2625
from transformers.training_args import TrainingArguments
2726

27+
from callbacks import StopAfterNStepsCallback
2828
from dataset import create_datasets_and_collator
2929
from metrics import compute_metrics
3030

@@ -45,7 +45,7 @@ def main(args: DictConfig):
4545
)
4646

4747
train_dataset, eval_dataset, data_collator = create_datasets_and_collator(
48-
pretained_model=args.model_tag,
48+
pretrained_model=args.model_tag,
4949
max_length=config.max_length,
5050
data_size=args.data_size,
5151
)
@@ -79,25 +79,8 @@ def main(args: DictConfig):
7979
if training_args.do_eval:
8080
trainer.evaluate()
8181

82-
torch.distributed.destroy_process_group()
83-
84-
85-
class StopAfterNStepsCallback(TrainerCallback):
86-
"""Callback to interrupt training after a specified number of steps.
87-
88-
This allows us to use a learning rate scheduler consistent with the full training run while
89-
stopping after a pre-determined number of steps.
90-
"""
91-
92-
def __init__(self, max_steps: int):
93-
"""Initialize the callback."""
94-
self.max_steps = max_steps
95-
96-
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
97-
"""Interrupt training after a specified number of steps."""
98-
if state.global_step >= self.max_steps:
99-
control.should_training_stop = True
100-
return control
82+
if torch.distributed.is_available() and torch.distributed.is_initialized():
83+
torch.distributed.destroy_process_group()
10184

10285

10386
if __name__ == "__main__":
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
Dockerfile
2+
README.md
3+
checkpoint_export/
4+
outputs/
5+
.ruff_cache
6+
__pycache__
7+
.pytest_cache
8+
.ruff.toml
9+
.dockerignore

recipes/esm2_accelerate/.ruff.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
extend = "../.ruff.toml"

0 commit comments

Comments
 (0)