Skip to content

Commit ac0c217

Browse files
committed
Add test coverage for fine-tuning
1 parent c22a90e commit ac0c217

1 file changed

Lines changed: 99 additions & 6 deletions

File tree

sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_train.py

Lines changed: 99 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,17 @@ def small_training_cmd(path, max_steps, val_check, devices: int = 1, additional_
6161
return cmd
6262

6363

64+
def small_training_finetune_cmd(path, max_steps, val_check, prev_ckpt, devices: int = 1, additional_args: str = ""):
    """Assemble the `train_evo2` command line for a small fine-tuning run.

    The command uses mock data and a 4-layer `1b_nv` model, and hands
    `prev_ckpt` to the trainer through `--ckpt-dir`.

    Args:
        path: Value passed to `--result-dir`.
        max_steps: Value passed to `--max-steps`.
        val_check: Value passed to `--val-check-interval`.
        prev_ckpt: Checkpoint location passed to `--ckpt-dir`.
        devices: Value passed to `--devices`.
        additional_args: Extra CLI text inserted verbatim just before `--ckpt-dir`.

    Returns:
        The full command as a single string.
    """
    # Each entry is one space-free-at-the-edges slice of the command; joining on a
    # single space reproduces the exact spacing of the concatenated form (an empty
    # `additional_args` therefore still leaves two spaces before `--ckpt-dir`).
    segments = [
        f"train_evo2 --mock-data --result-dir {path} --devices {devices}",
        "--model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* --limit-val-batches 1",
        "--no-activation-checkpointing --add-bias-output --create-tensorboard-logger --create-tflops-callback",
        f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1",
        f"--seq-length 16 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args} --ckpt-dir {prev_ckpt}",
    ]
    return " ".join(segments)
73+
74+
6475
def small_training_mamba_cmd(path, max_steps, val_check, devices: int = 1, additional_args: str = ""):
6576
cmd = (
6677
f"train_evo2 --mock-data --result-dir {path} --devices {devices} "
@@ -72,9 +83,22 @@ def small_training_mamba_cmd(path, max_steps, val_check, devices: int = 1, addit
7283
return cmd
7384

7485

86+
def small_training_mamba_finetune_cmd(
    path, max_steps, val_check, prev_ckpt, devices: int = 1, additional_args: str = ""
):
    """Assemble the `train_evo2` command line for a small hybrid-Mamba fine-tuning run.

    The command uses mock data and a 2-layer `hybrid_mamba_8b` model, and hands
    `prev_ckpt` to the trainer through `--ckpt-dir`.

    Args:
        path: Value passed to `--result-dir`.
        max_steps: Value passed to `--max-steps`.
        val_check: Value passed to `--val-check-interval`.
        prev_ckpt: Checkpoint location passed to `--ckpt-dir`.
        devices: Value passed to `--devices`.
        additional_args: Extra CLI text inserted verbatim just before `--ckpt-dir`.

    Returns:
        The full command as a single string.
    """
    # Every fragment except the last keeps its trailing space so that plain
    # concatenation yields the final command.
    launch_part = f"train_evo2 --mock-data --result-dir {path} --devices {devices} "
    model_part = "--model-size hybrid_mamba_8b --num-layers 2 --hybrid-override-pattern M- --limit-val-batches 1 "
    logging_part = "--no-activation-checkpointing --create-tensorboard-logger --create-tflops-callback "
    schedule_part = (
        f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1 "
    )
    tail_part = (
        f"--seq-length 16 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args} --ckpt-dir {prev_ckpt}"
    )
    return launch_part + model_part + logging_part + schedule_part + tail_part
97+
98+
7599
@pytest.mark.timeout(256) # Optional: fail if the test takes too long.
76100
@pytest.mark.slow
77-
def test_train_evo2_runs(tmp_path):
101+
def test_train_evo2_finetune_runs(tmp_path):
78102
"""
79103
This test runs the `train_evo2` command with mock data in a temporary directory.
80104
It uses the temporary directory provided by pytest as the working directory.
@@ -83,7 +107,8 @@ def test_train_evo2_runs(tmp_path):
83107
num_steps = 2
84108
# Note: The command assumes that `train_evo2` is in your PATH.
85109
command = small_training_cmd(tmp_path, max_steps=num_steps, val_check=num_steps)
86-
run_command_in_subprocess(command=command, path=str(tmp_path))
110+
stdout_pretrain: str = run_command_in_subprocess(command=command, path=str(tmp_path))
111+
assert "Restoring model weights from RestoreConfig(path='" not in stdout_pretrain
87112

88113
log_dir = tmp_path / "evo2"
89114
checkpoints_dir = log_dir / "checkpoints"
@@ -109,22 +134,56 @@ def test_train_evo2_runs(tmp_path):
109134
# Recursively search for files with tensorboard logger
110135
event_files = list(tensorboard_dir.rglob("events.out.tfevents*"))
111136
assert event_files, f"No TensorBoard event files found under {tensorboard_dir}"
137+
assert len(matching_subfolders) == 1, "Only one checkpoint subfolder should be found."
138+
command_finetune = small_training_finetune_cmd(
139+
tmp_path / "finetune", max_steps=num_steps, val_check=num_steps, prev_ckpt=matching_subfolders[0]
140+
)
141+
stdout_finetune: str = run_command_in_subprocess(command=command_finetune, path=str(tmp_path))
142+
assert "Restoring model weights from RestoreConfig(path='" in stdout_finetune
143+
144+
log_dir_ft = tmp_path / "finetune" / "evo2"
145+
checkpoints_dir_ft = log_dir_ft / "checkpoints"
146+
tensorboard_dir_ft = log_dir_ft / "dev"
147+
148+
# Check if logs dir exists
149+
assert log_dir_ft.exists(), "Logs folder should exist."
150+
# Check if checkpoints dir exists
151+
assert checkpoints_dir_ft.exists(), "Checkpoints folder does not exist."
152+
153+
expected_checkpoint_suffix = f"{num_steps}.0-last"
154+
# Check if any subfolder ends with the expected suffix
155+
matching_subfolders_ft = [
156+
p for p in checkpoints_dir_ft.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name)
157+
]
158+
159+
assert matching_subfolders_ft, (
160+
f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir_ft}."
161+
)
162+
163+
# Check if directory with tensorboard logs exists
164+
assert tensorboard_dir_ft.exists(), "TensorBoard logs folder does not exist."
165+
# Recursively search for files with tensorboard logger
166+
event_files = list(tensorboard_dir_ft.rglob("events.out.tfevents*"))
167+
assert event_files, f"No TensorBoard event files found under {tensorboard_dir_ft}"
168+
169+
assert len(matching_subfolders_ft) == 1, "Only one checkpoint subfolder should be found."
112170

113171

114172
@pytest.mark.timeout(256) # Optional: fail if the test takes too long.
115173
@pytest.mark.slow
116-
def test_train_evo2_mamba_runs(tmp_path):
174+
def test_train_evo2_mamba_finetune_runs(tmp_path):
117175
"""
118176
This test runs the `train_evo2` command with mock data in a temporary directory.
119177
It uses the temporary directory provided by pytest as the working directory.
120178
The command is run in a subshell, and we assert that it returns an exit code of 0.
121179
"""
122180
num_steps = 2
123181
# Note: The command assumes that `train_evo2` is in your PATH.
124-
command = small_training_mamba_cmd(tmp_path, max_steps=num_steps, val_check=num_steps)
125-
run_command_in_subprocess(command=command, path=str(tmp_path))
182+
command = small_training_mamba_cmd(tmp_path / "pretrain", max_steps=num_steps, val_check=num_steps)
183+
stdout_pretrain: str = run_command_in_subprocess(command=command, path=str(tmp_path))
184+
assert "Restoring model weights from RestoreConfig(path='" not in stdout_pretrain
126185

127-
log_dir = tmp_path / "evo2"
186+
log_dir = tmp_path / "pretrain" / "evo2"
128187
checkpoints_dir = log_dir / "checkpoints"
129188
tensorboard_dir = log_dir / "dev"
130189

@@ -149,6 +208,40 @@ def test_train_evo2_mamba_runs(tmp_path):
149208
event_files = list(tensorboard_dir.rglob("events.out.tfevents*"))
150209
assert event_files, f"No TensorBoard event files found under {tensorboard_dir}"
151210

211+
assert len(matching_subfolders) == 1, "Only one checkpoint subfolder should be found."
212+
command_finetune = small_training_mamba_finetune_cmd(
213+
tmp_path / "finetune", max_steps=num_steps, val_check=num_steps, prev_ckpt=matching_subfolders[0]
214+
)
215+
stdout_finetune: str = run_command_in_subprocess(command=command_finetune, path=str(tmp_path))
216+
assert "Restoring model weights from RestoreConfig(path='" in stdout_finetune
217+
218+
log_dir_ft = tmp_path / "finetune" / "evo2"
219+
checkpoints_dir_ft = log_dir_ft / "checkpoints"
220+
tensorboard_dir_ft = log_dir_ft / "dev"
221+
222+
# Check if logs dir exists
223+
assert log_dir_ft.exists(), "Logs folder should exist."
224+
# Check if checkpoints dir exists
225+
assert checkpoints_dir_ft.exists(), "Checkpoints folder does not exist."
226+
227+
expected_checkpoint_suffix = f"{num_steps}.0-last"
228+
# Check if any subfolder ends with the expected suffix
229+
matching_subfolders_ft = [
230+
p for p in checkpoints_dir_ft.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name)
231+
]
232+
233+
assert matching_subfolders_ft, (
234+
f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir_ft}."
235+
)
236+
237+
# Check if directory with tensorboard logs exists
238+
assert tensorboard_dir_ft.exists(), "TensorBoard logs folder does not exist."
239+
# Recursively search for files with tensorboard logger
240+
event_files = list(tensorboard_dir_ft.rglob("events.out.tfevents*"))
241+
assert event_files, f"No TensorBoard event files found under {tensorboard_dir_ft}"
242+
243+
assert len(matching_subfolders_ft) == 1, "Only one checkpoint subfolder should be found."
244+
152245

153246
@pytest.mark.timeout(256) # Optional: fail if the test takes too long.
154247
@pytest.mark.slow

0 commit comments

Comments
 (0)