Skip to content

Commit 2e73917

Browse files
committed
Add Kimi-K2-Thinking, K2.5, and K2.6 checkpoint conversion support.
1 parent b3a1832 commit 2e73917

5 files changed

Lines changed: 349 additions & 12 deletions

File tree

pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ addopts =
1111
--ignore=tests/integration/smoke/train_smoke_test.py
1212
--ignore=tests/integration/smoke/train_using_ragged_dot_smoke_test.py
1313
--ignore=tests/unit/dequantize_mxfp4_test.py
14+
--ignore=tests/unit/dequantize_pack_quantized_int4_test.py
1415
--ignore=tests/unit/gemma3_layers_test.py
1516
--ignore=tests/unit/gpt_vs_reference_test.py
1617
--ignore=tests/unit/llama4_layers_test.py

src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py

Lines changed: 124 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,39 @@
4848
absl.logging.set_verbosity(absl.logging.INFO) # for max_logging.log
4949

5050

51+
# Group size used by Kimi-K2 quantized variants for per-group symmetric int4 scales.
# Matches `quantization_config.config_groups.group_0.weights.group_size` in the HF config.
INT4_GROUP_SIZE = 32


def dequantize_pack_quantized_int4(
    packed: torch.Tensor,
    scale: torch.Tensor,
    out_shape,
    group_size: int = INT4_GROUP_SIZE,
) -> torch.Tensor:
  """Dequantize a compressed-tensors pack-quantized int4 weight to bf16.

  Each 4-bit value is treated as unsigned 0..15 and offset by 8 to recover the
  signed weight in [-8, 7].
  NOTE(review): compressed-tensors' reference unpacker sign-extends nibbles
  (two's complement) rather than subtracting 8 — confirm this offset convention
  against the checkpoint producer.

  Args:
    packed: int32 [out, in/8]. Each int32 packs 8 weights along the input dim;
      weight k lives in bits [4k : 4k+4] (so the first weight is in the low 4 bits).
    scale: bf16/fp16 [out, in/group_size], symmetric (no zero point).
    out_shape: logical (out_features, in_features) of the dequantized weight.
    group_size: number of consecutive input elements sharing one scale entry.
      Defaults to INT4_GROUP_SIZE, preserving the previous fixed behavior.

  Returns:
    bf16 tensor of shape (out_features, in_features).

  Raises:
    ValueError: if in_features is not divisible by group_size.
  """
  out_features, in_features = int(out_shape[0]), int(out_shape[1])
  if in_features % group_size != 0:
    raise ValueError(f"in_features={in_features} not divisible by group_size={group_size}")

  # Unpack the 8 nibbles of each int32 along a new trailing axis. The arithmetic
  # (sign-propagating) right shift is harmless because `& 0xF` keeps only the
  # low 4 bits. Build shifts on packed.device so non-CPU tensors also work.
  shifts = torch.arange(8, dtype=torch.int32, device=packed.device) * 4
  nibbles = (packed.to(torch.int32).unsqueeze(-1) >> shifts) & 0xF
  # Flatten nibbles back to the input dim; drop any packing padding past in_features.
  w_int = (nibbles - 8).reshape(out_features, -1)[:, :in_features].to(torch.float32)

  # Broadcast one scale over each group of `group_size` input elements.
  s = scale.to(torch.float32).unsqueeze(-1)
  w = (w_int.reshape(out_features, in_features // group_size, group_size) * s).reshape(
      out_features, in_features
  )
  return w.to(torch.bfloat16)
82+
83+
5184
MODEL_PARAMS_DICT = {
5285
"deepseek2-16b": {
5386
"num_layers": 27,
@@ -88,6 +121,56 @@
88121
"v_head_dim": 128,
89122
"has_mtp": False,
90123
},
124+
"kimi-k2-thinking": {
125+
"num_layers": 61,
126+
"first_num_dense_layers": 1,
127+
"base_num_query_heads": 64,
128+
"base_emb_dim": 7168,
129+
"num_experts": 384,
130+
"q_lora_rank": 1536,
131+
"kv_lora_rank": 512,
132+
"qk_nope_head_dim": 128,
133+
"qk_rope_head_dim": 64,
134+
"v_head_dim": 128,
135+
"has_mtp": False,
136+
# Only routed-expert projections are int4 (group_size=32, symmetric);
137+
# attention, shared experts, dense MLP, and lm_head stay bf16.
138+
"compressed_int4": True,
139+
},
140+
# Multimodal wrapper; text branch matches kimi-k2-thinking but keys are
141+
# prefixed `language_model.`. Vision keys are dropped (text-only target).
142+
"kimi-k2.5-text": {
143+
"num_layers": 61,
144+
"first_num_dense_layers": 1,
145+
"base_num_query_heads": 64,
146+
"base_emb_dim": 7168,
147+
"num_experts": 384,
148+
"q_lora_rank": 1536,
149+
"kv_lora_rank": 512,
150+
"qk_nope_head_dim": 128,
151+
"qk_rope_head_dim": 64,
152+
"v_head_dim": 128,
153+
"has_mtp": False,
154+
"compressed_int4": True,
155+
"hf_key_prefix": "language_model.",
156+
},
157+
# K2.6 reuses the K2.5 multimodal wrapper (KimiK25ForConditionalGeneration);
158+
# text branch shape and quantization layout are identical to kimi-k2.5-text.
159+
"kimi-k2.6-text": {
160+
"num_layers": 61,
161+
"first_num_dense_layers": 1,
162+
"base_num_query_heads": 64,
163+
"base_emb_dim": 7168,
164+
"num_experts": 384,
165+
"q_lora_rank": 1536,
166+
"kv_lora_rank": 512,
167+
"qk_nope_head_dim": 128,
168+
"qk_rope_head_dim": 64,
169+
"v_head_dim": 128,
170+
"has_mtp": False,
171+
"compressed_int4": True,
172+
"hf_key_prefix": "language_model.",
173+
},
91174
}
92175

93176

@@ -232,20 +315,52 @@ def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info,
232315

233316
ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.safetensors"))
234317
chkpt_vars = {}
318+
is_compressed = bool(model_params.get("compressed_int4", False))
319+
hf_key_prefix = model_params.get("hf_key_prefix", "") # for multimodal text-only
320+
321+
def _normalize(raw_key):
322+
"""Strip multimodal prefix; return None to drop keys outside the text branch."""
323+
if not hf_key_prefix:
324+
return raw_key
325+
if raw_key.startswith(hf_key_prefix):
326+
return raw_key[len(hf_key_prefix) :]
327+
return None
328+
235329
for i, ckpt_path in enumerate(ckpt_paths):
236330
max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
237331
with safe_open(ckpt_path, framework="pt", device="cpu") as f:
238-
for key in f.keys():
239-
parts = key.split(".")
240-
layer = int(parts[2]) if "layers" in key else 0
332+
for raw_key in f.keys():
333+
key = _normalize(raw_key)
334+
if key is None:
335+
continue # vision_tower / mm_projector etc. when text-only
241336
if key.endswith("_scale_inv"):
242337
raise ValueError("fp8 checkpoint is not supported.")
243-
if is_key_allowed(key, MTP_KEYS_TO_SKIP):
244-
mapped_key = hf_to_maxtext_mapping(
245-
layer, num_experts, first_num_dense_layers, base_num_decoder_layers, has_mtp
246-
).get(key)
247-
if mapped_key:
248-
chkpt_vars[mapped_key] = f.get_tensor(key)
338+
if is_compressed and key.endswith((".weight_scale", ".weight_shape")):
339+
continue # consumed alongside the matching .weight_packed
340+
341+
# Compressed weights advertise as ".weight_packed"; resolve back to the
342+
# logical ".weight" name so it matches the maxtext mapping table.
343+
is_packed = is_compressed and key.endswith(".weight_packed")
344+
hf_key = key[: -len(".weight_packed")] + ".weight" if is_packed else key
345+
346+
if not is_key_allowed(hf_key, MTP_KEYS_TO_SKIP):
347+
continue
348+
layer = int(hf_key.split(".")[2]) if "layers" in hf_key else 0
349+
mapped_key = hf_to_maxtext_mapping(
350+
layer, num_experts, first_num_dense_layers, base_num_decoder_layers, has_mtp
351+
).get(hf_key)
352+
if not mapped_key:
353+
continue
354+
355+
if is_packed:
356+
raw_base = raw_key[: -len(".weight_packed")]
357+
chkpt_vars[mapped_key] = dequantize_pack_quantized_int4(
358+
f.get_tensor(raw_key),
359+
f.get_tensor(raw_base + ".weight_scale"),
360+
f.get_tensor(raw_base + ".weight_shape").tolist(),
361+
)
362+
else:
363+
chkpt_vars[mapped_key] = f.get_tensor(raw_key)
249364

250365
logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
251366

src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,47 @@ def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info)
6161

