Skip to content

Commit 8eba86f

Browse files
AAnoosheh and danielkorzekwa
authored and committed
Layerwise KD mode (#802)
## What does this PR do? **Type of change:** new feature **Overview:** Add a subclass of `DistillationModel` which implements slightly different hooks to inject teacher tensors into corresponding student layers for module replacement purposes, as opposed to logits distillation. ## Usage ```python mtd.convert(model, mode=[("layerwise_kd", config)]) ``` ## Testing New units ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Introduced bypass-enabled knowledge distillation mode with layer-level loss mapping for fine-grained model optimization control. * Added model export functionality with automatic cleanup of intermediate activation capturing mechanisms. * **API Changes** * New bypass_kd mode configuration option available for advanced knowledge distillation workflows. * Updated model export interface for improved lifecycle management. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Asha Anoosheh <aanoosheh@nvidia.com> Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
1 parent 936aea1 commit 8eba86f

File tree

8 files changed

+450
-59
lines changed

8 files changed

+450
-59
lines changed

modelopt/torch/distill/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .config import *
2020
from .distillation import *
2121
from .distillation_model import *
22+
from .layerwise_distillation_model import *
2223
from .loss_balancers import *
2324
from .losses import *
2425
from .registry import *

modelopt/torch/distill/config.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from .loss_balancers import DistillationLossBalancer
2828

29-
__all__ = ["KDLossConfig"]
29+
__all__ = ["ExportStudentConfig", "KDLossConfig", "LayerwiseKDConfig"]
3030

3131
Criterion = Union[Loss, dict[tuple[str, str], Loss]] # noqa: UP007
3232

@@ -120,6 +120,25 @@ def _strict_validate(self) -> None:
120120
)
121121

122122

123+
class LayerwiseKDConfig(KDLossConfig):
    """Configuration for the Layerwise Knowledge-Distillation mode.

    This mode distills knowledge from a teacher model into a student model through
    per-layer losses between paired student and teacher submodules.
    """

    @pydantic.field_validator("criterion")
    @classmethod
    def format_criterion(cls, criterion: Criterion | None) -> dict[tuple[str, str], Loss]:
        """Ensure criterion is a mapping from layer names to loss (potentially entire module)."""
        if not isinstance(criterion, dict):
            raise ValueError("Layerwise Distillation mode requires explicit criterion pairs.")
        # An ("", "") key means "distill on the model outputs", which this mode forbids.
        has_output_only_pair = any(key == ("", "") for key in criterion)
        if has_output_only_pair:
            raise ValueError(
                "Layerwise Distillation mode does not support output-only distillation."
            )
        return criterion
140+
141+
123142
class ExportStudentConfig(ModeloptBaseConfig):
124143
"""Configuration for the export_student mode.
125144

modelopt/torch/distill/distillation_model.py

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
17-
1816
"""Meta-model wrapper to support knowledge-distillation learning."""
1917

2018
import inspect
@@ -45,6 +43,7 @@ def _setup(self):
4543
self._register_temp_attribute("_loss_modules", nn.ModuleList())
4644
self._register_temp_attribute("_only_teacher_fwd", False)
4745
self._register_temp_attribute("_only_student_fwd", False)
46+
self._register_temp_attribute("_hook_handles", set())
4847

4948
# HACK: set model's forward signature to match student class' original.
5049
# Needed for HF `transformers.utils.find_labels` which relies on inspecting class signature.
@@ -57,23 +56,22 @@ def _setup(self):
5756

5857
def modify(
5958
self,
60-
teacher_model: nn.Module, # To be frozen.
59+
teacher_model: nn.Module,
6160
criterion: dict[
6261
tuple[
63-
str, # Student model layer whose output to capture.
64-
str, # Teacher model layer whose output to capture.
62+
str, # Student model layer whose output to capture
63+
str, # Teacher model layer whose output to capture
6564
],
66-
Loss, # Loss fn.
65+
Loss, # Loss function
6766
],
6867
loss_balancer: DistillationLossBalancer | None = None,
6968
expose_minimal_state_dict: bool = True,
7069
):
7170
"""Constructor.
7271
7372
Args:
74-
teacher_model: A teacher model which this class would encapsulate.
75-
criterion: A dictionary mapping the tuple of student and teacher
76-
model layer names to the loss function to apply to that layer pair.
73+
teacher_model: The teacher model (will be frozen).
74+
criterion: Dictionary mapping (student_layer_name, teacher_layer_name) to loss functions.
7775
loss_balancer: Instance of
7876
:class:`DistillationLossBalancer <modelopt.torch.distill.DistillationLossBalancer>`
7977
which reduces distillation and non-distillation losses into a single value using some weighing scheme.
@@ -106,22 +104,30 @@ def modify(
106104
{m for m in self._layers_to_loss.values() if len(list(m.parameters())) > 0}
107105
)
108106

