Skip to content

Commit ddb23e7

Browse files
add kimi 2.5 support (#2858)
* Add Kimi K2.5 model support * move tests to models/ * use new test template * USE_FLASH_ATTN = False * use default * fix flash attn check * fix NotImplementedError: Cannot copy out of meta tensor; no data! * add todo * add retry to fix file not found.
1 parent ed42bc5 commit ddb23e7

6 files changed

Lines changed: 385 additions & 23 deletions

File tree

gptqmodel/models/auto.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
33
# SPDX-License-Identifier: Apache-2.0
44
# Contact: qubitium@modelcloud.ai, x.com/qubitium
5+
# ruff: noqa: I001
56

67
from __future__ import annotations
78

@@ -118,6 +119,7 @@
118119
from .definitions.internlm2 import InternLM2QModel # noqa: E402
119120
from .definitions.internvl_chat import InternVLChatQModel # noqa: E402
120121
from .definitions.klear import KlearQModel # noqa: E402
122+
from .definitions.kimi_k25 import KimiK25QModel # noqa: E402
121123
from .definitions.laguna import LagunaQModel # noqa: E402
122124
from .definitions.lfm2_moe import LFM2MoeQModel # noqa: E402
123125
from .definitions.llada2 import LLaDA2MoeQModel
@@ -187,6 +189,7 @@
187189
"brumby": BrumbyQModel,
188190
"gpt_neo": GptNeoQModel,
189191
"kimi_k2": DeepSeekV3QModel, # 100% DeepSeekV3QModel clone
192+
"kimi_k25": KimiK25QModel,
190193
"klear": KlearQModel,
191194
"laguna": LagunaQModel,
192195
"gpt_neox": GPTNeoXQModel,
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# SPDX-FileCopyrightText: 2026 ModelCloud.ai
2+
# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
3+
# SPDX-License-Identifier: Apache-2.0
4+
# Contact: qubitium@modelcloud.ai, x.com/qubitium
5+
6+
from ..base import BaseQModel
7+
from ..moe_lifecycle import GateUpDownMoELifecycleHooks
8+
from ...utils.model import get_module
9+
10+
11+
class KimiK25QModel(BaseQModel):
12+
# Kimi-K2.5 wraps a DeepSeek-V3 text backbone with a vision tower and
13+
# projector. Quantize the language model and keep the vision path in base.
14+
require_trust_remote_code = True
15+
16+
require_load_processor = True
17+
18+
pre_lm_head_norm_module = "language_model.model.norm"
19+
20+
dynamic_expert_index = "n_routed_experts"
21+
22+
layer_modules_strict = False
23+
24+
moe_lifecycle_hooks = GateUpDownMoELifecycleHooks()
25+
26+
module_tree = [
27+
"language_model",
28+
"model",
29+
"layers",
30+
"#",
31+
{
32+
"input_layernorm": ("input_layernorm:!",),
33+
"self_attn": ("q_proj:0", "q_a_proj:0", "kv_a_proj_with_mqa:0", "q_b_proj:1", "kv_b_proj:1", "o_proj:2"),
34+
"post_attention_layernorm": ("post_attention_layernorm:!",),
35+
"mlp:moe": {
36+
"": ("gate_proj:0", "up_proj:0", "down_proj:1"),
37+
"experts": {
38+
"#": ("gate_proj:0", "up_proj:0", "down_proj:1"),
39+
},
40+
"shared_experts": ("gate_proj:0", "up_proj:0", "down_proj:1"),
41+
},
42+
},
43+
]
44+
45+
@classmethod
46+
def get_base_modules(cls, model):
47+
base_modules = super().get_base_modules(model)
48+
prefix, core_model = cls._resolve_multimodal_layout(model)
49+
for name, _ in core_model.named_children():
50+
if name != "language_model":
51+
module_name = f"{prefix}.{name}" if prefix else name
52+
if module_name not in base_modules:
53+
base_modules.append(module_name)
54+
return base_modules
55+
56+
@classmethod
57+
def _resolve_multimodal_layout(cls, model):
58+
for prefix in ("model", ""):
59+
core_model = get_module(model, prefix) if prefix else model
60+
if core_model is None:
61+
continue
62+
if hasattr(core_model, "language_model"):
63+
return prefix, core_model
64+
raise AttributeError("Unable to resolve Kimi-K2.5 core model with a `language_model` module.")
65+
66+
67+
__all__ = ["KimiK25QModel"]

gptqmodel/models/loader.py

Lines changed: 90 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import copy
99
import numpy as np
1010
import os
11+
import shutil
1112
import time
1213
import torch
1314
import transformers
@@ -115,6 +116,47 @@ def _supports_flash_attn_2(config: PretrainedConfig) -> bool:
115116
return False
116117

117118

119+
def _iter_nested_pretrained_configs(config: PretrainedConfig):
120+
"""Yield config and all nested PretrainedConfig nodes once."""
121+
122+
stack = [config]
123+
visited = set()
124+
125+
while stack:
126+
cur = stack.pop()
127+
if not isinstance(cur, PretrainedConfig):
128+
continue
129+
130+
node_id = id(cur)
131+
if node_id in visited:
132+
continue
133+
visited.add(node_id)
134+
yield cur
135+
136+
for value in vars(cur).values():
137+
if isinstance(value, PretrainedConfig):
138+
stack.append(value)
139+
elif isinstance(value, dict):
140+
for sub in value.values():
141+
if isinstance(sub, PretrainedConfig):
142+
stack.append(sub)
143+
elif isinstance(value, (list, tuple, set)):
144+
for sub in value:
145+
if isinstance(sub, PretrainedConfig):
146+
stack.append(sub)
147+
148+
149+
def _override_attn_implementation(config: PretrainedConfig, attn_implementation: str) -> None:
150+
"""Apply attention implementation override to root and nested configs."""
151+
152+
for sub_config in _iter_nested_pretrained_configs(config):
153+
try:
154+
sub_config._attn_implementation = attn_implementation
155+
except Exception:
156+
# Some remote configs may expose read-only wrappers; ignore safely.
157+
pass
158+
159+
118160
def _is_accelerated_attention_device(device: object) -> bool:
119161
"""Return True when the selected device can run CUDA/ROCm flash attention."""
120162

@@ -197,6 +239,43 @@ def _is_meta_shell_build_error(exc: Exception) -> bool:
197239
return "cannot be called on meta tensors" in message and ".item()" in message
198240

199241

242+
def _is_broken_transformers_dynamic_module_error(exc: Exception) -> bool:
243+
if not isinstance(exc, FileNotFoundError):
244+
return False
245+
missing_path = str(getattr(exc, "filename", "") or exc)
246+
return "transformers_modules" in missing_path and missing_path.endswith(".py")
247+
248+
249+
def _hf_loader_from_pretrained_with_dynamic_module_retry(loader, model_local_path: str, **kwargs):
250+
try:
251+
return loader.from_pretrained(model_local_path, **kwargs)
252+
except Exception as exc:
253+
if not _is_broken_transformers_dynamic_module_error(exc):
254+
raise
255+
256+
missing_path = str(getattr(exc, "filename", "") or "")
257+
missing_name = os.path.basename(missing_path)
258+
source_path = os.path.join(model_local_path, missing_name)
259+
if missing_path and os.path.isfile(source_path):
260+
os.makedirs(os.path.dirname(missing_path), exist_ok=True)
261+
shutil.copy2(source_path, missing_path)
262+
log.warn(
263+
"Loader: repaired missing dynamic-module file by copying `%s` -> `%s`.",
264+
source_path,
265+
missing_path,
266+
)
267+
268+
retry_kwargs = dict(kwargs)
269+
retry_kwargs["force_download"] = True
270+
log.warn(
271+
"Loader: detected broken transformers dynamic-module cache while loading `%s`; "
272+
"retrying once with force_download=True: %s",
273+
model_local_path,
274+
exc,
275+
)
276+
return loader.from_pretrained(model_local_path, **retry_kwargs)
277+
278+
200279
def _coerce_quantized_awq_dtype(*, backend: BACKEND, qcfg: QuantizeConfig, dtype):
201280
if qcfg.quant_method not in (METHOD.AWQ, METHOD.PARO):
202281
return dtype
@@ -458,7 +537,7 @@ def from_pretrained(
458537

459538
if atten_impl is not None and atten_impl != "auto":
460539
log.info(f"Loader: overriding attn_implementation in config to `{atten_impl}`")
461-
config._attn_implementation = atten_impl
540+
_override_attn_implementation(config, atten_impl)
462541

463542
resolved_device = normalize_device_device_map(device, device_map)
464543
resolved_device = auto_select_device(resolved_device, backend)
@@ -536,7 +615,12 @@ def from_pretrained(
536615
hf_model_init_kwargs[ATTN_IMPLEMENTATION] = "flash_attention_2"
537616
log.info("Loader: Auto enabling flash_attention_2 for dense Bonsai PROFILE.%s.", effective_profile.name)
538617
# Load a non-quantized model, but do not perform quantization. For example, for evaluation.
539-
model = cls.loader.from_pretrained(model_local_path, config=config, **hf_model_init_kwargs)
618+
model = _hf_loader_from_pretrained_with_dynamic_module_retry(
619+
cls.loader,
620+
model_local_path,
621+
config=config,
622+
**hf_model_init_kwargs,
623+
)
540624
model._model_init_kwargs = hf_model_init_kwargs
541625
_maybe_print_module_tree(model=model)
542626

@@ -634,7 +718,8 @@ def skip(*args, **kwargs):
634718
fallback_init_kwargs = model_init_kwargs_without_internal.copy()
635719
fallback_init_kwargs.pop("device_map", None)
636720
fallback_init_kwargs["low_cpu_mem_usage"] = False
637-
model = cls.loader.from_pretrained(
721+
model = _hf_loader_from_pretrained_with_dynamic_module_retry(
722+
cls.loader,
638723
model_local_path,
639724
config=config,
640725
**fallback_init_kwargs,
@@ -674,7 +759,8 @@ def skip(*args, **kwargs):
674759
)
675760
else:
676761
log.info("Loader: loading model directly to CPU (not using meta device or turtle_model)")
677-
model = cls.loader.from_pretrained(
762+
model = _hf_loader_from_pretrained_with_dynamic_module_retry(
763+
cls.loader,
678764
model_local_path,
679765
config=config,
680766
**model_init_kwargs_without_internal,

gptqmodel/models/writer.py

Lines changed: 106 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from transformers import AutoConfig, PreTrainedTokenizerFast, ProcessorMixin
2121
from transformers.models.auto.tokenization_auto import get_tokenizer_config
2222

23+
from ._const import DEFAULT_MAX_SHARD_SIZE, DEVICE
2324
from ..adapter.adapter import HF_ADAPTER_FILE_NAME, HF_ADAPTER_WEIGHT_KEY_PREFIX, Lora
2425
from ..adapter.peft import LoraConfig
2526
from ..quantization.config import (
@@ -60,11 +61,9 @@
6061
make_quant,
6162
streaming_state_dict_to_shards,
6263
)
63-
from ..utils.structure import alias_all_from_turtle_if_meta
64+
from ..utils.structure import alias_all_from_turtle_if_meta, alias_from_turtle_for_submodule
6465
from ..utils.torch import torch_empty_cache
6566
from ..version import __version__
66-
from ._const import DEFAULT_MAX_SHARD_SIZE, DEVICE
67-
6867

6968
log = setup_logger()
7069

@@ -103,6 +102,104 @@ def _parse_split_by(value: Optional[str]) -> Optional[str]:
103102
return normalized
104103

105104

105+
def _materialize_remaining_meta_params_from_turtle(model: torch.nn.Module, turtle_model) -> int:
106+
"""Best-effort fallback for meta params that survive normal turtle sync."""
107+
108+
if (
109+
turtle_model is None
110+
or not hasattr(turtle_model, "_resolve_checkpoint_tensor_source")
111+
or not hasattr(turtle_model, "_weight_map")
112+
or not hasattr(turtle_model, "model_local_path")
113+
):
114+
return 0
115+
116+
restored = 0
117+
pending_by_shard: Dict[str, List[tuple[str, str, str, torch.nn.Parameter, Optional[int], Optional[int], Optional[int]]]] = {}
118+
119+
for full_name, param in list(model.named_parameters()):
120+
if not (getattr(param, "is_meta", False) or param.device.type == "meta"):
121+
continue
122+
123+
module_path, leaf = full_name.rsplit(".", 1)
124+
resolved_name, expert_index, split_index, split_dim = turtle_model._resolve_checkpoint_tensor_source(module_path, leaf)
125+
if resolved_name is None:
126+
continue
127+
shard = turtle_model._weight_map.get(resolved_name)
128+
if shard is None:
129+
continue
130+
pending_by_shard.setdefault(shard, []).append(
131+
(resolved_name, module_path, leaf, param, expert_index, split_index, split_dim)
132+
)
133+
134+
for shard, entries in pending_by_shard.items():
135+
shard_path = os.path.join(turtle_model.model_local_path, shard)
136+
unique_names = {name for name, _module_path, _leaf, _param, _expert_index, _split_index, _split_dim in entries}
137+
138+
try:
139+
with safe_open(shard_path, framework="pt", device="cpu") as handler:
140+
tensors = {name: handler.get_tensor(name) for name in unique_names}
141+
except RuntimeError as exc:
142+
log.warn("Model save: skipping shard `%s` during meta materialization due to runtime error: %s", shard, exc)
143+
continue
144+
145+
for tensor_name, module_path, leaf, param, expert_index, split_index, split_dim in entries:
146+
source = tensors.get(tensor_name)
147+
if source is None:
148+
continue
149+
target = source
150+
if expert_index is not None:
151+
if expert_index >= target.shape[0]:
152+
continue
153+
target = target.narrow(0, expert_index, 1).squeeze(0)
154+
if split_index is not None and split_dim is not None:
155+
if target.shape[split_dim] % 2 != 0:
156+
continue
157+
chunk = target.shape[split_dim] // 2
158+
target = target.narrow(split_dim, split_index * chunk, chunk)
159+
if target.dtype != param.dtype:
160+
target = target.to(dtype=param.dtype)
161+
if tuple(target.shape) != tuple(param.shape):
162+
continue
163+
module = model.get_submodule(module_path)
164+
replacement = torch.nn.Parameter(target.detach().clone(), requires_grad=param.requires_grad)
165+
setattr(module, leaf, replacement)
166+
restored += 1
167+
168+
return restored
169+
170+
171+
def _materialize_meta_layers_from_turtle(model: torch.nn.Module, turtle_model) -> int:
172+
if turtle_model is None or not hasattr(turtle_model, "materialize_submodule"):
173+
return 0
174+
175+
layer_paths = set()
176+
for full_name, param in model.named_parameters():
177+
if not (getattr(param, "is_meta", False) or param.device.type == "meta"):
178+
continue
179+
parts = full_name.split(".")
180+
if "layers" in parts:
181+
i = parts.index("layers")
182+
if i + 1 < len(parts):
183+
layer_paths.add(".".join(parts[: i + 2]))
184+
185+
materialized = 0
186+
for path in sorted(layer_paths):
187+
try:
188+
submodule = model.get_submodule(path)
189+
alias_from_turtle_for_submodule(
190+
target_model=model,
191+
turtle_model=turtle_model,
192+
target_submodule=submodule,
193+
device=torch.device("cpu"),
194+
non_blocking=False,
195+
)
196+
materialized += 1
197+
except Exception as exc:
198+
log.warn("Model save: failed to materialize meta layer `%s` from turtle: %s", path, exc)
199+
200+
return materialized
201+
202+
106203
def _cleanup_saved_weight_files(
107204
save_dir: str,
108205
expected_files: List[str],
@@ -658,6 +755,12 @@ def debug_saved_config(path):
658755
# Due to shell/turtle state, we need to sync the modules from turtle to shell
659756
if not self.load_quantized_model:
660757
alias_all_from_turtle_if_meta(shell_model=self.model, turtle_model=self.turtle_model)
758+
materialized_layers = _materialize_meta_layers_from_turtle(self.model, self.turtle_model)
759+
if materialized_layers:
760+
log.info("Model save: materialized %s meta layer modules from turtle source.", materialized_layers)
761+
restored_meta = _materialize_remaining_meta_params_from_turtle(self.model, self.turtle_model)
762+
if restored_meta:
763+
log.info("Model save: materialized %s remaining meta params from turtle source.", restored_meta)
661764

662765
offload_root = self.quantize_config.offload_to_disk_path if getattr(self.quantize_config, "offload_to_disk", False) else None
663766
state_dict = get_state_dict_for_save(self.model, offload_root=offload_root)

0 commit comments

Comments
 (0)