
Commit 1f1c250

realAsma and claude committed
Address PR review: causal shift fix, save_dtype default, forward restore safety
- Fix causal shift inconsistency between _standard_kd_loss and _liger_kd_loss
- Change save_dtype default from "bfloat16" to None (preserve original dtype)
- Add try/except in _forward_redirect to restore module.forward on failure
- Skip ARGUMENTS.md pre-commit hook gracefully when modelopt not installed

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent 4c5a889 commit 1f1c250

4 files changed: 18 additions & 11 deletions


.pre-commit-config.yaml

Lines changed: 2 additions & 1 deletion
@@ -139,7 +139,8 @@ repos:
     hooks:
       - id: generate-arguments-md
         name: Regenerate examples/llm_qat/ARGUMENTS.md
-        entry: bash -c 'python examples/llm_qat/train.py --generate_docs examples/llm_qat/ARGUMENTS.md'
+        entry: bash -c 'python -c "import modelopt" 2>/dev/null && python examples/llm_qat/train.py --generate_docs examples/llm_qat/ARGUMENTS.md || echo
+          "Skipping ARGUMENTS.md generation (modelopt not installed)"'
         language: system
         files: >-
           (?x)^(
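
For reference, the guard the bash one-liner implements can be read as the following Python, a sketch only (regenerate_arguments_md is a hypothetical helper, not part of the commit): check whether modelopt is importable, run the generator if so, and otherwise print the skip message and exit 0 so the commit is not blocked.

import importlib.util
import subprocess
import sys

def regenerate_arguments_md() -> int:
    """Run the ARGUMENTS.md generator only when modelopt is importable."""
    if importlib.util.find_spec("modelopt") is None:
        # Mirror the hook: report the skip but succeed, so environments
        # without modelopt can still commit.
        print("Skipping ARGUMENTS.md generation (modelopt not installed)")
        return 0
    return subprocess.call(
        [sys.executable, "examples/llm_qat/train.py",
         "--generate_docs", "examples/llm_qat/ARGUMENTS.md"]
    )

if __name__ == "__main__":
    sys.exit(regenerate_arguments_md())

One behavioral nuance of the bash form: in `a && b || c`, the echo also runs when the generator itself fails, masking that failure with exit 0; the sketch above only swallows the missing-import case.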

examples/llm_qat/ARGUMENTS.md

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ Extends [HuggingFace TrainingArguments](https://huggingface.co/docs/transformers
 | `--trainable_params` | `list[str]` | `None` | Glob patterns (fnmatch) for parameters that should be trainable. All other parameters will be frozen. Mutually exclusive with frozen_params. |
 | `--frozen_params` | `list[str]` | `None` | Glob patterns (fnmatch) for parameters that should be frozen. Mutually exclusive with trainable_params. |
 | `--lr_config` | `str` | `None` | Path to a YAML file mapping fnmatch patterns to optimizer kwargs (e.g. lr, weight_decay). First matching pattern wins per parameter. See examples/llm_qat/configs/train/lr_config_example.yaml. |
-| `--save_dtype` | `str` | `"bfloat16"` | Dtype string to write into the saved model's config.json (e.g. 'bfloat16', 'float16'). Defaults to 'bfloat16'. |
+| `--save_dtype` | `str` | `None` | Dtype string to write into the saved model's config.json (e.g. 'bfloat16', 'float16'). Preserves the original dtype when not set. |
 | `--manual_gc` | `bool` | `False` | Run `gc.collect()` before each training/prediction step to work around GPU memory leaks during QAT/distillation. |
 | `--liger_ce_label_smoothing` | `float` | `0.0` | Label smoothing for Liger fused CE loss. Only used when --use_liger_kernel is enabled. |
 | `--lora` | `bool` | `False` | Whether to add LoRA (Low-Rank Adaptation) adapter before training. When using real quantization, the LoRA adapter must be set, as quantized weights will be frozen during training. |
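
The default change reads as a simple resolution rule: an explicit --save_dtype wins, otherwise the checkpoint's existing dtype is kept. A minimal sketch (resolve_save_dtype is a hypothetical helper for illustration, not code from this repo):

def resolve_save_dtype(original_dtype: str, save_dtype: str | None) -> str:
    """Pick the dtype string written into the saved model's config.json."""
    # None (the new default) preserves whatever dtype the model already used.
    return save_dtype if save_dtype is not None else original_dtype

assert resolve_save_dtype("float16", None) == "float16"         # preserved
assert resolve_save_dtype("float16", "bfloat16") == "bfloat16"  # explicit override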

modelopt/torch/distill/plugins/huggingface.py

Lines changed: 6 additions & 4 deletions
@@ -180,13 +180,15 @@ def compute_kd_loss_func(self, outputs, labels, **kwargs):
         return self._standard_kd_loss(outputs, labels, **kwargs)

     def _standard_kd_loss(self, outputs, labels, **kwargs):
-        """KD loss with ignore-index masking."""
-        student_logits = outputs.logits.float()
-        teacher_logits = self._last_teacher_outputs.logits.float()
+        """KD loss with causal shift and ignore-index masking."""
+        # Causal LM shift (match _liger_kd_loss semantics)
+        student_logits = outputs.logits[..., :-1, :].contiguous().float()
+        teacher_logits = self._last_teacher_outputs.logits[..., :-1, :].contiguous().float()
         per_token_loss = self._kd_criterion(student_logits, teacher_logits)
         if labels is None:
             return per_token_loss.sum()
-        mask = labels != IGNORE_INDEX
+        shift_labels = labels[..., 1:].contiguous()
+        mask = shift_labels != IGNORE_INDEX
         loss = (per_token_loss * mask).sum() / mask.sum().clamp(min=1)
         self._last_teacher_outputs = None
         return loss
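
The substance of the fix in isolation: in a causal LM, the logit at position t predicts token t+1, so both logit tensors drop their last position and the labels drop their first, keeping the ignore-index mask aligned with the same tokens the loss is computed on. A self-contained sketch, assuming a per-token KL criterion (the diff's self._kd_criterion is not shown; the temperature and KL direction here are assumptions):

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # HF convention for masked label positions

def shifted_kd_loss(student_logits, teacher_logits, labels, temperature=1.0):
    # Logit at position t predicts token t+1: drop the last logit position...
    s_logits = student_logits[..., :-1, :].float()
    t_logits = teacher_logits[..., :-1, :].float()
    # ...and drop the first label so position t aligns with label t+1.
    shift_labels = labels[..., 1:]
    per_token = F.kl_div(
        F.log_softmax(s_logits / temperature, dim=-1),
        F.log_softmax(t_logits / temperature, dim=-1),
        reduction="none",
        log_target=True,
    ).sum(-1)  # [batch, seq_len - 1], matching the shifted mask
    mask = shift_labels != IGNORE_INDEX
    # clamp(min=1) mirrors the diff: avoid 0/0 when every label is masked.
    return (per_token * mask).sum() / mask.sum().clamp(min=1)

Before the fix, the unshifted mask had length seq_len while the Liger path masked seq_len - 1 shifted positions, so the two losses weighted tokens inconsistently.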

modelopt/torch/opt/plugins/transformers.py

Lines changed: 9 additions & 5 deletions
@@ -220,12 +220,12 @@ class ModelOptTrainerArguments(ModelOptHFArguments):
             ),
         },
     )
-    save_dtype: str = dataclasses.field(
-        default="bfloat16",
+    save_dtype: str | None = dataclasses.field(
+        default=None,
         metadata={
             "help": (
                 "Dtype string to write into the saved model's config.json "
-                "(e.g. 'bfloat16', 'float16'). Defaults to 'bfloat16'."
+                "(e.g. 'bfloat16', 'float16'). Preserves the original dtype when not set."
             ),
         },
     )
@@ -433,8 +433,12 @@ def wrapped_forward(*a, **kw):
         return fn()

     module.forward = wrapped_forward
-    dummy = torch.empty(1, device=next(module.parameters()).device)
-    return module(dummy)
+    try:
+        dummy = torch.empty(1, device=next(module.parameters()).device)
+        return module(dummy)
+    except Exception:
+        module.forward = original_forward
+        raise


 class ModelOptHFTrainer(Trainer):
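
The second hunk is a restore-on-failure pattern: the module's forward is swapped for a wrapper, and if the probe call raises, the original forward is put back before re-raising so the module is left usable. A standalone sketch mirroring the diff's names (probe_with_redirect and the wrapped_forward argument are hypothetical; the enclosing _forward_redirect is not fully shown):

import torch
import torch.nn as nn

def probe_with_redirect(module: nn.Module, wrapped_forward):
    """Temporarily redirect module.forward; undo the redirect if the call fails."""
    original_forward = module.forward
    module.forward = wrapped_forward
    try:
        dummy = torch.empty(1, device=next(module.parameters()).device)
        return module(dummy)
    except Exception:
        # Restore before re-raising so a failed probe doesn't leave the
        # module permanently wrapped.
        module.forward = original_forward
        raise

A try/finally would restore unconditionally; restoring only in the except branch suggests the success path hands the still-wrapped module back to code that restores it later. That reading matches the shape of the diff but is an inference, not something the hunk itself shows.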
