Skip to content

Commit 6d65e2e

Browse files
committed
Add support for inference using LoRA checkpoint
Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>
1 parent c97261d commit 6d65e2e

3 files changed

Lines changed: 105 additions & 16 deletions

File tree

bionemo-recipes/recipes/evo2_megatron/README.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,40 @@ rather than silently producing asymmetric behaviour.
395395
weights are always treated as a unit, and any asymmetric configuration will
396396
raise an error.
397397

398+
### Running inference on a LoRA checkpoint
399+
400+
A LoRA training checkpoint contains only adapter tensors — the base model weights
401+
are not duplicated. Point `--ckpt-dir` at the LoRA `iter_*` directory as usual:
402+
403+
```bash
404+
torchrun --nproc_per_node 1 --no-python \
405+
infer_evo2 \
406+
--ckpt-dir /path/to/lora_run/checkpoints/iter_0000250 \
407+
--prompt "ATCGATCGATCGATCG" \
408+
--max-new-tokens 200
409+
```
410+
411+
```bash
412+
torchrun --nproc_per_node 1 --no-python \
413+
predict_evo2 \
414+
--fasta sequences.fa \
415+
--ckpt-dir /path/to/lora_run/checkpoints/iter_0000250 \
416+
--output-dir ./predictions
417+
```
418+
419+
When `infer_evo2` / `predict_evo2` detect a `peft` section in the checkpoint's
420+
`run_config.yaml`, they:
421+
422+
1. load dense base weights from `checkpoint.pretrained_checkpoint` (the same
423+
value that was supplied during LoRA training),
424+
2. apply the stored PEFT config (`run_config["peft"]`) to graft `LoRALinear`
425+
wrappers onto the base modules,
426+
3. load only the adapter tensors from `--ckpt-dir`.
427+
428+
No merge step is required. The base checkpoint referenced by
429+
`pretrained_checkpoint` must still exist on disk at the path recorded in
430+
`run_config.yaml`.
431+
398432
## Exporting to Vortex format
399433

400434
Vortex is ARC Institute's inference format for Evo2 Hyena models, used by the

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,11 @@
7070

