
Commit 2e3d6ee

Add Kimi-k2-thinking and k2.5 checkpoint conversion support.
1 parent 142bcf1 commit 2e3d6ee

5 files changed

Lines changed: 345 additions & 5 deletions


pytest.ini

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ addopts =
     --ignore=tests/integration/smoke/train_smoke_test.py
     --ignore=tests/integration/smoke/train_using_ragged_dot_smoke_test.py
     --ignore=tests/unit/dequantize_mxfp4_test.py
+    --ignore=tests/unit/dequantize_pack_quantized_int4_test.py
     --ignore=tests/unit/gemma3_layers_test.py
     --ignore=tests/unit/gpt_vs_reference_test.py
     --ignore=tests/unit/llama4_layers_test.py

src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py

Lines changed: 121 additions & 2 deletions
@@ -48,6 +48,37 @@
 absl.logging.set_verbosity(absl.logging.INFO)  # for max_logging.log


+# Group size used by Kimi-K2 quantized variants for per-group symmetric int4 scales.
+# Matches `quantization_config.config_groups.group_0.weights.group_size` in the HF config.
+INT4_GROUP_SIZE = 32
+
+
+def dequantize_pack_quantized_int4(
+    packed: torch.Tensor,
+    scale: torch.Tensor,
+    out_shape,
+) -> torch.Tensor:
+  """Dequantize a compressed-tensors pack-quantized int4 weight to bf16.
+
+  packed: int32 [out, in/8], 8 nibbles per int32, low nibble first.
+  scale: bf16/fp16 [out, in/group], symmetric (no zero point).
+  Stored nibbles are biased by +8, so signed value = nibble - 8.
+  """
+  out_features, in_features = int(out_shape[0]), int(out_shape[1])
+  if in_features % INT4_GROUP_SIZE != 0:
+    raise ValueError(f"in_features={in_features} not divisible by group_size={INT4_GROUP_SIZE}")
+
+  shifts = torch.arange(8, dtype=torch.int32) * 4
+  nibbles = (packed.to(torch.int32).unsqueeze(-1) >> shifts) & 0xF
+  w_int = (nibbles - 8).reshape(out_features, -1)[:, :in_features].to(torch.float32)
+
+  s = scale.to(torch.float32).unsqueeze(-1)
+  w = (w_int.reshape(out_features, in_features // INT4_GROUP_SIZE, INT4_GROUP_SIZE) * s).reshape(
+      out_features, in_features
+  )
+  return w.to(torch.bfloat16)
+
+
 MODEL_PARAMS_DICT = {
     "deepseek2-16b": {
         "num_layers": 27,
@@ -88,6 +119,56 @@
         "v_head_dim": 128,
         "has_mtp": False,
     },
+    "kimi-k2-thinking": {
+        "num_layers": 61,
+        "first_num_dense_layers": 1,
+        "base_num_query_heads": 64,
+        "base_emb_dim": 7168,
+        "num_experts": 384,
+        "q_lora_rank": 1536,
+        "kv_lora_rank": 512,
+        "qk_nope_head_dim": 128,
+        "qk_rope_head_dim": 64,
+        "v_head_dim": 128,
+        "has_mtp": False,
+        # Only routed-expert projections are int4 (group_size=32, symmetric);
+        # attention, shared experts, dense MLP, and lm_head stay bf16.
+        "compressed_int4": True,
+    },
+    # Multimodal wrapper; text branch matches kimi-k2-thinking but keys are
+    # prefixed `language_model.`. Vision keys are dropped (text-only target).
+    "kimi-k2.5-text": {
+        "num_layers": 61,
+        "first_num_dense_layers": 1,
+        "base_num_query_heads": 64,
+        "base_emb_dim": 7168,
+        "num_experts": 384,
+        "q_lora_rank": 1536,
+        "kv_lora_rank": 512,
+        "qk_nope_head_dim": 128,
+        "qk_rope_head_dim": 64,
+        "v_head_dim": 128,
+        "has_mtp": False,
+        "compressed_int4": True,
+        "hf_key_prefix": "language_model.",
+    },
+    # K2.6 reuses the K2.5 multimodal wrapper (KimiK25ForConditionalGeneration);
+    # text branch shape and quantization layout are identical to kimi-k2.5-text.
+    "kimi-k2.6-text": {
+        "num_layers": 61,
+        "first_num_dense_layers": 1,
+        "base_num_query_heads": 64,
+        "base_emb_dim": 7168,
+        "num_experts": 384,
+        "q_lora_rank": 1536,
+        "kv_lora_rank": 512,
+        "qk_nope_head_dim": 128,
+        "qk_rope_head_dim": 64,
+        "v_head_dim": 128,
+        "has_mtp": False,
+        "compressed_int4": True,
+        "hf_key_prefix": "language_model.",
+    },
 }

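To make the on-disk layout concrete: each quantized projection is stored as a `weight_packed` / `weight_scale` / `weight_shape` triplet. A shape sketch, using the 7168 `base_emb_dim` above as the input width (the 2048-row output width is hypothetical, not taken from any config):

```python
# Illustrative shapes for one pack-quantized routed-expert projection.
out_features, in_features = 2048, 7168            # out_features is a placeholder
packed_shape = (out_features, in_features // 8)   # (2048, 896) int32, 8 nibbles per word
scale_shape = (out_features, in_features // 32)   # (2048, 224) bf16, one scale per group
weight_shape = (out_features, in_features)        # logical shape, stored as its own tensor
```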

@@ -232,10 +313,48 @@ def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info,

   ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.safetensors"))
   chkpt_vars = {}
+  is_compressed = bool(model_params.get("compressed_int4", False))
+  hf_key_prefix = model_params.get("hf_key_prefix", "")  # for multimodal text-only
+
+  def _normalize(raw_key):
+    """Strip multimodal prefix; return None to drop keys outside the text branch."""
+    if not hf_key_prefix:
+      return raw_key
+    if raw_key.startswith(hf_key_prefix):
+      return raw_key[len(hf_key_prefix) :]
+    return None
+
   for i, ckpt_path in enumerate(ckpt_paths):
     max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
     with safe_open(ckpt_path, framework="pt", device="cpu") as f:
-      for key in f.keys():
+      for raw_key in f.keys():
+        key = _normalize(raw_key)
+        if key is None:
+          continue  # vision_tower / mm_projector etc. when text-only
+
+        if is_compressed and key.endswith(".weight_packed"):
+          base = key[: -len(".weight_packed")]
+          hf_key = base + ".weight"
+          parts = hf_key.split(".")
+          layer = int(parts[2]) if "layers" in hf_key else 0
+          if not is_key_allowed(hf_key, MTP_KEYS_TO_SKIP):
+            continue
+          mapped_key = hf_to_maxtext_mapping(
+              layer, num_experts, first_num_dense_layers, base_num_decoder_layers, has_mtp
+          ).get(hf_key)
+          if not mapped_key:
+            continue
+          raw_base = raw_key[: -len(".weight_packed")]
+          shape_t = f.get_tensor(raw_base + ".weight_shape")
+          chkpt_vars[mapped_key] = dequantize_pack_quantized_int4(
+              f.get_tensor(raw_key),
+              f.get_tensor(raw_base + ".weight_scale"),
+              shape_t.tolist(),
+          )
+          continue
+        if is_compressed and key.endswith((".weight_scale", ".weight_shape")):
+          continue  # consumed alongside the matching .weight_packed
+
         parts = key.split(".")
         layer = int(parts[2]) if "layers" in key else 0
         if key.endswith("_scale_inv"):
@@ -245,7 +364,7 @@ def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info,
               layer, num_experts, first_num_dense_layers, base_num_decoder_layers, has_mtp
           ).get(key)
           if mapped_key:
-            chkpt_vars[mapped_key] = f.get_tensor(key)
+            chkpt_vars[mapped_key] = f.get_tensor(raw_key)

     logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
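The effect of `_normalize` on a multimodal checkpoint, in brief. The key names below are hypothetical, chosen only to mirror the `language_model.` / `vision_tower.` split described in the comments:

```python
hf_key_prefix = "language_model."

def _normalize(raw_key):
  if not hf_key_prefix:
    return raw_key
  if raw_key.startswith(hf_key_prefix):
    return raw_key[len(hf_key_prefix) :]
  return None

# Text-branch keys lose the prefix and continue into the normal mapping path:
assert _normalize("language_model.model.layers.3.mlp.experts.0.gate_proj.weight_packed") == (
    "model.layers.3.mlp.experts.0.gate_proj.weight_packed"
)
# Keys outside the text branch are dropped entirely:
assert _normalize("vision_tower.blocks.0.attn.qkv.weight") is None
```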

src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py

Lines changed: 39 additions & 2 deletions
@@ -61,10 +61,47 @@ def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info)

   ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.safetensors"))
   chkpt_vars = {}
+  is_compressed = bool(model_params.get("compressed_int4", False))
+  hf_key_prefix = model_params.get("hf_key_prefix", "")
+
+  def _normalize(raw_key):
+    if not hf_key_prefix:
+      return raw_key
+    if raw_key.startswith(hf_key_prefix):
+      return raw_key[len(hf_key_prefix) :]
+    return None
+
   for i, ckpt_path in enumerate(ckpt_paths):
     max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
     with safe_open(ckpt_path, framework="pt", device="cpu") as f:
-      for key in f.keys():
+      for raw_key in f.keys():
+        key = _normalize(raw_key)
+        if key is None:
+          continue
+
+        if is_compressed and key.endswith(".weight_packed"):
+          base = key[: -len(".weight_packed")]
+          hf_key = base + ".weight"
+          parts = hf_key.split(".")
+          layer = int(parts[2]) if "layers" in hf_key else 0
+          if not ds_ckpt.is_key_allowed(hf_key, ds_ckpt.MTP_KEYS_TO_SKIP):
+            continue
+          mapped_key = ds_ckpt.hf_to_maxtext_mapping(
+              layer, num_experts, first_num_dense_layers, base_num_decoder_layers
+          ).get(hf_key)
+          if not mapped_key:
+            continue
+          raw_base = raw_key[: -len(".weight_packed")]
+          shape_t = f.get_tensor(raw_base + ".weight_shape")
+          chkpt_vars[mapped_key] = ds_ckpt.dequantize_pack_quantized_int4(
+              f.get_tensor(raw_key),
+              f.get_tensor(raw_base + ".weight_scale"),
+              shape_t.tolist(),
+          )
+          continue
+        if is_compressed and key.endswith((".weight_scale", ".weight_shape")):
+          continue
+
         parts = key.split(".")
         layer = int(parts[2]) if "layers" in key else 0
         if key.endswith("_scale_inv"):
@@ -74,7 +111,7 @@ def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info)
               layer, num_experts, first_num_dense_layers, base_num_decoder_layers
           ).get(key)
           if mapped_key:
-            chkpt_vars[mapped_key] = f.get_tensor(key)
+            chkpt_vars[mapped_key] = f.get_tensor(raw_key)
           else:
             # This catches keys that are allowed but missing from the mapping dictionary
             max_logging.log(f"Debug: Allowed key '{key}' (layer {layer}) has no mapping in hf_to_maxtext_mapping.")

tests/end_to_end/tpu/kimi/Run_Kimi.md

Lines changed: 25 additions & 1 deletion
@@ -16,7 +16,7 @@

 # Kimi

-Kimi is a family of high-performance, open-weights sparse MoE models by Moonshot AI designed for agentic intelligence. The currently supported models are **Kimi K2 (1T)**.
+Kimi is a family of high-performance, open-weights sparse MoE models by Moonshot AI designed for agentic intelligence. Supported models for checkpoint conversion are **Kimi K2 (1T)**, **Kimi K2-Thinking**, **Kimi K2.5** (text branch), and **Kimi K2.6** (text branch). All variants share the same DeepSeek-V3-style architecture (61 layers, 384 routed experts, MLA).

 * **[Kimi K2](https://arxiv.org/pdf/2507.20534)** features a massive 1.04 trillion total parameters with 32 billion activated parameters. The architecture is similar to DeepSeek-V3. It utilizes **Multi-Head Latent Attention (MLA)** and an ultra-sparse MoE with **384 experts**, optimized for long-context and agentic tasks.
 * **MuonClip Optimizer**: Kimi K2 was trained using the token-efficient [Muon](https://kellerjordan.github.io/posts/muon) optimizer combined with a novel **QK-clip** technique to ensure training stability and eliminate loss spikes during large-scale pre-training.
@@ -45,6 +45,30 @@ python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_fam
 python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_family_unscanned_ckpt --model_size kimi-k2-1t --base_model_path $LOCAL_BF16_PATH --maxtext_model_path $GCS_PATH_TO_SAVE
 ```

+### Quantized variants: K2-Thinking, K2.5, K2.6
+
+K2-Thinking, K2.5, and K2.6 ship with routed-expert weights stored as `compressed-tensors` `pack-quantized` int4 (group_size=32, symmetric). The FP8→BF16 step above does not apply; the converter dequantizes the int4 expert tiles to bf16 inline. Attention, shared experts, dense MLP, and `lm_head` are already bf16 in the released checkpoints (per the `quantization_config.ignore` regex list) and pass through unchanged.
+
+K2.5 and K2.6 are multimodal wrappers (`KimiK25ForConditionalGeneration`); the converter strips the `language_model.` prefix and silently drops `vision_tower.*` / `mm_projector.*` keys to produce a text-only MaxText checkpoint.
+
+1. Download the model from HuggingFace (pick one):
+   ```sh
+   hf download moonshotai/Kimi-K2-Thinking --local-dir $LOCAL_HF_PATH   # --model_size kimi-k2-thinking
+   hf download moonshotai/Kimi-K2.5 --local-dir $LOCAL_HF_PATH          # --model_size kimi-k2.5-text
+   hf download moonshotai/Kimi-K2.6 --local-dir $LOCAL_HF_PATH          # --model_size kimi-k2.6-text
+   ```
+
+2. Convert directly to Orbax (no intermediate FP8→BF16 pass):
+   ```sh
+   python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_family_ckpt \
+     --model_size <kimi-k2-thinking | kimi-k2.5-text | kimi-k2.6-text> \
+     --base_model_path $LOCAL_HF_PATH \
+     --maxtext_model_path $GCS_PATH_TO_SAVE
+   ```
+   Use `convert_deepseek_family_unscanned_ckpt.py` with the same `--model_size` for the unscanned (decoding) layout.
+
+> **Note:** Pre-training / fine-tuning / decoding flows below currently reference `model_name=kimi-k2-1t`. To run those for the new variants, a parallel `src/maxtext/configs/models/<variant>.yml` must be added; the converter changes here only cover the checkpoint side.
+
 ## Pre-training
 You can train from scratch to generate a new checkpoint. One example command to run pre-training with Kimi K2 on tpu7x-512 (adjust parallelism for the 1T parameter scale). To use MuonClip optimizer, you need `optax>=0.2.7` and `tokamax>=0.0.11`.
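For orientation, the shape of the HF `quantization_config` these docs refer to, sketched as a Python dict. Only the int4 symmetric scheme, `group_size=32`, the `pack-quantized` format, and the existence of an `ignore` list come from this commit; the surrounding structure and the example `ignore` patterns are illustrative, not copied from the released configs:

```python
# Hypothetical compressed-tensors quantization_config fragment.
quantization_config = {
    "format": "pack-quantized",
    "config_groups": {
        "group_0": {
            "weights": {
                "num_bits": 4,      # int4 nibbles, packed 8 per int32
                "type": "int",
                "symmetric": True,  # no zero point; signed value = nibble - 8
                "group_size": 32,   # matches INT4_GROUP_SIZE in the converter
            },
        },
    },
    # Regexes for modules kept in bf16 (attention, shared experts, dense MLP, lm_head);
    # the patterns below are placeholders.
    "ignore": ["lm_head", "re:.*self_attn.*", "re:.*shared_experts.*"],
}
```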