109-
# Disable grad for teacher
107+
# Disable grad for teacher.
110108
self._teacher_model.requires_grad_(False)
111109

112-
# Register hooks for intermediate outputs from teacher models and the student model.
113-
# HACK: For inexplicable reasons, sometimes a model will have hooks remain after
114-
# `ato.restore()` so we check if they are present accidentally first.
110+
# Use hooks to capture relevant activation tensors for loss computation.
111+
self._register_hooks()
112+
113+
def _register_hooks(self):
114+
"""Register hooks for intermediate tensors from teacher models and the student model."""
115115
for student_layer, teacher_layer in self._layers_to_loss:
116116
setattr(student_layer, "_intermediate_output", None)
117-
if student_output_capture_fwd_hook not in student_layer._forward_hooks.values():
118-
student_layer.register_forward_hook(student_output_capture_fwd_hook)
117+
handle_s = student_layer.register_forward_hook(student_output_capture_fwd_hook)
119118
setattr(teacher_layer, "_intermediate_output", None)
120-
if teacher_output_capture_fwd_hook not in teacher_layer._forward_hooks.values():
121-
teacher_layer.register_forward_hook(teacher_output_capture_fwd_hook)
119+
handle_t = teacher_layer.register_forward_hook(teacher_output_capture_fwd_hook)
120+
self._hook_handles.update([handle_s, handle_t])
121+
122+
def export(self):
123+
"""Export the distillation model."""
124+
for handle in self._hook_handles:
125+
handle.remove()
126+
self._hook_handles.clear()
127+
return super().export()
122128

123129
@property
124-
def teacher_model(self) -> nn.ModuleList:
130+
def teacher_model(self) -> nn.Module:
125131
"""Fetch the teacher model."""
126132
return self._teacher_model
127133

@@ -148,7 +154,7 @@ def hide_teacher_model(self, enable=True):
148154

149155
@contextmanager
150156
def hide_loss_modules(self, enable=True):
151-
"""Context manager to temporarily hide teacher model from the model."""
157+
"""Context manager to temporarily hide loss modules from the model."""
152158
loss_modules = self._loss_modules
153159
if enable:
154160
self._loss_modules = nn.ModuleList()
@@ -169,7 +175,7 @@ def only_teacher_forward(self, enable=True):
169175

