2222from ...data .qat_dataset import QATDataset
2323from ...utils import patch_deepspeed_duplicate_check , print_info
2424from ..compressor_factory import CompressorFactory
25- from ..qat .plugins import PluginManager
26- from ..qat .qat import QAT
2725from .trainer import DistillSeq2SeqTrainer
2826
2927
30- def _unique_named_params (model , predicate ):
31- seen = set ()
32- result = []
33- for name , param in model .named_parameters ():
34- if id (param ) in seen or not predicate (name , param ):
35- continue
36- seen .add (id (param ))
37- result .append (param )
38- return result
39-
40-
4128def _normalize_device_map (device_map ):
4229 if isinstance (device_map , str ) and device_map .lower () in ("none" , "distributed" ):
4330 return None
4431 return device_map
4532
4633
4734@CompressorFactory .register
48- class Distill (QAT ):
35+ class Distill :
36+ """Full-precision knowledge distillation.
37+
38+ Quantized-student distillation lives in ``angelslim.compressor.qad``.
39+ Keeping this path fp-only prevents it from inheriting QAT state or save
40+ semantics by accident.
41+ """
42+
4943 def __init__ (self , model , slim_config = None ):
5044 self .quant_model = model
5145 self .config = slim_config
5246 self .distill_config = slim_config ["compress_config" ].Distill
53- self .student_type = self .distill_config . student_type .lower ()
47+ self .student_type = getattr ( self .distill_config , " student_type" , "fp" ) .lower ()
5448 self .trainable_parameters = self .distill_config .trainable_parameters .lower ()
5549 self .save_fmt = self .distill_config .save_format
56- self .plugin_config = self .distill_config .plugin_config
57- self .plugin_manager = PluginManager ()
5850 self .trainer = SimpleNamespace (external_trainer = None )
59- self ._rank0_state_dict = None
6051 self .teacher_model = None
6152 self .train_dataset = None
6253
6354 self ._validate_config ()
64- self .is_quantized_student = self .student_type == "quantized"
65- if self .is_quantized_student :
66- self .quant_model .init_ptq (slim_config )
67- self .quant_info = self .quant_model .quant_config
68- self ._init_plugins ()
69- else :
70- self .quant_info = None
7155
7256 def _validate_config (self ):
7357 if not self .distill_config .teacher_model_path :
7458 raise ValueError ("Distill requires compression.Distill.teacher_model_path." )
75- if self .student_type not in ("fp" , "quantized" ):
76- raise ValueError ("Distill student_type must be 'fp' or 'quantized'." )
77- if self .trainable_parameters not in ("all" , "quant" ):
78- raise ValueError ("Distill trainable_parameters must be 'all' or 'quant'." )
79- if self .student_type == "fp" and self .trainable_parameters == "quant" :
80- raise ValueError ("trainable_parameters='quant' requires a quantized student." )
81-
82- def _init_plugins (self ):
83- if self .plugin_config .get ("enable_scale" , False ):
84- self .plugin_manager .register_plugin (
85- "learnable_scale" ,
86- quant_info = self .quant_info ,
87- ignore_layers = self .config ["compress_config" ].quantization .ignore_layers ,
88- resume_ckpt_dir = self .distill_config .resume_ckpt_dir ,
89- from_ptq_ckpt_dir = self .distill_config .from_ptq_ckpt ,
90- config = self .plugin_config .get ("quant_config" , {}),
91- quant_model = self .quant_model ,
59+ if self .student_type != "fp" :
60+ raise ValueError ("Distill only supports fp students. Use QAD for quantized students." )
61+ if self .trainable_parameters != "all" :
62+ raise ValueError (
63+ "Distill trainable_parameters must be 'all'. Use QAD for quant params."
9264 )
9365
9466 def _prepare_dataset (self , dataloader ):
@@ -128,49 +100,11 @@ def _load_teacher_model(self):
128100
129101 def _apply_trainable_parameters (self ):
130102 model = self .quant_model .model
131- if self .trainable_parameters == "all" :
132- for param in model .parameters ():
133- param .requires_grad = True
134- return
135-
136- if not any (param .requires_grad for param in model .parameters ()):
137- raise ValueError ("Distill quant optimizer has no trainable parameters." )
103+ for param in model .parameters ():
104+ param .requires_grad = True
138105
139106 def _init_optimizer (self ):
140- if self .trainable_parameters == "all" :
141- return None
142-
143- lr = float (self .distill_config .hf_args .get ("learning_rate" , 1e-5 ))
144- wd = float (self .distill_config .hf_args .get ("weight_decay" , 0 ))
145- lwc_names = ("clip_factor_w_max" , "clip_factor_w_min" )
146- base_params = _unique_named_params (
147- self .quant_model .model ,
148- lambda n , p : p .requires_grad and not any (key in n for key in lwc_names ),
149- )
150- params = [{"params" : base_params , "weight_decay" : wd , "lr" : lr }]
151-
152- lwc_params = _unique_named_params (
153- self .quant_model .model ,
154- lambda n , p : p .requires_grad and any (key in n for key in lwc_names ),
155- )
156- if lwc_params :
157- lwc_lr = float (
158- self .plugin_config .get ("quant_config" , {}).get ("lwc" , {}).get ("lwc_lr" , lr )
159- )
160- params .append ({"params" : lwc_params , "weight_decay" : wd , "lr" : lwc_lr })
161- print_info (
162- f"Init distill optimizer with { len (base_params )} params, "
163- f"{ len (lwc_params )} lwc params, lr={ lr } , lwc_lr={ lwc_lr } , weight_decay={ wd } "
164- )
165- else :
166- print_info (
167- f"Init distill optimizer with { len (base_params )} params, "
168- f"lr={ lr } , weight_decay={ wd } "
169- )
170-
171- if not any (group ["params" ] for group in params ):
172- raise ValueError ("Distill optimizer has no trainable parameters." )
173- return torch .optim .AdamW (params )
107+ return None
174108
175109 def _prepare_trainer (self , place_teacher_on_device ):
176110 optimizer = self ._init_optimizer ()
@@ -212,9 +146,6 @@ def _load_resume_checkpoint(self):
212146
213147 def run (self , dataloader ):
214148 self ._prepare_dataset (dataloader )
215- if self .is_quantized_student :
216- self .plugin_manager .call_before_train (train_dataset = self .train_dataset )
217-
218149 self ._apply_trainable_parameters ()
219150 self ._load_resume_checkpoint ()
220151 self .teacher_model , place_teacher_on_device = self ._load_teacher_model ()
@@ -223,17 +154,10 @@ def run(self, dataloader):
223154 if self .distill_config .do_train :
224155 self .trainer .external_trainer .train ()
225156
226- if self .is_quantized_student :
227- self .plugin_manager .call_after_train ()
228-
def convert(self):
    """Do nothing and return ``None``; this path has no conversion step."""
    return None
232159
233160 def save (self , save_path : str ):
234- if self .is_quantized_student :
235- return super ().save (save_path )
236-
237161 if self .save_fmt not in ("hf" , "real" , "full" ):
238162 print_info ("Save format not specified, skip save." )
239163 return None
0 commit comments