@@ -76,9 +76,9 @@ class DataArguments:
7676 },
7777 )
7878 lazy_preprocess : bool = True
79- draft_vocab_cache_dir : str = field (
80- default = "draft_vocab_cache" ,
81- metadata = {"help" : "Path to the d2t cache directory ." },
79+ draft_vocab_cache : str = field (
80+ default = None ,
81+ metadata = {"help" : "Path to d2t.pt cache file ." },
8282 )
8383 vlm_img_dir : str = field (default = None , metadata = {"help" : "Path to the VLM image directory." })
8484 vlm_processor : str = field (default = None , metadata = {"help" : "Path to the VLM processor." })
@@ -97,7 +97,7 @@ class TrainingArguments(transformers.TrainingArguments):
9797 )
9898 dataloader_drop_last : bool = field (default = True )
9999 bf16 : bool = field (default = True )
100- mode : Literal ["eagle1" , " eagle3" , "medusa" ] = "eagle3"
100+ mode : Literal ["eagle3" , "medusa" ] = "eagle3"
101101 estimate_ar : bool = field (
102102 default = False , metadata = {"help" : "Whether to estimate AR during training for logging." }
103103 )
@@ -147,30 +147,35 @@ def train():
147147 training_args .parallelism_config .sp_backend = None
148148 print_rank_0 (f"arguments: { model_args } , { training_args } , { medusa_args } , { eagle_args } " )
149149
150- # Detecting last checkpoint.
151- last_checkpoint = None
152- if os .path .isdir (training_args .output_dir ):
153- last_checkpoint = get_last_checkpoint (training_args .output_dir )
150+ # Detect checkpoint to resume from
151+ last_checkpoint = (
152+ get_last_checkpoint (training_args .output_dir )
153+ if os .path .isdir (training_args .output_dir )
154+ else None
155+ )
156+ if last_checkpoint :
154157 print_rank_0 (f"Last checkpoint detected: { last_checkpoint } " )
155158
156- checkpoint = None
157- if training_args .resume_from_checkpoint is not None :
158- checkpoint = training_args .resume_from_checkpoint
159- elif last_checkpoint is not None :
160- checkpoint = last_checkpoint
159+ checkpoint = training_args .resume_from_checkpoint or last_checkpoint
161160
162161 use_offline_training = data_args .offline_data_path is not None
163162
163+ model_config = transformers .AutoConfig .from_pretrained (
164+ model_args .model_name_or_path , trust_remote_code = True
165+ )
166+ if "vl" in model_config .model_type .lower ():
167+ model_cls = transformers .AutoModelForVision2Seq
168+ else :
169+ model_cls = transformers .AutoModelForCausalLM
170+
164171 if checkpoint :
165- model = transformers .AutoModelForCausalLM .from_pretrained (
166- checkpoint , torch_dtype = "auto" , trust_remote_code = True
167- )
172+ model = model_cls .from_pretrained (checkpoint , torch_dtype = "auto" , trust_remote_code = True )
168173 tokenizer = transformers .AutoTokenizer .from_pretrained (checkpoint , trust_remote_code = True )
169174 else :
170175 # To avoid OOM for large models, we load and convert model on CPU first.
171176 # Model will be moved to GPU during HF trainer.init().
172177 offline_kwargs = {"num_hidden_layers" : 0 } if use_offline_training else {}
173- model = transformers . Qwen3VLForConditionalGeneration .from_pretrained (
178+ model = model_cls .from_pretrained (
174179 model_args .model_name_or_path ,
175180 torch_dtype = "auto" ,
176181 device_map = "cpu" ,
@@ -180,77 +185,38 @@ def train():
180185 if use_offline_training :
181186 # When doing offline training, we need to set num_hidden_layers
182187 # since we override it when loading the model for space savings
183- model_config = transformers .AutoConfig .from_pretrained (
184- model_args .model_name_or_path , trust_remote_code = True
185- )
186188 model .config .num_orig_hidden_layers = model_config .num_hidden_layers
187189 tokenizer = transformers .AutoTokenizer .from_pretrained (
188190 model_args .model_name_or_path ,
189191 model_max_length = training_args .training_seq_len ,
190192 trust_remote_code = True ,
191193 )
192- if tokenizer .chat_template is None :
193- tokenizer .chat_template = (
194- "{%- for message in messages %}"
195- "{{- '<|im_start|>' + message['role'] + '\n ' + message['content'] + '<|im_end|>' + '\n ' }}"
196- "{%- endfor %}"
197- )
198- if tokenizer .pad_token_id is None :
199- tokenizer .pad_token_id = tokenizer .eos_token_id
200-
201194 if training_args .mode == "medusa" :
202195 config = {
203196 "medusa_num_heads" : medusa_args .medusa_num_heads ,
204197 "medusa_num_layers" : medusa_args .medusa_num_layers ,
205198 }
206199 mtsp .convert (model , [("medusa" , config )])
207- elif training_args .mode in ["eagle1" , "eagle3" ]:
208- from modelopt .torch .speculative .config import (
209- default_eagle_config ,
210- eagle3_default_config ,
211- kimik2_eagle_default_config ,
200+ elif training_args .mode == "eagle3" :
201+ custom_config = (
202+ json .load (open (eagle_args .eagle_config )) if eagle_args .eagle_config else {}
212203 )
213204
214- if eagle_args .eagle_decoder_type == "kimik2" :
215- eagle_architecture_config = kimik2_eagle_default_config
216- else :
217- eagle_architecture_config = {
218- "eagle1" : default_eagle_config ,
219- "eagle3" : eagle3_default_config ,
220- }[training_args .mode ]
221-
222- if eagle_args .eagle_config :
223- with open (eagle_args .eagle_config ) as f :
224- custom_config = json .load (f )
225- eagle_architecture_config .update (custom_config )
226-
227205 config = {
228206 "eagle_decoder_type" : eagle_args .eagle_decoder_type ,
229207 "eagle_offline" : use_offline_training ,
230- "eagle_architecture_config" : eagle_architecture_config ,
208+ "eagle_architecture_config" : custom_config ,
231209 }
232210
233211 mtsp .convert (model , [("eagle" , config )])
234212
235- # read draft vocab cache
236- if model .eagle_config .draft_vocab_size < model .eagle_config .vocab_size :
237- try :
238- model_name = os .path .basename (os .path .normpath (model_args .model_name_or_path ))
239- vocab_cache_path = os .path .join (
240- data_args .draft_vocab_cache_dir , model_name , "d2t.pt"
241- )
242- vocab_cache = torch .load (vocab_cache_path )
243- model .eagle_module .d2t = vocab_cache
244- print_rank_0 (f"Loaded draft vocab cache from { vocab_cache_path } ." )
245- except Exception as e :
246- raise e
247213 else :
248214 raise Exception (f"{ training_args .mode } is not supported!" )
249215
250216 print_rank_0 ("Loading dataset..." )
251217 if training_args .mode == "medusa" :
252218 data_module = make_medusa_supervised_data_module (tokenizer , data_args )
253- elif training_args .mode in [ "eagle1" , " eagle3"] :
219+ elif training_args .mode == " eagle3" :
254220 data_module = make_eagle_supervised_data_module (
255221 tokenizer , data_args , train_len = training_args .training_seq_len
256222 )
0 commit comments