7171
import torch
7272
import torch.distributed as dist
73-
from megatron.bridge.training.checkpointing import _load_model_weights_from_checkpoint
73+
from megatron.bridge.training.checkpointing import (
74+
_generate_model_state_dict,
75+
_load_model_weights_from_checkpoint,
76+
apply_peft_adapter_filter_to_state_dict,
77+
)
7478
from megatron.bridge.training.config import DistributedInitConfig, RNGConfig
7579
from megatron.bridge.training.mixed_precision import get_mixed_precision_config
7680
from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer
@@ -81,7 +85,7 @@
8185
)
8286
from megatron.bridge.utils.common_utils import get_world_size_safe
8387
from megatron.bridge.utils.instantiate_utils import instantiate
84-
from megatron.core import parallel_state
88+
from megatron.core import dist_checkpointing, parallel_state
8589
from megatron.core.inference.contexts import StaticInferenceContext
8690
from megatron.core.inference.engines.static_engine import StaticInferenceEngine
8791
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
@@ -462,12 +466,35 @@ def setup_inference_engine(
462466

463467
raw_model = model_provider.provide().eval().cuda()
464468

465-
logger.info(f"Loading weights from: {resolved_ckpt_dir}")
466-
_load_model_weights_from_checkpoint(
467-
checkpoint_path=str(resolved_ckpt_dir),
468-
model=[raw_model],
469-
dist_ckpt_strictness="ignore_all",
470-
)
469+
# A LoRA finetune checkpoint only contains adapter tensors; the base weights live in
470+
# run_config["checkpoint"]["pretrained_checkpoint"]. Detect via the top-level `peft:`
471+
# section (same signal `peft_pre_wrap_hook` uses during training).
472+
peft_node = run_config.get("peft")
473+
if peft_node is not None:
474+
# pretrained_checkpoint may point at a training-output parent containing iter_*; resolve.
475+
resolved_pretrained_dir = resolve_checkpoint_path(Path(run_config["checkpoint"]["pretrained_checkpoint"]))
476+
logger.info(f"PEFT checkpoint detected. Loading base weights from: {resolved_pretrained_dir}")
477+
_load_model_weights_from_checkpoint(
478+
checkpoint_path=str(resolved_pretrained_dir),
479+
model=[raw_model],
480+
dist_ckpt_strictness="ignore_all",
481+
)
482+
483+
logger.info("Applying PEFT adapter structure to base model")
484+
peft_cfg = instantiate(peft_node)
485+
raw_model = peft_cfg(raw_model, training=False)
486+
487+
logger.info(f"Loading adapter weights from: {resolved_ckpt_dir}")
488+
sharded_sd = apply_peft_adapter_filter_to_state_dict(_generate_model_state_dict([raw_model], {}), peft_cfg)
489+
loaded = dist_checkpointing.load(sharded_sd, str(resolved_ckpt_dir), strict="ignore_all")
490+
raw_model.load_state_dict(loaded["model"], strict=False)
491+
else:
492+
logger.info(f"Loading weights from: {resolved_ckpt_dir}")
493+
_load_model_weights_from_checkpoint(
494+
checkpoint_path=str(resolved_ckpt_dir),
495+
model=[raw_model],
496+
dist_ckpt_strictness="ignore_all",
497+
)
471498
logger.info("Weights loaded successfully")
472499

473500
# Wrap with Float16Module

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,11 @@
6969
import torch
7070
import torch.distributed as dist
7171
from megatron.bridge.data.samplers import build_pretraining_data_loader
72-
from megatron.bridge.training.checkpointing import _load_model_weights_from_checkpoint
72+
from megatron.bridge.training.checkpointing import (
73+
_generate_model_state_dict,
74+
_load_model_weights_from_checkpoint,
75+
apply_peft_adapter_filter_to_state_dict,
76+
)
7377
from megatron.bridge.training.config import DistributedInitConfig, RNGConfig
7478
from megatron.bridge.training.mixed_precision import MIXED_PRECISION_RECIPES, get_mixed_precision_config
7579
from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer
@@ -86,7 +90,7 @@
8690
get_world_size_safe,
8791
)
8892
from megatron.bridge.utils.instantiate_utils import instantiate
89-
from megatron.core import parallel_state, tensor_parallel
93+
from megatron.core import dist_checkpointing, parallel_state, tensor_parallel
9094
from megatron.core.num_microbatches_calculator import init_num_microbatches_calculator
9195
from megatron.core.tensor_parallel.mappings import _gather_along_last_dim
9296
from megatron.core.transformer.module import Float16Module
@@ -1117,12 +1121,36 @@ def predict(
11171121
else:
11181122
logger.warning("Could not determine number of layers from model structure")
11191123

1120-
logger.info(f"Loading weights from: {resolved_ckpt_dir}")
1121-
_load_model_weights_from_checkpoint(
1122-
checkpoint_path=str(resolved_ckpt_dir),
1123-
model=model,
1124-
dist_ckpt_strictness="ignore_all",
1125-
)
1124+
peft_section = run_config.get("peft")
1125+
if peft_section is not None:
1126+
pretrained_ckpt = resolve_checkpoint_path(Path(run_config["checkpoint"]["pretrained_checkpoint"]))
1127+
logger.info(f"Loading base model weights from: {pretrained_ckpt}")
1128+
_load_model_weights_from_checkpoint(
1129+
checkpoint_path=str(pretrained_ckpt),
1130+
model=model,
1131+
dist_ckpt_strictness="ignore_all",
1132+
)
1133+
1134+
unwrapped = [m.module for m in model]
1135+
peft_cfg = instantiate(peft_section)
1136+
peft_cfg(unwrapped, training=False)
1137+
1138+
logger.info(f"Loading adapter weights from: {resolved_ckpt_dir}")
1139+
sharded_sd = _generate_model_state_dict(unwrapped, {})
1140+
sharded_sd = apply_peft_adapter_filter_to_state_dict(sharded_sd, peft_cfg)
1141+
loaded = dist_checkpointing.load(sharded_sd, str(resolved_ckpt_dir), strict="ignore_all")
1142+
if len(unwrapped) == 1:
1143+
unwrapped[0].load_state_dict(loaded["model"], strict=False)
1144+
else:
1145+
for i, inner in enumerate(unwrapped):
1146+
inner.load_state_dict(loaded[f"model{i}"], strict=False)
1147+
else:
1148+
logger.info(f"Loading weights from: {resolved_ckpt_dir}")
1149+
_load_model_weights_from_checkpoint(
1150+
checkpoint_path=str(resolved_ckpt_dir),
1151+
model=model,
1152+
dist_ckpt_strictness="ignore_all",
1153+
)
11261154
logger.info("Weights loaded successfully")
11271155

11281156
# -------------------------------------------------------------------------

0 commit comments

Comments
 (0)