6262
ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.safetensors"))
6363
chkpt_vars = {}
64+
is_compressed = bool(model_params.get("compressed_int4", False))
65+
hf_key_prefix = model_params.get("hf_key_prefix", "")
66+
67+
def _normalize(raw_key):
68+
if not hf_key_prefix:
69+
return raw_key
70+
if raw_key.startswith(hf_key_prefix):
71+
return raw_key[len(hf_key_prefix) :]
72+
return None
73+
6474
for i, ckpt_path in enumerate(ckpt_paths):
6575
max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
6676
with safe_open(ckpt_path, framework="pt", device="cpu") as f:
67-
for key in f.keys():
77+
for raw_key in f.keys():
78+
key = _normalize(raw_key)
79+
if key is None:
80+
continue
81+
82+
if is_compressed and key.endswith(".weight_packed"):
83+
base = key[: -len(".weight_packed")]
84+
hf_key = base + ".weight"
85+
parts = hf_key.split(".")
86+
layer = int(parts[2]) if "layers" in hf_key else 0
87+
if not ds_ckpt.is_key_allowed(hf_key, ds_ckpt.MTP_KEYS_TO_SKIP):
88+
continue
89+
mapped_key = ds_ckpt.hf_to_maxtext_mapping(
90+
layer, num_experts, first_num_dense_layers, base_num_decoder_layers
91+
).get(hf_key)
92+
if not mapped_key:
93+
continue
94+
raw_base = raw_key[: -len(".weight_packed")]
95+
shape_t = f.get_tensor(raw_base + ".weight_shape")
96+
chkpt_vars[mapped_key] = ds_ckpt.dequantize_pack_quantized_int4(
97+
f.get_tensor(raw_key),
98+
f.get_tensor(raw_base + ".weight_scale"),
99+
shape_t.tolist(),
100+
)
101+
continue
102+
if is_compressed and key.endswith((".weight_scale", ".weight_shape")):
103+
continue
104+
68105
parts = key.split(".")
69106
layer = int(parts[2]) if "layers" in key else 0
70107
if key.endswith("_scale_inv"):
@@ -74,7 +111,7 @@ def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info)
74111
layer, num_experts, first_num_dense_layers, base_num_decoder_layers
75112
).get(key)
76113
if mapped_key:
77-
chkpt_vars[mapped_key] = f.get_tensor(key)
114+
chkpt_vars[mapped_key] = f.get_tensor(raw_key)
78115
else:
79116
# This catches keys that are allowed but missing from the mapping dictionary
80117
max_logging.log(f"Debug: Allowed key '{key}' (layer {layer}) has no mapping in hf_to_maxtext_mapping.")

