Skip to content

Commit 1fa937b

Browse files
author
hongchao
committed
fix review
1 parent 8c0a548 commit 1fa937b

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

checkpoint_engine/worker.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,16 +171,18 @@ def _load_weights(weights: _WEIGHTS_TYPE):
171171
# Load main model weights
172172
self.model_runner.model.load_weights(weights)
173173
# Load drafter model weights if MTP/speculative decoding is enabled
174-
if hasattr(self.model_runner, "drafter") and hasattr(
175-
self.model_runner.drafter, "model"
174+
if (
175+
getattr(self.model_runner, "drafter", None) is not None
176+
and getattr(self.model_runner.drafter, "model", None) is not None
176177
):
177178
self.model_runner.drafter.model.load_weights(weights=weights)
178179

179180
def _post_hook():
180181
process_weights_after_loading(self.model_runner.model, self.model_config, self.device)
181182
# Also trigger drafter model's post processing if MTP is enabled
182-
if hasattr(self.model_runner, "drafter") and hasattr(
183-
self.model_runner.drafter, "model"
183+
if (
184+
getattr(self.model_runner, "drafter", None) is not None
185+
and getattr(self.model_runner.drafter, "model", None) is not None
184186
):
185187
process_weights_after_loading(
186188
self.model_runner.drafter.model, self.model_config, self.device

0 commit comments

Comments
 (0)