1717
1818import contextlib
1919import copy
20+ import os
2021from typing import Any
2122
2223import torch
2526from transformers .models .llama .modeling_llama import LlamaDecoderLayer
2627from transformers .utils import ModelOutput
2728
29+ from modelopt .torch .utils import print_rank_0
30+
2831from ...export .plugins .hf_spec_export import EagleExporter , SpeculativeDecodingExporter
2932from ..eagle .conversion import EagleDMRegistry
3033from ..eagle .eagle_model import EagleModel
@@ -88,7 +91,7 @@ def _nvtx_range(self, name):
8891
8992 return nvtx .range (name )
9093 except Exception as e :
91- print (f"Failed to create NVTX range { name } : { e } " )
94+ print_rank_0 (f"Failed to create NVTX range { name } : { e } " )
9295 return contextlib .nullcontext ()
9396
9497 def _find_base_model_parts (self ):
@@ -105,7 +108,7 @@ def _find_base_model_parts(self):
105108 try :
106109 submodule = self .get_submodule (path )
107110 assert isinstance (submodule , torch .nn .Module )
108- print (f"Found { name } at { path } " )
111+ print_rank_0 (f"Found { name } at { path } " )
109112 found_submodule = True
110113 setattr (self , name , path )
111114 break
@@ -128,7 +131,7 @@ def _activate_torch_compile(self):
128131 try :
129132 setattr (self , name , torch .compile (getattr (self , name ), dynamic = False , ** kwargs ))
130133 except Exception : # noqa: PERF203
131- print (f"Disabling torch.compile for { name } due to compilation error." )
134+ print_rank_0 (f"Disabling torch.compile for { name } due to compilation error." )
132135
133136 def get_dummy_inputs (self ) -> dict :
134137 """Construct dummy inputs for export forward pass."""
@@ -250,6 +253,16 @@ def _preservation_loss(
250253 )
251254 return - loss .sum (dim = - 1 ).mean () * self .eagle_base_lora_preservation_loss_weight
252255
@staticmethod
def load_draft_vocab_cache(model, d2t_path: str) -> None:
    """Load the draft-to-target vocab mapping (``d2t``) from *d2t_path* onto the model.

    Args:
        model: Model carrying ``eagle_config`` (with ``draft_vocab_size`` /
            ``vocab_size``) and an ``eagle_module`` that holds the ``d2t`` tensor.
        d2t_path: Path to a ``torch.save``-ed draft-vocab cache, or ``None``.

    Raises:
        FileNotFoundError: If a path was given but no file exists there.
    """
    # No path given: nothing to load.
    if d2t_path is None:
        return
    # Draft vocab already covers the full vocabulary, so no mapping is needed.
    if model.eagle_config.draft_vocab_size >= model.eagle_config.vocab_size:
        return
    if not os.path.isfile(d2t_path):
        raise FileNotFoundError(f"Draft vocab cache provided but not found: {d2t_path}")
    # weights_only=True restricts deserialization to tensors (safe load).
    model.eagle_module.d2t = torch.load(d2t_path, weights_only=True)
    print_rank_0(f"Loaded draft vocab cache from {d2t_path}.")
253266 def modify (
254267 self ,
255268 config ,
0 commit comments