Skip to content

Commit 2d6c7d9

Browse files
committed
Addressed review feedback and added Llama train and finetune tests
Signed-off-by: Yang Zhang <yangzhang@nvidia.com>
1 parent 7eebe89 commit 2d6c7d9

5 files changed

Lines changed: 107 additions & 132 deletions

File tree

sub-packages/bionemo-evo2/src/bionemo/evo2/data/sharded_eden_dataloader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515

1616
# high performance implementation of the EdenDataModule, assuming some items are pre-computed + sharded fasta files and fasta index files.
17-
17+
# Contributed by: BaseCamp Research https://basecamp-research.com/ https://github.com/NVIDIA/bionemo-framework/pull/1091
1818
import argparse
1919
import csv
2020
import os
@@ -100,7 +100,7 @@ def __init__(
100100
log_dir: Optional[str] = None,
101101
**kwargs,
102102
):
103-
"""Initialize the ShardedEdenDataModule."""
103+
"""Initialize the ShardedEdenDataModule. See sub-packages/bionemo-evo2/src/bionemo/evo2/data/sharded_eden_dataloader.md for how to prepare the input data. """
104104
super().__init__()
105105
self.sequence_db_dir = sequence_db_dir
106106
self.train_window_db_path = train_window_db_path

sub-packages/bionemo-evo2/src/bionemo/evo2/models/llama.py

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,12 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import logging
17-
import math
16+
1817
from dataclasses import dataclass
1918
from typing import Optional
2019

21-
import torch
2220
from nemo.collections import llm
21+
from nemo.collections.llm.gpt.model.llama import apply_rope_scaling
2322

2423

2524
@dataclass
@@ -57,48 +56,6 @@ def configure_model(self, *args, **kwargs):
5756
return model
5857

5958

60-
def apply_rope_scaling(
61-
inv_freq,
62-
factor: int = 8,
63-
low_freq_factor: int = 1,
64-
high_freq_factor: int = 4,
65-
old_context_len: int = 8192,
66-
):
67-
"""Apply RoPE scaling for extending context length in Llama models.
68-
69-
This implements the NTK-aware RoPE scaling method used in Llama 3.1 models to
70-
extend context length beyond the original training length.
71-
72-
Args:
73-
inv_freq: Original inverse frequency tensor
74-
factor: Scaling factor for context length extension
75-
low_freq_factor: Factor for low frequency components
76-
high_freq_factor: Factor for high frequency components
77-
old_context_len: Original context length
78-
79-
Returns:
80-
torch.Tensor: Modified inverse frequency tensor for extended context
81-
"""
82-
logging.info(
83-
f"Apply rope scaling with factor={factor}, low_freq_factor={low_freq_factor}, high_freq_factor={high_freq_factor}, old_context_len={old_context_len}."
84-
)
85-
86-
low_freq_wavelen = old_context_len / low_freq_factor
87-
high_freq_wavelen = old_context_len / high_freq_factor
88-
89-
wavelen = 2 * math.pi / inv_freq
90-
# wavelen < high_freq_wavelen: do nothing
91-
# wavelen > low_freq_wavelen: divide by factor
92-
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
93-
# otherwise: interpolate between the two, using a smooth factor
94-
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
95-
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
96-
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
97-
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
98-
99-
return inv_freq_llama
100-
101-
10259
@dataclass
10360
class Eden11BConfig(EdenConfig):
10461
"""Eden-flavoured Llama-3.1 ~14B (keeps all Eden behaviors)."""

sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -756,9 +756,10 @@ def train(args: argparse.Namespace) -> nl.Trainer:
756756
"distribute_saved_activations": False if args.sequence_parallel else True,
757757
"cross_entropy_loss_fusion": args.cross_entropy_loss_fusion,
758758
"fp32_residual_connection": not args.no_fp32_residual_connection,
759-
"add_bias_output": args.add_bias_output,
760759
**activation_checkpointing_args,
761760
}
761+
if args.add_bias_output:
762+
config_modifiers_init["add_bias_output"] = args.add_bias_output
762763
if args.spike_no_more_embedding_init:
763764
config_modifiers_init["embedding_init_method_std"] = 1.0
764765
# When using spike_no_more_embedding_init, we don't want to share embeddings and outputs.
@@ -803,12 +804,10 @@ def train(args: argparse.Namespace) -> nl.Trainer:
803804
config_modifiers_init["lowercase_loss_reweighting"] = args.mamba_lowercase_loss_weight
804805
if args.model_size not in MAMBA_MODEL_OPTIONS:
805806
raise ValueError(f"Invalid model size for Mamba: {args.model_size}")
806-
add_bias_output = config_modifiers_init.pop("add_bias_output")
807-
if add_bias_output:
808-
raise ValueError("Bias output is not supported for Mamba models.")
809807
model_config = MAMBA_MODEL_OPTIONS[args.model_size](**config_modifiers_init)
810808
model = MambaModel(model_config, tokenizer=data_module.tokenizer)
811809
elif model_type == "llama":
810+
config_modifiers_init.pop("to_upper")
812811
model_config = LLAMA_MODEL_OPTIONS[args.model_size](**config_modifiers_init)
813812
model = llm.LlamaModel(model_config, tokenizer=data_module.tokenizer)
814813

@@ -895,7 +894,7 @@ def train(args: argparse.Namespace) -> nl.Trainer:
895894
f"-GBS{global_batch_size}-MBS{args.micro_batch_size}-SkipLossRenorm{args.no_renormalize_loss}"
896895
f"-NOAC{args.no_activation_checkpointing}-SELAC{args.selective_activation_checkpointing}"
897896
f"-ACRNL{model_config.recompute_num_layers}"
898-
f"-PAT{model_config.hybrid_override_pattern}"
897+
f"-PAT{getattr(model_config, 'hybrid_override_pattern', 'None')}"
899898
f"-F32R{model_config.fp32_residual_connection}"
900899
f"-FCE{model_config.cross_entropy_loss_fusion}"
901900
f"-AIC{average_in_collective}"
@@ -1071,7 +1070,7 @@ def train(args: argparse.Namespace) -> nl.Trainer:
10711070
constant_steps=args.constant_steps,
10721071
)
10731072
# This is where the no weight decay condition is applied to the optimizer state.
1074-
opt = MegatronOptimizerModule(opt_config, sched, no_weight_decay_cond=model_config.hyena_no_weight_decay_cond_fn)
1073+
opt = MegatronOptimizerModule(opt_config, sched, no_weight_decay_cond=getattr(model_config, 'hyena_no_weight_decay_cond_fn', None))
10751074
opt.connect(model)
10761075
# Start training
10771076
trainer.fit(model, data_module)

sub-packages/bionemo-evo2/tests/bionemo/evo2/data/test_sharded_eden_dataset.py

Lines changed: 1 addition & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ def test_dataset_reverse_complement(sequence_db_dir, window_dbs):
430430
# Test with N bases
431431
test_seq_with_n = "ATCN"
432432
rc_seq_with_n = dataset.reverse_complement(test_seq_with_n)
433-
assert rc_seq_with_n == "NCGAT"
433+
assert rc_seq_with_n == "NGAT"
434434

435435
# Clean up
436436
dataset.__del__()
@@ -480,54 +480,6 @@ def test_dataset_collate_fn(sequence_db_dir, window_dbs):
480480
dataset.__del__()
481481

482482