tests/end_to_end/tpu/kimi/Run_Kimi.md

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
# Kimi
1818

19-
Kimi is a family of high-performance, open-weights sparse MoE models by Moonshot AI designed for agentic intelligence. The currently supported models are **Kimi K2 (1T)**.
19+
Kimi is a family of high-performance, open-weights sparse MoE models by Moonshot AI designed for agentic intelligence. Supported models for checkpoint conversion are **Kimi K2 (1T)**, **Kimi K2-Thinking**, **Kimi K2.5** (text branch), and **Kimi K2.6** (text branch). All variants share the same DeepSeek-V3-style architecture (61 layers, 384 routed experts, MLA).
2020

2121
* **[Kimi K2](https://arxiv.org/pdf/2507.20534)** features a massive 1.04 trillion total parameters with 32 billion activated parameters. The architecture is similar to DeepSeek-V3. It utilizes **Multi-Head Latent Attention (MLA)** and an ultra-sparse MoE with **384 experts**, optimized for long-context and agentic tasks.
2222
* **MuonClip Optimizer**: Kimi K2 was trained using the token-efficient [Muon](https://kellerjordan.github.io/posts/muon) optimizer combined with a novel **QK-clip** technique to ensure training stability and eliminate loss spikes during large-scale pre-training.
@@ -45,6 +45,30 @@ python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_fam
4545
python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_family_unscanned_ckpt --model_size kimi-k2-1t --base_model_path $LOCAL_BF16_PATH --maxtext_model_path $GCS_PATH_TO_SAVE
4646
```
4747

48+
### Quantized variants: K2-Thinking, K2.5, K2.6
49+
50+
K2-Thinking/K2.5/K2.6 ship routed-expert weights as int4 (compressed-tensors pack-quantized, group_size=32 symmetric); the converter dequantizes them inline to bf16. The FP8→bf16 path above does not apply. Other tensors (attention, shared experts, dense MLP, `lm_head`) are already bf16 in the checkpoint and pass through unchanged.
51+
52+
K2.5 and K2.6 are multimodal wrappers (`KimiK25ForConditionalGeneration`); the converter strips the `language_model.` prefix and silently drops `vision_tower.*` / `mm_projector.*` keys to produce a text-only MaxText checkpoint.
53+
54+
1. Download the model from HuggingFace — pick one:
55+
```sh
56+
hf download moonshotai/Kimi-K2-Thinking --local-dir $LOCAL_HF_PATH # --model_size kimi-k2-thinking
57+
hf download moonshotai/Kimi-K2.5 --local-dir $LOCAL_HF_PATH # --model_size kimi-k2.5-text
58+
hf download moonshotai/Kimi-K2.6 --local-dir $LOCAL_HF_PATH # --model_size kimi-k2.6-text
59+
```
60+
61+
2. Convert directly to Orbax (no intermediate FP8→BF16 pass):
62+
```sh
63+
python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_family_ckpt \
64+
--model_size <kimi-k2-thinking | kimi-k2.5-text | kimi-k2.6-text> \
65+
--base_model_path $LOCAL_HF_PATH \
66+
--maxtext_model_path $GCS_PATH_TO_SAVE
67+
```
68+
Use `convert_deepseek_family_unscanned_ckpt.py` with the same `--model_size` for the unscanned (decoding) layout.
69+
70+
> **Note:** Pre-training / fine-tuning / decoding flows below currently reference `model_name=kimi-k2-1t`. To run those for the new variants, a parallel `src/maxtext/configs/models/<variant>.yml` must be added; the converter changes here only cover the checkpoint side.
71+
4872
## Pre-training
4973
You can train from scratch to generate a new checkpoint. One example command to run pre-training with Kimi K2 on tpu7x-512 (adjust parallelism for the 1T parameter scale). To use MuonClip optimizer, you need `optax>=0.2.7` and `tokamax>=0.0.11`.
5074

0 commit comments

Comments
 (0)