Skip to content

Commit cb9f079

Browse files
committed
Add LoRA example
1 parent 86c8329 commit cb9f079

5 files changed

Lines changed: 256 additions & 108 deletions

File tree

bionemo-recipes/recipes/evo2_megatron/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# syntax=docker/dockerfile:1.4
2-
FROM nvcr.io/nvidia/pytorch:26.03-py3
2+
FROM nvcr.io/nvidia/pytorch:26.02-py3
33

44
# uv is pre-installed in the nvcr.io/nvidia/pytorch base image.
55
# If using a base image without uv, uncomment the following line:

bionemo-recipes/recipes/evo2_megatron/examples/fine-tuning-tutorial.ipynb

Lines changed: 109 additions & 92 deletions
Large diffs are not rendered by default.
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
import logging
18+
from dataclasses import dataclass, field
19+
20+
import torch
21+
from megatron.bridge.peft.base import ModelType
22+
from megatron.bridge.peft.lora import LoRA
23+
from megatron.bridge.peft.utils import wildcard_match
24+
from torch import nn
25+
26+
27+
logger: logging.Logger = logging.getLogger(__name__)
28+
29+
30+
@dataclass
31+
class Evo2LoRA(LoRA):
32+
"""LoRA variant that allows selectively skipping parameter freezing for specified modules.
33+
34+
Extends LoRA with a ``skip_freeze_modules`` field that follows the same pattern-matching
35+
semantics as ``target_modules``:
36+
37+
- Exact short name: ``"mixer"`` matches any module whose immediate name is ``"mixer"``,
38+
regardless of depth.
39+
- Wildcard on full path: ``"*.layers.0.*.mixer"`` matches using ``*`` as a substring
40+
wildcard anchored to the full dotted path.
41+
42+
Args:
43+
skip_freeze_modules: List of module name patterns to exclude from freezing.
44+
Supports the same syntax as ``target_modules``. Modules whose short name or
45+
full path matches any pattern will remain trainable.
46+
"""
47+
48+
skip_freeze_modules: list[str] = field(default_factory=list)
49+
50+
def freeze_model(self, model: ModelType, training: bool = True) -> None:
51+
"""Freeze all model parameters except those matching ``skip_freeze_modules``.
52+
53+
Args:
54+
model: The model (or list of model chunks) to freeze.
55+
training: Whether the model is being used for training. When True, sets
56+
the model to training mode after freezing.
57+
"""
58+
matched_patterns: set[str] = set()
59+
60+
def selective_freeze(module: nn.Module, name: str | None = None, prefix: str | None = None) -> nn.Module:
61+
full_name = f"{prefix}.{name}" if prefix else (name or "")
62+
short_name = name or ""
63+
matched = [p for p in self.skip_freeze_modules if short_name == p or wildcard_match(p, full_name)]
64+
if not matched:
65+
for param in module.parameters(recurse=False):
66+
param.requires_grad = False
67+
else:
68+
matched_patterns.update(matched)
69+
logger.info(f"Evo2LoRA: Skipping freezing module: {full_name}.")
70+
return module
71+
72+
self._walk_model(model, selective_freeze)
73+
74+
for p in self.skip_freeze_modules:
75+
if p not in matched_patterns:
76+
logger.warning(f"Evo2LoRA: skip_freeze_modules pattern '{p}' did not match any module.")
77+
78+
if training:
79+
if isinstance(model, list):
80+
for model_chunk in model:
81+
model_chunk.train(mode=True)
82+
elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
83+
model.module.train(mode=True)
84+
else:
85+
model.train(mode=True)

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/recipes/evo2.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pathlib import Path
1919

