Skip to content

Commit affb95a

Browse files
committed
Add support for inference using LoRA checkpoint
Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>
1 parent c97261d commit affb95a

4 files changed

Lines changed: 163 additions & 23 deletions

File tree

bionemo-recipes/recipes/evo2_megatron/README.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,40 @@ rather than silently producing asymmetric behaviour.
395395
weights are always treated as a unit, and any asymmetric configuration will
396396
raise an error.
397397

398+
### Running inference on a LoRA checkpoint
399+
400+
A LoRA training checkpoint contains only adapter tensors — the base model weights
401+
are not duplicated. Point `--ckpt-dir` at the LoRA run's `checkpoints/` directory (or at a specific `iter_*` directory inside it) as usual:
402+
403+
```bash
404+
torchrun --nproc_per_node 1 --no-python \
405+
infer_evo2 \
406+
--ckpt-dir </path/to/lora_run/checkpoints/> \
407+
--prompt "ATCGATCGATCGATCG" \
408+
--max-new-tokens 200
409+
```
410+
411+
```bash
412+
torchrun --nproc_per_node 1 --no-python \
413+
predict_evo2 \
414+
--fasta </path/to/fasta/sequences> \
415+
--ckpt-dir </path/to/lora_run/checkpoints/> \
416+
--output-dir ./predictions
417+
```
418+
419+
When `infer_evo2` / `predict_evo2` detect a `peft` section in the checkpoint's
420+
`run_config.yaml`, they:
421+
422+
1. load dense base weights from `checkpoint.pretrained_checkpoint` (the same
423+
value that was supplied during LoRA training),
424+
2. apply the stored PEFT config (`run_config["peft"]`) to graft `LoRALinear`
425+
wrappers onto the base modules,
426+
3. load only the adapter tensors from `--ckpt-dir`.
427+
428+
No merge step is required. The base checkpoint referenced by
429+
`pretrained_checkpoint` must still exist on disk at the path recorded in
430+
`run_config.yaml`.
431+
398432
## Exporting to Vortex format
399433

400434
Vortex is ARC Institute's inference format for Evo2 Hyena models, used by the

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,11 @@
7070

