1616
1717import logging
1818from dataclasses import dataclass , field
19+ from functools import wraps
20+ from typing import Set
1921
2022import torch
2123from megatron .bridge .peft .base import ModelType
2224from megatron .bridge .peft .lora import LoRA
2325from megatron .bridge .peft .utils import wildcard_match
26+ from megatron .core .utils import unwrap_model
2427from torch import nn
2528
29+ from bionemo .evo2 .models .megatron .hyena .hyena_block import HyenaStack
30+
2631
# Module-level logger, named after this module via ``__name__``.
2732logger : logging .Logger = logging .getLogger (__name__ )
2833
# Default registry of ``id()`` values for unwrapped models that have already been
# patched. ``_enable_recompute_inputs_grad_for_hyena`` falls back to this set when
# the caller passes no ``patched_registry``, and adds a model's id after patching it.
34+ _HYENA_RECOMPUTE_PATCHED : Set [int ] = set ()
35+
36+
def _enable_recompute_inputs_grad_for_hyena(model, patched_registry: Set[int] | None = None) -> Set[int]:
    """Make HyenaStack inputs require grad when adapters are the only trainable weights.

    ``maybe_enable_recompute_inputs_grad`` in ``megatron.bridge.peft.recompute`` only
    patches ``TransformerBlock``. HyenaStack does not derive from TransformerBlock, so
    Evo2 models never receive that upstream fix; this function is its HyenaStack
    counterpart.

    Background: under activation checkpointing, PyTorch autograd only invokes
    Megatron's ``CheckpointFunction.backward()`` if at least one *input* tensor of
    the checkpoint has ``requires_grad=True``. With PP=1 and a fully frozen base
    model, the embedding outputs have ``requires_grad=False``; the checkpoint
    backward is therefore skipped and LoRA gradients computed inside the
    checkpointed region are silently lost.

    Remedy: each ``HyenaStack`` instance gets its ``forward`` replaced by a thin
    wrapper that turns a floating-point ``hidden_states`` tensor without grad into a
    detached tensor with ``requires_grad=True`` before delegating to the original
    forward. No parameter is unfrozen; only autograd bookkeeping changes.

    A model is skipped when it is ``None``, has no ``config``, its config has no
    ``recompute_method``, its ``id()`` is already in the registry, or it is not in
    the "trainable adapters, frozen base" state.

    Args:
        model: A model or list of models; passed through ``unwrap_model`` first.
        patched_registry: Optional set of ``id()`` values of models that were
            already patched. The module-level ``_HYENA_RECOMPUTE_PATCHED`` set is
            used when this is ``None``.

    Returns:
        The registry that was consulted, including the ids of any models patched
        by this call.
    """
    seen_ids = _HYENA_RECOMPUTE_PATCHED if patched_registry is None else patched_registry

    chunks = unwrap_model(model)
    if not isinstance(chunks, list):
        chunks = [chunks]

    for chunk in chunks:
        if chunk is None:
            continue

        chunk_cfg = getattr(chunk, "config", None)
        if chunk_cfg is None:
            continue
        if getattr(chunk_cfg, "recompute_method", None) is None:
            continue
        if id(chunk) in seen_ids:
            continue

        # Classify every trainable parameter by name: adapter weights, wrapped
        # (".to_wrap.") weights, or plain base-model weights.
        has_trainable_adapter = False
        has_trainable_base = False
        for param_name, param in chunk.named_parameters():
            if not param.requires_grad:
                continue
            lowered = param_name.lower()
            if ".adapter." in lowered:
                has_trainable_adapter = True
            elif ".to_wrap." not in lowered:
                has_trainable_base = True

        # Only the "adapters trainable, base frozen" configuration needs the patch.
        if not has_trainable_adapter or has_trainable_base:
            continue

        hyena_stacks = [sub for sub in chunk.modules() if isinstance(sub, HyenaStack)]
        for stack in hyena_stacks:
            bound_forward = stack.forward

            # ``_orig`` is bound as a default so each wrapper keeps its own
            # original forward instead of the loop's last value.
            @wraps(bound_forward)
            def _forward_with_grad_input(hidden_states, *args, _orig=bound_forward, **kwargs):
                force_grad = (
                    torch.is_tensor(hidden_states)
                    and hidden_states.is_floating_point()
                    and not hidden_states.requires_grad
                )
                if force_grad:
                    hidden_states = hidden_states.detach().requires_grad_(True)
                return _orig(hidden_states, *args, **kwargs)

            stack.forward = _forward_with_grad_input

        if hyena_stacks:
            seen_ids.add(id(chunk))
            logger.info(
                "[Evo2LoRA+Recompute] Patched HyenaStack.forward to enable grad on "
                "hidden_states input. This ensures checkpoint backward is called when "
                "only adapters are trainable (PP=1 with frozen base model)."
            )

    return seen_ids
109+
29110
30111@dataclass
31112class Evo2LoRA (LoRA ):
@@ -47,6 +128,13 @@ class Evo2LoRA(LoRA):
47128
48129 skip_freeze_modules : list [str ] = field (default_factory = list )
49130
def __call__(self, model: ModelType, training: bool = True) -> ModelType:
    """Run the parent LoRA transform, then patch HyenaStack inputs for recompute.

    Args:
        model: Model handed to the parent ``__call__``.
        training: Forwarded to the parent ``__call__``; the HyenaStack recompute
            patch is applied only when this is truthy.

    Returns:
        The model returned by the parent ``__call__``.
    """
    adapted = super().__call__(model, training=training)
    if not training:
        return adapted
    # Uses the module-level default registry (no ``patched_registry`` is passed).
    _enable_recompute_inputs_grad_for_hyena(adapted)
    return adapted
137+
50138 def freeze_model (self , model : ModelType , training : bool = True ) -> None :
51139 """Freeze all model parameters except those matching ``skip_freeze_modules``.
52140
0 commit comments