Skip to content

Commit 3baf730

Browse files
committed
add llama predict test
Signed-off-by: Yang Zhang <yangzhang@nvidia.com>
1 parent a39524d commit 3baf730

1 file changed

Lines changed: 100 additions & 0 deletions

File tree

sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,27 @@
2222
import subprocess
2323
import sys
2424

25+
import pytest
2526
import torch
2627
from lightning.fabric.plugins.environments.lightning import find_free_network_port
2728

2829
from bionemo.core.data.load import load
2930
from bionemo.noodles.nvfaidx import NvFaidx
3031
from bionemo.testing.data.fasta import ALU_SEQUENCE, create_fasta_file
32+
from bionemo.testing.subprocess_utils import run_command_in_subprocess
3133

3234

35+
36+
def small_training_llama_cmd(path, max_steps, val_check, devices: int = 1, additional_args: str = ""):
37+
cmd = (
38+
f"train_evo2 --no-fp32-residual-connection --mock-data --result-dir {path} --devices {devices} "
39+
"--model-size 8B --num-layers 2 --limit-val-batches 1 "
40+
"--no-activation-checkpointing --create-tensorboard-logger --create-tflops-callback "
41+
f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1 "
42+
f"--seq-length 8 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args}"
43+
)
44+
return cmd
45+
3346
def test_predict_evo2_runs(
3447
tmp_path, num_sequences: int = 5, target_sequence_lengths: list[int] = [3149, 3140, 1024, 3149, 3149]
3548
):
@@ -104,3 +117,90 @@ def test_predict_evo2_runs(
104117
idx = seq_idx_map[seq_name] # look up the out of order prediction index for this sequence.
105118
assert preds["pad_mask"][idx].sum() == expected_len
106119
assert preds["token_logits"][idx].shape == (max(target_sequence_lengths), 512)
120+
121+
122+
@pytest.mark.timeout(512) # Optional: fail if the test takes too long.
123+
def test_predict_evo2_llama_runs(
124+
tmp_path, num_sequences: int = 5, target_sequence_lengths: list[int] = [3149, 3140, 1024, 3149, 3149]
125+
):
126+
"""
127+
This test first trains a small Llama model to create a checkpoint, then runs the `predict_evo2` command
128+
with that checkpoint and mock data in a temporary directory.
129+
It uses the temporary directory provided by pytest as the working directory.
130+
The command is run in a subshell, and we assert that it returns an exit code of 0.
131+
"""
132+
# First, train a small Llama model to create a checkpoint
133+
num_steps = 2
134+
train_command = small_training_llama_cmd(tmp_path / "pretrain", max_steps=num_steps, val_check=num_steps)
135+
stdout_pretrain: str = run_command_in_subprocess(command=train_command, path=str(tmp_path))
136+
assert "Restoring model weights from RestoreConfig(path='" not in stdout_pretrain
137+
138+
# Find the created checkpoint
139+
log_dir = tmp_path / "pretrain" / "evo2"
140+
checkpoints_dir = log_dir / "checkpoints"
141+
assert checkpoints_dir.exists(), "Checkpoints folder does not exist."
142+
143+
expected_checkpoint_suffix = f"{num_steps}.0-last"
144+
matching_subfolders = [
145+
p for p in checkpoints_dir.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name)
146+
]
147+
assert matching_subfolders, (
148+
f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir}."
149+
)
150+
assert len(matching_subfolders) == 1, "Only one checkpoint subfolder should be found."
151+
checkpoint_path = matching_subfolders[0]
152+
153+
# Now create the FASTA file for prediction
154+
fasta_file_path = tmp_path / "test_llama.fasta"
155+
create_fasta_file(
156+
fasta_file_path, num_sequences, sequence_lengths=target_sequence_lengths, repeating_dna_pattern=ALU_SEQUENCE
157+
)
158+
159+
# Build the command string for Llama model prediction.
160+
# Note: The command assumes that `predict_evo2` is in your PATH.
161+
output_dir = tmp_path / "test_llama_output"
162+
command = (
163+
f"predict_evo2 --fasta {fasta_file_path} --ckpt-dir {checkpoint_path} "
164+
f"--output-dir {output_dir} --model-type llama --model-size 8B --tensor-parallel-size 1 "
165+
"--pipeline-model-parallel-size 1 --context-parallel-size 1 --num-layers 2"
166+
)
167+
168+
# Run the command in a subshell, using the temporary directory as the current working directory.
169+
result = subprocess.run(
170+
command,
171+
shell=True, # Use the shell to interpret wildcards (e.g. SDH*)
172+
cwd=tmp_path, # Run in the temporary directory
173+
capture_output=True, # Capture stdout and stderr for debugging
174+
text=True, # Decode output as text
175+
)
176+
177+
# For debugging purposes, print the output if the test fails.
178+
if result.returncode != 0:
179+
sys.stderr.write("STDOUT:\n" + result.stdout + "\n")
180+
sys.stderr.write("STDERR:\n" + result.stderr + "\n")
181+
182+
# Assert that the command completed successfully.
183+
assert result.returncode == 0, "predict_evo2 command with Llama model failed."
184+
185+
# Assert that the output directory was created.
186+
pred_files = glob.glob(os.path.join(output_dir, "predictions__rank_*.pt"))
187+
assert len(pred_files) == 1, "Expected 1 prediction file (for this test), got {}".format(len(pred_files))
188+
with open(output_dir / "seq_idx_map.json", "r") as f:
189+
seq_idx_map = json.load(
190+
f
191+
) # This gives us the mapping from the sequence names to the indices in the predictions.
192+
preds = torch.load(pred_files[0])
193+
assert isinstance(preds, dict)
194+
assert "token_logits" in preds
195+
assert "pad_mask" in preds
196+
assert "seq_idx" in preds
197+
assert len(preds["token_logits"]) == len(preds["pad_mask"]) == len(preds["seq_idx"]) == num_sequences
198+
assert len(seq_idx_map) == num_sequences
199+
fasta = NvFaidx(fasta_file_path)
200+
for i, seq_name in enumerate(sorted(fasta.keys())):
201+
expected_len = target_sequence_lengths[i]
202+
idx = seq_idx_map[seq_name] # look up the out of order prediction index for this sequence.
203+
assert preds["pad_mask"][idx].sum() == expected_len
204+
assert preds["token_logits"][idx].shape == (max(target_sequence_lengths), 512)
205+
206+

0 commit comments

Comments
 (0)