Skip to content

Commit f6f802e

Browse files
committed
fix exclude modules deadlock
Signed-off-by: jenchen13 <jennifchen@nvidia.com>
1 parent d7911a4 commit f6f802e

1 file changed

Lines changed: 11 additions & 6 deletions

File tree

modelopt/torch/export/unified_export_megatron.py

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -289,16 +289,22 @@ def save_pretrained(
289289
except (OSError, ValueError, ImportError):
290290
pass
291291

292+
mtp_state_dict = self._get_mtp_state_dict()
293+
if len(mtp_state_dict) > 0:
294+
state_dict.update(mtp_state_dict)
295+
print(f"Successfully loaded {len(mtp_state_dict)} MTP tensors")
296+
297+
combined_exclude_modules = self._gather_exclude_modules()
298+
292299
if is_last_stage_main_rank and quantization is not None:
293-
self._gather_exclude_modules() # gather exclude_modules from all ranks
294300
self._hf_quant_config = {
295301
"producer": {
296302
"name": "modelopt",
297303
"version": __version__,
298304
},
299305
"quantization": {
300306
"quant_algo": quantization,
301-
"exclude_modules": self.exclude_modules,
307+
"exclude_modules": combined_exclude_modules,
302308
},
303309
}
304310
if quantization == "NVFP4": # update block size
@@ -377,10 +383,6 @@ def save_pretrained(
377383
# Add multimodal components to state_dict
378384
state_dict.update(multimodal_state_dict)
379385

380-
mtp_state_dict = self._get_mtp_state_dict()
381-
state_dict.update(mtp_state_dict)
382-
print(f"Successfully loaded {len(mtp_state_dict)} MTP tensors")
383-
384386
# Barrier to ensure the export_dir has been created.
385387
torch.distributed.barrier()
386388

@@ -1238,6 +1240,9 @@ def _pack_name_remapping_gpt_oss(self, module, prefix, layer_type=None):
12381240

12391241
def _gather_exclude_modules(self):
12401242
"""Get exclude_modules from all ranks to ensure hf_quant_config is complete."""
1243+
if not torch.distributed.is_initialized():
1244+
return sorted(self.exclude_modules)
1245+
12411246
all_exclude_modules = [None] * torch.distributed.get_world_size()
12421247
torch.distributed.all_gather_object(all_exclude_modules, self.exclude_modules)
12431248
combined_exclude_modules = set()

0 commit comments

Comments (0)