Skip to content

Commit b764770

Browse files
authored
Adding inference CDS length tests (#991)
### Description
CDS length tests

### Type of changes
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Refactor
- [ ] Documentation update
- [ ] Other (please describe):

Signed-off-by: John St John <jstjohn@nvidia.com>
1 parent 7d8b81b commit b764770

2 files changed

Lines changed: 108 additions & 3 deletions

File tree

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Sequence,Name,Percent
2+
ATGAGTCAGAATACGCTGAAAGTTCATGATTTAAATGAAGATGCGGAATTTGATGAGAACGGAGTTGAGGTTTTTGACGAAAAGGCCTTAGTAGAACAGGAACCCAGTGATAACGATTTGGCCGAAGAGGAACTGTTATCGCAGGGAGCCACACAGCGTGTGTTGGACGCGACTCAGCTTTACCTTGGTGAGATTGGTTATTCACCACTGTTAACGGCCGAAGAAGAAGTTTATTTTGCGCGTCGCGCACTGCGTGGAGATGTCGCCTCTCGCCGCCGGATGATCGAGAGTAACTTGCGTCTGGTGGTAAAAATTGCCCGCCGTTATGGCAATCGTGGTCTGGCGTTGCTGGACCTTATCGAAGAGGGCAACCTGGGGCTGATCCGCGCGGTAGAGAAGTTTGACCCGGAACGTGGTTTCCGCTTCTCAACATACGCAACCTGGTGGATTCGCCAGACGATTGAACGGGCGATTATGAACCAAACCCGTACTATTCGTTTGCCGATTCACATCGTAAAGGAGCTGAACGTTTACCTGCGAACCGCACGTGAGTTGTCCCATAAGCTGGACCATGAACCAAGTGCGGAAGAGATCGCAGAGCAACTGGATAAGCCAGTTGATGACGTCAGCCGTATGCTTCGTCTTAACGAGCGCATTACCTCGGTAGACACCCCGCTGGGTGGTGATTCCGAAAAAGCGTTGCTGGACATCCTGGCCGATGAAAAAGAGAACGGTCCGGAAGATACCACGCAAGATGACGATATGAAGCAGAGCATCGTCAAATGGCTGTTCGAGCTGAACGCCAAACAGCGTGAAGTGCTGGCACGTCGATTCGGTTTGCTGGGGTACGAAGCGGCAACACTGGAAGATGTAGGTCGTGAAATTGGCCTCACCCGTGAACGTGTTCGCCAGATTCAGGTTGAAGGCCTGCGCCGTTTGCGCGAAATCCTGCAAACGCAGGGGCTGAATATCGAAGCGCTGTTCCGCGAGTAA,rpoS_NC_000913.3_cds_NP_417221.1_2701,100%
3+
ATGAGCGACCTTGCGAGAGAAATTACACCGGTCAACATTGAGGAGGAGCTGAAGAGCTCCTATCTGGATTATGCGATGTCGGTCATTGTTGGCCGTGCGCTGCCGGATGTCCGAGATGGCCTGAAGCCGGTACACCGTCGCGTACTTTACGCCATGAACGTATTGGGCAATGACTGGAACAAAGCCTATAAAAAATCTGCCCGTGTCGTTGGTGACGTAATCGGTAAATACCATCCCCACGGCGATTCCGCAGTGTATGACACCATCGTTCGTATGGCGCAGCCATTCTCGCTGCGTTACATGCTGGTGGATGGTCAGGGTAACTTCGGTTCTATTGACGGCGACTCCGCGGCGGCAATGCGTTATACGGAGATCCGTCTGGCGAAAATCGCCCACGAACTGATGGCCGATCTCGAAAAAGAGACGGTGGATTTCGTGGATAACTATGACGGTACGGAAAAAATTCCGGACGTCATGCCGACCAAAATTCCGAATCTGCTGGTGAACGGTTCTTCCGGTATCGCAGTAGGTATGGCGACGAATATCCCGCCGCACAACCTGACGGAAGTGATTAACGGCTGCCTGGCGTATATCGACAACGAAGACATCAGCATTGAAGGGCTGATGGAACATATTCCGGGGCCGGACTTCCCGACCGCCGCGATCATCAACGGTCGTCGTGGTATCGAAGAAGCCTACCGCACCGGTCGTGGCAAAGTGTACATTCGCGCCCGCGCGGAAGTTGAAGCTGACGCCAAAACGGGCCGTGAAACCATCATCGTCCATGAAATTCCCTATCAGGTGAACAAAGCGCGCCTGATCGAGAAAATCGCCGAGCTGGTGAAAGATAAACGCGTGGAAGGCATCAGCGCGCTGCGTGACGAATCCGACAAAGACGGGATGCGCATCGTGATTGAAGTGAAACGCGATGCGGTGGGCGAGGTGGTGCTTAATAATCTCTACTCCCAGACCCAGCTACAGGTTTCCTTCGGTATTAACATGGTGGCGCTGCATCACGGCCAGCCGAAGATCATGAACCTGAAAGATATCATTTCAGCGTTCGTGCGCCACCGCCGTGAAGTGGTGACGCGTCGGACTATTTTTGAACTGCGTAAAGCCCGTGACCGTGCGCATATCCTTGAAGCTCTGGCGATTGCGCTGGCCAACATCGACCCGATTATCGAACTGATTCGCCGCGCGCCAACGCCGGCGGAAGCAAAAGCGGCGCTGATTTCGCGTCCGTGGGATCTGGGCAACGTTGCTGCGATGCTGGAGCGCGCTGGTGATGACGCCGCGCGTCCGGAATGGCTGGAGCCAGAATTTGGCGTGCGTGACGGTCAGTACTACCTGACTGAACAGCAGGCGCAGGCGATTCTGGATCTGCGTTTGCAGAAACTGACCGGCCTGGAGCATGAAAAACTGCTCGACGAATACAAAGAGCTGCTGGAGCAGATTGCTGAATTGCTGCACATTCTGGGCAGCGCCGATCGCCTGATGGAAGTGATCCGCGAAGAGATGGAGTTAATTCGCGATCAGTTCGGCGATGAGCGTCGTACCGAAATCACCGCCAACAGCGCCGATATTAATATCGAAGATCTGATTAGCCAGGAAGATGTTGTCGTGACGCTGTCTCACCAGGGTTACGTCAAATATCAACCGCTGACAGATTACGAAGCGCAACGTCGTGGTGGGAAAGGTAAATCTGCCGCGCGTATTAAAGAAGAAGACTTTATCGACCGCCTGCTGGTGGCTAACACCCATGACACCATCCTCTGCTTCTCCAGCCGGGGCCGTCTGTACTGGATGAAGGTCTATCAGCTGCCGGAAGCCAGCCGCGGCGCGCGCGGTCGTCCGATCGTCAACCTGCTGCCGCTGGAAGCCAACGAACGTATCACCGCGATTCTGCCGGTTCGTGAGTATGAAGAAGGCGTCAACGTCTTTATGGCGACCGCCAGCGGTACCGTGAAGAAAACGGCGCTGACCGAATTCAGCCGTCCGCG
TTCCGCCGGTATTATCGCGGTGAACCTCAACGACGGCGACGAGCTGATTGGCGTTGACCTGACTTCTGGTTCTGACGAAGTCATGCTGTTCTCGGCCGCGGGTAAAGTGGTGCGCTTCAAAGAAGACGCCGTCCGTGCGATGGGGCGTACCGCGACCGGTGTGCGCGGTATTAAGCTGGCGGGAGACGATAAAGTCGTCTCTCTGATCATCCCACGCGGCGAAGGCGCTATTCTGACCGTAACGCAAAACGGCTACGGGAAGCGTACCGCAGCGGACGAGTACCCGACCAAGTCTCGTGCGACGCAGGGCGTTATCTCTATCAAAGTGACCGAGCGCAACGGTTCCGTTGTCGGTGCGGTACAGGTAGACGATTGCGACCAGATCATGATGATCACGGATGCCGGTACTCTGGTGCGTACCCGTGTGTCCGAGATCAGCGTAGTGGGACGTAATACCCAGGGCGTTATCCTTATCCGCACGGCGGAAGATGAAAACGTGGTGGGTCTGCAACGCGTTGCTGAACCGGTAGATGACGAAGAACTCGACGCTATCGACGGCAGCGTGGCGGAAGGGGATGAGGATATCGCCCCGGAAGCGGAAAGCGATGACGACGTTGCGGATGACGCTGACGAGTAA,gyrA_NC_003197.2_cds_NP_461214.1_2209,100%
4+
ATGGACTCTATCGTCGGCGACGCAATTGACGAGGCCGAGGCCGAGGACATGGGGGATGAGTCGGCTCAGGTCGACGGCGCGGCCAACATCAACCGGTCCGGGACGATGACTGACGACGAACTGAAAGCGGTTCTCAAAGACCTCCAGACCAACATCACGGTGGTCGGGTGCGGCGGTGCCGGCGGTAACACCGTCAACCGGATGCACGAGGAGGGAATCAAGGGGGCGAAGCTCGTCGCCGCCAACACCGACGTGCAGCACCTCGTGGAAATCGGGGCCGATACGAAGATTCTCATGGGCGAGCAGAAGACCCAAGGCCGCGGCGCGGGCTCGCTCCCGCAGGTCGGTGAGGAGGCCGCCCTCGAATCCCAAGAGGAGATTTACGACGCCATCGAGGGCTCCGACATGGTGTTCGTCACCGCCGGACTCGGCGGCGGCACCGGCACCGGTTCGGCTCCCGTCGTCGCCAAGGCGGCCCGCGAGTCGGGCGCGCTCACCATCGCCATCGTCACGACGCCCTTTACGGCCGAAGGCGAGGTACGACGGACGAACGCCGAGGCCGGTCTCGAACGGCTCCGCGACGTGTCGGACACGGTCATCGTCGTCCCGAACGACCGCCTGCTCGACGCCGTGGGCAAACTCCCCGTCCGGCAGGCGTTCAAGGTCTCCGACGAGGTGCTGATGCGCTCGGTCAAGGGCATCACCGAACTCATCACTAAGCCCGGTCTCGTCAACCTCGACTTCGCCGACGTGAAGACCGTCATGGAGCGCGGCGGCGTCGCCATGATCGGTCTCGGCGAGTCCGACTCCGAGTCCAAGGCTCAGGAGTCCGTCAAGTCCGCCCTCCGCTCGCCGCTTCTTGACGTGGACATCTCCGGCGCGAACTCCGCGCTCGTCAACGTCACCGGCGGTTCGGACATGAGCATCGAGGAGGCCGAGGGCGTCGTCGAGGAGATTTACGACCGCATCGACCCCGACGCGCGCATCATCTGGGGGACCTCCGTCGACGACGAACTCGAAGGCATGATGCGGACGATGATCGTCGTCACCGGCGTCGAGTCGCCCCAAATCTACGGCCGCAACGGCGAGGCACAGGCGCACGCCGAAGAGCGTCTCGAAGACATCGACTACGTCGAGTAG,ftsZ_NC_013967.1_cds_WP_004044352.1_564,100%

sub-packages/bionemo-evo2/tests/bionemo/evo2/test_evo2.py

Lines changed: 104 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,15 @@ def sequences():
241241
return [row["Sequence"] for row in reader]
242242

243243

244+
@pytest.fixture
def coding_sequences():
    """Return the coding-sequence (CDS) prompts used by the generation tests.

    Reads ``data/cds_prompts.csv`` located next to this test module and
    returns the ``Sequence`` column of every row, in file order.
    """
    from csv import DictReader

    csv_path = Path(__file__).parent / "data" / "cds_prompts.csv"
    # newline="" lets the csv module do its own newline handling.
    with csv_path.open(newline="") as handle:
        return [record["Sequence"] for record in DictReader(handle)]
251+
252+
244253
def get_trainer(pipeline_parallel=1):
245254
import nemo.lightning as nl
246255

@@ -468,10 +477,13 @@ def test_forward_manual(sequences: list[str], ckpt_name: str, expected_matchperc
468477
)
469478

470479

471-
def mid_point_split(*, seq, num_tokens):
472-
mid_point = 2 * (len(seq) // 4)
480+
def mid_point_split(*, seq, num_tokens: int | None = None, fraction: float = 0.5):
481+
mid_point = int(fraction * len(seq))
473482
prompt = seq[:mid_point]
474-
target = seq[mid_point : mid_point + num_tokens] # Only compare to the section of sequence directly
483+
if num_tokens is not None:
484+
target = seq[mid_point : mid_point + num_tokens] # Only compare to the section of sequence directly
485+
else:
486+
target = seq[mid_point:]
475487
return prompt, target
476488

477489

@@ -550,3 +562,92 @@ def test_batch_generate(
550562
assert all(mp >= 0.90 * ep for mp, ep in zip(match_percents, expected_matchpercents)), (
551563
f"Expected at least 90% of {matchperc_print_expected=}, got {matchperc_print=}"
552564
)
565+
566+
567+
@pytest.mark.parametrize(
    "ckpt_name,model_tokenizer_provider,expected_matchpercents",
    [
        ("evo2/1b-8k-bf16:1.0", get_model_and_tokenizer, [86.4, 78.8, 87.6]),
        ("evo2/1b-8k:1.0", get_model_and_tokenizer, [86.4, 78.8, 87.6]),
        ("evo2_mamba/7b-8k:0.1", get_model_and_tokenizer_ignore_vortex, [86.5, 88.4, 88.2]),
        # ("evo2/7b-8k:1.0", get_model_and_tokenizer, [88.8, 88.5, 82.2]),
        # ("evo2/7b-1m:1.0", get_model_and_tokenizer, [88.8, 88.5, 82.2]),
    ],
)
def test_batch_generate_coding_sequences(
    coding_sequences: list[str],
    ckpt_name: str,
    model_tokenizer_provider: Callable,
    expected_matchpercents: list[float],
):
    """Check that generated continuations of coding sequences keep a plausible CDS length.

    Each coding sequence is split into a prompt (its first 30%) and a target (the rest). The
    model generates a continuation for every prompt, and for each ``prompt + generation`` the
    test:

    * asserts that it starts with the ``ATG`` start codon,
    * records the end of the first in-frame stop codon (``TAA``/``TAG``/``TGA``) as the
      predicted CDS length, or ``None`` when no in-frame stop codon is present,
    * records the identity between the target and the generation as reported by
      ``calculate_sequence_identity``.

    The test then asserts that every predicted CDS length is either ``None`` or both (a) ends
    more than 96 codons past the prompt (unless the target itself is shorter than 96 codons)
    and (b) is at least 90% of the original CDS length, and that every match percent is at
    least 90% of its expected value.

    Args:
        coding_sequences: Coding sequences provided by the ``coding_sequences`` fixture.
        ckpt_name: Checkpoint identifier handed to ``model_tokenizer_provider``.
        model_tokenizer_provider: Callable returning ``(inference_wrapped_model, tokenizer)``
            for ``ckpt_name``.
        expected_matchpercents: Reference match percent for each coding sequence, in order.
    """
    assert len(coding_sequences) > 0
    # Fail fast (before any checkpoint is loaded) if the fixture and the parametrized
    # expectations disagree, instead of hitting an IndexError halfway through the loop below.
    assert len(coding_sequences) == len(expected_matchpercents), (
        f"Got {len(coding_sequences)} coding sequences but {len(expected_matchpercents)} expected match percents"
    )
    is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device())
    skip = "evo2/1b-8k:" in ckpt_name and not is_fp8_supported
    if skip:
        # This checkpoint is sensitive to FP8, so we skip it if it is not supported on the current device.
        pytest.skip(f"Skipping {ckpt_name} because it is not supported on {device_info} ({compute_capability})")
    if "evo2_mamba" in ckpt_name and os.environ.get("BIONEMO_DATA_SOURCE") != "pbss":
        # TODO: add evo2_mamba/7b-8k to NGC and remove this skip
        pytest.skip(f"Skipping {ckpt_name} because it is not on NGC yet. Run with `BIONEMO_DATA_SOURCE=pbss`.")
    # only use vortex_style_fp8 for non-bf16 checkpoints with fp8 support
    vortex_style_fp8 = is_fp8_supported and "bf16" not in ckpt_name
    inference_wrapped_model, mcore_tokenizer = model_tokenizer_provider(ckpt_name, vortex_style_fp8=vortex_style_fp8)

    match_percents: list[float] = []
    cds_lengths: list[int | None] = []
    original_cds_lengths: list[int] = [len(seq) for seq in coding_sequences]
    seq_prompts = [mid_point_split(seq=seq, num_tokens=None, fraction=0.3) for seq in coding_sequences]
    # Generate enough tokens to cover the longest target plus a small margin of 15.
    num_tokens = max(len(sq[1]) for sq in seq_prompts) + 15
    from megatron.core.inference.common_inference_params import CommonInferenceParams
    from nemo.collections.llm.inference import generate

    results = generate(
        model=inference_wrapped_model,
        max_batch_size=1,  # vortex only supports batch size 1
        tokenizer=mcore_tokenizer,
        prompts=[sq[0] for sq in seq_prompts],
        random_seed=42,
        inference_params=CommonInferenceParams(
            temperature=1.0,
            top_k=1,
            top_p=0.0,
            return_log_probs=False,
            num_tokens_to_generate=num_tokens,
        ),
    )

    stop_codons = frozenset({"TAA", "TAG", "TGA"})
    for result, (prompt, target), expected in zip(results, seq_prompts, expected_matchpercents):
        gen_seq = result.generated_text
        logging.info(f"{ckpt_name} {torch.distributed.get_rank()=} {gen_seq=}")
        logging.info(f"{ckpt_name} {torch.distributed.get_rank()=} {target=}")
        full_seq = prompt + gen_seq
        assert full_seq[:3] == "ATG"  # start codon
        # Walk the sequence codon by codon in the reading frame of the start codon; the
        # predicted CDS length is the end of the first stop codon, or None if there is none.
        cds_length = next(
            (
                codon_start + 3
                for codon_start in range(0, len(full_seq), 3)
                if full_seq[codon_start : codon_start + 3] in stop_codons
            ),
            None,
        )
        match_percent = calculate_sequence_identity(target, gen_seq)
        logging.info(f"{ckpt_name} {torch.distributed.get_rank()=} {match_percent=} expected: {expected}")
        match_percents.append(match_percent)
        cds_lengths.append(cds_length)

    assert len(match_percents) == len(expected_matchpercents)
    assert len(cds_lengths) == len(original_cds_lengths)
    matchperc_print = [f"{mp:.1f}%" for mp in match_percents]
    matchperc_print_expected = [f"{ep:.1f}%" for ep in expected_matchpercents]
    # If everything were random, 99% of the time there would be a stop codon within the first
    # 96 codons ((61/64)**96 is about 1%). So verify that the first stop codon comes after
    # this point, as well as it being at least 90% of the original sequence length.
    # A sequence without any in-frame stop codon (cds_length is None) is accepted here.
    random_stop_window = 96 * 3  # 96 codons, in nucleotides
    assert all(
        pcl is None
        or ((pcl - len(pmpt) > random_stop_window or len(tgt) < random_stop_window) and pcl >= 0.9 * ocl)
        for pcl, ocl, (pmpt, tgt) in zip(cds_lengths, original_cds_lengths, seq_prompts)
    ), f"Expected at least 90% of {original_cds_lengths=}, got {cds_lengths=}"
    assert all(mp >= 0.90 * ep for mp, ep in zip(match_percents, expected_matchpercents)), (
        f"Expected at least 90% of {matchperc_print_expected=}, got {matchperc_print=}"
    )

0 commit comments

Comments
 (0)