Skip to content

Commit 6256dcc

Browse files
committed
Remove mamba specific test file and consolidate duplicated code
Signed-off-by: John St John <jstjohn@nvidia.com>
1 parent fc97c2a commit 6256dcc

2 files changed

Lines changed: 36 additions & 199 deletions

File tree

sub-packages/bionemo-evo2/tests/bionemo/evo2/test_evo2.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
# limitations under the License.
1818

1919
import logging
20+
import os
2021
from pathlib import Path
21-
from typing import Literal, Set
22+
from typing import Callable, Literal
2223

2324
import numpy as np
2425
import pytest
@@ -46,7 +47,7 @@
4647
def load_weights_sharded_inplace_nemo2_to_mcore(
4748
model: MegatronModelType,
4849
distributed_checkpoint_dir: str | Path,
49-
skip_keys_with_these_prefixes: Set[str],
50+
skip_keys_with_these_prefixes: set[str],
5051
ckpt_format: Literal["zarr", "torch_dist"] = "torch_dist",
5152
):
5253
logger.info("Start setting up state dict")
@@ -274,11 +275,19 @@ def get_trainer(pipeline_parallel=1):
274275
)
275276

276277

277-
def get_model_and_tokenizer(ckpt_name, vortex_style_fp8=False):
278+
def get_model_and_tokenizer_raw(ckpt_dir_or_name: Path | str, **kwargs):
279+
"""
280+
Load a model and tokenizer from a checkpoint directory or name. If you supply a Path argument then we assume that
281+
the path is already a checkpoint directory; otherwise we load the checkpoint from NGC or PBSS depending on
282+
the environment variable BIONEMO_DATA_SOURCE.
283+
"""
278284
trainer = get_trainer()
279285
from bionemo.core.data.load import load
280286

281-
ckpt_dir: Path = load(ckpt_name)
287+
if isinstance(ckpt_dir_or_name, Path):
288+
ckpt_dir: Path = ckpt_dir_or_name
289+
else:
290+
ckpt_dir: Path = load(ckpt_dir_or_name)
282291
from nemo.collections.llm import inference
283292

284293
inference_wrapped_model, mcore_tokenizer = inference.setup_model_and_tokenizer(
@@ -287,20 +296,23 @@ def get_model_and_tokenizer(ckpt_name, vortex_style_fp8=False):
287296
params_dtype=torch.bfloat16,
288297
inference_batch_times_seqlen_threshold=8192, # TODO
289298
inference_max_seq_length=8192, # TODO
290-
vortex_style_fp8=vortex_style_fp8,
291-
# use_te_rng_tracker=True,
292-
# te_rng_tracker=True,
293-
# inference_rng_tracker=True,
294-
# enable_cuda_graph=True,
295-
# cudagraph_rng_tracker=True,
296-
# flash_decode=True,
297299
recompute_granularity=None,
298300
recompute_num_layers=None,
299301
recompute_method=None,
302+
**kwargs,
300303
)
301304
return inference_wrapped_model, mcore_tokenizer
302305

303306

307+
def get_model_and_tokenizer(ckpt_name, vortex_style_fp8=False):
    """Load a model and tokenizer, forwarding ``vortex_style_fp8`` to the raw loader.

    Args:
        ckpt_name: Checkpoint name (or directory) handed unchanged to ``get_model_and_tokenizer_raw``.
        vortex_style_fp8: Passed through as the keyword argument of the same name.

    Returns:
        Whatever ``get_model_and_tokenizer_raw`` returns for these arguments.
    """
    # Collect the extra options first so the delegation below stays a single, uniform call.
    forwarded_options = {"vortex_style_fp8": vortex_style_fp8}
    return get_model_and_tokenizer_raw(ckpt_name, **forwarded_options)
