Skip to content

Commit 7775b41

Browse files
committed
Add support for inference using LoRA checkpoint
Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>
1 parent c97261d commit 7775b41

6 files changed

Lines changed: 236 additions & 23 deletions

File tree

bionemo-recipes/recipes/evo2_megatron/README.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,40 @@ rather than silently producing asymmetric behaviour.
395395
weights are always treated as a unit, and any asymmetric configuration will
396396
raise an error.
397397

398+
### Running inference on a LoRA checkpoint
399+
400+
A LoRA training checkpoint contains only adapter tensors — the base model weights
401+
are not duplicated. Point `--ckpt-dir` at the LoRA checkpoint directory (the `checkpoints/` parent shown below, or a specific `iter_*` directory) as usual:
402+
403+
```bash
404+
torchrun --nproc_per_node 1 --no-python \
405+
infer_evo2 \
406+
--ckpt-dir </path/to/lora_run/checkpoints/> \
407+
--prompt "ATCGATCGATCGATCG" \
408+
--max-new-tokens 200
409+
```
410+
411+
```bash
412+
torchrun --nproc_per_node 1 --no-python \
413+
predict_evo2 \
414+
--fasta <path/to/fasta/sequences> \
415+
--ckpt-dir </path/to/lora_run/checkpoints/> \
416+
--output-dir ./predictions
417+
```
418+
419+
When `infer_evo2` / `predict_evo2` detect a `peft` section in the checkpoint's
420+
`run_config.yaml`, they:
421+
422+
1. load dense base weights from `checkpoint.pretrained_checkpoint` (the same
423+
value that was supplied during LoRA training),
424+
2. apply the stored PEFT config (`run_config["peft"]`) to graft `LoRALinear`
425+
wrappers onto the base modules,
426+
3. load only the adapter tensors from `--ckpt-dir`.
427+
428+
No merge step is required. The base checkpoint referenced by
429+
`pretrained_checkpoint` must still exist on disk at the path recorded in
430+
`run_config.yaml`.
431+
398432
## Exporting to Vortex format
399433

400434
Vortex is ARC Institute's inference format for Evo2 Hyena models, used by the

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,11 @@
7070