170176
@contextmanager
171177
def only_student_forward(self, enable=True):
172-
"""Context manager to temporarily disable forward passes on the student model."""
178+
"""Context manager to temporarily run forward passes only on the student model."""
173179
if enable:
174180
self._only_student_fwd = True
175181
try:
@@ -245,15 +251,13 @@ def compute_kd_loss(
245251
246252
Args:
247253
student_loss: Original loss computed from the student's output.
248-
loss_reduction_fn: Callable to be called on each loss tensor prior to balancing. Useful for
249-
loss-masking situations where the callable changes arguments each iteration.
254+
loss_reduction_fn: Callable to be called on each loss tensor prior to balancing.
255+
Useful for loss-masking situations where the callable changes arguments each iteration.
250256
skip_balancer: Whether or not to use loss balancer to reduce the loss dict into a scalar.
251257
**loss_fn_kwargs: Additional keyword arguments to be passed to the loss function, if needed.
252-
This facilitates losses that require extras, such as labels for ``mtd.MFTLoss``.
253258
254259
Returns:
255-
If reduce is True, the scalar total loss weighted between ``student_loss`` and the distillation losses.
256-
If reduce is False, a dict of student model output loss and layer-wise distillation losses.
260+
A dict of losses if skip_balancer is True, else the scalar total loss.
257261
"""
258262
if self._loss_balancer is None:
259263
assert student_loss is None, "Cannot pass in student loss without using Loss Balancer."
@@ -288,9 +292,9 @@ def compute_kd_loss(
288292
return loss_total
289293

290294

291-
def student_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any): # pylint: disable=redefined-builtin
295+
def student_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any):
292296
"""A hook to capture layer output."""
293-
# NOTE: Defined externally to allow pickling.
297+
# NOTE: Defined externally to allow pickling during DDP initialization.
294298

295299
if getattr(module, "_only_teacher_fwd", False):
296300
return # Might be hooked on entire model fwd
@@ -303,9 +307,9 @@ def student_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any):
303307
module._intermediate_output = output
304308

305309

306-
def teacher_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any): # pylint: disable=redefined-builtin
310+
def teacher_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any):
307311
"""A hook to capture layer output."""
308-
# NOTE: Defined externally to allow pickling.
312+
# NOTE: Defined externally to allow pickling during DDP initialization.
309313

310314
if module._intermediate_output is not None:
311315
# NOTE: cannot tell if train or eval since teacher is always eval
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Meta-model wrapper to support layerwise-enabled knowledge-distillation learning."""
17+
18+
import warnings
19+
from typing import Any
20+
21+
import torch.nn as nn
22+
23+
from .distillation_model import DistillationModel, student_output_capture_fwd_hook
24+
25+
__all__ = ["LayerwiseDistillationModel"]
26+
27+
28+
class LayerwiseDistillationModel(DistillationModel):
    """Meta-model wrapper to support layerwise-enabled knowledge-distillation learning.

    The LayerwiseDistillationModel is a subclass of the DistillationModel that injects teacher
    inputs into the corresponding student layers. This accommodates the case where the student
    model is the teacher with specific submodules replaced, which now need to be trained to
    mimic the original submodule in the teacher.
    """

    def modify(self, *args, **kwargs):
        """Set up layerwise distillation on top of the base distillation setup.

        Args:
            *args: Positional arguments forwarded to ``DistillationModel.modify``.
            **kwargs: Keyword arguments forwarded to ``DistillationModel.modify``.

        Returns:
            This model, to allow chained calls.
        """
        super().modify(*args, **kwargs)

        # Freeze everything, then re-enable grads only for the student layers under distillation.
        self.requires_grad_(False)
        for student_layer, _ in self._layers_to_loss:
            student_layer.requires_grad_(True)

        # Replace lm heads (if present) with no-ops to save compute; originals are stashed
        # on `_lm_head` and restored in `export()`.
        if hasattr(self, "lm_head"):
            self._lm_head = self.lm_head
            self.lm_head = nn.Identity()
        if hasattr(self._teacher_model, "lm_head"):
            self._teacher_model._lm_head = self._teacher_model.lm_head
            self._teacher_model.lm_head = nn.Identity()

        return self

    def _register_hooks(self):
        """Register hooks capturing teacher tensors and bypassing them into student layers."""
        for student_layer, teacher_layer in self._layers_to_loss:
            # Wrapped in a list so the teacher layer is not registered as a child submodule.
            setattr(student_layer, "_teacher_layer", [teacher_layer])
            handle_s1 = student_layer.register_forward_pre_hook(student_input_bypass_fwd_hook)
            setattr(student_layer, "_intermediate_output", None)
            handle_s2 = student_layer.register_forward_hook(student_output_capture_fwd_hook)
            setattr(teacher_layer, "_intermediate_input", None)
            setattr(teacher_layer, "_intermediate_output", None)
            handle_t = teacher_layer.register_forward_hook(teacher_input_output_capture_fwd_hook)
            self._hook_handles.update([handle_s1, handle_s2, handle_t])

    def export(self):
        """Export the distillation model, undoing layerwise-specific modifications.

        Removes the teacher-layer references from the student layers and restores any
        lm heads that were replaced with no-ops during ``modify()``.
        """
        for student_layer, _ in self._layers_to_loss:
            delattr(student_layer, "_teacher_layer")

        if hasattr(self, "_lm_head"):
            self.lm_head = self._lm_head
            # Drop the stash so the exported model does not keep a duplicate submodule.
            del self._lm_head
        if hasattr(self._teacher_model, "_lm_head"):
            self._teacher_model.lm_head = self._teacher_model._lm_head
            del self._teacher_model._lm_head

        return super().export()
79+
80+
81+
def student_input_bypass_fwd_hook(module: nn.Module, input: Any):
    """Forward pre-hook replacing a student layer's input with the teacher's captured input."""
    # NOTE: Defined externally to allow pickling during DDP initialization.

    only_teacher = getattr(module, "_only_teacher_fwd", False)
    if only_teacher:
        # Might be hooked on entire model fwd
        return input

    paired_teacher = module._teacher_layer[0]
    captured = paired_teacher._intermediate_input
    if captured is not None:
        # Consume the stored input so a stale value is never reused.
        paired_teacher._intermediate_input = None
        return captured

    warnings.warn(
        f"Teacher's Module `{type(paired_teacher).__name__}` has no intermediate input stored."
        " This is expected when the `only_student_forward` context manager is in use."
    )
    return input
99+
100+
101+
def teacher_input_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any):
    """Forward hook storing a teacher layer's input and output for later consumption."""
    # NOTE: Defined externally to allow pickling during DDP initialization.

    previous_output = module._intermediate_output
    if previous_output is not None:
        # NOTE: cannot tell if train or eval since teacher is always eval
        warnings.warn(
            f"Teacher's Module `{type(module).__name__}` already has an intermediate output stored."
            " This is expected when `DistillationModel.compute_kd_loss` is not called in eval mode."
        )

    module._intermediate_input, module._intermediate_output = input, output

0 commit comments

Comments
 (0)