Skip to content

Commit a2bdc85

Browse files
committed
Remove TODO about adding numerical test
Signed-off-by: John St John <jstjohn@nvidia.com>
1 parent 3e77e8d commit a2bdc85

2 files changed

Lines changed: 67 additions & 11 deletions

File tree

sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,9 @@ def predict_step(self, batch, batch_idx: int | None = None) -> Tensor | dict[str
201201
"""Alias for forward_step, also log the pad mask since sequences may not all have the same length."""
202202
if len(batch) == 0:
203203
return
204-
forward_out = self.forward_step(batch)
204+
assert self.training is False, "predict_step should be called in eval mode"
205+
with torch.no_grad():
206+
forward_out = self.forward_step(batch)
205207
if not isinstance(forward_out, Tensor):
206208
return forward_out
207209
# Reminder: the model's predictions for input i land at output i+1. To get everything to align, we prepend the

sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import os
2222
import subprocess
2323
import sys
24+
import tempfile
2425
from pathlib import Path
2526

2627
import pytest
@@ -54,7 +55,7 @@ def checkpoint_1b_8k_bf16_path() -> Path:
5455
)
5556
else:
5657
raise e
57-
return checkpoint_path
58+
yield checkpoint_path
5859

5960

6061
@pytest.fixture(scope="module")
@@ -69,11 +70,11 @@ def checkpoint_7b_1m_path() -> Path:
6970
)
7071
else:
7172
raise e
72-
return checkpoint_path
73+
yield checkpoint_path
7374

7475

