Skip to content

Commit 89b5dc5

Browse files
committed
added connector for lora skip
1 parent a933771 commit 89b5dc5

4 files changed

Lines changed: 389 additions & 2 deletions

File tree

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_lora.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,97 @@
1616

1717
import logging
1818
from dataclasses import dataclass, field
19+
from functools import wraps
20+
from typing import Set
1921

2022
import torch
2123
from megatron.bridge.peft.base import ModelType
2224
from megatron.bridge.peft.lora import LoRA
2325
from megatron.bridge.peft.utils import wildcard_match
26+
from megatron.core.utils import unwrap_model
2427
from torch import nn
2528

29+
from bionemo.evo2.models.megatron.hyena.hyena_block import HyenaStack
30+
2631

2732
logger: logging.Logger = logging.getLogger(__name__)
2833

34+
_HYENA_RECOMPUTE_PATCHED: Set[int] = set()
35+
36+
37+
def _enable_recompute_inputs_grad_for_hyena(model, patched_registry: Set[int] | None = None) -> Set[int]:
38+
"""Enable grad on HyenaStack inputs when only adapters are trainable.
39+
40+
This is the HyenaStack analogue of ``maybe_enable_recompute_inputs_grad`` from
41+
``megatron.bridge.peft.recompute``, which only patches ``TransformerBlock``.
42+
HyenaStack is not a TransformerBlock subclass, so the upstream fix never fires
43+
for Evo2 models.
44+
45+
When activation checkpointing is active (``recompute_granularity == "full"``),
46+
Megatron's ``CheckpointFunction.backward()`` is only invoked by PyTorch autograd
47+
when at least one *input* tensor to the checkpoint has ``requires_grad=True``.
48+
With PP=1 and a fully frozen base model the embedding outputs carry
49+
``requires_grad=False``, so ``CheckpointFunction.backward()`` is never called
50+
and LoRA gradients inside the checkpoint are silently dropped.
51+
52+
The fix: monkey-patch ``HyenaStack.forward`` to force
53+
``hidden_states.requires_grad_(True)`` before the tensor enters the checkpointed
54+
region. No parameters are unfrozen; only the autograd bookkeeping is corrected.
55+
"""
56+
registry = patched_registry if patched_registry is not None else _HYENA_RECOMPUTE_PATCHED
57+
58+
unwrapped = unwrap_model(model)
59+
if not isinstance(unwrapped, list):
60+
unwrapped = [unwrapped]
61+
62+
for unwrapped_model in unwrapped:
63+
if unwrapped_model is None:
64+
continue
65+
66+
cfg = getattr(unwrapped_model, "config", None)
67+
if cfg is None or getattr(cfg, "recompute_method", None) is None:
68+
continue
69+
70+
if id(unwrapped_model) in registry:
71+
continue
72+
73+
params = list(unwrapped_model.named_parameters())
74+
trainable_adapter = any(p.requires_grad and ".adapter." in n.lower() for n, p in params)
75+
trainable_base = any(
76+
p.requires_grad and ".to_wrap." not in n.lower() and ".adapter." not in n.lower() for n, p in params
77+
)
78+
79+
if not (trainable_adapter and not trainable_base):
80+
continue
81+
82+
patched_any = False
83+
for module in unwrapped_model.modules():
84+
if isinstance(module, HyenaStack):
85+
original_forward = module.forward
86+
87+
@wraps(original_forward)
88+
def _patched_forward(hidden_states, *args, _orig=original_forward, **kwargs):
89+
if (
90+
torch.is_tensor(hidden_states)
91+
and not hidden_states.requires_grad
92+
and hidden_states.is_floating_point()
93+
):
94+
hidden_states = hidden_states.detach().requires_grad_(True)
95+
return _orig(hidden_states, *args, **kwargs)
96+
97+
module.forward = _patched_forward
98+
patched_any = True
99+
100+
if patched_any:
101+
registry.add(id(unwrapped_model))
102+
logger.info(
103+
"[Evo2LoRA+Recompute] Patched HyenaStack.forward to enable grad on "
104+
"hidden_states input. This ensures checkpoint backward is called when "
105+
"only adapters are trainable (PP=1 with frozen base model)."
106+
)
107+
108+
return registry
109+
29110

