Commit be7d8e2
Add ModelOptHFTrainer and simplify KDTrainer distillation API
Introduce ModelOptHFTrainer wrapping HF Trainer with modelopt features (quantization, LR config, trainable/frozen param globs, save_dtype config rewrite, Liger fused CE, manual GC, etc.) and simplify the KDTrainer distillation API on top of it.

Also includes follow-up fixes applied during review:

- Causal shift fix and forward-restore safety in KDTrainer
- DeepSpeed ZeRO-3 support in KDTrainer; Liger hidden-states dtype fix
- save_dtype defaults to "bfloat16"; config.json rewrite skipped when save_dtype is None
- Narrowed exceptions, moved defaults to configs, fixed recipe.quantize reference in transformers_trainer.py

Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent d94e64e commit be7d8e2

File tree

22 files changed: +1028 -265 lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 1 deletion

```diff
@@ -140,7 +140,8 @@ repos:
     hooks:
       - id: generate-arguments-md
        name: Regenerate examples/llm_qat/ARGUMENTS.md
-        entry: bash -c 'python examples/llm_qat/train.py --generate_docs examples/llm_qat/ARGUMENTS.md'
+        entry: bash -c 'python -c "import modelopt" 2>/dev/null && python examples/llm_qat/train.py --generate_docs examples/llm_qat/ARGUMENTS.md || echo
+          "Skipping ARGUMENTS.md generation (modelopt not installed)"'
        language: system
        files: >-
          (?x)^(
```

CHANGELOG.rst

Lines changed: 3 additions & 0 deletions

```diff
@@ -6,6 +6,9 @@ Changelog
 
 **New Features**
 
+- Add model-agnostic `Liger kernel <https://github.com/linkedin/Liger-Kernel>`_ fused loss support in ``ModelOptHFTrainer`` for any HuggingFace causal LM, with distributed param gathering for FSDP2, DeepSpeed ZeRO-3, and DDP. Extends HuggingFace's built-in Liger integration which is limited to `a fixed set of model architectures <https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/monkey_patch.py>`_, FSDP only, and CrossEntropy loss. ModelOpt additionally supports Liger fused KD loss (JSD) for knowledge distillation.
+- Add ``ModelOptTrainerArguments`` to ``ModelOptHFTrainer`` with ``--trainable_params``, ``--frozen_params``, ``--lr_config``, ``--save_dtype``, and ``--manual_gc`` flags. Add per-parameter learning rate support via YAML config.
+- Simplify ``KDTrainer`` for HuggingFace knowledge distillation: remove ``mtd.convert()`` class-swap in favor of explicit teacher forwarding with logit-level distillation support.
 - Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use custom ModelOpt spec. Note that this does not affect the usage of the pruning workflow but makes pruning slightly faster and may result in slightly different pruned model because of different kernel and numerics.
 - Add Puzzletron - a new algorithm for heterogeneous pruning of LLM and VLM models. See `examples/puzzletron/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/puzzletron>`_ for more details.
 - Added iterator interface using CalibrationDataReader in ONNX quantization workflow.
```

examples/llm_distill/main.py

Lines changed: 2 additions & 8 deletions

```diff
@@ -27,7 +27,7 @@
 from trl import SFTTrainer
 
 import modelopt.torch.opt as mto
-from modelopt.torch.distill.plugins.huggingface import KDTrainer, LMLogitsLoss
+from modelopt.torch.distill.plugins.huggingface import KDTrainer
 
 logger = get_logger(__name__, log_level="INFO")
 
@@ -115,12 +115,6 @@ def train():
         model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
     )
 
-    # Distillation configuration
-    kd_config = {
-        "teacher_model": teacher_model,
-        "criterion": LMLogitsLoss(),
-    }
-
     # Fix problematic settings that logger.info excessive warnings
     model.generation_config.temperature = None
     model.generation_config.top_p = None
@@ -129,7 +123,7 @@ def train():
     trainer = KDSFTTrainer(
         model,
         training_args,
-        distill_config=kd_config,
+        distill_args={"teacher_model": teacher_model},
         train_dataset=dset_train,
         eval_dataset=dset_eval,
         formatting_func=lambda sample: _format_smoltalk_chat_template(sample, tokenizer),
```

examples/llm_qat/ARGUMENTS.md

Lines changed: 12 additions & 4 deletions

```diff
@@ -7,8 +7,10 @@ _Auto-generated — do not edit by hand._
 | Argument | Type | Default | Description |
 |----------|------|---------|-------------|
 | `--distill` | `bool` | `False` | Enable training with knowledge distillation. |
-| `--teacher_model` | `str` | `None` | The name or path of the teacher model to use for distillation. |
+| `--teacher_model` | `str` | `None` | The name or path of the teacher model. |
 | `--criterion` | `str` | `"logits_loss"` | Distillation loss criterion. Currently only 'logits_loss' is supported. |
+| `--temperature` | `float` | `1.0` | Softmax temperature for softening logits in KD loss. Used by both standard and Liger KD loss. |
+| `--liger_jsd_beta` | `float` | `0.0` | JSD beta coefficient in [0, 1]. 0=forward KL, 1=reverse KL. Only used when --use_liger_kernel is enabled. |
 
 ## DataArguments
 
@@ -27,8 +29,9 @@ _Auto-generated — do not edit by hand._
 
 | Argument | Type | Default | Description |
 |----------|------|---------|-------------|
-| `--model_name_or_path` | `str` | `"meta-llama/Llama-2-7b-hf"` | HuggingFace model name or local path to the base model to quantize/train. |
-| `--model_max_length` | `int` | `4096` | Maximum sequence length. Sequences will be right-padded (and possibly truncated). |
+| `--model_name_or_path` | `str` | `"meta-llama/Llama-2-7b-hf"` | HuggingFace model ID or local path to a pretrained model. |
+| `--model_max_length` | `int` | `8192` | Maximum sequence length. Sequences will be right-padded (and possibly truncated). |
+| `--attn_implementation` | `str` | `None` | Attention implementation: 'flash_attention_2', 'flash_attention_3', 'sdpa', or 'eager'. |
 
 ## QuantizeArguments
 
@@ -46,5 +49,10 @@ Extends [HuggingFace TrainingArguments](https://huggingface.co/docs/transformers
 
 | Argument | Type | Default | Description |
 |----------|------|---------|-------------|
-| `--cache_dir` | `str` | `None` | |
+| `--trainable_params` | `list[str]` | `None` | Glob patterns (fnmatch) for parameters that should be trainable. All other parameters will be frozen. Mutually exclusive with frozen_params. |
+| `--frozen_params` | `list[str]` | `None` | Glob patterns (fnmatch) for parameters that should be frozen. Mutually exclusive with trainable_params. |
+| `--lr_config` | `str` | `None` | Path to a YAML file mapping fnmatch patterns to optimizer kwargs (e.g. lr, weight_decay). First matching pattern wins per parameter. See examples/llm_qat/configs/train/lr_config_example.yaml. |
+| `--save_dtype` | `str` | `"bfloat16"` | Dtype string to write into the saved model's config.json (e.g. 'bfloat16', 'float16'). Set to None to preserve the original dtype. |
+| `--manual_gc` | `bool` | `False` | Run `gc.collect()` before each training/prediction step to work around GPU memory leaks during QAT/distillation. |
+| `--liger_ce_label_smoothing` | `float` | `0.0` | Label smoothing for Liger fused CE loss. Only used when --use_liger_kernel is enabled. |
 | `--lora` | `bool` | `False` | Whether to add LoRA (Low-Rank Adaptation) adapter before training. When using real quantization, the LoRA adapter must be set, as quantized weights will be frozen during training. |
```
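As an illustration of the `--trainable_params` / `--frozen_params` glob semantics described above, here is a minimal sketch using stdlib `fnmatch` (the helper name `apply_param_filters` and its signature are hypothetical, not the actual trainer code):

```python
from fnmatch import fnmatch


def apply_param_filters(param_names, trainable_params=None, frozen_params=None):
    """Decide which parameters stay trainable, per fnmatch glob patterns.

    Hypothetical sketch of the --trainable_params / --frozen_params
    semantics: the two options are mutually exclusive, and with
    --trainable_params every non-matching parameter is frozen.
    """
    if trainable_params and frozen_params:
        raise ValueError("trainable_params and frozen_params are mutually exclusive")
    trainable = {}
    for name in param_names:
        if trainable_params is not None:
            trainable[name] = any(fnmatch(name, p) for p in trainable_params)
        elif frozen_params is not None:
            trainable[name] = not any(fnmatch(name, p) for p in frozen_params)
        else:
            trainable[name] = True  # no filters: everything stays trainable
    return trainable


names = ["model.lm_head.weight", "model.layers.0.self_attn.q_proj.weight"]
print(apply_param_filters(names, trainable_params=["*lm_head*"]))
# {'model.lm_head.weight': True, 'model.layers.0.self_attn.q_proj.weight': False}
```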

examples/llm_qat/README.md

Lines changed: 8 additions & 7 deletions

````diff
@@ -140,22 +140,23 @@ trainer.train()
 trainer.save_model()
 ```
 
-`QADTrainer` extends `QATTrainer` with distillation:
+`QADTrainer` extends `QATTrainer` with distillation. Pass the teacher model and a `DistillArguments` instance:
 
 ```python
-from modelopt.torch.distill.plugins.huggingface import LMLogitsLoss
+from modelopt.torch.distill.plugins.huggingface import DistillArguments
 from modelopt.torch.quantization.plugins.transformers_trainer import QADTrainer
 
-distill_config = {
-    "teacher_model": teacher_model,
-    "criterion": LMLogitsLoss(),
-}
+distill_args = DistillArguments(
+    distill=True,
+    teacher_model="Qwen/Qwen3-8B",
+    criterion="logits_loss",
+)
 
 trainer = QADTrainer(
     model=model,  # pre-quantized model
     processing_class=tokenizer,
     args=training_args,
-    distill_config=distill_config,
+    distill_args=distill_args,
     **data_module,
 )
 trainer.train()
````
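To make the `criterion="logits_loss"` option above concrete, here is a pure-Python sketch of a temperature-softened, logits-level distillation loss (forward KL from teacher to student). This is an illustrative assumption about what a logits loss computes, not ModelOpt's exact implementation; the `T**2` scaling is a common convention, not something the source specifies:

```python
import math


def softmax(logits, temperature=1.0):
    # Numerically stable softmax over temperature-scaled logits.
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def logits_kd_loss(student_logits, teacher_logits, temperature=1.0):
    """Forward KL(teacher || student) on temperature-softened distributions.

    Hypothetical sketch of a logits-level KD loss. Scaling by
    temperature**2 keeps gradient magnitudes comparable across
    temperatures (a common convention).
    """
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)
    return kl * temperature**2


print(logits_kd_loss([2.0, 0.5, -1.0], [1.5, 0.8, -0.5], temperature=2.0))
```

A higher temperature softens both distributions, exposing more of the teacher's "dark knowledge" in the non-argmax logits.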

examples/llm_qat/arguments.py

Lines changed: 17 additions & 7 deletions

```diff
@@ -19,24 +19,31 @@
 
 import transformers
 
-from modelopt.torch.opt.plugins.transformers import ModelOptHFArguments
+from modelopt.torch.opt.plugins.transformers import ModelOptHFArguments, ModelOptTrainerArguments
 
 
 class ModelArguments(ModelOptHFArguments):
     model_name_or_path: str = field(
         default="meta-llama/Llama-2-7b-hf",
-        metadata={
-            "help": "HuggingFace model name or local path to the base model to quantize/train."
-        },
+        metadata={"help": "HuggingFace model ID or local path to a pretrained model."},
     )
     model_max_length: int = field(
-        default=4096,
+        default=8192,
         metadata={
             "help": (
                 "Maximum sequence length. Sequences will be right-padded (and possibly truncated)."
             )
         },
     )
+    attn_implementation: str | None = field(
+        default=None,
+        metadata={
+            "help": (
+                "Attention implementation: 'flash_attention_2', 'flash_attention_3', "
+                "'sdpa', or 'eager'."
+            )
+        },
+    )
 
 
 class DataArguments(ModelOptHFArguments):
@@ -74,10 +81,13 @@ class DataArguments(ModelOptHFArguments):
     )
 
 
-class TrainingArguments(ModelOptHFArguments, transformers.TrainingArguments):
-    cache_dir: str | None = field(default=None)
+class TrainingArguments(ModelOptTrainerArguments, transformers.TrainingArguments):
     dataloader_drop_last: bool = field(default=True)
     bf16: bool = field(default=True)
+    use_liger_kernel: bool = field(
+        default=True,
+        metadata={"help": "Use Liger kernel for fused loss computation. Reduces memory usage."},
+    )
     lora: bool = field(
         default=False,
         metadata={
```

examples/llm_qat/configs/train/finetune.yaml

Lines changed: 2 additions & 2 deletions

```diff
@@ -15,10 +15,10 @@ learning_rate: 1e-5
 per_device_train_batch_size: 2
 per_device_eval_batch_size: 2
 gradient_accumulation_steps: 2
-model_max_length: 4096
+model_max_length: 8192
 warmup_ratio: 0.05
 lr_scheduler_type: cosine
-gradient_checkpointing: true
+use_liger_kernel: true
 seed: 42
 
 # Checkpointing
```
examples/llm_qat/configs/train/lr_config_example.yaml

Lines changed: 39 additions & 0 deletions

```diff
@@ -0,0 +1,39 @@
+# Per-parameter optimizer config example
+#
+# Maps fnmatch glob patterns to optimizer kwargs (lr, weight_decay, betas,
+# eps, etc.). First matching pattern wins per parameter. Parameters not
+# matching any pattern use the global values from the train config.
+#
+# Any keyword accepted by the optimizer constructor can be specified here.
+# Common kwargs for AdamW:
+#   lr           - learning rate
+#   weight_decay - L2 penalty (overrides the global --weight_decay)
+#   betas        - Adam momentum coefficients [beta1, beta2]
+#   eps          - term added to denominator for numerical stability
+#
+# Usage:
+#   --lr_config configs/train/lr_config_example.yaml
+#
+# Tip: use `model.named_parameters()` to find the exact parameter names
+# for your model.
+
+# Output head — lower LR, no weight decay
+"*lm_head*":
+  lr: 1e-5
+  weight_decay: 0.0
+
+# Attention layers — custom LR + more aggressive momentum
+"*self_attn*":
+  lr: 5e-5
+  betas: [0.9, 0.95]
+
+# MLP layers — custom LR + higher weight decay
+"*mlp*":
+  lr: 5e-5
+  weight_decay: 0.05
+
+# Embedding layers (often kept at a lower LR or frozen)
+"*embed_tokens*":
+  lr: 1e-6
+  weight_decay: 0.0
+  eps: 1e-7
```
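The first-match-wins mapping above can be sketched as a small param-group builder in pure Python (a hypothetical sketch; `build_param_groups` and its signature are illustrative, not ModelOpt's implementation):

```python
from fnmatch import fnmatch


def build_param_groups(param_names, lr_config, default_kwargs):
    """Build optimizer param groups from a pattern -> kwargs mapping.

    Hypothetical sketch of the lr_config semantics: the first matching
    pattern wins per parameter; unmatched parameters fall back to the
    global defaults from the train config.
    """
    groups = {}
    for name in param_names:
        kwargs = dict(default_kwargs)
        for pattern, overrides in lr_config.items():
            if fnmatch(name, pattern):
                kwargs.update(overrides)
                break  # first matching pattern wins
        # Group together parameters that ended up with identical kwargs.
        key = repr(sorted(kwargs.items()))
        groups.setdefault(key, {"params": [], **kwargs})["params"].append(name)
    return list(groups.values())


lr_config = {"*lm_head*": {"lr": 1e-5, "weight_decay": 0.0}, "*mlp*": {"lr": 5e-5}}
defaults = {"lr": 1e-4, "weight_decay": 0.01}
for group in build_param_groups(
    ["lm_head.weight", "layers.0.mlp.up_proj.weight", "layers.0.input_layernorm.weight"],
    lr_config,
    defaults,
):
    print(group)
```

The resulting list has the same shape as the `params` argument torch optimizers accept, so it could be passed straight to an optimizer constructor once names are swapped for the actual parameter tensors.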

examples/llm_qat/configs/train/qad_nvfp4.yaml

Lines changed: 4 additions & 2 deletions

```diff
@@ -3,6 +3,7 @@
 # Model
 model_name_or_path: # e.g., Qwen/Qwen3-8B
 output_dir: # e.g., qwen3-8b-qad-nvfp4
+attn_implementation: flash_attention_2
 
 # Quantization
 recipe: general/ptq/nvfp4_default-fp8_kv
@@ -22,10 +23,11 @@ learning_rate: 1e-5
 per_device_train_batch_size: 2
 per_device_eval_batch_size: 2
 gradient_accumulation_steps: 2
-model_max_length: 4096
+model_max_length: 8192
 warmup_ratio: 0.05
 lr_scheduler_type: cosine
-gradient_checkpointing: true
+use_liger_kernel: true
+manual_gc: true
 seed: 42
 do_train: true
 do_eval: true
```

examples/llm_qat/configs/train/qat_nvfp4.yaml

Lines changed: 4 additions & 2 deletions

```diff
@@ -3,6 +3,7 @@
 # Model
 model_name_or_path: # e.g., Qwen/Qwen3-8B
 output_dir: # e.g., qwen3-8b-qat-nvfp4
+attn_implementation: flash_attention_2
 
 # Quantization
 recipe: general/ptq/nvfp4_default-fp8_kv
@@ -18,10 +19,11 @@ learning_rate: 1e-5
 per_device_train_batch_size: 2
 per_device_eval_batch_size: 2
 gradient_accumulation_steps: 2
-model_max_length: 4096
+model_max_length: 8192
 warmup_ratio: 0.05
 lr_scheduler_type: cosine
-gradient_checkpointing: true
+use_liger_kernel: true
+manual_gc: true
 seed: 42
 do_train: true
 do_eval: true
```
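The `manual_gc: true` knob in these configs amounts to forcing a full garbage collection at a deterministic point each step. A minimal sketch (hypothetical helper, not the actual trainer hook):

```python
import gc


def manual_gc_step(step_fn, *args, **kwargs):
    """Force a full garbage collection before running one train/predict step.

    Sketch of the manual_gc idea: collecting at a fixed point each step
    frees reference cycles that may be pinning GPU tensors alive, working
    around memory leaks observed during QAT/distillation. Hypothetical
    wrapper; the real trainer integrates this into its step callbacks.
    """
    gc.collect()
    return step_fn(*args, **kwargs)


# Usage with a stand-in step function:
loss = manual_gc_step(lambda batch: sum(batch) / len(batch), [1.0, 2.0, 3.0])
print(loss)  # 2.0
```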
