Skip to content

Commit bca66db

Browse files
authored
split distill with qad and distill (#313)
1 parent 089576d commit bca66db

27 files changed

Lines changed: 553 additions & 295 deletions

angelslim/compressor/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,5 +14,6 @@
1414

1515
from .compressor_factory import CompressorFactory # noqa: F401
1616
from .distill import Distill # noqa: F401
17+
from .qad import QAD # noqa: F401
1718
from .qat.qat import QAT # noqa: F401
1819
from .quant import PTQ # noqa: F401

angelslim/compressor/distill/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,5 +13,6 @@
1313
# limitations under the License.
1414

1515
from .distill import Distill # noqa: F401
16+
from .loss import DistillLoss # noqa: F401
1617

17-
__all__ = ["Distill"]
18+
__all__ = ["Distill", "DistillLoss"]

angelslim/compressor/distill/distill.py

Lines changed: 18 additions & 94 deletions
Original file line number | Diff line number | Diff line change
@@ -22,73 +22,45 @@
2222
from ...data.qat_dataset import QATDataset
2323
from ...utils import patch_deepspeed_duplicate_check, print_info
2424
from ..compressor_factory import CompressorFactory
25-
from ..qat.plugins import PluginManager
26-
from ..qat.qat import QAT
2725
from .trainer import DistillSeq2SeqTrainer
2826

2927

30-
def _unique_named_params(model, predicate):
31-
seen = set()
32-
result = []
33-
for name, param in model.named_parameters():
34-
if id(param) in seen or not predicate(name, param):
35-
continue
36-
seen.add(id(param))
37-
result.append(param)
38-
return result
39-
40-
4128
def _normalize_device_map(device_map):
4229
if isinstance(device_map, str) and device_map.lower() in ("none", "distributed"):
4330
return None
4431
return device_map
4532

4633

4734
@CompressorFactory.register
48-
class Distill(QAT):
35+
class Distill:
36+
"""Full-precision knowledge distillation.
37+
38+
Quantized-student distillation lives in ``angelslim.compressor.qad``.
39+
Keeping this path fp-only prevents it from inheriting QAT state or save
40+
semantics by accident.
41+
"""
42+
4943
def __init__(self, model, slim_config=None):
5044
self.quant_model = model
5145
self.config = slim_config
5246
self.distill_config = slim_config["compress_config"].Distill
53-
self.student_type = self.distill_config.student_type.lower()
47+
self.student_type = getattr(self.distill_config, "student_type", "fp").lower()
5448
self.trainable_parameters = self.distill_config.trainable_parameters.lower()
5549
self.save_fmt = self.distill_config.save_format
56-
self.plugin_config = self.distill_config.plugin_config
57-
self.plugin_manager = PluginManager()
5850
self.trainer = SimpleNamespace(external_trainer=None)
59-
self._rank0_state_dict = None
6051
self.teacher_model = None
6152
self.train_dataset = None
6253

6354
self._validate_config()
64-
self.is_quantized_student = self.student_type == "quantized"
65-
if self.is_quantized_student:
66-
self.quant_model.init_ptq(slim_config)
67-
self.quant_info = self.quant_model.quant_config
68-
self._init_plugins()
69-
else:
70-
self.quant_info = None
7155

7256
def _validate_config(self):
7357
if not self.distill_config.teacher_model_path:
7458
raise ValueError("Distill requires compression.Distill.teacher_model_path.")
75-
if self.student_type not in ("fp", "quantized"):
76-
raise ValueError("Distill student_type must be 'fp' or 'quantized'.")
77-
if self.trainable_parameters not in ("all", "quant"):
78-
raise ValueError("Distill trainable_parameters must be 'all' or 'quant'.")
79-
if self.student_type == "fp" and self.trainable_parameters == "quant":
80-
raise ValueError("trainable_parameters='quant' requires a quantized student.")
81-
82-
def _init_plugins(self):
83-
if self.plugin_config.get("enable_scale", False):
84-
self.plugin_manager.register_plugin(
85-
"learnable_scale",
86-
quant_info=self.quant_info,
87-
ignore_layers=self.config["compress_config"].quantization.ignore_layers,
88-
resume_ckpt_dir=self.distill_config.resume_ckpt_dir,
89-
from_ptq_ckpt_dir=self.distill_config.from_ptq_ckpt,
90-
config=self.plugin_config.get("quant_config", {}),
91-
quant_model=self.quant_model,
59+
if self.student_type != "fp":
60+
raise ValueError("Distill only supports fp students. Use QAD for quantized students.")
61+
if self.trainable_parameters != "all":
62+
raise ValueError(
63+
"Distill trainable_parameters must be 'all'. Use QAD for quant params."
9264
)
9365

9466
def _prepare_dataset(self, dataloader):
@@ -128,49 +100,11 @@ def _load_teacher_model(self):
128100

129101
def _apply_trainable_parameters(self):
130102
model = self.quant_model.model
131-
if self.trainable_parameters == "all":
132-
for param in model.parameters():
133-
param.requires_grad = True
134-
return
135-
136-
if not any(param.requires_grad for param in model.parameters()):
137-
raise ValueError("Distill quant optimizer has no trainable parameters.")
103+
for param in model.parameters():
104+
param.requires_grad = True
138105

139106
def _init_optimizer(self):
140-
if self.trainable_parameters == "all":
141-
return None
142-
143-
lr = float(self.distill_config.hf_args.get("learning_rate", 1e-5))
144-
wd = float(self.distill_config.hf_args.get("weight_decay", 0))
145-
lwc_names = ("clip_factor_w_max", "clip_factor_w_min")
146-
base_params = _unique_named_params(
147-
self.quant_model.model,
148-
lambda n, p: p.requires_grad and not any(key in n for key in lwc_names),
149-
)
150-
params = [{"params": base_params, "weight_decay": wd, "lr": lr}]
151-
152-
lwc_params = _unique_named_params(
153-
self.quant_model.model,
154-
lambda n, p: p.requires_grad and any(key in n for key in lwc_names),
155-
)
156-
if lwc_params:
157-
lwc_lr = float(
158-
self.plugin_config.get("quant_config", {}).get("lwc", {}).get("lwc_lr", lr)
159-
)
160-
params.append({"params": lwc_params, "weight_decay": wd, "lr": lwc_lr})
161-
print_info(
162-
f"Init distill optimizer with {len(base_params)} params, "
163-
f"{len(lwc_params)} lwc params, lr={lr}, lwc_lr={lwc_lr}, weight_decay={wd}"
164-
)
165-
else:
166-
print_info(
167-
f"Init distill optimizer with {len(base_params)} params, "
168-
f"lr={lr}, weight_decay={wd}"
169-
)
170-
171-
if not any(group["params"] for group in params):
172-
raise ValueError("Distill optimizer has no trainable parameters.")
173-
return torch.optim.AdamW(params)
107+
return None
174108

175109
def _prepare_trainer(self, place_teacher_on_device):
176110
optimizer = self._init_optimizer()
@@ -212,9 +146,6 @@ def _load_resume_checkpoint(self):
212146

213147
def run(self, dataloader):
214148
self._prepare_dataset(dataloader)
215-
if self.is_quantized_student:
216-
self.plugin_manager.call_before_train(train_dataset=self.train_dataset)
217-
218149
self._apply_trainable_parameters()
219150
self._load_resume_checkpoint()
220151
self.teacher_model, place_teacher_on_device = self._load_teacher_model()
@@ -223,17 +154,10 @@ def run(self, dataloader):
223154
if self.distill_config.do_train:
224155
self.trainer.external_trainer.train()
225156

226-
if self.is_quantized_student:
227-
self.plugin_manager.call_after_train()
228-
229157
def convert(self):
230-
if self.is_quantized_student:
231-
super().convert()
158+
return None
232159

233160
def save(self, save_path: str):
234-
if self.is_quantized_student:
235-
return super().save(save_path)
236-
237161
if self.save_fmt not in ("hf", "real", "full"):
238162
print_info("Save format not specified, skip save.")
239163
return None

angelslim/compressor/qat/plugins/distill_loss.py renamed to angelslim/compressor/distill/loss.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,7 @@ def compute(self, student_logits, teacher_logits, labels):
9696
kd = self._kl_from_logps(top_s_logp, top_t_logp).mean()
9797
else:
9898
raise ValueError(
99-
f"Unsupported QAT kd loss_type: {self.loss_type}. "
99+
f"Unsupported distill loss_type: {self.loss_type}. "
100100
"Valid: kl, rkl, mse, kd, cakld, kl_top[_K], r_kl_top[_K]."
101101
)
102102

angelslim/compressor/distill/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch
1818
from transformers import Seq2SeqTrainer
1919

20-
from ..qat.plugins.distill_loss import DistillLoss
20+
from .loss import DistillLoss
2121

2222

2323
class DistillSeq2SeqTrainer(Seq2SeqTrainer):
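angelslim/compressor/qad/__init__.py (new file; path inferred from the `from .qad import QAD` imports — confirm against the commit)

Lines changed: 17 additions & 0 deletions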
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,17 @@
1+
# Copyright 2025 Tencent Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .qad import QAD # noqa: F401
16+
17+
__all__ = ["QAD"]

0 commit comments

Comments
 (0)