|
22 | 22 | import subprocess |
23 | 23 | import sys |
24 | 24 |
|
| 25 | +import pytest |
25 | 26 | import torch |
26 | 27 | from lightning.fabric.plugins.environments.lightning import find_free_network_port |
27 | 28 |
|
28 | 29 | from bionemo.core.data.load import load |
29 | 30 | from bionemo.noodles.nvfaidx import NvFaidx |
30 | 31 | from bionemo.testing.data.fasta import ALU_SEQUENCE, create_fasta_file |
| 32 | +from bionemo.testing.subprocess_utils import run_command_in_subprocess |
31 | 33 |
|
32 | 34 |
|
| 35 | + |
def small_training_llama_cmd(
    path, max_steps: int, val_check: int, devices: int = 1, additional_args: str = ""
) -> str:
    """Build the `train_evo2` command line for a short mock-data training run.

    The command requests a 2-layer `8B`-size configuration with sequence length 8, one warmup step,
    one validation batch, and tensorboard / TFLOPs callbacks enabled.

    Args:
        path: Directory passed to `--result-dir` (interpolated into the command as-is).
        max_steps: Value for `--max-steps`.
        val_check: Value for `--val-check-interval`.
        devices: Value for `--devices`.
        additional_args: Extra command-line text appended verbatim at the end of the command.

    Returns:
        The full command as a single string.
    """
    # `--limit-val-batches 1` is passed exactly once; repeating the flag adds nothing.
    return (
        f"train_evo2 --no-fp32-residual-connection --mock-data --result-dir {path} --devices {devices} "
        "--model-size 8B --num-layers 2 --limit-val-batches 1 "
        "--no-activation-checkpointing --create-tensorboard-logger --create-tflops-callback "
        f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} "
        f"--seq-length 8 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args}"
    )
| 45 | + |
33 | 46 | def test_predict_evo2_runs( |
34 | 47 | tmp_path, num_sequences: int = 5, target_sequence_lengths: list[int] = [3149, 3140, 1024, 3149, 3149] |
35 | 48 | ): |
@@ -104,3 +117,90 @@ def test_predict_evo2_runs( |
104 | 117 | idx = seq_idx_map[seq_name] # look up the out of order prediction index for this sequence. |
105 | 118 | assert preds["pad_mask"][idx].sum() == expected_len |
106 | 119 | assert preds["token_logits"][idx].shape == (max(target_sequence_lengths), 512) |
| 120 | + |
| 121 | + |
@pytest.mark.timeout(512)  # Optional: fail if the test takes too long.
def test_predict_evo2_llama_runs(
    tmp_path, num_sequences: int = 5, target_sequence_lengths: list[int] = [3149, 3140, 1024, 3149, 3149]
):
    """Train a tiny checkpoint, then verify that `predict_evo2 --model-type llama` runs on it.

    Steps:
        1. Run the command built by `small_training_llama_cmd` for two steps to produce a checkpoint
           under `tmp_path / "pretrain"`.
        2. Write a FASTA file with `num_sequences` sequences of the given lengths.
        3. Run `predict_evo2` against that checkpoint with `tmp_path` as the working directory and
           assert that it exits with code 0.
        4. Load the single predictions file and the sequence-index map and check their contents.

    Args:
        tmp_path: Temporary directory provided by pytest.
        num_sequences: Number of sequences written to the FASTA file.
        target_sequence_lengths: Length of each generated sequence. The default list is only read
            in this function, never modified.
    """
    # First, train a small model to create a checkpoint.
    num_steps = 2
    train_command = small_training_llama_cmd(tmp_path / "pretrain", max_steps=num_steps, val_check=num_steps)
    stdout_pretrain: str = run_command_in_subprocess(command=train_command, path=str(tmp_path))
    # The training output must not mention restoring weights from an existing checkpoint.
    assert "Restoring model weights from RestoreConfig(path='" not in stdout_pretrain

    # Find the created checkpoint.
    log_dir = tmp_path / "pretrain" / "evo2"
    checkpoints_dir = log_dir / "checkpoints"
    assert checkpoints_dir.exists(), "Checkpoints folder does not exist."

    # The match is a substring test on the folder name, so the messages below say "containing".
    expected_checkpoint_marker = f"{num_steps}.0-last"
    matching_subfolders = [
        p for p in checkpoints_dir.iterdir() if p.is_dir() and (expected_checkpoint_marker in p.name)
    ]
    assert matching_subfolders, (
        f"No checkpoint subfolder containing '{expected_checkpoint_marker}' found in {checkpoints_dir}."
    )
    assert len(matching_subfolders) == 1, "Only one checkpoint subfolder should be found."
    checkpoint_path = matching_subfolders[0]

    # Now create the FASTA file for prediction.
    fasta_file_path = tmp_path / "test_llama.fasta"
    create_fasta_file(
        fasta_file_path, num_sequences, sequence_lengths=target_sequence_lengths, repeating_dna_pattern=ALU_SEQUENCE
    )

    # Build the command string for Llama model prediction.
    # Note: The command assumes that `predict_evo2` is in your PATH.
    output_dir = tmp_path / "test_llama_output"
    command = (
        f"predict_evo2 --fasta {fasta_file_path} --ckpt-dir {checkpoint_path} "
        f"--output-dir {output_dir} --model-type llama --model-size 8B --tensor-parallel-size 1 "
        "--pipeline-model-parallel-size 1 --context-parallel-size 1 --num-layers 2"
    )

    # Run the command in a subshell, using the temporary directory as the current working directory.
    result = subprocess.run(
        command,
        shell=True,  # The command is a single string, so let the shell split it into arguments.
        cwd=tmp_path,  # Run in the temporary directory
        capture_output=True,  # Capture stdout and stderr for debugging
        text=True,  # Decode output as text
    )

    # For debugging purposes, print the output if the command fails.
    if result.returncode != 0:
        sys.stderr.write("STDOUT:\n" + result.stdout + "\n")
        sys.stderr.write("STDERR:\n" + result.stderr + "\n")

    # Assert that the command completed successfully.
    assert result.returncode == 0, "predict_evo2 command with Llama model failed."

    # Assert that exactly one per-rank predictions file was written to the output directory.
    pred_files = sorted(output_dir.glob("predictions__rank_*.pt"))
    assert len(pred_files) == 1, f"Expected 1 prediction file (for this test), got {len(pred_files)}"

    # seq_idx_map.json gives the mapping from the sequence names to the indices in the predictions.
    with open(output_dir / "seq_idx_map.json", "r") as f:
        seq_idx_map = json.load(f)

    preds = torch.load(pred_files[0])
    assert isinstance(preds, dict)
    assert "token_logits" in preds
    assert "pad_mask" in preds
    assert "seq_idx" in preds
    assert len(preds["token_logits"]) == len(preds["pad_mask"]) == len(preds["seq_idx"]) == num_sequences
    assert len(seq_idx_map) == num_sequences

    fasta = NvFaidx(fasta_file_path)
    # Sequence names in sorted order are paired positionally with target_sequence_lengths.
    for i, seq_name in enumerate(sorted(fasta.keys())):
        expected_len = target_sequence_lengths[i]
        idx = seq_idx_map[seq_name]  # look up the out of order prediction index for this sequence.
        # The number of unmasked positions must equal the sequence length.
        assert preds["pad_mask"][idx].sum() == expected_len
        # Logits are padded to the longest sequence; 512 is presumably the vocabulary size (confirm).
        assert preds["token_logits"][idx].shape == (max(target_sequence_lengths), 512)
| 205 | + |
| 206 | + |
0 commit comments