Skip to content

Commit af51539

Browse files
committed
adopt *experts.{id}.* naming pattern
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
1 parent f0326e5 commit af51539

File tree

2 files changed

+38
-23
lines changed

2 files changed

+38
-23
lines changed

modelopt/torch/export/unified_export_hf.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -959,6 +959,36 @@ def _export_diffusers_checkpoint(
959959
print(f"Export complete. Saved to: {export_dir}")
960960

961961

962+
def _revert_weight_conversion_noop(model: Any, state_dict: dict) -> dict:
963+
"""No-op replacement for transformers' revert_weight_conversion."""
964+
return state_dict
965+
966+
967+
def _patch_revert_weight_conversion() -> list[tuple[Any, Any]]:
    """Patch ``revert_weight_conversion`` in transformers to a no-op.

    The stock implementation raises ``IndexError`` on quantized state dicts
    whose scalar scale tensors have 0 dimensions. Both the defining module and
    the importing module must be patched, because ``modeling_utils`` does
    ``from core_model_loading import revert_weight_conversion``.

    Returns:
        ``(module, original_function)`` pairs suitable for
        ``_unpatch_revert_weight_conversion``. Modules that cannot be imported
        (e.g. on transformers versions lacking them) are silently skipped.
    """
    import importlib

    patches: list[tuple[Any, Any]] = []
    for mod_path in (
        "transformers.core_model_loading",
        "transformers.modeling_utils",
    ):
        # Keep the try body minimal: only the import can legitimately fail here.
        try:
            mod = importlib.import_module(mod_path)
        except ImportError:
            continue
        if hasattr(mod, "revert_weight_conversion"):
            # Direct attribute access instead of getattr() with a constant
            # literal (flake8-bugbear B009); behavior is identical.
            patches.append((mod, mod.revert_weight_conversion))
            mod.revert_weight_conversion = _revert_weight_conversion_noop
    return patches
984+
985+
986+
def _unpatch_revert_weight_conversion(patches: list[tuple[Any, Any]]) -> None:
987+
"""Restore the original revert_weight_conversion functions."""
988+
for mod, original in patches:
989+
mod.revert_weight_conversion = original
990+
991+
962992
def export_hf_checkpoint(
963993
model: Any,
964994
dtype: torch.dtype | None = None,
@@ -1022,21 +1052,7 @@ def export_hf_checkpoint(
10221052
# quantized state dicts (scalar scale tensors have 0 dimensions, causing IndexError).
10231053
# We must patch both the source module and the importing module since
10241054
# modeling_utils does `from core_model_loading import revert_weight_conversion`.
1025-
_patches = []
1026-
_noop = lambda model, state_dict: state_dict
1027-
for _mod_path in [
1028-
"transformers.core_model_loading",
1029-
"transformers.modeling_utils",
1030-
]:
1031-
try:
1032-
import importlib
1033-
1034-
_mod = importlib.import_module(_mod_path)
1035-
if hasattr(_mod, "revert_weight_conversion"):
1036-
_patches.append((_mod, getattr(_mod, "revert_weight_conversion")))
1037-
setattr(_mod, "revert_weight_conversion", _noop)
1038-
except (ImportError, AttributeError):
1039-
pass
1055+
_patches = _patch_revert_weight_conversion()
10401056

10411057
try:
10421058
model.save_pretrained(
@@ -1045,8 +1061,7 @@ def export_hf_checkpoint(
10451061
save_modelopt_state=save_modelopt_state,
10461062
)
10471063
finally:
1048-
for _mod, _original in _patches:
1049-
_mod.revert_weight_conversion = _original
1064+
_unpatch_revert_weight_conversion(_patches)
10501065

10511066
original_config = f"{export_dir}/config.json"
10521067
config_data = {}

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,7 @@ def forward(
653653
return next_states
654654

655655

656-
class _Qwen3_5MoeExpertModule(nn.Module):
656+
class _Qwen35MoeExpertModule(nn.Module):
657657
"""Container for a single Qwen3.5 MoE expert's linear layers.
658658
659659
Produces the naming pattern: experts.{id}.gate_proj.weight
@@ -667,7 +667,7 @@ def __init__(self, hidden_dim: int, expert_dim: int):
667667
self.down_proj = nn.Linear(expert_dim, hidden_dim, bias=False)
668668

669669

670-
class _QuantQwen3_5MoeExperts(QuantModule):
670+
class _QuantQwen35MoeExperts(QuantModule):
671671
def _setup(self):
672672
"""Modify the Qwen3_5MoeExperts by using per-expert nn.Module containers.
673673
@@ -688,7 +688,7 @@ def _copy_weight(module, weight):
688688
with init_empty_weights():
689689
expert_modules = nn.ModuleList(
690690
[
691-
_Qwen3_5MoeExpertModule(self.hidden_dim, expert_dim)
691+
_Qwen35MoeExpertModule(self.hidden_dim, expert_dim)
692692
for _ in range(self.num_experts)
693693
]
694694
)
@@ -898,7 +898,7 @@ def unpack_weight(self):
898898
pass
899899

900900

901-
class _QuantQwen3_5MoeSparseMoeBlock(_QuantSparseMoe):
901+
class _QuantQwen35MoeSparseMoeBlock(_QuantSparseMoe):
902902
"""Qwen3.5 MoE stores top_k/num_experts in the router (self.gate), not as direct attributes.
903903
904904
We override forward instead of just bridging attributes because the router (self.gate)
@@ -927,12 +927,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
927927

928928
if Qwen3_5MoeSparseMoeBlock not in QuantModuleRegistry:
929929
QuantModuleRegistry.register({Qwen3_5MoeSparseMoeBlock: "hf.Qwen3_5MoeSparseMoeBlock"})(
930-
_QuantQwen3_5MoeSparseMoeBlock
930+
_QuantQwen35MoeSparseMoeBlock
931931
)
932932

933933
if Qwen3_5MoeExperts not in QuantModuleRegistry:
934934
QuantModuleRegistry.register({Qwen3_5MoeExperts: "hf.Qwen3_5MoeExperts"})(
935-
_QuantQwen3_5MoeExperts
935+
_QuantQwen35MoeExperts
936936
)
937937
except ImportError:
938938
pass

0 commit comments

Comments (0)