2020
import torch
21+
from megatron.bridge.peft.lora import LoRA
2122
from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing
2223
from megatron.bridge.training.comm_overlap import CommOverlapConfig
2324
from megatron.bridge.training.config import (
@@ -89,6 +90,11 @@ class Evo2CommonKwargs(TypedDict, total=False):
8990
comm_overlap_config: CommOverlapConfig | None
9091
pad_eod_loss_mask: bool
9192
no_weight_decay_embeddings: bool
93+
lora_finetune: bool
94+
lora_alpha: int
95+
lora_dim: int
96+
lora_dropout: float
97+
lora_target_modules: list[str]
9298

9399

94100
def evo2_1b_pretrain_config(**user_kwargs: Unpack[Evo2CommonKwargs]) -> ConfigContainer:
@@ -159,6 +165,11 @@ def _evo2_common(
159165
comm_overlap_config: CommOverlapConfig | None = None,
160166
no_weight_decay_embeddings: bool = False,
161167
pad_eod_loss_mask: bool = False,
168+
lora_finetune: bool = False,
169+
lora_alpha: int = 32,
170+
lora_dim: int = 16,
171+
lora_dropout: float = 0.1,
172+
lora_target_modules: list[str] = ["dense_projection", "linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"],
162173
) -> ConfigContainer:
163174
"""Create a pre-training configuration for Mamba 2.x models.
164175
@@ -233,6 +244,16 @@ def _evo2_common(
233244
min_lr=min_lr,
234245
)
235246

247+
if lora_finetune:
248+
peft = LoRA(
249+
target_modules=lora_target_modules,
250+
dim=lora_dim,
251+
alpha=lora_alpha,
252+
dropout=lora_dropout,
253+
)
254+
else:
255+
peft = None
256+
236257
cfg = ConfigContainer(
237258
model=model_cfg,
238259
train=TrainingConfig(
@@ -289,6 +310,7 @@ def _evo2_common(
289310
rng=RNGConfig(seed=seed),
290311
comm_overlap=comm_overlap_config,
291312
mixed_precision=precision_config,
313+
peft=peft,
292314
)
293315

294316
return cfg

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -594,9 +594,6 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
594594
# help="Disable saving the last checkpoint.",
595595
# ) # TODO implement
596596
# parser.add_argument(
597-
# "--lora-finetune", action="store_true", help="Use LoRA fine-tuning", default=False
598-
# ) # TODO implement
599-
# parser.add_argument(
600597
# "--lora-checkpoint-path", type=str, default=None, help="LoRA checkpoint path"
601598
# ) # TODO implement
602599
parser.add_argument(
@@ -618,18 +615,6 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
618615
default=False,
619616
help="Enable CUDA memory cleanup before validation to prevent initialization errors.",
620617
) # DONE
621-
parser.add_argument(
622-
"--lora-alpha",
623-
type=int,
624-
default=None,
625-
help="Alpha parameter for LoRA fine-tuning.",
626-
) # TODO implement
627-
parser.add_argument(
628-
"--lora-dim",
629-
type=int,
630-
default=None,
631-
help="Dim parameter for LoRA fine-tuning.",
632-
) # TODO implement
633618
parser.add_argument(
634619
"--debug",
635620
action="store_true",
@@ -671,6 +656,38 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
671656
"--hf-tokenizer-model-name", type=str, help="Name of a remote HF tokenizer model."
672657
) # DONE
673658

659+
# LoRA
660+
parser.add_argument(
661+
"--lora-finetune",
662+
action="store_true",
663+
default=False,
664+
help="Use LoRA fine-tuning.",
665+
)
666+
parser.add_argument(
667+
"--lora-alpha",
668+
type=int,
669+
default=32,
670+
help="Alpha parameter for LoRA fine-tuning.",
671+
)
672+
parser.add_argument(
673+
"--lora-dim",
674+
type=int,
675+
default=16,
676+
help="Dim parameter for LoRA fine-tuning.",
677+
)
678+
parser.add_argument(
679+
"--lora-dropout",
680+
type=float,
681+
default=0.1,
682+
help="Dropout parameter for LoRA fine-tuning.",
683+
)
684+
parser.add_argument(
685+
"--lora-target-modules",
686+
type=lambda s: [m.strip() for m in s.split(",")],
687+
default=["dense_projection", "dense", "linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"],
688+
help="Target modules for LoRA fine-tuning, as a comma-separated list.",
689+
)
690+
674691
return parser.parse_args(args=args)
675692

676693

@@ -794,6 +811,13 @@ def train(args: argparse.Namespace) -> None:
794811
if args.no_weight_decay_embeddings:
795812
recipe_kwargs["no_weight_decay_embeddings"] = True
796813

814+
# LoRA
815+
recipe_kwargs["lora_finetune"] = args.lora_finetune
816+
recipe_kwargs["lora_alpha"] = args.lora_alpha
817+
recipe_kwargs["lora_dim"] = args.lora_dim
818+
recipe_kwargs["lora_dropout"] = args.lora_dropout
819+
recipe_kwargs["lora_target_modules"] = args.lora_target_modules
820+
797821
# 2. Generate Base Configuration
798822
cfg: ConfigContainer = pretrain_config(**recipe_kwargs)
799823

0 commit comments

Comments
 (0)