@@ -289,16 +289,22 @@ def save_pretrained(
         except (OSError, ValueError, ImportError):
             pass

+        mtp_state_dict = self._get_mtp_state_dict()
+        if len(mtp_state_dict) > 0:
+            state_dict.update(mtp_state_dict)
+            print(f"Successfully loaded {len(mtp_state_dict)} MTP tensors")
+
+        combined_exclude_modules = self._gather_exclude_modules()
+
         if is_last_stage_main_rank and quantization is not None:
-            self._gather_exclude_modules()  # gather exclude_modules from all ranks
             self._hf_quant_config = {
                 "producer": {
                     "name": "modelopt",
                     "version": __version__,
                 },
                 "quantization": {
                     "quant_algo": quantization,
-                    "exclude_modules": self.exclude_modules,
+                    "exclude_modules": combined_exclude_modules,
                 },
             }
             if quantization == "NVFP4":  # update block size
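For context, a minimal sketch of the `hf_quant_config` payload assembled above, assuming an NVFP4 export where several ranks contributed different exclude entries; the module names and version string below are illustrative placeholders, not values from a real run:

```python
# Illustrative only: the shape of the quant config built in save_pretrained,
# with exclude_modules now being the union gathered across all ranks.
hf_quant_config = {
    "producer": {"name": "modelopt", "version": "0.0.0"},  # placeholder version
    "quantization": {
        "quant_algo": "NVFP4",
        "exclude_modules": ["lm_head", "decoder.layers.0.router"],  # hypothetical names
    },
}
```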
@@ -377,10 +383,6 @@ def save_pretrained(
         # Add multimodal components to state_dict
         state_dict.update(multimodal_state_dict)

-        mtp_state_dict = self._get_mtp_state_dict()
-        state_dict.update(mtp_state_dict)
-        print(f"Successfully loaded {len(mtp_state_dict)} MTP tensors")
-
         # Barrier to ensure the export_dir has been created.
         torch.distributed.barrier()

@@ -1238,6 +1240,9 @@ def _pack_name_remapping_gpt_oss(self, module, prefix, layer_type=None):

     def _gather_exclude_modules(self):
         """Get exclude_modules from all ranks to ensure hf_quant_config is complete."""
+        if not torch.distributed.is_initialized():
+            return sorted(self.exclude_modules)
+
         all_exclude_modules = [None] * torch.distributed.get_world_size()
         torch.distributed.all_gather_object(all_exclude_modules, self.exclude_modules)
         combined_exclude_modules = set()
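For readers who want the gather pattern in isolation, here is a minimal, self-contained sketch of what `_gather_exclude_modules` now does, assuming `exclude_modules` is a plain list of strings on each rank; the standalone helper name is hypothetical:

```python
import torch.distributed as dist

def gather_exclude_modules(exclude_modules: list[str]) -> list[str]:
    """Union per-rank exclude lists so the exported quant config is complete."""
    if not dist.is_initialized():
        # Single-process export (e.g. no torchrun): nothing to gather.
        return sorted(exclude_modules)
    gathered = [None] * dist.get_world_size()
    # all_gather_object pickles each rank's list and makes every list visible on every rank.
    dist.all_gather_object(gathered, exclude_modules)
    combined: set[str] = set()
    for modules in gathered:
        combined.update(modules or [])
    return sorted(combined)
```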