
Commit d9fe3d7

polish

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent 01be4d9

2 files changed: 22 additions & 14 deletions

examples/speculative_decoding/main.py (6 additions & 11 deletions)
```diff
@@ -151,21 +151,16 @@ def train():
         trust_remote_code=recipe.model.trust_remote_code,
     )
     if isinstance(recipe, ModelOptMedusaRecipe):
-        mtsp.convert(model, [("medusa", recipe.medusa.model_dump())])
+        medusa_cfg: dict = recipe.medusa.model_dump()
+        mtsp.convert(model, [("medusa", medusa_cfg)])
     elif isinstance(recipe, ModelOptEagleRecipe):
-        eagle_cfg = recipe.eagle.model_dump()
+        eagle_cfg: dict = recipe.eagle.model_dump()
         mtsp.convert(model, [("eagle", eagle_cfg)])
-
-        # Load draft vocab cache if the draft model uses a compressed vocabulary
-        if model.eagle_config.draft_vocab_size < model.eagle_config.vocab_size:
-            d2t = recipe.data.draft_vocab_cache
-            if d2t is None or not os.path.isfile(d2t):
-                raise FileNotFoundError(f"Draft vocab cache provided but not found: {d2t}")
-            model.eagle_module.d2t = torch.load(d2t, weights_only=True)
-            print_rank_0(f"Loaded draft vocab cache from {d2t}.")
+        # Load draft vocab cache
+        mtsp.plugins.HFEagleModel.load_draft_vocab_cache(model, recipe.data.draft_vocab_cache)
     elif isinstance(recipe, ModelOptDFlashRecipe):
         # Re-validate with tokenizer to resolve dflash_mask_token_id and enforce its presence.
-        dflash_cfg = DFlashConfig.model_validate(
+        dflash_cfg: dict = DFlashConfig.model_validate(
             recipe.dflash.model_dump(), context={"tokenizer": tokenizer}
         ).model_dump()
         mtsp.convert(model, [("dflash", dflash_cfg)])
```

modelopt/torch/speculative/plugins/hf_eagle.py (16 additions & 3 deletions)
```diff
@@ -17,6 +17,7 @@

 import contextlib
 import copy
+import os
 from typing import Any

 import torch
@@ -25,6 +26,8 @@
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 from transformers.utils import ModelOutput

+from modelopt.torch.utils import print_rank_0
+
 from ...export.plugins.hf_spec_export import EagleExporter, SpeculativeDecodingExporter
 from ..eagle.conversion import EagleDMRegistry
 from ..eagle.eagle_model import EagleModel
@@ -88,7 +91,7 @@ def _nvtx_range(self, name):

             return nvtx.range(name)
         except Exception as e:
-            print(f"Failed to create NVTX range {name}: {e}")
+            print_rank_0(f"Failed to create NVTX range {name}: {e}")
         return contextlib.nullcontext()

     def _find_base_model_parts(self):
@@ -105,7 +108,7 @@ def _find_base_model_parts(self):
             try:
                 submodule = self.get_submodule(path)
                 assert isinstance(submodule, torch.nn.Module)
-                print(f"Found {name} at {path}")
+                print_rank_0(f"Found {name} at {path}")
                 found_submodule = True
                 setattr(self, name, path)
                 break
@@ -128,7 +131,7 @@ def _activate_torch_compile(self):
             try:
                 setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs))
             except Exception:  # noqa: PERF203
-                print(f"Disabling torch.compile for {name} due to compilation error.")
+                print_rank_0(f"Disabling torch.compile for {name} due to compilation error.")

     def get_dummy_inputs(self) -> dict:
         """Construct dummy inputs for export forward pass."""
@@ -250,6 +253,16 @@ def _preservation_loss(
         )
         return -loss.sum(dim=-1).mean() * self.eagle_base_lora_preservation_loss_weight

+    @staticmethod
+    def load_draft_vocab_cache(model, d2t_path: str) -> None:
+        """Load the draft vocab cache from the given path."""
+        if d2t_path is None or model.eagle_config.draft_vocab_size >= model.eagle_config.vocab_size:
+            return
+        if not os.path.isfile(d2t_path):
+            raise FileNotFoundError(f"Draft vocab cache provided but not found: {d2t_path}")
+        model.eagle_module.d2t = torch.load(d2t_path, weights_only=True)
+        print_rank_0(f"Loaded draft vocab cache from {d2t_path}.")
+
     def modify(
         self,
         config,
```
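
The recurring change in this file swaps bare `print` calls for `print_rank_0`, so that messages are emitted once rather than once per process in multi-GPU runs. A minimal sketch of what such a helper typically does (the real implementation lives in `modelopt.torch.utils` and may differ in detail):

```python
# Illustrative rank-0-only print helper; not the actual modelopt.torch.utils code.
import torch.distributed as dist


def print_rank_0(*args, **kwargs) -> None:
    """Print only on global rank 0 to avoid duplicated logs in distributed runs."""
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print(*args, **kwargs)
```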