309+
310+
311+
def get_model_and_tokenizer_ignore_vortex(ckpt_name, vortex_style_fp8=False):
    """Load a model and tokenizer without forwarding ``vortex_style_fp8``.

    The ``vortex_style_fp8`` parameter is accepted so that this provider can be called with the same
    arguments as ``get_model_and_tokenizer``, but it is never passed on to the raw loader. This
    provider is intended for mamba models.

    Args:
        ckpt_name: Checkpoint name (or directory) handed unchanged to ``get_model_and_tokenizer_raw``.
        vortex_style_fp8: Accepted and discarded.

    Returns:
        Whatever ``get_model_and_tokenizer_raw`` returns for ``ckpt_name`` alone.
    """
    # Capture and remove the vortex_style_fp8 argument for mamba models.
    del vortex_style_fp8
    model_and_tokenizer = get_model_and_tokenizer_raw(ckpt_name)
    return model_and_tokenizer
314+
315+
304316
def calc_matchrate(*, tokenizer, in_seq, logits):
305317
softmax_logprobs = torch.log_softmax(logits, dim=-1)
306318
softmax_logprobs = softmax_logprobs[:, :-1]
@@ -476,24 +488,30 @@ def calculate_sequence_identity(seq1: str, seq2: str) -> float | None:
476488

477489

478490
@pytest.mark.parametrize(
479-
"ckpt_name,expected_matchpercents",
491+
"ckpt_name,model_tokenizer_provider,expected_matchpercents",
480492
[
481-
("evo2/1b-8k-bf16:1.0", [96.8, 29.7, 76.6, 71.6]),
482-
("evo2/1b-8k:1.0", [96.8, 29.7, 76.6, 71.6]),
483-
# ("evo2/7b-8k:1.0", [97.60, 89.63, 80.03, 84.57]),
484-
# ("evo2/7b-1m:1.0", [97.60, 89.63, 80.03, 84.57]),
493+
("evo2/1b-8k-bf16:1.0", get_model_and_tokenizer, [96.8, 29.7, 76.6, 71.6]),
494+
("evo2/1b-8k:1.0", get_model_and_tokenizer, [96.8, 29.7, 76.6, 71.6]),
495+
("evo2_mamba/7b-8k:0.1", get_model_and_tokenizer_ignore_vortex, [99.2, 51.0, 73.0, 82.6]),
496+
# ("evo2/7b-8k:1.0", get_model_and_tokenizer, [97.60, 89.63, 80.03, 84.57]),
497+
# ("evo2/7b-1m:1.0", get_model_and_tokenizer, [97.60, 89.63, 80.03, 84.57]),
485498
],
486499
)
487-
def test_batch_generate(sequences: list[str], ckpt_name: str, expected_matchpercents: list[float]):
500+
def test_batch_generate(
501+
sequences: list[str], ckpt_name: str, model_tokenizer_provider: Callable, expected_matchpercents: list[float]
502+
):
488503
assert len(sequences) > 0
489504
is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device())
490505
skip = "evo2/1b-8k:" in ckpt_name and not is_fp8_supported
491506
if skip:
492507
# This checkpoint is sensitive to FP8, so we skip it if FP8 is not supported on the current device.
493508
pytest.skip(f"Skipping {ckpt_name} because it is not supported on {device_info} ({compute_capability})")
509+
if "evo2_mamba" in ckpt_name and os.environ.get("BIONEMO_DATA_SOURCE") != "pbss":
510+
# TODO: add evo2_mamba/7b-8k to NGC and remove this skip
511+
pytest.skip(f"Skipping {ckpt_name} because it is not on NGC yet. Run with `BIONEMO_DATA_SOURCE=pbss`.")
494512
# only use vortex_style_fp8 for non-bf16 checkpoints with fp8 support
495513
vortex_style_fp8 = is_fp8_supported and "bf16" not in ckpt_name
496-
inference_wrapped_model, mcore_tokenizer = get_model_and_tokenizer(ckpt_name, vortex_style_fp8=vortex_style_fp8)
514+
inference_wrapped_model, mcore_tokenizer = model_tokenizer_provider(ckpt_name, vortex_style_fp8=vortex_style_fp8)
497515

498516
match_percents = []
499517
num_tokens = 500

sub-packages/bionemo-evo2/tests/bionemo/evo2/test_evo2_mamba_batch_generate.py

Lines changed: 0 additions & 181 deletions
This file was deleted.

0 commit comments

Comments
 (0)