1 change: 1 addition & 0 deletions pytest.ini
@@ -11,6 +11,7 @@ addopts =
--ignore=tests/integration/smoke/train_smoke_test.py
--ignore=tests/integration/smoke/train_using_ragged_dot_smoke_test.py
--ignore=tests/unit/dequantize_mxfp4_test.py
--ignore=tests/unit/dequantize_pack_quantized_int4_test.py
--ignore=tests/unit/gemma3_layers_test.py
--ignore=tests/unit/gpt_vs_reference_test.py
--ignore=tests/unit/llama4_layers_test.py
@@ -48,6 +48,39 @@
absl.logging.set_verbosity(absl.logging.INFO) # for max_logging.log


# Group size used by Kimi-K2 quantized variants for per-group symmetric int4 scales.
# Matches `quantization_config.config_groups.group_0.weights.group_size` in the HF config.
INT4_GROUP_SIZE = 32
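
# Illustrative sketch, not part of this change: the constant above can be
# cross-checked against the checkpoint's config.json, which stores the same
# value at the path referenced in the comment above. The helper name is
# hypothetical.
def _check_int4_group_size(base_model_path):
  import json
  import pathlib

  cfg = json.loads((pathlib.Path(base_model_path) / "config.json").read_text())
  group = cfg["quantization_config"]["config_groups"]["group_0"]["weights"]["group_size"]
  if group != INT4_GROUP_SIZE:
    raise ValueError(f"unexpected int4 group_size={group}, expected {INT4_GROUP_SIZE}")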


def dequantize_pack_quantized_int4(
packed: torch.Tensor,
scale: torch.Tensor,
out_shape,
) -> torch.Tensor:
"""Dequantize a compressed-tensors pack-quantized int4 weight to bf16.

packed: int32 [out, in/8]. Each int32 packs 8 weights along the input dim;
weight k lives in bits [4k : 4k+4] (so the first weight is in the low 4 bits).
scale: bf16/fp16 [out, in/group_size], symmetric (no zero point).

Each 4-bit value is unsigned 0..15; subtract 8 to get the signed weight in [-8, 7].
"""
out_features, in_features = int(out_shape[0]), int(out_shape[1])
if in_features % INT4_GROUP_SIZE != 0:
raise ValueError(f"in_features={in_features} not divisible by group_size={INT4_GROUP_SIZE}")

shifts = torch.arange(8, dtype=torch.int32) * 4
nibbles = (packed.to(torch.int32).unsqueeze(-1) >> shifts) & 0xF
w_int = (nibbles - 8).reshape(out_features, -1)[:, :in_features].to(torch.float32)

s = scale.to(torch.float32).unsqueeze(-1)
w = (w_int.reshape(out_features, in_features // INT4_GROUP_SIZE, INT4_GROUP_SIZE) * s).reshape(
out_features, in_features
)
return w.to(torch.bfloat16)
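
# Worked example (illustrative sketch, not part of this change): pack one group
# of 32 hand-picked weights the way compressed-tensors does (weight k of each
# int32 in bits [4k : 4k+4], stored as the unsigned value w + 8) and check that
# the routine above returns w * scale. The helper name is hypothetical.
def _packing_roundtrip_example():
  group = 32
  weights = torch.zeros(1, group)
  weights[0, :8] = torch.tensor([-8.0, -1.0, 0.0, 7.0, 3.0, -4.0, 2.0, 1.0])
  unsigned = (weights + 8).to(torch.int64).reshape(1, group // 8, 8)
  shifts = torch.arange(8, dtype=torch.int64) * 4
  packed = (unsigned << shifts).sum(dim=-1).to(torch.int32)  # [1, in/8] = [1, 4]
  scale = torch.full((1, group // INT4_GROUP_SIZE), 0.5, dtype=torch.bfloat16)
  recovered = dequantize_pack_quantized_int4(packed, scale, (1, group))
  assert torch.equal(recovered, (weights * 0.5).to(torch.bfloat16))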


MODEL_PARAMS_DICT = {
"deepseek2-16b": {
"num_layers": 27,
@@ -88,6 +121,56 @@
"v_head_dim": 128,
"has_mtp": False,
},
"kimi-k2-thinking": {
"num_layers": 61,
"first_num_dense_layers": 1,
"base_num_query_heads": 64,
"base_emb_dim": 7168,
"num_experts": 384,
"q_lora_rank": 1536,
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"has_mtp": False,
# Only routed-expert projections are int4 (group_size=32, symmetric);
# attention, shared experts, dense MLP, and lm_head stay bf16.
"compressed_int4": True,
},
# Multimodal wrapper; text branch matches kimi-k2-thinking but keys are
# prefixed `language_model.`. Vision keys are dropped (text-only target).
"kimi-k2.5-text": {
"num_layers": 61,
"first_num_dense_layers": 1,
"base_num_query_heads": 64,
"base_emb_dim": 7168,
"num_experts": 384,
"q_lora_rank": 1536,
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"has_mtp": False,
"compressed_int4": True,
"hf_key_prefix": "language_model.",
},
# K2.6 reuses the K2.5 multimodal wrapper (KimiK25ForConditionalGeneration);
# text branch shape and quantization layout are identical to kimi-k2.5-text.
"kimi-k2.6-text": {
"num_layers": 61,
"first_num_dense_layers": 1,
"base_num_query_heads": 64,
"base_emb_dim": 7168,
"num_experts": 384,
"q_lora_rank": 1536,
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"has_mtp": False,
"compressed_int4": True,
"hf_key_prefix": "language_model.",
},
}
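
# Illustrative sketch, not part of this change: a quick scan of a single
# checkpoint shard to confirm the layout described above, i.e. routed-expert
# projections show up as .weight_packed / .weight_scale / .weight_shape
# triplets while everything else is a plain bf16 .weight. The helper name and
# the shard path passed by the caller are hypothetical.
def _count_tensor_suffixes(shard_path):
  from collections import Counter

  counts = Counter()
  with safe_open(shard_path, framework="pt", device="cpu") as f:
    for key in f.keys():
      suffix = key.rsplit(".", 1)[-1]
      counts[(suffix, ".experts." in key)] += 1
  return counts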


@@ -232,20 +315,52 @@ def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info,

ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.safetensors"))
chkpt_vars = {}
is_compressed = bool(model_params.get("compressed_int4", False))
hf_key_prefix = model_params.get("hf_key_prefix", "") # for multimodal text-only

def _normalize(raw_key):
"""Strip multimodal prefix; return None to drop keys outside the text branch."""
if not hf_key_prefix:
return raw_key
if raw_key.startswith(hf_key_prefix):
return raw_key[len(hf_key_prefix) :]
return None
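
# Example of how a raw checkpoint key flows through the loop below (the key
# name is hypothetical, following the usual HF layout for these models):
#   raw_key:  "language_model.model.layers.2.mlp.experts.5.gate_proj.weight_packed"
#   _normalize strips "language_model." for kimi-k2.5/k2.6; vision keys return None
#   the ".weight_packed" suffix is resolved back to "...gate_proj.weight" for mapping
#   the tensor is rebuilt from weight_packed + weight_scale + weight_shape via
#   dequantize_pack_quantized_int4 before being stored under the mapped key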

for i, ckpt_path in enumerate(ckpt_paths):
max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
with safe_open(ckpt_path, framework="pt", device="cpu") as f:
for key in f.keys():
parts = key.split(".")
layer = int(parts[2]) if "layers" in key else 0
for raw_key in f.keys():
key = _normalize(raw_key)
if key is None:
continue # vision_tower / mm_projector etc. when text-only
if key.endswith("_scale_inv"):
raise ValueError("fp8 checkpoint is not supported.")
if is_key_allowed(key, MTP_KEYS_TO_SKIP):
mapped_key = hf_to_maxtext_mapping(
layer, num_experts, first_num_dense_layers, base_num_decoder_layers, has_mtp
).get(key)
if mapped_key:
chkpt_vars[mapped_key] = f.get_tensor(key)
if is_compressed and key.endswith((".weight_scale", ".weight_shape")):
continue # consumed alongside the matching .weight_packed

# Compressed weights advertise as ".weight_packed"; resolve back to the
# logical ".weight" name so it matches the maxtext mapping table.
is_packed = is_compressed and key.endswith(".weight_packed")
hf_key = key[: -len(".weight_packed")] + ".weight" if is_packed else key

if not is_key_allowed(hf_key, MTP_KEYS_TO_SKIP):
continue
layer = int(hf_key.split(".")[2]) if "layers" in hf_key else 0
mapped_key = hf_to_maxtext_mapping(
layer, num_experts, first_num_dense_layers, base_num_decoder_layers, has_mtp
).get(hf_key)
if not mapped_key:
continue

if is_packed:
raw_base = raw_key[: -len(".weight_packed")]
chkpt_vars[mapped_key] = dequantize_pack_quantized_int4(
f.get_tensor(raw_key),
f.get_tensor(raw_base + ".weight_scale"),
f.get_tensor(raw_base + ".weight_shape").tolist(),
)
else:
chkpt_vars[mapped_key] = f.get_tensor(raw_key)

logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

@@ -61,10 +61,47 @@ def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info)

ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.safetensors"))
chkpt_vars = {}
is_compressed = bool(model_params.get("compressed_int4", False))
hf_key_prefix = model_params.get("hf_key_prefix", "")

def _normalize(raw_key):
if not hf_key_prefix:
return raw_key
if raw_key.startswith(hf_key_prefix):
return raw_key[len(hf_key_prefix) :]
return None

for i, ckpt_path in enumerate(ckpt_paths):
max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
with safe_open(ckpt_path, framework="pt", device="cpu") as f:
for key in f.keys():
for raw_key in f.keys():
key = _normalize(raw_key)
if key is None:
continue

if is_compressed and key.endswith(".weight_packed"):
base = key[: -len(".weight_packed")]
hf_key = base + ".weight"
parts = hf_key.split(".")
layer = int(parts[2]) if "layers" in hf_key else 0
if not ds_ckpt.is_key_allowed(hf_key, ds_ckpt.MTP_KEYS_TO_SKIP):
continue
mapped_key = ds_ckpt.hf_to_maxtext_mapping(
layer, num_experts, first_num_dense_layers, base_num_decoder_layers
).get(hf_key)
if not mapped_key:
continue
raw_base = raw_key[: -len(".weight_packed")]
shape_t = f.get_tensor(raw_base + ".weight_shape")
chkpt_vars[mapped_key] = ds_ckpt.dequantize_pack_quantized_int4(
f.get_tensor(raw_key),
f.get_tensor(raw_base + ".weight_scale"),
shape_t.tolist(),
)
continue
if is_compressed and key.endswith((".weight_scale", ".weight_shape")):
continue

parts = key.split(".")
layer = int(parts[2]) if "layers" in key else 0
if key.endswith("_scale_inv"):
@@ -74,7 +111,7 @@ def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info)
layer, num_experts, first_num_dense_layers, base_num_decoder_layers
).get(key)
if mapped_key:
chkpt_vars[mapped_key] = f.get_tensor(key)
chkpt_vars[mapped_key] = f.get_tensor(raw_key)
else:
# This catches keys that are allowed but missing from the mapping dictionary
max_logging.log(f"Debug: Allowed key '{key}' (layer {layer}) has no mapping in hf_to_maxtext_mapping.")
26 changes: 25 additions & 1 deletion tests/end_to_end/tpu/kimi/Run_Kimi.md
@@ -16,7 +16,7 @@

# Kimi

Kimi is a family of high-performance, open-weights sparse MoE models by Moonshot AI designed for agentic intelligence. The currently supported models are **Kimi K2 (1T)**.
Kimi is a family of high-performance, open-weights sparse MoE models by Moonshot AI designed for agentic intelligence. Supported models for checkpoint conversion are **Kimi K2 (1T)**, **Kimi K2-Thinking**, **Kimi K2.5** (text branch), and **Kimi K2.6** (text branch). All variants share the same DeepSeek-V3-style architecture (61 layers, 384 routed experts, MLA).

* **[Kimi K2](https://arxiv.org/pdf/2507.20534)** features a massive 1.04 trillion total parameters with 32 billion activated parameters. The architecture is similar to DeepSeek-V3. It utilizes **Multi-Head Latent Attention (MLA)** and an ultra-sparse MoE with **384 experts**, optimized for long-context and agentic tasks.
* **MuonClip Optimizer**: Kimi K2 was trained using the token-efficient [Muon](https://kellerjordan.github.io/posts/muon) optimizer combined with a novel **QK-clip** technique to ensure training stability and eliminate loss spikes during large-scale pre-training.
@@ -45,6 +45,30 @@ python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_fam
python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_family_unscanned_ckpt --model_size kimi-k2-1t --base_model_path $LOCAL_BF16_PATH --maxtext_model_path $GCS_PATH_TO_SAVE
```

### Quantized variants: K2-Thinking, K2.5, K2.6

K2-Thinking/K2.5/K2.6 ship routed-expert weights as int4 (compressed-tensors pack-quantized, group_size=32 symmetric); the converter dequantizes them inline to bf16. The FP8→bf16 path above does not apply. Other tensors (attention, shared experts, dense MLP, `lm_head`) are already bf16 in the checkpoint and pass through unchanged.
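
For intuition, each group of 32 consecutive input-channel weights shares one bf16 scale, and every stored 4-bit value is an unsigned nibble in [0, 15]. A minimal per-value sketch (illustrative only; the converter operates on whole packed tensors, and the function name here is hypothetical):

```python
def dequant_one(u: int, s: float) -> float:
    """One pack-quantized int4 value: unsigned nibble u in [0, 15], per-group scale s."""
    return (u - 8) * s

assert dequant_one(15, 0.5) == 3.5   # largest int4 weight, +7, times its group scale
assert dequant_one(0, 0.5) == -4.0   # most negative int4 weight, -8, times its group scale
```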

K2.5 and K2.6 are multimodal wrappers (`KimiK25ForConditionalGeneration`); the converter strips the `language_model.` prefix and silently drops `vision_tower.*` / `mm_projector.*` keys to produce a text-only MaxText checkpoint.

1. Download the model from HuggingFace — pick one:
```sh
hf download moonshotai/Kimi-K2-Thinking --local-dir $LOCAL_HF_PATH # --model_size kimi-k2-thinking
hf download moonshotai/Kimi-K2.5 --local-dir $LOCAL_HF_PATH # --model_size kimi-k2.5-text
hf download moonshotai/Kimi-K2.6 --local-dir $LOCAL_HF_PATH # --model_size kimi-k2.6-text
```

2. Convert directly to Orbax (no intermediate FP8→BF16 pass):
```sh
python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_family_ckpt \
--model_size <kimi-k2-thinking | kimi-k2.5-text | kimi-k2.6-text> \
--base_model_path $LOCAL_HF_PATH \
--maxtext_model_path $GCS_PATH_TO_SAVE
```
Use `convert_deepseek_family_unscanned_ckpt.py` with the same `--model_size` for the unscanned (decoding) layout.

> **Note:** Pre-training / fine-tuning / decoding flows below currently reference `model_name=kimi-k2-1t`. To run those for the new variants, a parallel `src/maxtext/configs/models/<variant>.yml` must be added; the converter changes here only cover the checkpoint side.

## Pre-training
You can train from scratch to generate a new checkpoint. An example command for pre-training Kimi K2 on tpu7x-512 is shown below (adjust parallelism for the 1T-parameter scale). To use the MuonClip optimizer, you need `optax>=0.2.7` and `tokamax>=0.0.11`.