30111
@dataclass
31112
class Evo2LoRA(LoRA):
@@ -47,6 +128,13 @@ class Evo2LoRA(LoRA):
47128

48129
skip_freeze_modules: list[str] = field(default_factory=list)
49130

131+
def __call__(self, model: ModelType, training: bool = True) -> ModelType:
132+
"""Apply LoRA to the model, with HyenaStack-aware recompute patching."""
133+
model = super().__call__(model, training=training)
134+
if training:
135+
_enable_recompute_inputs_grad_for_hyena(model)
136+
return model
137+
50138
def freeze_model(self, model: ModelType, training: bool = True) -> None:
51139
"""Freeze all model parameters except those matching ``skip_freeze_modules``.
52140

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/recipes/evo2.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from pathlib import Path
1919

2020
import torch
21-
from megatron.bridge.peft.lora import LoRA
2221
from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing
2322
from megatron.bridge.training.comm_overlap import CommOverlapConfig
2423
from megatron.bridge.training.config import (
@@ -37,6 +36,7 @@
3736
from bionemo.evo2.data.evo2_mock_dataset_provider import MockEvo2DatasetProvider
3837
from bionemo.evo2.data.megatron.hyena.evo2_dataset import Evo2Dataset, Evo2DatasetPadEodLossMask
3938
from bionemo.evo2.data.sharded_eden_dataset_provider import ShardedEdenDatasetProvider
39+
from bionemo.evo2.models.evo2_lora import Evo2LoRA
4040
from bionemo.evo2.models.evo2_provider import (
4141
Hyena1bModelProvider,
4242
HyenaModelProvider,
@@ -95,6 +95,7 @@ class Evo2CommonKwargs(TypedDict, total=False):
9595
lora_dim: int
9696
lora_dropout: float
9797
lora_target_modules: list[str]
98+
lora_skip_freeze_modules: list[str]
9899

99100

100101
def evo2_1b_pretrain_config(**user_kwargs: Unpack[Evo2CommonKwargs]) -> ConfigContainer:
@@ -170,6 +171,7 @@ def _evo2_common(
170171
lora_dim: int = 16,
171172
lora_dropout: float = 0.1,
172173
lora_target_modules: list[str] = ["dense_projection", "linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"],
174+
lora_skip_freeze_modules: list[str] = [],
173175
) -> ConfigContainer:
174176
"""Create a pre-training configuration for Mamba 2.x models.
175177
@@ -245,11 +247,12 @@ def _evo2_common(
245247
)
246248

247249
if lora_finetune:
248-
peft = LoRA(
250+
peft = Evo2LoRA(
249251
target_modules=lora_target_modules,
250252
dim=lora_dim,
251253
alpha=lora_alpha,
252254
dropout=lora_dropout,
255+
skip_freeze_modules=lora_skip_freeze_modules,
253256
)
254257
else:
255258
peft = None

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,12 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
687687
default=["dense_projection", "dense", "linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"],
688688
help="Target modules for LoRA fine-tuning, as a comma-separated list.",
689689
)
690+
parser.add_argument(
691+
"--lora-skip-freeze-modules",
692+
type=lambda s: [m.strip() for m in s.split(",")],
693+
default=[],
694+
help="Skip freeze modules for LoRA fine-tuning, as a comma-separated list.",
695+
)
690696

691697
return parser.parse_args(args=args)
692698

@@ -817,6 +823,7 @@ def train(args: argparse.Namespace) -> None:
817823
recipe_kwargs["lora_dim"] = args.lora_dim
818824
recipe_kwargs["lora_dropout"] = args.lora_dropout
819825
recipe_kwargs["lora_target_modules"] = args.lora_target_modules
826+
recipe_kwargs["lora_skip_freeze_modules"] = args.lora_skip_freeze_modules
820827

821828
# 2. Generate Base Configuration
822829
cfg: ConfigContainer = pretrain_config(**recipe_kwargs)

0 commit comments

Comments
 (0)