Skip to content

Commit fff65b0

Browse files
Allow HF trainer to mask sequences prior to reduction (#1009)
### What does this PR do? Type of change: Bug fix Previously HF trainer did not account for loss masking ### Usage ```python # Add a code snippet demonstrating how to use this ``` ### Testing <!-- Mention how have you tested your change if applicable. --> ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.). - Is this change backward compatible?: ✅ / ❌ / N/A <!--- If ❌, explain why. --> - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: ✅ / ❌ / N/A <!--- Mandatory --> - Did you write any new necessary tests?: ✅ / ❌ / N/A <!--- Mandatory for new features or examples. --> - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅ / ❌ / N/A <!--- Only for new features, API changes, critical bug fixes or backward incompatible changes. --> ### Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Improvements** * Knowledge-distillation loss now properly ignores padding/special tokens and supports masked per-token averaging. * Default loss reduction behavior adjusted for finer-grained training control and clearer per-token outputs. * More robust logit handling with consistent numeric casting for improved stability and accuracy, including mixed-precision scenarios. 
<!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Asha Anoosheh <aanoosheh@nvidia.com> Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 695c8e8 commit fff65b0

2 files changed

Lines changed: 38 additions & 7 deletions

File tree

modelopt/torch/distill/losses.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class LogitsDistillationLoss(Loss):
3131
This function implements the distillation loss found in the paper: https://arxiv.org/abs/1503.02531.
3232
"""
3333

34-
def __init__(self, temperature: float = 1.0, reduction: str = "batchmean"):
34+
def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
3535
"""Constructor.
3636
3737
Args:
@@ -57,11 +57,12 @@ def forward(self, logits_s: torch.Tensor, logits_t: torch.Tensor) -> torch.Tenso
5757
soft_log_probs = F.log_softmax(logits_s / self._temperature, dim=-1)
5858
soft_targets = F.softmax(logits_t / self._temperature, dim=-1)
5959

60-
soft_log_probs = soft_log_probs.view(-1, soft_log_probs.size(-1))
61-
soft_targets = soft_targets.view(-1, soft_targets.size(-1))
62-
6360
kd_loss = F.kl_div(soft_log_probs, soft_targets.detach(), reduction=self._reduction)
6461

62+
if self._reduction == "none":
63+
# Remove vocab dimension
64+
kd_loss = kd_loss.sum(dim=-1)
65+
6566
# Since the magnitudes of the gradients produced by the soft logits scale as 1/(T^2),
6667
# multiplying them by T^2 ensures that the relative contributions of the logits
6768
# remain roughly unchanged while experimenting with meta-parameters.

modelopt/torch/distill/plugins/huggingface.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,16 @@
1515

1616
"""ModelOpt plugin to train HuggingFace models with knowledge distillation."""
1717

18+
from torch import Tensor
1819
from transformers.modeling_outputs import CausalLMOutputWithPast
20+
from transformers.trainer_pt_utils import LabelSmoother
1921

2022
import modelopt.torch.distill as mtd
2123
from modelopt.torch.opt.plugins import ModelOptHFTrainer
2224
from modelopt.torch.utils import print_rank_0
2325

26+
IGNORE_TOKEN_ID = LabelSmoother.ignore_index # equals -100
27+
2428

2529
class KDTrainer(ModelOptHFTrainer):
2630
"""Distillation trainer for HuggingFace models."""
@@ -98,12 +102,37 @@ def save_model(
98102

99103
def train(self, *args, **kwargs):
    """Train the model.

    Installs a distillation-aware ``compute_loss_func`` on the trainer, then
    delegates to the parent trainer's training loop. The installed function
    reduces the per-token KD loss with a label mask so that padding /
    ignored tokens (label == ``IGNORE_TOKEN_ID``) do not contribute to the
    averaged loss.
    """

    def _masked_kd_loss(outputs: Tensor, labels: Tensor | None, **unused):
        # Reduction callback handed to the DistillationModel: averages the
        # unreduced (per-token) KD loss over non-ignored tokens only.
        def _reduce(per_token_loss: Tensor):
            if labels is None:
                # No labels available -> fall back to a plain mean.
                return per_token_loss.mean()
            valid = labels != IGNORE_TOKEN_ID
            # clamp(min=1) guards against division by zero when every
            # token in the batch is masked out.
            return (per_token_loss * valid).sum() / valid.sum().clamp(min=1)

        return self.model.compute_kd_loss(loss_reduction_fn=_reduce)

    self.compute_loss_func = _masked_kd_loss
    return super().train(*args, **kwargs)
103117

104118

105119
class LMLogitsLoss(mtd.LogitsDistillationLoss):
106-
"""Logits loss for knowledge distillation."""
120+
"""Logits loss for language-model knowledge distillation.
121+
122+
Defaults to ``reduction="none"`` to support per-token loss masking via ``loss_reduction_fn``
123+
in :meth:`DistillationModel.compute_kd_loss`. This allows masking out padding and non-assistant
124+
tokens before reducing the loss.
125+
"""
126+
127+
def __init__(self, temperature: float = 1.0, reduction: str = "none"):
    """Constructor.

    Args:
        temperature: A value used to soften the logits before computing loss.
        reduction: How to reduce the final pointwise loss. Defaults to ``"none"``
            so the per-token loss can be masked and averaged later via
            ``loss_reduction_fn`` in ``compute_kd_loss``.
    """
    # Only the default reduction differs from the parent class.
    super().__init__(temperature, reduction)
107136

108137
def forward(self, out_student: CausalLMOutputWithPast, out_teacher: CausalLMOutputWithPast):
    """Forward pass for logits distillation loss.

    Args:
        out_student: The student model output.
        out_teacher: The teacher model output.

    Returns:
        The KD loss computed on the ``.logits`` of both outputs.
    """
    # Cast to float32 before the softmax-based loss for numerical stability
    # under mixed precision.
    student_logits = out_student.logits.float()
    teacher_logits = out_teacher.logits.float()
    return super().forward(student_logits, teacher_logits)

0 commit comments

Comments
 (0)