7171
import torch
7272
import torch.distributed as dist
73-
from megatron.bridge.training.checkpointing import _load_model_weights_from_checkpoint
73+
from megatron.bridge.training.checkpointing import (
74+
_generate_model_state_dict,
75+
_load_model_weights_from_checkpoint,
76+
apply_peft_adapter_filter_to_state_dict,
77+
)
7478
from megatron.bridge.training.config import DistributedInitConfig, RNGConfig
7579
from megatron.bridge.training.mixed_precision import get_mixed_precision_config
7680
from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer
@@ -81,7 +85,7 @@
8185
)
8286
from megatron.bridge.utils.common_utils import get_world_size_safe
8387
from megatron.bridge.utils.instantiate_utils import instantiate
84-
from megatron.core import parallel_state
88+
from megatron.core import dist_checkpointing, parallel_state
8589
from megatron.core.inference.contexts import StaticInferenceContext
8690
from megatron.core.inference.engines.static_engine import StaticInferenceEngine
8791
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
@@ -462,12 +466,35 @@ def setup_inference_engine(
462466

463467
raw_model = model_provider.provide().eval().cuda()
464468

465-
logger.info(f"Loading weights from: {resolved_ckpt_dir}")
466-
_load_model_weights_from_checkpoint(
467-
checkpoint_path=str(resolved_ckpt_dir),
468-
model=[raw_model],
469-
dist_ckpt_strictness="ignore_all",
470-
)
469+
# A LoRA finetune checkpoint only contains adapter tensors; the base weights live in
470+
# run_config["checkpoint"]["pretrained_checkpoint"]. Detect via the top-level `peft:`
471+
# section (same signal `peft_pre_wrap_hook` uses during training).
472+
peft_node = run_config.get("peft")
473+
if peft_node is not None:
474+
# pretrained_checkpoint may point at a training-output parent directory that contains iter_* subdirectories, so resolve it before loading.
475+
resolved_pretrained_dir = resolve_checkpoint_path(Path(run_config["checkpoint"]["pretrained_checkpoint"]))
476+
logger.info(f"PEFT checkpoint detected. Loading base weights from: {resolved_pretrained_dir}")
477+
_load_model_weights_from_checkpoint(
478+
checkpoint_path=str(resolved_pretrained_dir),
479+
model=[raw_model],
480+
dist_ckpt_strictness="ignore_all",
481+
)
482+
483+
logger.info("Applying PEFT adapter structure to base model")
484+
peft_cfg = instantiate(peft_node)
485+
raw_model = peft_cfg(raw_model, training=False)
486+
487+
logger.info(f"Loading adapter weights from: {resolved_ckpt_dir}")
488+
sharded_sd = apply_peft_adapter_filter_to_state_dict(_generate_model_state_dict([raw_model], {}), peft_cfg)
489+
loaded = dist_checkpointing.load(sharded_sd, str(resolved_ckpt_dir), strict="ignore_all")
490+
raw_model.load_state_dict(loaded["model"], strict=False)
491+
else:
492+
logger.info(f"Loading weights from: {resolved_ckpt_dir}")
493+
_load_model_weights_from_checkpoint(
494+
checkpoint_path=str(resolved_ckpt_dir),
495+
model=[raw_model],
496+
dist_ckpt_strictness="ignore_all",
497+
)
471498
logger.info("Weights loaded successfully")
472499

473500
# Wrap with Float16Module

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,11 @@
6969
import torch
7070
import torch.distributed as dist
7171
from megatron.bridge.data.samplers import build_pretraining_data_loader
72-
from megatron.bridge.training.checkpointing import _load_model_weights_from_checkpoint
72+
from megatron.bridge.training.checkpointing import (
73+
_generate_model_state_dict,
74+
_load_model_weights_from_checkpoint,
75+
apply_peft_adapter_filter_to_state_dict,
76+
)
7377
from megatron.bridge.training.config import DistributedInitConfig, RNGConfig
7478
from megatron.bridge.training.mixed_precision import MIXED_PRECISION_RECIPES, get_mixed_precision_config
7579
from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer
@@ -86,7 +90,7 @@
8690
get_world_size_safe,
8791
)
8892
from megatron.bridge.utils.instantiate_utils import instantiate
89-
from megatron.core import parallel_state, tensor_parallel
93+
from megatron.core import dist_checkpointing, parallel_state, tensor_parallel
9094
from megatron.core.num_microbatches_calculator import init_num_microbatches_calculator
9195
from megatron.core.tensor_parallel.mappings import _gather_along_last_dim
9296
from megatron.core.transformer.module import Float16Module
@@ -1117,12 +1121,36 @@ def predict(
11171121
else:
11181122
logger.warning("Could not determine number of layers from model structure")
11191123

1120-
logger.info(f"Loading weights from: {resolved_ckpt_dir}")
1121-
_load_model_weights_from_checkpoint(
1122-
checkpoint_path=str(resolved_ckpt_dir),
1123-
model=model,
1124-
dist_ckpt_strictness="ignore_all",
1125-
)
1124+
peft_section = run_config.get("peft")
1125+
if peft_section is not None:
1126+
pretrained_ckpt = resolve_checkpoint_path(Path(run_config["checkpoint"]["pretrained_checkpoint"]))
1127+
logger.info(f"Loading base model weights from: {pretrained_ckpt}")
1128+
_load_model_weights_from_checkpoint(
1129+
checkpoint_path=str(pretrained_ckpt),
1130+
model=model,
1131+
dist_ckpt_strictness="ignore_all",
1132+
)
1133+
1134+
unwrapped = [m.module for m in model]
1135+
peft_cfg = instantiate(peft_section)
1136+
peft_cfg(unwrapped, training=False)
1137+
1138+
logger.info(f"Loading adapter weights from: {resolved_ckpt_dir}")
1139+
sharded_sd = _generate_model_state_dict(unwrapped, {})
1140+
sharded_sd = apply_peft_adapter_filter_to_state_dict(sharded_sd, peft_cfg)
1141+
loaded = dist_checkpointing.load(sharded_sd, str(resolved_ckpt_dir), strict="ignore_all")
1142+
if len(unwrapped) == 1:
1143+
unwrapped[0].load_state_dict(loaded["model"], strict=False)
1144+
else:
1145+
for i, inner in enumerate(unwrapped):
1146+
inner.load_state_dict(loaded[f"model{i}"], strict=False)
1147+
else:
1148+
logger.info(f"Loading weights from: {resolved_ckpt_dir}")
1149+
_load_model_weights_from_checkpoint(
1150+
checkpoint_path=str(resolved_ckpt_dir),
1151+
model=model,
1152+
dist_ckpt_strictness="ignore_all",
1153+
)
11261154
logger.info("Weights loaded successfully")
11271155

11281156
# -------------------------------------------------------------------------

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.py

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -429,13 +429,64 @@ def test_predict_evo2_equivalent_with_log_probs(
429429
assert log_probs.item() == pytest.approx(baseline_predictions_7b_1m_results[original_idx.item()], rel=rel)
430430

431431

432-
# Note: The PEFT/LoRA test is commented out as it requires training infrastructure and LoRA support
433-
# which may need additional updates for the Megatron Bridge API
434-
# @pytest.mark.timeout(512)
435-
# @pytest.mark.slow
436-
# def test_different_results_with_without_peft(tmp_path):
437-
# """Test that predictions differ when using PEFT/LoRA adapters."""
438-
# pass
432+
@pytest.mark.timeout(512)
433+
@pytest.mark.slow
434+
def test_different_results_with_without_peft(tmp_path, mbridge_checkpoint_1b_8k_bf16_path):
435+
"""LoRA-finetune a few steps off the base ckpt, then predict on base vs. LoRA ckpt and assert logits differ."""
436+
num_steps = 2
437+
result_dir = tmp_path / "lora_finetune"
438+
env = copy.deepcopy(PRETEST_ENV)
439+
if is_a6000_gpu():
440+
env["NCCL_P2P_DISABLE"] = "1"
441+
442+
ft_port = find_free_network_port()
443+
ft_cmd = (
444+
f"torchrun --nproc-per-node 1 --no-python --master_port {ft_port} "
445+
f"train_evo2 --finetune-ckpt-dir {mbridge_checkpoint_1b_8k_bf16_path.parent} "
446+
f"--lora-finetune --lora-dim 8 --lora-alpha 16 "
447+
f"--lora-target-modules linear_qkv,linear_proj,linear_fc1,linear_fc2 "
448+
f"--hf-tokenizer-model-path {DEFAULT_HF_TOKENIZER_MODEL_PATH_512} "
449+
f"--model-size evo2_1b_base --max-steps {num_steps} --eval-interval {num_steps} --eval-iters 1 "
450+
f"--mock-data --result-dir {result_dir} --mixed-precision-recipe bf16_mixed "
451+
f"--micro-batch-size 1 --global-batch-size 1 --seq-length 512 "
452+
f"--ckpt-format torch_dist --log-interval 1 --decay-steps 100 --warmup-steps 1 "
453+
f"--seed 42 --dataset-seed 33"
454+
)
455+
ft_result = subprocess.run(shlex.split(ft_cmd), check=False, capture_output=True, text=True, cwd=tmp_path, env=env)
456+
assert ft_result.returncode == 0, (
457+
f"LoRA finetune failed:\nSTDOUT:\n{ft_result.stdout}\nSTDERR:\n{ft_result.stderr}"
458+
)
459+
lora_ckpt = result_dir / "evo2" / "checkpoints" / f"iter_{num_steps:07d}"
460+
assert lora_ckpt.exists(), f"Expected LoRA checkpoint at {lora_ckpt}"
461+
462+
fasta_file_path = tmp_path / "test.fasta"
463+
create_fasta_file(fasta_file_path, 3, sequence_lengths=[32, 65, 129], repeating_dna_pattern=ALU_SEQUENCE)
464+
465+
def _run_predict(ckpt: Path, output_dir: Path) -> None:
466+
port = find_free_network_port()
467+
cmd = (
468+
f"torchrun --nproc_per_node 1 --nnodes 1 --master_port {port} "
469+
f"-m bionemo.evo2.run.predict --fasta {fasta_file_path} --ckpt-dir {ckpt} "
470+
f"--output-dir {output_dir} --micro-batch-size 3 --write-interval epoch "
471+
f"--pipeline-model-parallel-size 1 --num-nodes 1 --devices 1"
472+
)
473+
r = subprocess.run(shlex.split(cmd), check=False, cwd=tmp_path, capture_output=True, text=True, env=env)
474+
assert r.returncode == 0, f"predict_evo2 failed:\nSTDOUT:\n{r.stdout}\nSTDERR:\n{r.stderr}"
475+
476+
out_base = tmp_path / "out_base"
477+
out_lora = tmp_path / "out_lora"
478+
_run_predict(mbridge_checkpoint_1b_8k_bf16_path, out_base)
479+
_run_predict(lora_ckpt, out_lora)
480+
481+
base_files = glob.glob(str(out_base / "predictions__rank_*__dp_rank_*.pt"))
482+
lora_files = glob.glob(str(out_lora / "predictions__rank_*__dp_rank_*.pt"))
483+
assert len(base_files) == 1 and len(lora_files) == 1
484+
485+
base = torch.load(base_files[0], weights_only=False)
486+
lora = torch.load(lora_files[0], weights_only=False)
487+
assert torch.equal(base["seq_idx"], lora["seq_idx"])
488+
assert base["token_logits"].shape == lora["token_logits"].shape
489+
assert (base["token_logits"] != lora["token_logits"]).any(), "LoRA adapter had no effect on logits"
439490

440491

441492
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)