Support DeepSpeed ZeRO-3 in KDTrainer; fix Liger hidden-states dtype
- Add fully-frozen-model fallback in ModelOptHFTrainer._prepare_model so
DS ZeRO-3 can prepare a frozen teacher without hitting the empty
trainable_param_groups assertion.
- Add KDTrainer._ds_gather context manager for explicit param gather,
since the teacher is loaded under zero.Init but not wrapped in a
DeepSpeedEngine (no per-module hooks).
- Unify KD sharded Liger compute: delegate student lm_head gather to
the parent's _sharded_liger_compute and add teacher lm_head gather
via _apply_teacher_gather.
- Cast outputs.logits to lm_head.weight's dtype before the Liger fused
  kernels run (the final RMSNorm may leave hidden_states in fp32).
- Drop redundant KDTrainer._get_lm_head override (inherited).
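The fully-frozen-model fallback could look roughly like this. This is a hypothetical sketch, not the actual ModelOptHFTrainer code: the function name, signature, and group layout are illustrative.

```python
def trainable_param_groups(model, lr, weight_decay):
    """Hypothetical sketch of the fully-frozen-model fallback.

    A KD teacher has requires_grad=False on every parameter. Instead of
    asserting that the trainable group list is non-empty, return an empty
    list so DeepSpeed ZeRO-3 can still prepare the frozen model.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    if not trainable:
        # Fully frozen model (e.g. the distillation teacher): nothing to
        # optimize, but the model must still go through ZeRO-3 preparation.
        return []
    return [{"params": trainable, "lr": lr, "weight_decay": weight_decay}]
```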
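The explicit-gather context manager could be sketched as below. This is an assumption-laden illustration, not the actual _ds_gather implementation; it uses DeepSpeed's real `deepspeed.zero.GatheredParameters` API and falls back to a no-op when DeepSpeed is absent.

```python
from contextlib import contextmanager, nullcontext


@contextmanager
def _ds_gather(module):
    """Gather ZeRO-3-sharded parameters of `module` for the body's duration.

    Hypothetical sketch: the teacher is created under zero.Init, so its
    parameters live as shards, but it is never wrapped in a DeepSpeedEngine,
    so no per-module forward hooks gather them automatically. We therefore
    gather explicitly around the code that touches the full weights.
    """
    try:
        import deepspeed
        ctx = deepspeed.zero.GatheredParameters(
            list(module.parameters()), modifier_rank=None
        )
    except ImportError:
        # DeepSpeed not installed (e.g. unit tests): nothing to gather.
        ctx = nullcontext()
    with ctx:
        yield
```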
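The dtype cast amounts to a small guard before handing hidden states to the fused kernel. The helper name below is made up for illustration; it only assumes tensor-like objects with `.dtype` and `.to(...)`, matching the PyTorch convention.

```python
def cast_to_head_dtype(hidden_states, lm_head):
    """Hypothetical helper: match hidden_states to lm_head.weight's dtype.

    The model's final RMSNorm may compute in fp32 and hand back fp32
    hidden states while lm_head.weight is bf16/fp16; Liger's fused
    linear + loss kernels expect the two operands in the same dtype.
    """
    target = lm_head.weight.dtype
    if hidden_states.dtype != target:
        hidden_states = hidden_states.to(target)
    return hidden_states
```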
Signed-off-by: realAsma <akuriparambi@nvidia.com>