Skip to content

Commit 7bbcb28

Browse files
fix test_subset.py (#2907)
* fix test_subset.py Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai> * cleanup Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai> --------- Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
1 parent 3ff0371 commit 7bbcb28

2 files changed

Lines changed: 8 additions & 4 deletions

File tree

gptqmodel/models/moe_lifecycle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def get_moe_block(self, layer_module: nn.Module, model_class: type) -> Optional[
9999
return None
100100

101101
# Get the module by name
102-
moe_block = getattr(layer_module, moe_module_name, None)
102+
moe_block = getattr(layer_module, moe_module_name[0], None)
103103

104104
return moe_block
105105

tests/module_tree/test_subset.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
33
# SPDX-License-Identifier: Apache-2.0
44
# Contact: qubitium@modelcloud.ai, x.com/qubitium
5+
import importlib
56
import os
67
import sys
78
import threading
@@ -583,7 +584,10 @@ def test_qwen3_5_moe_subset_early_stop_follows_module_tree_execution_order():
583584
layer = model.model.layers[0]
584585
replace_module_with_hooked_legacy(layer)
585586

586-
quant_cfg = _make_quant_config()
587+
is_causal_conv1d_available = importlib.util.find_spec("causal_conv1d") is not None
588+
589+
device = "cuda" if is_causal_conv1d_available else "cpu"
590+
quant_cfg = _make_quant_config(device)
587591

588592
class _DummyQwen3_5Model:
589593
moe_lifecycle_hooks = Qwen3_5_MoeQModel.moe_lifecycle_hooks
@@ -632,7 +636,7 @@ def prepare_layer_replay_kwargs(self, layer, layer_input, additional_inputs, tar
632636
]
633637
assert subset_names[-1] == "mlp.experts.3.up_proj"
634638

635-
layer_inputs = [[torch.randn(1, 4, cfg.hidden_size)]]
639+
layer_inputs = [[torch.randn(1, 4, cfg.hidden_size).to(device)]]
636640
full_modules = find_modules(layer)
637641
subset = looper.create_named_modules(
638642
module=layer,
@@ -665,7 +669,7 @@ def prepare_layer_replay_kwargs(self, layer, layer_input, additional_inputs, tar
665669
layer_input_kwargs=[{}],
666670
position_ids=[None],
667671
attention_masks=[None],
668-
cur_layer_device=torch.device("cpu"),
672+
cur_layer_device=torch.device(device),
669673
is_lm_head_module=False,
670674
layer_descriptor="layers.0",
671675
layer_title="subset-check",

0 commit comments

Comments
 (0)