Skip to content

Commit e1fcfc4

Browse files
authored
add freeze, unfreeze methods with experimental tag (#15729)
1 parent 5d9b6bc commit e1fcfc4

3 files changed

Lines changed: 79 additions & 80 deletions

File tree

nemo/collections/asr/modules/transformer_encoder.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
import torch.nn as nn
1919
from torch.nn.attention.flex_attention import and_masks, create_block_mask, flex_attention
2020

21+
from nemo.core.classes.module import freeze, unfreeze
22+
from nemo.utils.decorators import experimental
23+
2124
flex_attention_compiled = torch.compile(flex_attention, dynamic=True)
2225

2326

@@ -163,6 +166,7 @@ def forward(self, x, block_mask=None):
163166
return x
164167

165168

169+
@experimental
166170
class TransformerEncoder(nn.Module):
167171
"""Pre-norm Transformer encoder for ASR.
168172
@@ -259,3 +263,9 @@ def forward(self, audio_signal, length):
259263
x = self.final_norm(x)
260264
x = x.transpose(1, 2) # (B, T, D) -> (B, D, T)
261265
return x, length
266+
267+
def freeze(self) -> None:
    """Delegate to the ``freeze`` helper imported from ``nemo.core.classes.module``."""
    freeze(self)
269+
270+
def unfreeze(self, partial: bool = False) -> None:
    """Delegate to the ``unfreeze`` helper imported from ``nemo.core.classes.module``.

    Args:
        partial: Forwarded unchanged as the helper's ``partial`` keyword argument.
    """
    unfreeze(self, partial=partial)

nemo/core/classes/module.py

Lines changed: 54 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,50 @@
2020
from nemo.core.classes.common import FileIO, Serialization, Typing
2121
from nemo.utils import logging
2222

23-
__all__ = ['NeuralModule']
23+
__all__ = ['NeuralModule', 'freeze', 'unfreeze']
24+
25+
26+
def freeze(module: Module) -> None:
    """Disable gradients on every parameter of ``module`` and remember what they were.

    The ``requires_grad`` flag of each named parameter is recorded before it is cleared. The
    record is kept on ``module._frozen_grad_map`` so that ``unfreeze(..., partial=True)`` can
    put the flags back instead of enabling gradients everywhere. If a record already exists,
    it is merged with ``dict.update``, so entries for parameters seen again are overwritten with
    the flags observed by this call. Finally the module is switched to eval mode.

    Args:
        module: The module whose parameters are frozen in place.
    """
    snapshot = {}
    for name, weight in module.named_parameters():
        # Capture the flag first; it is overwritten on the very next line.
        snapshot[name] = weight.requires_grad
        weight.requires_grad = False

    if hasattr(module, '_frozen_grad_map'):
        module._frozen_grad_map.update(snapshot)
    else:
        module._frozen_grad_map = snapshot

    module.eval()
40+
41+
42+
def unfreeze(module: Module, partial: bool = False) -> None:
    """Re-enable gradients on ``module``, either everywhere or according to the freeze snapshot.

    With ``partial=False`` every parameter gets ``requires_grad = True``. With ``partial=True``
    each parameter's flag is taken from ``module._frozen_grad_map`` (written by ``freeze``);
    a parameter missing from that snapshot is reported with a warning and made trainable.
    In both modes the snapshot attribute is removed if present and the module is put back
    into training mode.

    Args:
        module: The module whose parameters are updated in place.
        partial: If True, restore the recorded flags rather than enabling all gradients.

    Raises:
        ValueError: If ``partial`` is True but ``module`` carries no ``_frozen_grad_map``.
    """
    has_snapshot = hasattr(module, '_frozen_grad_map')
    if partial and not has_snapshot:
        raise ValueError("Cannot unfreeze partially without first freezing the module with `freeze()`")

    if partial:
        recorded = module._frozen_grad_map
        for name, weight in module.named_parameters():
            if name in recorded:
                weight.requires_grad = recorded[name]
                continue
            # Parameter was not seen by freeze(); fall back to making it trainable.
            logging.warning(
                f"Parameter {name} not found in list of previously frozen parameters. Unfreezing this parameter."
            )
            weight.requires_grad = True
    else:
        for weight in module.parameters():
            weight.requires_grad = True

    # The snapshot is single-use: drop it regardless of the mode.
    if has_snapshot:
        delattr(module, '_frozen_grad_map')

    module.train()
2467

2568

2669
class NeuralModule(Module, Typing, Serialization, FileIO):
@@ -53,99 +96,30 @@ def input_example(self, max_batch=None, max_dim=None):
5396
return None
5497

5598
def freeze(self) -> None:
56-
r"""
57-
Freeze all params for inference.
58-
59-
This method sets `requires_grad` to False for all parameters of the module.
60-
It also stores the original `requires_grad` state of each parameter in a dictionary,
61-
so that `unfreeze()` can restore the original state if `partial=True` is set in `unfreeze()`.
62-
"""
63-
grad_map = {}
64-
65-
for pname, param in self.named_parameters():
66-
# Store the original grad state
67-
grad_map[pname] = param.requires_grad
68-
# Freeze the parameter
69-
param.requires_grad = False
70-
71-
# Store the frozen grad map
72-
if not hasattr(self, '_frozen_grad_map'):
73-
self._frozen_grad_map = grad_map
74-
else:
75-
self._frozen_grad_map.update(grad_map)
76-
77-
self.eval()
99+
r"""Freeze all params for inference. See :func:`freeze` for details."""
100+
freeze(self)
78101

79102
def unfreeze(self, partial: bool = False) -> None:
80-
"""
81-
Unfreeze all parameters for training.
82-
83-
Allows for either total unfreeze or partial unfreeze (if the module was explicitly frozen previously with `freeze()`).
84-
The `partial` argument is used to determine whether to unfreeze all parameters or only the parameters that were
85-
previously unfrozen prior `freeze()`.
103+
"""Unfreeze parameters for training. See :func:`unfreeze` for details.
86104
87105
Example:
88-
Consider a model that has an encoder and a decoder module. Assume we want the encoder to be frozen always.
89-
90-
```python
91-
model.encoder.freeze() # Freezes all parameters in the encoder explicitly
92-
```
93-
94-
During inference, all parameters of the model should be frozen - we do this by calling the model's freeze method.
95-
This step records that the encoder module parameters were already frozen, and so if partial unfreeze is called,
96-
we should keep the encoder parameters frozen.
97-
98106
```python
99-
model.freeze() # Freezes all parameters in the model; encoder remains frozen
107+
model.encoder.freeze() # caller freezes encoder
108+
model.freeze() # freezes everything; encoder snapshot preserved
109+
model.unfreeze(partial=True) # decoder unfrozen, encoder stays frozen
100110
```
101-
102-
Now, during fine-tuning, we want to unfreeze the decoder but keep the encoder frozen. We can do this by calling
103-
`unfreeze(partial=True)`.
104-
105-
```python
106-
model.unfreeze(partial=True) # Unfreezes only the decoder; encoder remains frozen
107-
```
108-
109-
Args:
110-
partial: If True, only unfreeze parameters that were previously frozen. If the parameter was already frozen
111-
when calling `freeze()`, it will remain frozen after calling `unfreeze(partial=True)`.
112111
"""
113-
if partial and not hasattr(self, '_frozen_grad_map'):
114-
raise ValueError("Cannot unfreeze partially without first freezing the module with `freeze()`")
115-
116-
for pname, param in self.named_parameters():
117-
if not partial:
118-
# Unfreeze all parameters
119-
param.requires_grad = True
120-
else:
121-
# Unfreeze only parameters that were previously frozen
122-
123-
# Check if the parameter was frozen
124-
if pname in self._frozen_grad_map:
125-
param.requires_grad = self._frozen_grad_map[pname]
126-
else:
127-
# Log a warning if the parameter was not found in the frozen grad map
128-
logging.warning(
129-
f"Parameter {pname} not found in list of previously frozen parameters. "
130-
f"Unfreezing this parameter."
131-
)
132-
param.requires_grad = True
133-
134-
# Clean up the frozen grad map
135-
if hasattr(self, '_frozen_grad_map'):
136-
delattr(self, '_frozen_grad_map')
137-
138-
self.train()
112+
unfreeze(self, partial=partial)
139113

140114
@contextmanager
141115
def as_frozen(self):
142116
"""
143117
Context manager which temporarily freezes a module, yields control and finally unfreezes the module partially
144118
to return to original state.
145119
146-
Allows for either total unfreeze or partial unfreeze (if the module was explicitly frozen previously with `freeze()`).
147-
The `partial` argument is used to determine whether to unfreeze all parameters or only the parameters that were
148-
previously unfrozen prior `freeze()`.
120+
Allows for either total unfreeze or partial unfreeze (if the module was explicitly frozen
121+
previously with `freeze()`). The `partial` argument is used to determine whether to unfreeze
122+
all parameters or only the parameters that were previously unfrozen prior to `freeze()`.
149123
150124
Example:
151125
with model.as_frozen(): # by default, partial = True

tests/collections/asr/test_transformer_encoder.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,21 @@ def test_causal_future_does_not_affect_past(self):
163163
safe_t = (T // 2) // model.pre_encode.subsampling_factor
164164
assert torch.allclose(out_a[:, :, :safe_t], out_b[:, :, :safe_t], atol=1e-5)
165165

166+
@pytest.mark.unit
def test_freeze_unfreeze_partial_restores_prior_state(self):
    """freeze() then unfreeze(partial=True) must reproduce the pre-freeze ``requires_grad`` flags."""

    def grad_flags(module):
        return {name: weight.requires_grad for name, weight in module.named_parameters()}

    encoder = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2)
    # Turn gradients off on one submodule up front so the expected flags are not uniform.
    for weight in encoder.final_norm.parameters():
        weight.requires_grad = False
    expected = grad_flags(encoder)

    encoder.freeze()
    assert not any(weight.requires_grad for weight in encoder.parameters())
    assert not encoder.training

    encoder.unfreeze(partial=True)
    assert grad_flags(encoder) == expected
    assert encoder.training
180+
166181
@pytest.mark.unit
167182
def test_forward_cpu(self):
168183
"""Forward pass on CPU uses unfused FlexAttention fallback."""

0 commit comments

Comments
 (0)