@@ -171,16 +171,18 @@ def _load_weights(weights: _WEIGHTS_TYPE):
171171 # Load main model weights
172172 self .model_runner .model .load_weights (weights )
173173 # Load drafter model weights if MTP/speculative decoding is enabled
174- if hasattr (self .model_runner , "drafter" ) and hasattr (
175- self .model_runner .drafter , "model"
174+ if (
175+ getattr (self .model_runner , "drafter" , None ) is not None
176+ and getattr (self .model_runner .drafter , "model" , None ) is not None
176177 ):
177178 self .model_runner .drafter .model .load_weights (weights = weights )
178179
179180 def _post_hook ():
180181 process_weights_after_loading (self .model_runner .model , self .model_config , self .device )
181182 # Also trigger drafter model's post processing if MTP is enabled
182- if hasattr (self .model_runner , "drafter" ) and hasattr (
183- self .model_runner .drafter , "model"
183+ if (
184+ getattr (self .model_runner , "drafter" , None ) is not None
185+ and getattr (self .model_runner .drafter , "model" , None ) is not None
184186 ):
185187 process_weights_after_loading (
186188 self .model_runner .drafter .model , self .model_config , self .device
0 commit comments