
Commit ba1d725

Support DeepSpeed ZeRO-3 in KDTrainer; fix Liger hidden-states dtype
- Add fully-frozen-model fallback in ModelOptHFTrainer._prepare_model so DS ZeRO-3 can prepare a frozen teacher without hitting the empty trainable_param_groups assertion.
- Add KDTrainer._ds_gather context manager for explicit param gather, since the teacher is loaded under zero.Init but not wrapped in a DeepSpeedEngine (no per-module hooks).
- Unify KD sharded Liger compute: delegate student lm_head gather to the parent's _sharded_liger_compute and add teacher lm_head gather via _apply_teacher_gather.
- Cast outputs.logits to lm_head.weight dtype before Liger fused kernels (final RMSNorm may leave hidden_states in fp32).
- Drop redundant KDTrainer._get_lm_head override (inherited).

Signed-off-by: realAsma <akuriparambi@nvidia.com>
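For context, the gather mechanism the new _ds_gather helper builds on is standard DeepSpeed API: parameters created under zero.Init are partitioned across ranks at construction, and without a DeepSpeedEngine's per-module hooks they must be gathered explicitly around any forward pass. A minimal sketch of that pattern (assumes a distributed DeepSpeed launch; the toy model and config below are illustrative, not part of this commit):

# Illustrative sketch only (not from the commit): toy model and config,
# assumes a distributed launch such as `deepspeed train.py`.
import deepspeed
import torch
import torch.nn as nn

ds_config = {"train_micro_batch_size_per_gpu": 1, "zero_optimization": {"stage": 3}}

with deepspeed.zero.Init(config_dict_or_path=ds_config):
    teacher = nn.Linear(8, 8)  # params are ZeRO-3 partitioned at construction

for p in teacher.parameters():
    p.requires_grad_(False)  # frozen teacher: never wrapped in a DeepSpeedEngine

teacher.eval()
x = torch.randn(2, 8)
# No engine means no per-module gather hooks, so gather explicitly around forward.
with torch.no_grad(), deepspeed.zero.GatheredParameters(
    list(teacher.parameters()), modifier_rank=None
):
    out = teacher(x)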
1 parent 1f1c250 commit ba1d725

2 files changed

Lines changed: 44 additions & 25 deletions


modelopt/torch/distill/plugins/huggingface.py

Lines changed: 37 additions & 23 deletions
@@ -157,6 +157,22 @@ def _get_unwrapped_teacher(self):
         """Unwrap teacher model (removes FSDP/DDP/DeepSpeed wrapper)."""
         return self.accelerator.unwrap_model(self._teacher_model)

+    @contextmanager
+    def _ds_gather(self, params):
+        """Gather DS ZeRO-3 partitioned params; no-op if DeepSpeed disabled.
+
+        The teacher is loaded under an active ``zero.Init`` but not wrapped in a
+        DeepSpeedEngine, so its params have no per-module gather hooks and need an
+        explicit gather around any forward use.
+        """
+        if self.is_deepspeed_enabled:
+            import deepspeed
+
+            with deepspeed.zero.GatheredParameters(list(params), modifier_rank=None):
+                yield
+        else:
+            yield
+
     def compute_loss(self, model, inputs, **kwargs):
         """Store teacher inputs before delegating to parent (which handles liger ctx)."""
         self._ensure_teacher_prepared()
@@ -169,7 +185,7 @@ def compute_kd_loss_func(self, outputs, labels, **kwargs):
         Teacher forward runs here so it is inside the liger identity-lm_head
         context when liger is enabled (ModelOptHFTrainer wraps compute_loss).
         """
-        with torch.no_grad():
+        with torch.no_grad(), self._ds_gather(self._teacher_model.parameters()):
             self._teacher_model.eval()
             teacher_outputs = self._teacher_model(**self._teacher_inputs)
             self._teacher_inputs = None
@@ -193,10 +209,6 @@ def _standard_kd_loss(self, outputs, labels, **kwargs):
         self._last_teacher_outputs = None
         return loss

-    def _get_lm_head(self, model):
-        """Resolve lm_head from a model."""
-        return model.lm_head
-
     @contextmanager
     def _liger_identity_lm_head(self):
         """Patch both student+teacher lm_heads to identity."""
@@ -215,22 +227,24 @@ def _liger_identity_lm_head(self):
             teacher_lm_head.forward = teacher_orig

     def _sharded_liger_compute(self, fn):
-        """Route fn through sharded DP, gathering both student+teacher lm_head params."""
+        """Delegate student lm_head gather to parent; add teacher lm_head gather on top."""
+        return super()._sharded_liger_compute(self._apply_teacher_gather(fn))
+
+    def _apply_teacher_gather(self, fn):
+        """Wrap fn so the teacher's lm_head params are gathered when it runs."""
         if self.is_fsdp_enabled:
-            return _forward_redirect(
-                self.model,
-                lambda: _forward_redirect(self._teacher_model, fn),
-            )
+            teacher = self._teacher_model
+            return lambda: _forward_redirect(teacher, fn)
         if self.is_deepspeed_enabled:
-            model = self.accelerator.unwrap_model(self.model)
-            teacher = self._get_unwrapped_teacher()
-            student_lm_head = self._get_lm_head(model)
-            teacher_lm_head = self._get_lm_head(teacher)
-            return _forward_redirect(
-                student_lm_head,
-                lambda: _forward_redirect(teacher_lm_head, fn),
-            )
-        return fn()
+            # Teacher is not in the DS engine; gather its lm_head explicitly.
+            teacher_lm_head = self._get_lm_head(self._get_unwrapped_teacher())
+
+            def _wrapped():
+                with self._ds_gather([teacher_lm_head.weight]):
+                    return fn()
+
+            return _wrapped
+        return fn

     def _liger_kd_loss(self, outputs, labels, **kwargs):
         """Fused lm_head + JSD for KD."""
@@ -239,13 +253,13 @@ def _liger_kd_loss(self, outputs, labels, **kwargs):
         model = self.accelerator.unwrap_model(self.model)
         teacher = self._get_unwrapped_teacher()

-        student_hs = outputs.logits
-        teacher_hs = self._last_teacher_outputs.logits
-        self._last_teacher_outputs = None
-
         student_lm_head = self._get_lm_head(model)
         teacher_lm_head = self._get_lm_head(teacher)

+        student_hs = outputs.logits.to(student_lm_head.weight.dtype)  # RMSNorm may upcast to fp32
+        teacher_hs = self._last_teacher_outputs.logits.to(teacher_lm_head.weight.dtype)
+        self._last_teacher_outputs = None
+
         # Causal LM shift
         student_hs = student_hs[..., :-1, :].contiguous().view(-1, student_hs.size(-1))
         teacher_hs = teacher_hs[..., :-1, :].contiguous().view(-1, teacher_hs.size(-1))
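The .to(lm_head.weight.dtype) casts address a concrete failure mode: with a bf16 model, the final RMSNorm can leave hidden states in fp32, and matmul-style ops (including the fused Liger path) reject mixed input/weight dtypes. A small stand-alone illustration, with a plain F.linear standing in for the fused kernel:

# Illustrative only: plain F.linear stands in for the Liger fused linear+loss kernel.
import torch
import torch.nn.functional as F

lm_head = torch.nn.Linear(16, 32, bias=False, dtype=torch.bfloat16)
hidden_states = torch.randn(4, 16)  # fp32, e.g. upcast by a final RMSNorm

try:
    F.linear(hidden_states, lm_head.weight)  # fp32 input vs bf16 weight
except RuntimeError as err:
    print(err)  # dtype mismatch

logits = F.linear(hidden_states.to(lm_head.weight.dtype), lm_head.weight)  # ok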

modelopt/torch/opt/plugins/transformers.py

Lines changed: 7 additions & 2 deletions
@@ -493,8 +493,13 @@ def _prepare_model(self, model):
         """Prepare a model via accelerator (materializes meta-device params, applies sharding).

         Uses a dummy optimizer because ``accelerator.prepare`` requires one for FSDP2.
-        Works generically for FSDP2, DDP, and DeepSpeed backends.
+        Works generically for FSDP2, DDP, and DeepSpeed backends. For fully-frozen models
+        under DS ZeRO-3, falls back to inference-mode prep since ZeRO-3 asserts on empty
+        trainable_param_groups; in that case the caller is responsible for gathering
+        ``zero.Init``-partitioned params around forward passes.
         """
+        if self.is_deepspeed_enabled and not any(p.requires_grad for p in model.parameters()):
+            return self.accelerator.prepare_model(model, evaluation_mode=True)
         dummy_optimizer = torch.optim.SGD([next(model.parameters())], lr=0.0)
         model, _ = self.accelerator.prepare(model, dummy_optimizer)
         return model
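Sketched outside the trainer, the new fallback amounts to the following (the toy teacher and plain accelerate.Accelerator here are illustrative; under a DeepSpeed ZeRO-3 config the usual prepare(model, optimizer) path would otherwise build empty trainable param groups):

# Illustrative only: a fully frozen model contributes no trainable params,
# so it is prepared in evaluation mode instead of through prepare(model, optimizer).
import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator()
teacher = nn.Linear(8, 8)
for p in teacher.parameters():
    p.requires_grad_(False)

if not any(p.requires_grad for p in teacher.parameters()):
    teacher = accelerator.prepare_model(teacher, evaluation_mode=True)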
@@ -712,8 +717,8 @@ def _liger_loss_func(self, outputs, labels, num_items_in_batch=None, **kwargs):
         from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss

         model = self.accelerator.unwrap_model(self.model)
-        hidden_states = outputs.logits
         lm_head = self._get_lm_head(model)
+        hidden_states = outputs.logits.to(lm_head.weight.dtype)  # RMSNorm may upcast to fp32

         def _compute():
             return LigerForCausalLMLoss(
