|
| 1 | +# SPDX-FileCopyrightText: 2026 ModelCloud.ai |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +from types import SimpleNamespace |
| 5 | + |
| 6 | +import defuser |
| 7 | +from accelerate import init_empty_weights |
| 8 | +from transformers import AutoModelForCausalLM |
| 9 | +from transformers.models.hunyuan_v1_moe.configuration_hunyuan_v1_moe import HunYuanMoEV1Config |
| 10 | + |
| 11 | +from gptqmodel.models import auto |
| 12 | +from gptqmodel.models.definitions.hunyuan_v1_dense import HunYuanDenseV1QModel |
| 13 | +from gptqmodel.models.definitions.hunyuan_v1_moe import HunYuanMoEV1QModel |
| 14 | + |
| 15 | + |
def test_hunyuan_v1_dense_model_type_selects_definition(monkeypatch):
    """Check that a ``hunyuan_v1_dense`` config resolves to ``HunYuanDenseV1QModel``.

    ``auto.resolve_trust_remote_code`` and ``auto.AutoConfig.from_pretrained``
    are replaced with stubs, so the lookup only sees the fake config built here.
    """
    stub_config = SimpleNamespace(model_type="hunyuan_v1_dense")

    def _echo_trust_flag(path, trust_remote_code=False):
        # Hand the caller's flag back unchanged.
        return trust_remote_code

    def _return_stub_config(*args, **kwargs):
        # Ignore every argument and always yield the fake config.
        return stub_config

    monkeypatch.setattr(auto, "resolve_trust_remote_code", _echo_trust_flag)
    monkeypatch.setattr(auto.AutoConfig, "from_pretrained", _return_stub_config)

    resolved_definition = auto.check_and_get_model_definition("/tmp/hunyuan_v1_dense")
    assert resolved_definition is HunYuanDenseV1QModel
| 23 | + |
| 24 | + |
def test_hunyuan_v1_dense_module_tree_skips_qk_norms():
    """Check the ``self_attn`` entries in the last element of the dense module tree.

    The projection keys and the two layernorm keys must all be present.
    NOTE(review): the ``:!`` suffix presumably marks modules that are skipped
    during quantization -- confirm against the module_tree conventions.
    """
    attn_entries = HunYuanDenseV1QModel.module_tree[-1]["self_attn"]

    # Checked in this order so the first missing key is the one reported.
    required_keys = (
        "q_proj:0",
        "k_proj:0",
        "v_proj:0",
        "o_proj:1",
        "query_layernorm:!",
        "key_layernorm:!",
    )
    for required_key in required_keys:
        assert required_key in attn_entries
| 34 | + |
| 35 | + |
def test_hunyuan_v1_moe_model_type_selects_definition(monkeypatch):
    """Check that a ``hunyuan_v1_moe`` config resolves to ``HunYuanMoEV1QModel``.

    ``auto.resolve_trust_remote_code`` and ``auto.AutoConfig.from_pretrained``
    are replaced with stubs, so the lookup only sees the fake config built here.
    """
    stub_config = SimpleNamespace(model_type="hunyuan_v1_moe")

    def _echo_trust_flag(path, trust_remote_code=False):
        # Hand the caller's flag back unchanged.
        return trust_remote_code

    def _return_stub_config(*args, **kwargs):
        # Ignore every argument and always yield the fake config.
        return stub_config

    monkeypatch.setattr(auto, "resolve_trust_remote_code", _echo_trust_flag)
    monkeypatch.setattr(auto.AutoConfig, "from_pretrained", _return_stub_config)

    resolved_definition = auto.check_and_get_model_definition("/tmp/hunyuan_v1_moe")
    assert resolved_definition is HunYuanMoEV1QModel
| 43 | + |
| 44 | + |
def test_hunyuan_v1_moe_module_tree_matches_defused_experts():
    """Compare a converted tiny HunYuan MoE model with ``HunYuanMoEV1QModel``'s tree.

    Steps:
      1. Build a one-layer model from a small ``HunYuanMoEV1Config``.
      2. Run ``defuser.convert_model`` on it and require a ``True`` result.
      3. Check that the first layer exposes the attention layernorms, the shared
         MLP, and per-expert ``gate_proj`` / ``up_proj`` / ``down_proj``.
      4. Check that the module tree and ``simple_layer_modules`` output name
         those same modules.
    """
    tiny_settings = {
        "vocab_size": 128,
        "hidden_size": 64,
        "intermediate_size": 32,
        "num_hidden_layers": 1,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "num_experts": 4,
        "moe_topk": 2,
        "head_dim": 16,
        "max_position_embeddings": 128,
        "pad_token_id": 0,
        "bos_token_id": 1,
        "eos_token_id": 2,
        "tie_word_embeddings": True,
    }
    cfg = HunYuanMoEV1Config(**tiny_settings)

    # NOTE(review): presumably this context avoids allocating real weights and
    # buffers while the model is instantiated -- confirm in the accelerate docs.
    with init_empty_weights(include_buffers=True):
        model = AutoModelForCausalLM.from_config(cfg)

    conversion_result = defuser.convert_model(model, cleanup_original=False)
    assert conversion_result is True

    first_layer = model.model.layers[0]
    first_expert = first_layer.mlp.experts[0]

    # Structural expectations on the converted model, checked in order.
    expected_attributes = (
        (first_layer.self_attn, "query_layernorm"),
        (first_layer.self_attn, "key_layernorm"),
        (first_layer.mlp, "shared_mlp"),
        (first_expert, "gate_proj"),
        (first_expert, "up_proj"),
        (first_expert, "down_proj"),
    )
    for owner, attribute_name in expected_attributes:
        assert hasattr(owner, attribute_name)

    last_tree_node = HunYuanMoEV1QModel.module_tree[-1]
    attn_entries = last_tree_node["self_attn"]
    moe_entries = HunYuanMoEV1QModel.module_tree[-1]["mlp:moe:?"]
    module_groups = HunYuanMoEV1QModel.simple_layer_modules(
        model_config=cfg,
        quantize_config=SimpleNamespace(dynamic=None),
    )

    # The attention subtree lists both layernorm keys.
    for norm_key in ("query_layernorm:!", "key_layernorm:!"):
        assert norm_key in attn_entries

    # The MoE subtree lists the shared MLP and the experts entry.
    for moe_key in ("shared_mlp", "experts:0"):
        assert moe_key in moe_entries

    # The shared MLP projections appear as exact groups.
    shared_up_group = ["mlp.shared_mlp.gate_proj", "mlp.shared_mlp.up_proj"]
    shared_down_group = ["mlp.shared_mlp.down_proj"]
    assert shared_up_group in module_groups
    assert shared_down_group in module_groups

    # The first expert's projections appear inside at least one group.
    assert any("mlp.experts.0.gate_proj" in group for group in module_groups)
    assert any("mlp.experts.0.down_proj" in group for group in module_groups)
0 commit comments