7576
@pytest.mark.parametrize(
76-
"ddp,pp,tp,wi",
77+
"ddp,pp,wi",
7778
[
7879
pytest.param(1, 1, "epoch", id="ddp=1,pp=1,wi=epoch"),
7980
pytest.param(2, 1, "epoch", id="ddp=2,pp=1,wi=epoch"),
@@ -91,7 +92,6 @@ def test_predict_evo2_runs(
9192
tmp_path,
9293
ddp: int,
9394
pp: int,
94-
tp: int,
9595
wi: str,
9696
checkpoint_1b_8k_bf16_path: Path,
9797
num_sequences: int = 5,
@@ -105,7 +105,7 @@ def test_predict_evo2_runs(
105105
Since it's the full output this does not support CP, so we only test with TP=1. We also want coverage of the
106106
case where the sequence lengths are different and not necessarily divisible by CP.
107107
"""
108-
world_size = ddp * pp * tp
108+
world_size = ddp * pp
109109
if world_size > torch.cuda.device_count():
110110
pytest.skip(f"World size {world_size} is less than the number of GPUs {torch.cuda.device_count()}")
111111
fasta_file_path = tmp_path / "test.fasta"
@@ -125,7 +125,7 @@ def test_predict_evo2_runs(
125125
command = (
126126
f"torchrun --nproc_per_node {world_size} --nnodes 1 --no-python "
127127
f"predict_evo2 --fasta {fasta_file_path} --ckpt-dir {checkpoint_1b_8k_bf16_path} "
128-
f"--output-dir {output_dir} --model-size 1b --tensor-parallel-size {tp} "
128+
f"--output-dir {output_dir} --model-size 1b "
129129
f"--micro-batch-size 3 --write-interval {wi} "
130130
f"--pipeline-model-parallel-size {pp} --num-nodes 1 --devices {world_size}"
131131
)
@@ -180,6 +180,56 @@ def test_predict_evo2_runs(
180180
assert token_logits.shape == (max(target_sequence_lengths), 512)
181181

182182

183+
@pytest.fixture(scope="module")
184+
def baseline_predictions_7b_1m_results(
185+
checkpoint_7b_1m_path: Path,
186+
num_sequences: int = 5,
187+
target_sequence_lengths: list[int] = [2048, 2048, 2048, 2048, 2048],
188+
) -> dict[int, float]:
189+
with tempfile.TemporaryDirectory() as tmp_dir:
190+
tmp_path = Path(tmp_dir)
191+
fasta_file_path = tmp_path / "test.fasta"
192+
create_fasta_file(
193+
fasta_file_path,
194+
num_sequences,
195+
sequence_lengths=target_sequence_lengths,
196+
repeating_dna_pattern=ALU_SEQUENCE,
197+
)
198+
output_dir = tmp_path / "test_output"
199+
command = (
200+
f"torchrun --nproc_per_node 1 --nnodes 1 --no-python "
201+
f"predict_evo2 --fasta {fasta_file_path} --ckpt-dir {checkpoint_7b_1m_path} "
202+
f"--num-layers 4 --hybrid-override-pattern SDH* " # subset of layers for testing
203+
# FIXME changing batch size from 3 to 1 required dropping rel=1e-6 to rel=1e-3
204+
# even when model parallelism is not used. This should be investigated.
205+
f"--micro-batch-size 3 "
206+
f"--output-dir {output_dir} --model-size 7b_arc_longcontext "
207+
f"--num-nodes 1 --write-interval epoch "
208+
"--output-log-prob-seqs --log-prob-collapse-option sum"
209+
)
210+
# Create a mock data directory.
211+
# a local copy of the environment
212+
env = dict(**os.environ)
213+
open_port = find_free_network_port()
214+
env["MASTER_PORT"] = str(open_port)
215+
result = subprocess.run(
216+
command,
217+
shell=True, # Use the shell to interpret wildcards (e.g. SDH*)
218+
cwd=tmp_path, # Run in the temporary directory
219+
capture_output=True, # Capture stdout and stderr for debugging
220+
env=env, # Pass in the env where we override the master port.
221+
text=True, # Decode output as text
222+
)
223+
assert result.returncode == 0, "predict_evo2 command failed."
224+
# Assert that the output directory was created.
225+
pred_files = glob.glob(os.path.join(output_dir, "predictions__rank_*.pt"))
226+
preds = [torch.load(pf) for pf in pred_files]
227+
preds = batch_collator(
228+
[p for p in preds if p is not None],
229+
)
230+
yield dict(zip([i.item() for i in preds["seq_idx"]], [p.item() for p in preds["log_probs_seqs"]]))
231+
232+
183233
@pytest.mark.parametrize(
184234
"ddp,cp,pp,tp,fp8,wi",
185235
[
@@ -216,6 +266,7 @@ def test_predict_evo2_runs_with_log_probs(
216266
fp8: bool,
217267
wi: str,
218268
checkpoint_7b_1m_path: Path,
269+
baseline_predictions_7b_1m_results: dict[int, float],
219270
num_sequences: int = 5,
220271
target_sequence_lengths: list[int] = [2048, 2048, 2048, 2048, 2048],
221272
):
@@ -228,6 +279,7 @@ def test_predict_evo2_runs_with_log_probs(
228279
"""
229280

230281
world_size = ddp * cp * pp * tp
282+
mp_size = cp * pp * tp
231283
if world_size > torch.cuda.device_count():
232284
pytest.skip(f"World size {world_size} is less than the number of GPUs {torch.cuda.device_count()}")
233285
is_fp8_supported, _, _ = check_fp8_support(torch.cuda.current_device())
@@ -290,7 +342,6 @@ def test_predict_evo2_runs_with_log_probs(
290342
f
291343
) # This gives us the mapping from the sequence names to the indices in the predictions.
292344
preds = [torch.load(pf) for pf in pred_files]
293-
preds = [torch.load(pf) for pf in pred_files]
294345
preds = batch_collator(
295346
[p for p in preds if p is not None],
296347
)
@@ -299,6 +350,9 @@ def test_predict_evo2_runs_with_log_probs(
299350
assert "seq_idx" in preds
300351
assert len(preds["log_probs_seqs"]) == len(preds["seq_idx"]) == num_sequences
301352
assert len(seq_idx_map) == num_sequences
302-
# TODO consider some kind of numerical test on the log probabilities returned. For now though there is no
303-
# correct answer, and the model is just a subset so it is not even a real model we would expect a good result
304-
# from. Checking that output is made without error will still capture API drift.
353+
for original_idx, log_probs in zip(preds["seq_idx"], preds["log_probs_seqs"]):
354+
if mp_size > 1 or fp8:
355+
rel = 1e-3
356+
else:
357+
rel = 1e-6
358+
assert log_probs.item() == pytest.approx(baseline_predictions_7b_1m_results[original_idx.item()], rel=rel)

0 commit comments

Comments
 (0)