Skip to content

Commit 500e2eb

Browse files
[FIX] module_tree tests (#2411)
* FIX moe flag passing not passing nested ci test

  Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>

* fix failed test_model_alignment.py with transformers v5

  Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>

---------

Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
1 parent 60e99b9 commit 500e2eb

3 files changed

Lines changed: 74 additions & 14 deletions

File tree

gptqmodel/models/base.py

Lines changed: 12 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -353,7 +353,7 @@ def has_moe_flag(cls, module_spec: str) -> bool:
353353
return MOE_FLAG.lstrip(":") in flags
354354

355355
@classmethod
356-
def _collect_moe_modules_from_tree(cls, tree_node, parent_path="") -> Set[str]:
356+
def _collect_moe_modules_from_tree(cls, tree_node, parent_path="", parent_is_moe=False) -> Set[str]:
357357
"""
358358
Recursively collect all module paths that have the :moe flag.
359359
Returns a set of full module paths (e.g., "mlp", "mlp.experts", "mlp.shared_experts").
@@ -366,27 +366,26 @@ def _collect_moe_modules_from_tree(cls, tree_node, parent_path="") -> Set[str]:
366366
if key == "#":
367367
# Recursively process the value if it's a dict
368368
if isinstance(value, dict):
369-
moe_modules.update(cls._collect_moe_modules_from_tree(value, parent_path))
369+
moe_modules.update(cls._collect_moe_modules_from_tree(value, parent_path, parent_is_moe))
370370
continue
371371

372372
# Build full path
373+
module_name, _ = cls._parse_module_flags(key) if isinstance(key, str) else (key, [])
373374
if parent_path:
374-
full_path = f"{parent_path}.{key}"
375+
full_path = f"{parent_path}.{module_name}"
375376
else:
376-
full_path = key
377+
full_path = module_name
377378

378379
# Check if this key has :moe flag
379-
if cls.has_moe_flag(key):
380-
# Extract just the module name without flags
381-
module_name, _ = cls._parse_module_flags(key)
382-
if parent_path:
383-
moe_modules.add(f"{parent_path}.{module_name}")
384-
else:
385-
moe_modules.add(module_name)
380+
is_moe = cls.has_moe_flag(key) if isinstance(key, str) else False
381+
if is_moe or parent_is_moe:
382+
moe_modules.add(full_path)
386383

387384
# Recursively process nested structures
388385
if isinstance(value, (dict, tuple, list)):
389-
moe_modules.update(cls._collect_moe_modules_from_tree(value, full_path.split(":")[0]))
386+
moe_modules.update(
387+
cls._collect_moe_modules_from_tree(value, full_path, parent_is_moe or is_moe)
388+
)
390389

391390
elif isinstance(tree_node, (tuple, list)):
392391
for item in tree_node:
@@ -397,7 +396,7 @@ def _collect_moe_modules_from_tree(cls, tree_node, parent_path="") -> Set[str]:
397396
else:
398397
moe_modules.add(module_name)
399398
elif isinstance(item, dict):
400-
moe_modules.update(cls._collect_moe_modules_from_tree(item, parent_path))
399+
moe_modules.update(cls._collect_moe_modules_from_tree(item, parent_path, parent_is_moe))
401400

402401
return moe_modules
403402

tests/module_tree/test_model_alignment.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515

1616
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "models"))
1717
from model_test import ModelTest # noqa: E402
18+
from torch import nn
1819

1920

2021
class TestDots1Struct(ModelTest):
@@ -32,7 +33,11 @@ def test_module_tree_alignment(self):
3233

3334
moe_layer = shell.model.layers[config.first_k_dense_replace]
3435
self.assertTrue(hasattr(moe_layer.mlp, "experts"))
35-
self.assertEqual(len(moe_layer.mlp.experts), config.n_routed_experts)
36+
if isinstance(moe_layer.mlp.experts, nn.ModuleList):
37+
expert_num = len(moe_layer.mlp.experts)
38+
else:
39+
expert_num = moe_layer.mlp.experts.num_experts
40+
self.assertEqual(expert_num, config.n_routed_experts)
3641
self.assertTrue(hasattr(moe_layer.mlp, "shared_experts"))
3742

3843
self.assertIn("q_norm:!", Dots1QModel.module_tree[3]["self_attn"])

tests/module_tree/test_moe_flag_parsing.py

Lines changed: 56 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -31,20 +31,56 @@ def test_parse_module_flags_combined(self):
3131
assert "moe" in flags
3232
assert "!" in flags
3333

34+
def test_parse_module_flags_with_optional_capture(self):
35+
"""Test parsing with :? and order variations."""
36+
name, flags = BaseQModel._parse_module_flags("mlp:moe:?")
37+
assert name == "mlp"
38+
assert "moe" in flags
39+
assert "?" in flags
40+
41+
name, flags = BaseQModel._parse_module_flags("mlp:?:moe")
42+
assert name == "mlp"
43+
assert "moe" in flags
44+
assert "?" in flags
45+
3446
def test_has_moe_flag(self):
3547
"""Test MoE flag detection."""
3648
assert BaseQModel.has_moe_flag("mlp:moe") is True
3749
assert BaseQModel.has_moe_flag("experts:moe") is True
3850
assert BaseQModel.has_moe_flag("gate:moe:!") is True
51+
assert BaseQModel.has_moe_flag("mlp:moe:?") is True
52+
assert BaseQModel.has_moe_flag("mlp:?:moe") is True
3953
assert BaseQModel.has_moe_flag("gate_proj") is False
4054
assert BaseQModel.has_moe_flag("gate_proj:!") is False
55+
assert BaseQModel.has_moe_flag("mlp:mode:?") is False
56+
assert BaseQModel.has_moe_flag("mlp:?:mode") is False
4157

4258
def test_has_moe_flag_non_string(self):
4359
"""Test MoE flag detection with non-string input."""
4460
assert BaseQModel.has_moe_flag(None) is False
4561
assert BaseQModel.has_moe_flag(123) is False
4662
assert BaseQModel.has_moe_flag({}) is False
4763

64+
def test_collect_moe_modules_with_optional_capture(self):
65+
"""Test MoE module collection with :? flag in module tree."""
66+
module_tree = [
67+
"model",
68+
"layers",
69+
"#",
70+
{
71+
"mlp:moe:?": {
72+
"experts": {
73+
"#": ("gate_proj:0",),
74+
},
75+
},
76+
}
77+
]
78+
79+
MockModel = TestMockMoEModel.create_mock_model_class(module_tree)
80+
moe_modules = MockModel.get_moe_modules()
81+
82+
assert "mlp" in moe_modules
83+
4884

4985
class TestMockMoEModel:
5086
"""Mock MoE model for testing."""
@@ -138,6 +174,26 @@ def test_is_moe_module_detection(self):
138174
assert MockModel.is_moe_module("model.layers.0.self_attn.q_proj") is False
139175
assert MockModel.is_moe_module("model.layers.0.self_attn") is False
140176

177+
def test_is_moe_module_with_optional_capture(self):
178+
"""Test is_moe_module() with :? flag in module tree."""
179+
module_tree = [
180+
"model",
181+
"layers",
182+
"#",
183+
{
184+
"mlp:?:moe": {
185+
"experts": {
186+
"#": ("gate_proj:0",),
187+
},
188+
},
189+
}
190+
]
191+
192+
MockModel = TestMockMoEModel.create_mock_model_class(module_tree)
193+
194+
assert MockModel.is_moe_module("model.layers.0.mlp") is True
195+
assert MockModel.is_moe_module("model.layers.0.mlp.experts.3.gate_proj") is True
196+
141197
def test_backward_compatibility_no_moe_flags(self):
142198
"""Test that models without :moe flags still work."""
143199
module_tree = [

0 commit comments

Comments
 (0)