483-
def test_invalid_sequence_db_dir(window_dbs):
484-
"""Test error handling for invalid sequence database directory."""
485-
# Mock tokenizer
486-
mock_tokenizer = Mock()
487-
mock_tokenizer.bos_id = 1
488-
mock_tokenizer.eos_id = 2
489-
mock_tokenizer._sep_id = 3
490-
mock_tokenizer.pad_id = 0
491-
mock_tokenizer.text_to_ids.return_value = [10, 11, 12]
492-
493-
# Test with non-existent directory
494-
with pytest.raises(ValueError, match="No SQLite files found"):
495-
ShardedEdenDataset(
496-
tokenizer=mock_tokenizer,
497-
sequence_db_dir="/non/existent/path",
498-
window_db_path=window_dbs["train"],
499-
seq_length=8192,
500-
create_attention_mask=False,
501-
stride=7992,
502-
rc_aug=False,
503-
use_control_tags=False,
504-
split="train",
505-
)
506-
507-
508-
def test_invalid_window_db_path(sequence_db_dir):
509-
"""Test error handling for invalid window database path."""
510-
# Mock tokenizer
511-
mock_tokenizer = Mock()
512-
mock_tokenizer.bos_id = 1
513-
mock_tokenizer.eos_id = 2
514-
mock_tokenizer._sep_id = 3
515-
mock_tokenizer.pad_id = 0
516-
mock_tokenizer.text_to_ids.return_value = [10, 11, 12]
517-
518-
# Test with non-existent window database
519-
with pytest.raises(ValueError):
520-
ShardedEdenDataset(
521-
tokenizer=mock_tokenizer,
522-
sequence_db_dir=sequence_db_dir,
523-
window_db_path="/non/existent/windows.db",
524-
seq_length=8192,
525-
create_attention_mask=False,
526-
stride=7992,
527-
rc_aug=False,
528-
use_control_tags=False,
529-
split="train",
530-
)
531483

532484

533485
def test_window_min_length_threshold(temp_dir, train_parquet):
@@ -594,33 +546,3 @@ def test_dataset_length_and_iteration(sequence_db_dir, window_dbs):
594546
# Clean up
595547
dataset.__del__()
596548

597-
598-
def test_dataset_with_different_seq_lengths(sequence_db_dir, window_dbs):
599-
"""Test dataset with different sequence lengths."""
600-
# Mock tokenizer
601-
mock_tokenizer = Mock()
602-
mock_tokenizer.bos_id = 1
603-
mock_tokenizer.eos_id = 2
604-
mock_tokenizer._sep_id = 3
605-
mock_tokenizer.pad_id = 0
606-
mock_tokenizer.text_to_ids.return_value = [10, 11, 12]
607-
608-
# Test with different sequence lengths
609-
for seq_length in [4096, 8192, 16384]:
610-
dataset = ShardedEdenDataset(
611-
tokenizer=mock_tokenizer,
612-
sequence_db_dir=sequence_db_dir,
613-
window_db_path=window_dbs["train"],
614-
seq_length=seq_length,
615-
create_attention_mask=False,
616-
stride=seq_length - 200, # Adjust stride
617-
rc_aug=False,
618-
use_control_tags=False,
619-
split="train",
620-
)
621-
622-
# Verify sequence length
623-
assert dataset.seq_length == seq_length
624-
625-
# Clean up
626-
dataset.__del__()

sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_train.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,29 @@ def small_training_mamba_finetune_cmd(
9696
return cmd
9797

9898

99+
def small_training_llama_cmd(path, max_steps, val_check, devices: int = 1, additional_args: str = ""):
100+
cmd = (
101+
f"train_evo2 --no-fp32-residual-connection --mock-data --result-dir {path} --devices {devices} "
102+
"--model-size 8B --num-layers 2 --limit-val-batches 1 "
103+
"--no-activation-checkpointing --create-tensorboard-logger --create-tflops-callback "
104+
f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1 "
105+
f"--seq-length 8 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args}"
106+
)
107+
return cmd
108+
109+
110+
def small_training_llama_finetune_cmd(
111+
path, max_steps, val_check, prev_ckpt, devices: int = 1, additional_args: str = ""
112+
):
113+
cmd = (
114+
f"train_evo2 --no-fp32-residual-connection --mock-data --result-dir {path} --devices {devices} "
115+
"--model-size 8B --num-layers 2 --limit-val-batches 1 "
116+
"--no-activation-checkpointing --create-tensorboard-logger --create-tflops-callback "
117+
f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1 "
118+
f"--seq-length 16 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args} --ckpt-dir {prev_ckpt}"
119+
)
120+
return cmd
121+
99122
@pytest.mark.timeout(512) # Optional: fail if the test takes too long.
100123
@pytest.mark.slow
101124
def test_train_evo2_finetune_runs(tmp_path):
@@ -243,6 +266,80 @@ def test_train_evo2_mamba_finetune_runs(tmp_path):
243266
assert len(matching_subfolders_ft) == 1, "Only one checkpoint subfolder should be found."
244267

245268

269+
270+
@pytest.mark.timeout(512) # Optional: fail if the test takes too long.
271+
@pytest.mark.slow
272+
def test_train_evo2_llama_finetune_runs(tmp_path):
273+
"""
274+
This test runs the `train_evo2` command with mock data in a temporary directory using a Llama model, then fine-tunes from the resulting checkpoint.
275+
It uses the temporary directory provided by pytest as the working directory.
276+
The command is run in a subshell, and we assert that it returns an exit code of 0.
277+
"""
278+
num_steps = 2
279+
# Note: The command assumes that `train_evo2` is in your PATH.
280+
command = small_training_llama_cmd(tmp_path / "pretrain", max_steps=num_steps, val_check=num_steps)
281+
stdout_pretrain: str = run_command_in_subprocess(command=command, path=str(tmp_path))
282+
assert "Restoring model weights from RestoreConfig(path='" not in stdout_pretrain
283+
284+
log_dir = tmp_path / "pretrain" / "evo2"
285+
checkpoints_dir = log_dir / "checkpoints"
286+
tensorboard_dir = log_dir / "dev"
287+
288+
# Check if logs dir exists
289+
assert log_dir.exists(), "Logs folder should exist."
290+
# Check if checkpoints dir exists
291+
assert checkpoints_dir.exists(), "Checkpoints folder does not exist."
292+
293+
expected_checkpoint_suffix = f"{num_steps}.0-last"
294+
# Check if any subfolder ends with the expected suffix
295+
matching_subfolders = [
296+
p for p in checkpoints_dir.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name)
297+
]
298+
299+
assert matching_subfolders, (
300+
f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir}."
301+
)
302+
303+
# Check if directory with tensorboard logs exists
304+
assert tensorboard_dir.exists(), "TensorBoard logs folder does not exist."
305+
# Recursively search for files with tensorboard logger
306+
event_files = list(tensorboard_dir.rglob("events.out.tfevents*"))
307+
assert event_files, f"No TensorBoard event files found under {tensorboard_dir}"
308+
309+
assert len(matching_subfolders) == 1, "Only one checkpoint subfolder should be found."
310+
command_finetune = small_training_llama_finetune_cmd(
311+
tmp_path / "finetune", max_steps=num_steps, val_check=num_steps, prev_ckpt=matching_subfolders[0]
312+
)
313+
stdout_finetune: str = run_command_in_subprocess(command=command_finetune, path=str(tmp_path))
314+
assert "Restoring model weights from RestoreConfig(path='" in stdout_finetune
315+
316+
log_dir_ft = tmp_path / "finetune" / "evo2"
317+
checkpoints_dir_ft = log_dir_ft / "checkpoints"
318+
tensorboard_dir_ft = log_dir_ft / "dev"
319+
320+
# Check if logs dir exists
321+
assert log_dir_ft.exists(), "Logs folder should exist."
322+
# Check if checkpoints dir exists
323+
assert checkpoints_dir_ft.exists(), "Checkpoints folder does not exist."
324+
325+
expected_checkpoint_suffix = f"{num_steps}.0-last"
326+
matching_subfolders_ft = [
327+
p for p in checkpoints_dir_ft.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name)
328+
]
329+
330+
assert matching_subfolders_ft, (
331+
f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir_ft}."
332+
)
333+
334+
# Check if directory with tensorboard logs exists
335+
assert tensorboard_dir_ft.exists(), "TensorBoard logs folder does not exist."
336+
# Recursively search for files with tensorboard logger
337+
event_files = list(tensorboard_dir_ft.rglob("events.out.tfevents*"))
338+
assert event_files, f"No TensorBoard event files found under {tensorboard_dir_ft}"
339+
340+
assert len(matching_subfolders_ft) == 1, "Only one checkpoint subfolder should be found."
341+
342+
246343
@pytest.mark.timeout(256) # Optional: fail if the test takes too long.
247344
@pytest.mark.slow
248345
def test_train_evo2_stops(tmp_path):

0 commit comments

Comments
 (0)