1717# limitations under the License.
1818
1919import logging
20+ import os
2021from pathlib import Path
21- from typing import Literal , Set
22+ from typing import Callable , Literal
2223
2324import numpy as np
2425import pytest
4647def load_weights_sharded_inplace_nemo2_to_mcore (
4748 model : MegatronModelType ,
4849 distributed_checkpoint_dir : str | Path ,
49- skip_keys_with_these_prefixes : Set [str ],
50+ skip_keys_with_these_prefixes : set [str ],
5051 ckpt_format : Literal ["zarr" , "torch_dist" ] = "torch_dist" ,
5152):
5253 logger .info ("Start setting up state dict" )
@@ -274,11 +275,19 @@ def get_trainer(pipeline_parallel=1):
274275 )
275276
276277
277- def get_model_and_tokenizer (ckpt_name , vortex_style_fp8 = False ):
278+ def get_model_and_tokenizer_raw (ckpt_dir_or_name : Path | str , ** kwargs ):
279+ """
280+ Load a model and tokenizer from a checkpoint directory or name. If you supply a Path argument then we assume that
281+ the path is already a checkpoint directory, otherwise we load the checkpoint from NGC or PBSS depending on
282+ the environment variable BIONEMO_DATA_SOURCE.
283+ """
278284 trainer = get_trainer ()
279285 from bionemo .core .data .load import load
280286
281- ckpt_dir : Path = load (ckpt_name )
287+ if isinstance (ckpt_dir_or_name , Path ):
288+ ckpt_dir : Path = ckpt_dir_or_name
289+ else :
290+ ckpt_dir : Path = load (ckpt_dir_or_name )
282291 from nemo .collections .llm import inference
283292
284293 inference_wrapped_model , mcore_tokenizer = inference .setup_model_and_tokenizer (
@@ -287,20 +296,23 @@ def get_model_and_tokenizer(ckpt_name, vortex_style_fp8=False):
287296 params_dtype = torch .bfloat16 ,
288297 inference_batch_times_seqlen_threshold = 8192 , # TODO
289298 inference_max_seq_length = 8192 , # TODO
290- vortex_style_fp8 = vortex_style_fp8 ,
291- # use_te_rng_tracker=True,
292- # te_rng_tracker=True,
293- # inference_rng_tracker=True,
294- # enable_cuda_graph=True,
295- # cudagraph_rng_tracker=True,
296- # flash_decode=True,
297299 recompute_granularity = None ,
298300 recompute_num_layers = None ,
299301 recompute_method = None ,
302+ ** kwargs ,
300303 )
301304 return inference_wrapped_model , mcore_tokenizer
302305
303306
307+ def get_model_and_tokenizer (ckpt_name , vortex_style_fp8 = False ):
308+ return get_model_and_tokenizer_raw (ckpt_name , vortex_style_fp8 = vortex_style_fp8 )
309+
310+
311+ def get_model_and_tokenizer_ignore_vortex (ckpt_name , vortex_style_fp8 = False ):
312+ # Capture and remove the vortex_style_fp8 argument for mamba models.
313+ return get_model_and_tokenizer_raw (ckpt_name )
314+
315+
304316def calc_matchrate (* , tokenizer , in_seq , logits ):
305317 softmax_logprobs = torch .log_softmax (logits , dim = - 1 )
306318 softmax_logprobs = softmax_logprobs [:, :- 1 ]
@@ -476,24 +488,30 @@ def calculate_sequence_identity(seq1: str, seq2: str) -> float | None:
476488
477489
478490@pytest .mark .parametrize (
479- "ckpt_name,expected_matchpercents" ,
491+ "ckpt_name,model_tokenizer_provider, expected_matchpercents" ,
480492 [
481- ("evo2/1b-8k-bf16:1.0" , [96.8 , 29.7 , 76.6 , 71.6 ]),
482- ("evo2/1b-8k:1.0" , [96.8 , 29.7 , 76.6 , 71.6 ]),
483- # ("evo2/7b-8k:1.0", [97.60, 89.63, 80.03, 84.57]),
484- # ("evo2/7b-1m:1.0", [97.60, 89.63, 80.03, 84.57]),
493+ ("evo2/1b-8k-bf16:1.0" , get_model_and_tokenizer , [96.8 , 29.7 , 76.6 , 71.6 ]),
494+ ("evo2/1b-8k:1.0" , get_model_and_tokenizer , [96.8 , 29.7 , 76.6 , 71.6 ]),
495+ ("evo2_mamba/7b-8k:0.1" , get_model_and_tokenizer_ignore_vortex , [99.2 , 51.0 , 73.0 , 82.6 ]),
496+ # ("evo2/7b-8k:1.0", get_model_and_tokenizer, [97.60, 89.63, 80.03, 84.57]),
497+ # ("evo2/7b-1m:1.0", get_model_and_tokenizer, [97.60, 89.63, 80.03, 84.57]),
485498 ],
486499)
487- def test_batch_generate (sequences : list [str ], ckpt_name : str , expected_matchpercents : list [float ]):
500+ def test_batch_generate (
501+ sequences : list [str ], ckpt_name : str , model_tokenizer_provider : Callable , expected_matchpercents : list [float ]
502+ ):
488503 assert len (sequences ) > 0
489504 is_fp8_supported , compute_capability , device_info = check_fp8_support (torch .cuda .current_device ())
490505 skip = "evo2/1b-8k:" in ckpt_name and not is_fp8_supported
491506 if skip :
492507 # This checkpoint is sensitive to FP8, so we skip it if it is not supported on the current device.
493508 pytest .skip (f"Skipping { ckpt_name } because it is not supported on { device_info } ({ compute_capability } )" )
509+ if "evo2_mamba" in ckpt_name and os .environ .get ("BIONEMO_DATA_SOURCE" ) != "pbss" :
510+ # TODO: add evo2_mamba/7b-8k to NGC and remove this skip
511+ pytest .skip (f"Skipping { ckpt_name } because it is not on NGC yet. Run with `BIONEMO_DATA_SOURCE=pbss`." )
494512 # only use vortex_style_fp8 for non-bf16 checkpoints with fp8 support
495513 vortex_style_fp8 = is_fp8_supported and "bf16" not in ckpt_name
496- inference_wrapped_model , mcore_tokenizer = get_model_and_tokenizer (ckpt_name , vortex_style_fp8 = vortex_style_fp8 )
514+ inference_wrapped_model , mcore_tokenizer = model_tokenizer_provider (ckpt_name , vortex_style_fp8 = vortex_style_fp8 )
497515
498516 match_percents = []
499517 num_tokens = 500
0 commit comments