7171
import torch
7272
import torch.distributed as dist
73-
from megatron.bridge.training.checkpointing import _load_model_weights_from_checkpoint
73+
from megatron.bridge.training.checkpointing import (
74+
_generate_model_state_dict,
75+
_load_model_weights_from_checkpoint,
76+
apply_peft_adapter_filter_to_state_dict,
77+
)
7478
from megatron.bridge.training.config import DistributedInitConfig, RNGConfig
7579
from megatron.bridge.training.mixed_precision import get_mixed_precision_config
7680
from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer
@@ -81,7 +85,7 @@
8185
)
8286
from megatron.bridge.utils.common_utils import get_world_size_safe
8387
from megatron.bridge.utils.instantiate_utils import instantiate
84-
from megatron.core import parallel_state
88+
from megatron.core import dist_checkpointing, parallel_state
8589
from megatron.core.inference.contexts import StaticInferenceContext
8690
from megatron.core.inference.engines.static_engine import StaticInferenceEngine
8791
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
@@ -462,12 +466,35 @@ def setup_inference_engine(
462466

463467
raw_model = model_provider.provide().eval().cuda()
464468

465-
logger.info(f"Loading weights from: {resolved_ckpt_dir}")
466-
_load_model_weights_from_checkpoint(
467-
checkpoint_path=str(resolved_ckpt_dir),
468-
model=[raw_model],
469-
dist_ckpt_strictness="ignore_all",
470-
)
469+
# A LoRA finetune checkpoint only contains adapter tensors; the base weights live in
470+
# run_config["checkpoint"]["pretrained_checkpoint"]. Detect via the top-level `peft:`
471+
# section (same signal `peft_pre_wrap_hook` uses during training).
472+
peft_node = run_config.get("peft")
473+
if peft_node is not None:
474+
# pretrained_checkpoint may point at a training-output parent containing iter_*; resolve.
475+
resolved_pretrained_dir = resolve_checkpoint_path(Path(run_config["checkpoint"]["pretrained_checkpoint"]))
476+
logger.info(f"PEFT checkpoint detected. Loading base weights from: {resolved_pretrained_dir}")
477+
_load_model_weights_from_checkpoint(
478+
checkpoint_path=str(resolved_pretrained_dir),
479+
model=[raw_model],
480+
dist_ckpt_strictness="ignore_all",
481+
)
482+
483+
logger.info("Applying PEFT adapter structure to base model")
484+
peft_cfg = instantiate(peft_node)
485+
raw_model = peft_cfg(raw_model, training=False)
486+
487+
logger.info(f"Loading adapter weights from: {resolved_ckpt_dir}")
488+
sharded_sd = apply_peft_adapter_filter_to_state_dict(_generate_model_state_dict([raw_model], {}), peft_cfg)
489+
loaded = dist_checkpointing.load(sharded_sd, str(resolved_ckpt_dir), strict="ignore_all")
490+
raw_model.load_state_dict(loaded["model"], strict=False)
491+
else:
492+
logger.info(f"Loading weights from: {resolved_ckpt_dir}")
493+
_load_model_weights_from_checkpoint(
494+
checkpoint_path=str(resolved_ckpt_dir),
495+
model=[raw_model],
496+
dist_ckpt_strictness="ignore_all",
497+
)
471498
logger.info("Weights loaded successfully")
472499

473500
# Wrap with Float16Module

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,11 @@
6969
import torch
7070
import torch.distributed as dist
7171
from megatron.bridge.data.samplers import build_pretraining_data_loader
72-
from megatron.bridge.training.checkpointing import _load_model_weights_from_checkpoint
72+
from megatron.bridge.training.checkpointing import (
73+
_generate_model_state_dict,
74+
_load_model_weights_from_checkpoint,
75+
apply_peft_adapter_filter_to_state_dict,
76+
)
7377
from megatron.bridge.training.config import DistributedInitConfig, RNGConfig
7478
from megatron.bridge.training.mixed_precision import MIXED_PRECISION_RECIPES, get_mixed_precision_config
7579
from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer
@@ -86,7 +90,7 @@
8690
get_world_size_safe,
8791
)
8892
from megatron.bridge.utils.instantiate_utils import instantiate
89-
from megatron.core import parallel_state, tensor_parallel
93+
from megatron.core import dist_checkpointing, parallel_state, tensor_parallel
9094
from megatron.core.num_microbatches_calculator import init_num_microbatches_calculator
9195
from megatron.core.tensor_parallel.mappings import _gather_along_last_dim
9296
from megatron.core.transformer.module import Float16Module
@@ -1117,12 +1121,36 @@ def predict(
11171121
else:
11181122
logger.warning("Could not determine number of layers from model structure")
11191123

1120-
logger.info(f"Loading weights from: {resolved_ckpt_dir}")
1121-
_load_model_weights_from_checkpoint(
1122-
checkpoint_path=str(resolved_ckpt_dir),
1123-
model=model,
1124-
dist_ckpt_strictness="ignore_all",
1125-
)
1124+
peft_section = run_config.get("peft")
1125+
if peft_section is not None:
1126+
pretrained_ckpt = resolve_checkpoint_path(Path(run_config["checkpoint"]["pretrained_checkpoint"]))
1127+
logger.info(f"Loading base model weights from: {pretrained_ckpt}")
1128+
_load_model_weights_from_checkpoint(
1129+
checkpoint_path=str(pretrained_ckpt),
1130+
model=model,
1131+
dist_ckpt_strictness="ignore_all",
1132+
)
1133+
1134+
unwrapped = [m.module for m in model]
1135+
peft_cfg = instantiate(peft_section)
1136+
peft_cfg(unwrapped, training=False)
1137+
1138+
logger.info(f"Loading adapter weights from: {resolved_ckpt_dir}")
1139+
sharded_sd = _generate_model_state_dict(unwrapped, {})
1140+
sharded_sd = apply_peft_adapter_filter_to_state_dict(sharded_sd, peft_cfg)
1141+
loaded = dist_checkpointing.load(sharded_sd, str(resolved_ckpt_dir), strict="ignore_all")
1142+
if len(unwrapped) == 1:
1143+
unwrapped[0].load_state_dict(loaded["model"], strict=False)
1144+
else:
1145+
for i, inner in enumerate(unwrapped):
1146+
inner.load_state_dict(loaded[f"model{i}"], strict=False)
1147+
else:
1148+
logger.info(f"Loading weights from: {resolved_ckpt_dir}")
1149+
_load_model_weights_from_checkpoint(
1150+
checkpoint_path=str(resolved_ckpt_dir),
1151+
model=model,
1152+
dist_ckpt_strictness="ignore_all",
1153+
)
11261154
logger.info("Weights loaded successfully")
11271155

11281156
# -------------------------------------------------------------------------

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515

1616

1717
# conftest.py
18+
import copy
1819
import gc
1920
import os
21+
import shlex
22+
import subprocess
2023
from pathlib import Path
2124

2225
import pytest
@@ -26,6 +29,8 @@
2629
from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH_512
2730
from bionemo.evo2.utils.checkpoint.nemo2_to_mbridge import run_nemo2_to_mbridge
2831

32+
from .utils import find_free_network_port, is_a6000_gpu
33+
2934

3035
def get_device_and_memory_allocated() -> str:
3136
"""Get the current device index, name, and memory usage."""
@@ -139,3 +144,43 @@ def mbridge_checkpoint_path(mbridge_checkpoint_1b_8k_bf16) -> Path:
139144
Path to the MBridge checkpoint iteration directory
140145
"""
141146
return mbridge_checkpoint_1b_8k_bf16
147+
148+
149+
@pytest.fixture(scope="session")
def lora_finetune_checkpoint(mbridge_checkpoint_1b_8k_bf16, tmp_path_factory) -> Path:
    """Session-scoped LoRA-finetuned checkpoint produced from ``mbridge_checkpoint_1b_8k_bf16``.

    Runs ``train_evo2 --lora-finetune`` for 2 steps with mock data so downstream tests
    can exercise PEFT-aware load paths (infer/predict) against a LoRA adapter checkpoint.
    Shared across test files to avoid doing the finetune more than once per session.

    Args:
        mbridge_checkpoint_1b_8k_bf16: Path of the base checkpoint; its parent directory is
            passed to ``--finetune-ckpt-dir``.
        tmp_path_factory: pytest factory used to create the session-level result directory.

    Returns:
        Path to the ``iter_0000002/`` directory of the LoRA adapter checkpoint.
    """
    num_steps = 2
    result_dir = tmp_path_factory.mktemp("lora_finetune_session") / "lora_finetune"

    # ``os.environ.copy()`` returns a plain dict. ``copy.deepcopy(os.environ)`` would return
    # another ``os._Environ`` whose ``__setitem__`` still calls ``putenv``, so the override
    # below would leak into the pytest process and every subprocess started afterwards.
    env = os.environ.copy()
    if is_a6000_gpu():
        env["NCCL_P2P_DISABLE"] = "1"

    # Quote the interpolated paths so ``shlex.split`` keeps each one as a single argument
    # even if it contains whitespace or shell metacharacters.
    finetune_ckpt_dir = shlex.quote(str(mbridge_checkpoint_1b_8k_bf16.parent))
    tokenizer_path = shlex.quote(str(DEFAULT_HF_TOKENIZER_MODEL_PATH_512))
    quoted_result_dir = shlex.quote(str(result_dir))

    port = find_free_network_port()
    cmd = (
        f"torchrun --nproc-per-node 1 --no-python --master_port {port} "
        f"train_evo2 --finetune-ckpt-dir {finetune_ckpt_dir} "
        f"--lora-finetune --lora-dim 8 --lora-alpha 16 "
        f"--lora-target-modules linear_qkv,linear_proj,linear_fc1,linear_fc2 "
        f"--hf-tokenizer-model-path {tokenizer_path} "
        f"--model-size evo2_1b_base --max-steps {num_steps} --eval-interval {num_steps} --eval-iters 1 "
        f"--mock-data --result-dir {quoted_result_dir} --mixed-precision-recipe bf16_mixed "
        f"--micro-batch-size 1 --global-batch-size 1 --seq-length 512 "
        f"--ckpt-format torch_dist --log-interval 1 --decay-steps 100 --warmup-steps 1 "
        f"--seed 42 --dataset-seed 33 --disable-tensorboard-logger"
    )
    # check=False so a failure surfaces through the assertion below with full stdout/stderr.
    result = subprocess.run(
        shlex.split(cmd), check=False, capture_output=True, text=True, cwd=result_dir.parent, env=env
    )
    assert result.returncode == 0, f"LoRA finetune fixture failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
    lora_ckpt = result_dir / "evo2" / "checkpoints" / f"iter_{num_steps:07d}"
    assert lora_ckpt.exists(), f"Expected LoRA checkpoint at {lora_ckpt}"
    return lora_ckpt

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,56 @@ def test_savanna_to_mbridge_inference_accuracy_7b(mbridge_checkpoint_7b_from_sav
896896
)
897897

898898

899+
@pytest.mark.timeout(512)
@pytest.mark.slow
def test_different_results_with_without_peft(tmp_path, mbridge_checkpoint_path, lora_finetune_checkpoint):
    """Check that the LoRA checkpoint yields different completion logprobs than the base checkpoint.

    ``bionemo.evo2.run.infer`` is launched once per checkpoint with identical prompt, seed and
    sampling options (``--top-k 1``); the completion logprobs from the first record of each
    output file are then compared.
    """
    # 64-char prompt for FP8 divisibility.
    prompt = "ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCG"
    env = copy.deepcopy(PRETEST_ENV)

    def _run_infer(ckpt: Path, output_file: Path) -> dict:
        """Run inference in a torchrun subprocess and return the first JSON record written."""
        launcher = [
            "torchrun",
            "--nproc_per_node",
            "1",
            "--nnodes",
            "1",
            "--master_port",
            str(find_free_network_port()),
        ]
        infer_args = [
            "-m",
            "bionemo.evo2.run.infer",
            "--ckpt-dir",
            str(ckpt),
            "--prompt",
            prompt,
            "--max-new-tokens",
            "10",
            "--temperature",
            "1.0",
            "--top-k",
            "1",
            "--seed",
            "0",
            "--return-log-probs",
            "--output-file",
            str(output_file),
        ]
        proc = subprocess.run(
            launcher + infer_args, check=False, capture_output=True, text=True, timeout=300, env=env
        )
        assert proc.returncode == 0, f"infer_evo2 failed:\nSTDOUT:\n{proc.stdout}\nSTDERR:\n{proc.stderr}"
        # Only the first line of the output file is parsed.
        with open(output_file) as fh:
            first_record = fh.readline()
        return json.loads(first_record)

    base_record = _run_infer(mbridge_checkpoint_path, tmp_path / "out_base.jsonl")
    lora_record = _run_infer(lora_finetune_checkpoint, tmp_path / "out_lora.jsonl")

    base_lp = base_record["logprobs"]["completion_logprobs"]
    lora_lp = lora_record["logprobs"]["completion_logprobs"]
    assert len(base_lp) == len(lora_lp), f"Different completion lengths: {len(base_lp)} vs {len(lora_lp)}"
    assert base_lp != lora_lp, "LoRA adapter had no effect on completion logprobs"
948+
899949
class TestHyenaInferenceContext:
900950
"""Unit tests for the Hyena-specific inference context."""
901951

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -429,13 +429,42 @@ def test_predict_evo2_equivalent_with_log_probs(
429429
assert log_probs.item() == pytest.approx(baseline_predictions_7b_1m_results[original_idx.item()], rel=rel)
430430

431431

432-
# Note: The PEFT/LoRA test is commented out as it requires training infrastructure and LoRA support
433-
# which may need additional updates for the Megatron Bridge API
434-
# @pytest.mark.timeout(512)
435-
# @pytest.mark.slow
436-
# def test_different_results_with_without_peft(tmp_path):
437-
# """Test that predictions differ when using PEFT/LoRA adapters."""
438-
# pass
432+
@pytest.mark.timeout(512)
@pytest.mark.slow
def test_different_results_with_without_peft(tmp_path, mbridge_checkpoint_1b_8k_bf16_path, lora_finetune_checkpoint):
    """Check that predicting with the LoRA checkpoint yields different logits than the base checkpoint.

    ``bionemo.evo2.run.predict`` is launched once per checkpoint on the same FASTA file; the
    single prediction file from each run must cover the same sequences with same-shaped logits
    that differ in at least one element.
    """
    env = copy.deepcopy(PRETEST_ENV)
    if is_a6000_gpu():
        env["NCCL_P2P_DISABLE"] = "1"

    fasta_file_path = tmp_path / "test.fasta"
    create_fasta_file(fasta_file_path, 3, sequence_lengths=[32, 65, 129], repeating_dna_pattern=ALU_SEQUENCE)

    def _run_predict(ckpt: Path, output_dir: Path) -> None:
        """Run prediction for ``ckpt`` in a torchrun subprocess, writing into ``output_dir``."""
        port = find_free_network_port()
        cmd = " ".join(
            [
                f"torchrun --nproc_per_node 1 --nnodes 1 --master_port {port}",
                f"-m bionemo.evo2.run.predict --fasta {fasta_file_path} --ckpt-dir {ckpt}",
                f"--output-dir {output_dir} --micro-batch-size 3 --write-interval epoch",
                "--pipeline-model-parallel-size 1 --num-nodes 1 --devices 1",
            ]
        )
        proc = subprocess.run(shlex.split(cmd), check=False, cwd=tmp_path, capture_output=True, text=True, env=env)
        assert proc.returncode == 0, f"predict_evo2 failed:\nSTDOUT:\n{proc.stdout}\nSTDERR:\n{proc.stderr}"

    out_base = tmp_path / "out_base"
    out_lora = tmp_path / "out_lora"
    _run_predict(mbridge_checkpoint_1b_8k_bf16_path, out_base)
    _run_predict(lora_finetune_checkpoint, out_lora)

    # Exactly one prediction file is expected per run.
    pattern = "predictions__rank_*__dp_rank_*.pt"
    base_files = glob.glob(str(out_base / pattern))
    lora_files = glob.glob(str(out_lora / pattern))
    assert len(base_files) == 1 and len(lora_files) == 1

    base = torch.load(base_files[0], weights_only=False)
    lora = torch.load(lora_files[0], weights_only=False)
    assert torch.equal(base["seq_idx"], lora["seq_idx"])
    assert base["token_logits"].shape == lora["token_logits"].shape
    logits_differ = (base["token_logits"] != lora["token_logits"]).any()
    assert logits_differ, "LoRA adapter had no effect on logits"
439468

440469

441470
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)