@@ -241,6 +241,15 @@ def sequences():
241241 return [row ["Sequence" ] for row in reader ]
242242
243243
@pytest.fixture
def coding_sequences():
    """Load the ``Sequence`` column of ``data/cds_prompts.csv`` (next to this file) as a list of strings."""
    from csv import DictReader

    csv_path = Path(__file__).parent / "data" / "cds_prompts.csv"
    # newline="" lets the csv module do its own newline handling, as the csv docs recommend.
    with csv_path.open(newline="") as handle:
        records = DictReader(handle)
        return [record["Sequence"] for record in records]
251+
252+
244253def get_trainer (pipeline_parallel = 1 ):
245254 import nemo .lightning as nl
246255
@@ -468,10 +477,13 @@ def test_forward_manual(sequences: list[str], ckpt_name: str, expected_matchperc
468477 )
469478
470479
471- def mid_point_split (* , seq , num_tokens ):
472- mid_point = 2 * ( len (seq ) // 4 )
480+ def mid_point_split (* , seq , num_tokens : int | None = None , fraction : float = 0.5 ):
481+ mid_point = int ( fraction * len (seq ))
473482 prompt = seq [:mid_point ]
474- target = seq [mid_point : mid_point + num_tokens ] # Only compare to the section of sequence directly
483+ if num_tokens is not None :
484+ target = seq [mid_point : mid_point + num_tokens ] # Only compare to the section of sequence directly
485+ else :
486+ target = seq [mid_point :]
475487 return prompt , target
476488
477489
@@ -550,3 +562,92 @@ def test_batch_generate(
550562 assert all (mp >= 0.90 * ep for mp , ep in zip (match_percents , expected_matchpercents )), (
551563 f"Expected at least 90% of { matchperc_print_expected = } , got { matchperc_print = } "
552564 )
565+
566+
def _first_stop_codon_end(seq: str) -> int | None:
    """Return the end index of the first in-frame (frame 0) stop codon in ``seq``, or ``None`` if absent."""
    stop_codons = {"TAA", "TAG", "TGA"}
    for codon_start in range(0, len(seq), 3):
        if seq[codon_start : codon_start + 3] in stop_codons:
            return codon_start + 3
    return None


@pytest.mark.parametrize(
    "ckpt_name,model_tokenizer_provider,expected_matchpercents",
    [
        ("evo2/1b-8k-bf16:1.0", get_model_and_tokenizer, [86.4, 78.8, 87.6]),
        ("evo2/1b-8k:1.0", get_model_and_tokenizer, [86.4, 78.8, 87.6]),
        ("evo2_mamba/7b-8k:0.1", get_model_and_tokenizer_ignore_vortex, [86.5, 88.4, 88.2]),
        # ("evo2/7b-8k:1.0", get_model_and_tokenizer, [88.8, 88.5, 82.2]),
        # ("evo2/7b-1m:1.0", get_model_and_tokenizer, [88.8, 88.5, 82.2]),
    ],
)
def test_batch_generate_coding_sequences(
    coding_sequences: list[str],
    ckpt_name: str,
    model_tokenizer_provider: Callable,
    expected_matchpercents: list[float],
):
    """Check that greedy generation from the first 30% of each coding sequence stays a plausible CDS.

    For every sequence the test prompts the model with the leading 30%, generates greedily (``top_k=1``), and
    verifies that:

    * the sequence identity between the generated text and the true remainder is at least 90% of the
      per-sequence expected value, and
    * when an in-frame stop codon appears in ``prompt + generated``, it appears late enough to be unlikely by
      chance and yields a CDS at least 90% as long as the original.

    Args:
        coding_sequences: Coding sequences provided by the ``coding_sequences`` fixture.
        ckpt_name: Checkpoint identifier passed to ``model_tokenizer_provider``.
        model_tokenizer_provider: Callable returning ``(inference_wrapped_model, tokenizer)`` for a checkpoint.
        expected_matchpercents: Expected sequence identity (percent), one entry per coding sequence.
    """
    assert len(coding_sequences) > 0
    # Fail before the expensive generation step if the fixture and the expectations are out of sync; otherwise
    # the per-sequence lookup of the expected value below would raise IndexError or be silently truncated.
    assert len(coding_sequences) == len(expected_matchpercents)

    is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device())
    skip = "evo2/1b-8k:" in ckpt_name and not is_fp8_supported
    if skip:
        # This checkpoint is sensitive to FP8, so we skip it if it is not supported on the current device.
        pytest.skip(f"Skipping {ckpt_name} because it is not supported on {device_info} ({compute_capability})")
    if "evo2_mamba" in ckpt_name and os.environ.get("BIONEMO_DATA_SOURCE") != "pbss":
        # TODO: add evo2_mamba/7b-8k to NGC and remove this skip
        pytest.skip(f"Skipping {ckpt_name} because it is not on NGC yet. Run with `BIONEMO_DATA_SOURCE=pbss`.")
    # only use vortex_style_fp8 for non-bf16 checkpoints with fp8 support
    vortex_style_fp8 = is_fp8_supported and "bf16" not in ckpt_name
    inference_wrapped_model, mcore_tokenizer = model_tokenizer_provider(ckpt_name, vortex_style_fp8=vortex_style_fp8)

    match_percents: list[float] = []
    cds_lengths: list[int | None] = []
    original_cds_lengths: list[int] = [len(seq) for seq in coding_sequences]
    # Prompt with the first 30% of every sequence; the remaining 70% is the comparison target.
    seq_prompts = [mid_point_split(seq=seq, num_tokens=None, fraction=0.3) for seq in coding_sequences]
    # Generate slightly past the longest target so the model has room to emit a stop codon.
    num_tokens = max(len(target) for _, target in seq_prompts) + 15
    from megatron.core.inference.common_inference_params import CommonInferenceParams
    from nemo.collections.llm.inference import generate

    results = generate(
        model=inference_wrapped_model,
        max_batch_size=1,  # vortex only supports batch size 1
        tokenizer=mcore_tokenizer,
        prompts=[prompt for prompt, _ in seq_prompts],
        random_seed=42,
        inference_params=CommonInferenceParams(
            temperature=1.0,
            top_k=1,
            top_p=0.0,
            return_log_probs=False,
            num_tokens_to_generate=num_tokens,
        ),
    )

    for result, (prompt, target), expected_percent in zip(results, seq_prompts, expected_matchpercents):
        gen_seq = result.generated_text
        logging.info(f"{ckpt_name} {torch.distributed.get_rank()=} {gen_seq=}")
        logging.info(f"{ckpt_name} {torch.distributed.get_rank()=} {target=}")
        full_seq = prompt + gen_seq
        assert full_seq[:3] == "ATG"  # start codon
        # Length of the CDS up to and including the first in-frame stop codon (None when no stop was found).
        cds_length = _first_stop_codon_end(full_seq)
        match_percent = calculate_sequence_identity(target, gen_seq)
        logging.info(
            f"{ckpt_name} {torch.distributed.get_rank()=} {match_percent=} expected: {expected_percent}"
        )
        match_percents.append(match_percent)
        cds_lengths.append(cds_length)

    assert len(match_percents) == len(expected_matchpercents)
    assert len(cds_lengths) == len(original_cds_lengths)
    matchperc_print = [f"{mp:.1f}%" for mp in match_percents]
    matchperc_print_expected = [f"{ep:.1f}%" for ep in expected_matchpercents]
    # 99% of the time, you have a stop codon within the first 96 codons if everything were random, so verify
    # that the first stop codon lands after this point in the generated region (unless the target itself is
    # shorter than that), as well as the CDS being at least 90% of the original sequence length.
    # Sequences where no stop codon was found (None) are not checked here.
    chance_stop_window = 96 * 3  # nucleotides spanned by 96 codons
    assert all(
        pcl is None or ((pcl - len(pmpt) > chance_stop_window or len(tgt) < chance_stop_window) and pcl >= 0.9 * ocl)
        for pcl, ocl, (pmpt, tgt) in zip(cds_lengths, original_cds_lengths, seq_prompts)
    ), f"Expected at least 90% of {original_cds_lengths=}, got {cds_lengths=}"
    assert all(mp >= 0.90 * ep for mp, ep in zip(match_percents, expected_matchpercents)), (
        f"Expected at least 90% of {matchperc_print_expected=}, got {matchperc_print=}"
    )
0 commit comments