2121import os
2222import subprocess
2323import sys
24+ import tempfile
2425from pathlib import Path
2526
2627import pytest
@@ -54,7 +55,7 @@ def checkpoint_1b_8k_bf16_path() -> Path:
5455 )
5556 else :
5657 raise e
57- return checkpoint_path
58+ yield checkpoint_path
5859
5960
6061@pytest .fixture (scope = "module" )
@@ -69,11 +70,11 @@ def checkpoint_7b_1m_path() -> Path:
6970 )
7071 else :
7172 raise e
72- return checkpoint_path
73+ yield checkpoint_path
7374
7475
7576@pytest .mark .parametrize (
76- "ddp,pp,tp, wi" ,
77+ "ddp,pp,wi" ,
7778 [
7879 pytest .param (1 , 1 , "epoch" , id = "ddp=1,pp=1,wi=epoch" ),
7980 pytest .param (2 , 1 , "epoch" , id = "ddp=2,pp=1,wi=epoch" ),
@@ -91,7 +92,6 @@ def test_predict_evo2_runs(
9192 tmp_path ,
9293 ddp : int ,
9394 pp : int ,
94- tp : int ,
9595 wi : str ,
9696 checkpoint_1b_8k_bf16_path : Path ,
9797 num_sequences : int = 5 ,
@@ -105,7 +105,7 @@ def test_predict_evo2_runs(
105105 Since it's the full output this does not support CP, so we only test with TP=1. We also want coverage of the
106106 case where the sequence lengths are different and not necessarily divisible by CP.
107107 """
108- world_size = ddp * pp * tp
108+ world_size = ddp * pp
109109 if world_size > torch .cuda .device_count ():
110110 pytest .skip (f"World size { world_size } is less than the number of GPUs { torch .cuda .device_count ()} " )
111111 fasta_file_path = tmp_path / "test.fasta"
@@ -125,7 +125,7 @@ def test_predict_evo2_runs(
125125 command = (
126126 f"torchrun --nproc_per_node { world_size } --nnodes 1 --no-python "
127127 f"predict_evo2 --fasta { fasta_file_path } --ckpt-dir { checkpoint_1b_8k_bf16_path } "
128- f"--output-dir { output_dir } --model-size 1b --tensor-parallel-size { tp } "
128+ f"--output-dir { output_dir } --model-size 1b "
129129 f"--micro-batch-size 3 --write-interval { wi } "
130130 f"--pipeline-model-parallel-size { pp } --num-nodes 1 --devices { world_size } "
131131 )
@@ -180,6 +180,56 @@ def test_predict_evo2_runs(
180180 assert token_logits .shape == (max (target_sequence_lengths ), 512 )
181181
182182
183+ @pytest .fixture (scope = "module" )
184+ def baseline_predictions_7b_1m_results (
185+ checkpoint_7b_1m_path : Path ,
186+ num_sequences : int = 5 ,
187+ target_sequence_lengths : list [int ] = [2048 , 2048 , 2048 , 2048 , 2048 ],
188+ ) -> dict [int , float ]:
189+ with tempfile .TemporaryDirectory () as tmp_dir :
190+ tmp_path = Path (tmp_dir )
191+ fasta_file_path = tmp_path / "test.fasta"
192+ create_fasta_file (
193+ fasta_file_path ,
194+ num_sequences ,
195+ sequence_lengths = target_sequence_lengths ,
196+ repeating_dna_pattern = ALU_SEQUENCE ,
197+ )
198+ output_dir = tmp_path / "test_output"
199+ command = (
200+ f"torchrun --nproc_per_node 1 --nnodes 1 --no-python "
201+ f"predict_evo2 --fasta { fasta_file_path } --ckpt-dir { checkpoint_7b_1m_path } "
202+ f"--num-layers 4 --hybrid-override-pattern SDH* " # subset of layers for testing
203+ # FIXME changing batch size from 3 to 1 required dropping rel=1e-6 to rel=1e-3
204+ # even when model parallelism is not used. This should be investigated.
205+ f"--micro-batch-size 3 "
206+ f"--output-dir { output_dir } --model-size 7b_arc_longcontext "
207+ f"--num-nodes 1 --write-interval epoch "
208+ "--output-log-prob-seqs --log-prob-collapse-option sum"
209+ )
210+ # Create a mock data directory.
211+ # a local copy of the environment
212+ env = dict (** os .environ )
213+ open_port = find_free_network_port ()
214+ env ["MASTER_PORT" ] = str (open_port )
215+ result = subprocess .run (
216+ command ,
217+ shell = True , # Use the shell to interpret wildcards (e.g. SDH*)
218+ cwd = tmp_path , # Run in the temporary directory
219+ capture_output = True , # Capture stdout and stderr for debugging
220+ env = env , # Pass in the env where we override the master port.
221+ text = True , # Decode output as text
222+ )
223+ assert result .returncode == 0 , "predict_evo2 command failed."
224+ # Assert that the output directory was created.
225+ pred_files = glob .glob (os .path .join (output_dir , "predictions__rank_*.pt" ))
226+ preds = [torch .load (pf ) for pf in pred_files ]
227+ preds = batch_collator (
228+ [p for p in preds if p is not None ],
229+ )
230+ yield dict (zip ([i .item () for i in preds ["seq_idx" ]], [p .item () for p in preds ["log_probs_seqs" ]]))
231+
232+
183233@pytest .mark .parametrize (
184234 "ddp,cp,pp,tp,fp8,wi" ,
185235 [
@@ -216,6 +266,7 @@ def test_predict_evo2_runs_with_log_probs(
216266 fp8 : bool ,
217267 wi : str ,
218268 checkpoint_7b_1m_path : Path ,
269+ baseline_predictions_7b_1m_results : dict [int , float ],
219270 num_sequences : int = 5 ,
220271 target_sequence_lengths : list [int ] = [2048 , 2048 , 2048 , 2048 , 2048 ],
221272):
@@ -228,6 +279,7 @@ def test_predict_evo2_runs_with_log_probs(
228279 """
229280
230281 world_size = ddp * cp * pp * tp
282+ mp_size = cp * pp * tp
231283 if world_size > torch .cuda .device_count ():
232284 pytest .skip (f"World size { world_size } is less than the number of GPUs { torch .cuda .device_count ()} " )
233285 is_fp8_supported , _ , _ = check_fp8_support (torch .cuda .current_device ())
@@ -290,7 +342,6 @@ def test_predict_evo2_runs_with_log_probs(
290342 f
291343 ) # This gives us the mapping from the sequence names to the indices in the predictions.
292344 preds = [torch .load (pf ) for pf in pred_files ]
293- preds = [torch .load (pf ) for pf in pred_files ]
294345 preds = batch_collator (
295346 [p for p in preds if p is not None ],
296347 )
@@ -299,6 +350,9 @@ def test_predict_evo2_runs_with_log_probs(
299350 assert "seq_idx" in preds
300351 assert len (preds ["log_probs_seqs" ]) == len (preds ["seq_idx" ]) == num_sequences
301352 assert len (seq_idx_map ) == num_sequences
302- # TODO consider some kind of numerical test on the log probabilities returned. For now though there is no
303- # correct answer, and the model is just a subset so it is not even a real model we would expect a good result
304- # from. Checking that output is made without error will still capture API drift.
353+ for original_idx , log_probs in zip (preds ["seq_idx" ], preds ["log_probs_seqs" ]):
354+ if mp_size > 1 or fp8 :
355+ rel = 1e-3
356+ else :
357+ rel = 1e-6
358+ assert log_probs .item () == pytest .approx (baseline_predictions_7b_1m_results [original_idx .item ()], rel = rel )
0 commit comments