diff --git a/pytest.ini b/pytest.ini
index 185a0157c9..10ed0cc6f5 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -11,6 +11,7 @@ addopts =
     --ignore=tests/integration/smoke/train_smoke_test.py
     --ignore=tests/integration/smoke/train_using_ragged_dot_smoke_test.py
     --ignore=tests/unit/dequantize_mxfp4_test.py
+    --ignore=tests/unit/dequantize_pack_quantized_int4_test.py
     --ignore=tests/unit/gemma3_layers_test.py
     --ignore=tests/unit/gpt_vs_reference_test.py
     --ignore=tests/unit/llama4_layers_test.py
diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py
index dcaa323362..9880e164f7 100644
--- a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py
+++ b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py
@@ -48,6 +48,39 @@
 absl.logging.set_verbosity(absl.logging.INFO)  # for max_logging.log
 
 
+# Group size used by Kimi-K2 quantized variants for per-group symmetric int4 scales.
+# Matches `quantization_config.config_groups.group_0.weights.group_size` in the HF config.
+INT4_GROUP_SIZE = 32
+
+
+def dequantize_pack_quantized_int4(
+    packed: torch.Tensor,
+    scale: torch.Tensor,
+    out_shape,
+) -> torch.Tensor:
+  """Dequantize a compressed-tensors pack-quantized int4 weight to bf16.
+
+  packed: int32 [out, in/8]. Each int32 packs 8 weights along the input dim;
+    weight k lives in bits [4k : 4k+4] (so the first weight is in the low 4 bits).
+  scale: bf16/fp16 [out, in/group_size], symmetric (no zero point).
+
+  Each 4-bit value is unsigned 0..15; subtract 8 to get the signed weight in [-8, 7].
+  """
+  out_features, in_features = int(out_shape[0]), int(out_shape[1])
+  if in_features % INT4_GROUP_SIZE != 0:
+    raise ValueError(f"in_features={in_features} not divisible by group_size={INT4_GROUP_SIZE}")
+
+  shifts = torch.arange(8, dtype=torch.int32) * 4
+  nibbles = (packed.to(torch.int32).unsqueeze(-1) >> shifts) & 0xF
+  w_int = (nibbles - 8).reshape(out_features, -1)[:, :in_features].to(torch.float32)
+
+  s = scale.to(torch.float32).unsqueeze(-1)
+  w = (w_int.reshape(out_features, in_features // INT4_GROUP_SIZE, INT4_GROUP_SIZE) * s).reshape(
+      out_features, in_features
+  )
+  return w.to(torch.bfloat16)
+
+
 MODEL_PARAMS_DICT = {
     "deepseek2-16b": {
         "num_layers": 27,
@@ -88,6 +121,56 @@
         "v_head_dim": 128,
         "has_mtp": False,
     },
+    "kimi-k2-thinking": {
+        "num_layers": 61,
+        "first_num_dense_layers": 1,
+        "base_num_query_heads": 64,
+        "base_emb_dim": 7168,
+        "num_experts": 384,
+        "q_lora_rank": 1536,
+        "kv_lora_rank": 512,
+        "qk_nope_head_dim": 128,
+        "qk_rope_head_dim": 64,
+        "v_head_dim": 128,
+        "has_mtp": False,
+        # Only routed-expert projections are int4 (group_size=32, symmetric);
+        # attention, shared experts, dense MLP, and lm_head stay bf16.
+        "compressed_int4": True,
+    },
+    # Multimodal wrapper; text branch matches kimi-k2-thinking but keys are
+    # prefixed `language_model.`. Vision keys are dropped (text-only target).
+ "kimi-k2.5-text": { + "num_layers": 61, + "first_num_dense_layers": 1, + "base_num_query_heads": 64, + "base_emb_dim": 7168, + "num_experts": 384, + "q_lora_rank": 1536, + "kv_lora_rank": 512, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "v_head_dim": 128, + "has_mtp": False, + "compressed_int4": True, + "hf_key_prefix": "language_model.", + }, + # K2.6 reuses the K2.5 multimodal wrapper (KimiK25ForConditionalGeneration); + # text branch shape and quantization layout are identical to kimi-k2.5-text. + "kimi-k2.6-text": { + "num_layers": 61, + "first_num_dense_layers": 1, + "base_num_query_heads": 64, + "base_emb_dim": 7168, + "num_experts": 384, + "q_lora_rank": 1536, + "kv_lora_rank": 512, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "v_head_dim": 128, + "has_mtp": False, + "compressed_int4": True, + "hf_key_prefix": "language_model.", + }, } @@ -232,20 +315,52 @@ def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info, ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.safetensors")) chkpt_vars = {} + is_compressed = bool(model_params.get("compressed_int4", False)) + hf_key_prefix = model_params.get("hf_key_prefix", "") # for multimodal text-only + + def _normalize(raw_key): + """Strip multimodal prefix; return None to drop keys outside the text branch.""" + if not hf_key_prefix: + return raw_key + if raw_key.startswith(hf_key_prefix): + return raw_key[len(hf_key_prefix) :] + return None + for i, ckpt_path in enumerate(ckpt_paths): max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...") with safe_open(ckpt_path, framework="pt", device="cpu") as f: - for key in f.keys(): - parts = key.split(".") - layer = int(parts[2]) if "layers" in key else 0 + for raw_key in f.keys(): + key = _normalize(raw_key) + if key is None: + continue # vision_tower / mm_projector etc. when text-only if key.endswith("_scale_inv"): raise ValueError("fp8 checkpoint is not supported.") - if is_key_allowed(key, MTP_KEYS_TO_SKIP): - mapped_key = hf_to_maxtext_mapping( - layer, num_experts, first_num_dense_layers, base_num_decoder_layers, has_mtp - ).get(key) - if mapped_key: - chkpt_vars[mapped_key] = f.get_tensor(key) + if is_compressed and key.endswith((".weight_scale", ".weight_shape")): + continue # consumed alongside the matching .weight_packed + + # Compressed weights advertise as ".weight_packed"; resolve back to the + # logical ".weight" name so it matches the maxtext mapping table. 
+        is_packed = is_compressed and key.endswith(".weight_packed")
+        hf_key = key[: -len(".weight_packed")] + ".weight" if is_packed else key
+
+        if not is_key_allowed(hf_key, MTP_KEYS_TO_SKIP):
+          continue
+        layer = int(hf_key.split(".")[2]) if "layers" in hf_key else 0
+        mapped_key = hf_to_maxtext_mapping(
+            layer, num_experts, first_num_dense_layers, base_num_decoder_layers, has_mtp
+        ).get(hf_key)
+        if not mapped_key:
+          continue
+
+        if is_packed:
+          raw_base = raw_key[: -len(".weight_packed")]
+          chkpt_vars[mapped_key] = dequantize_pack_quantized_int4(
+              f.get_tensor(raw_key),
+              f.get_tensor(raw_base + ".weight_scale"),
+              f.get_tensor(raw_base + ".weight_shape").tolist(),
+          )
+        else:
+          chkpt_vars[mapped_key] = f.get_tensor(raw_key)
 
     logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
 
diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py
index f5e978c80d..97d2f3954f 100644
--- a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py
+++ b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py
@@ -61,10 +61,47 @@ def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info)
   ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.safetensors"))
   chkpt_vars = {}
 
+  is_compressed = bool(model_params.get("compressed_int4", False))
+  hf_key_prefix = model_params.get("hf_key_prefix", "")
+
+  def _normalize(raw_key):
+    if not hf_key_prefix:
+      return raw_key
+    if raw_key.startswith(hf_key_prefix):
+      return raw_key[len(hf_key_prefix) :]
+    return None
+
   for i, ckpt_path in enumerate(ckpt_paths):
     max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
     with safe_open(ckpt_path, framework="pt", device="cpu") as f:
-      for key in f.keys():
+      for raw_key in f.keys():
+        key = _normalize(raw_key)
+        if key is None:
+          continue
+
+        if is_compressed and key.endswith(".weight_packed"):
+          base = key[: -len(".weight_packed")]
+          hf_key = base + ".weight"
+          parts = hf_key.split(".")
+          layer = int(parts[2]) if "layers" in hf_key else 0
+          if not ds_ckpt.is_key_allowed(hf_key, ds_ckpt.MTP_KEYS_TO_SKIP):
+            continue
+          mapped_key = ds_ckpt.hf_to_maxtext_mapping(
+              layer, num_experts, first_num_dense_layers, base_num_decoder_layers
+          ).get(hf_key)
+          if not mapped_key:
+            continue
+          raw_base = raw_key[: -len(".weight_packed")]
+          shape_t = f.get_tensor(raw_base + ".weight_shape")
+          chkpt_vars[mapped_key] = ds_ckpt.dequantize_pack_quantized_int4(
+              f.get_tensor(raw_key),
+              f.get_tensor(raw_base + ".weight_scale"),
+              shape_t.tolist(),
+          )
+          continue
+        if is_compressed and key.endswith((".weight_scale", ".weight_shape")):
+          continue
+
         parts = key.split(".")
         layer = int(parts[2]) if "layers" in key else 0
         if key.endswith("_scale_inv"):
@@ -74,7 +111,7 @@ def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info)
               layer, num_experts, first_num_dense_layers, base_num_decoder_layers
           ).get(key)
           if mapped_key:
-            chkpt_vars[mapped_key] = f.get_tensor(key)
+            chkpt_vars[mapped_key] = f.get_tensor(raw_key)
           else:
             # This catches keys that are allowed but missing from the mapping dictionary
             max_logging.log(f"Debug: Allowed key '{key}' (layer {layer}) has no mapping in hf_to_maxtext_mapping.")
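For orientation, the three companion tensors consumed above scale directly with the projection shape. A quick back-of-the-envelope sketch in plain Python, using the kimi-k2-thinking expert tile of [2048, 7168] that the new unit test exercises (the numbers follow from the helper's docstring, not from inspecting a real checkpoint):

```python
# Size arithmetic for one routed-expert projection of shape [2048, 7168],
# with 8 int4 weights per int32 and one scale per group of 32 (see the helper docstring).
out_features, in_features, group_size = 2048, 7168, 32

packed_cols = in_features // 8          # -> 896
scale_cols = in_features // group_size  # -> 224

print(f"weight_packed: int32 [{out_features}, {packed_cols}]")   # [2048, 896]
print(f"weight_scale:  bf16  [{out_features}, {scale_cols}]")    # [2048, 224]
print(f"weight_shape:  [{out_features}, {in_features}]")         # [2048, 7168]
```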
diff --git a/tests/end_to_end/tpu/kimi/Run_Kimi.md b/tests/end_to_end/tpu/kimi/Run_Kimi.md
index 2e80205d18..702a36712e 100644
--- a/tests/end_to_end/tpu/kimi/Run_Kimi.md
+++ b/tests/end_to_end/tpu/kimi/Run_Kimi.md
@@ -16,7 +16,7 @@
 
 # Kimi
 
-Kimi is a family of high-performance, open-weights sparse MoE models by Moonshot AI designed for agentic intelligence. The currently supported models are **Kimi K2 (1T)**.
+Kimi is a family of high-performance, open-weights sparse MoE models by Moonshot AI designed for agentic intelligence. Supported models for checkpoint conversion are **Kimi K2 (1T)**, **Kimi K2-Thinking**, **Kimi K2.5** (text branch), and **Kimi K2.6** (text branch). All variants share the same DeepSeek-V3-style architecture (61 layers, 384 routed experts, MLA).
 
 * **[Kimi K2](https://arxiv.org/pdf/2507.20534)** features a massive 1.04 trillion total parameters with 32 billion activated parameters. The architecture is similar to DeepSeek-V3. It utilizes **Multi-Head Latent Attention (MLA)** and an ultra-sparse MoE with **384 experts**, optimized for long-context and agentic tasks.
 * **MuonClip Optimizer**: Kimi K2 was trained using the token-efficient [Muon](https://kellerjordan.github.io/posts/muon) optimizer combined with a novel **QK-clip** technique to ensure training stability and eliminate loss spikes during large-scale pre-training.
@@ -45,6 +45,30 @@ python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_fam
 python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_family_unscanned_ckpt --model_size kimi-k2-1t --base_model_path $LOCAL_BF16_PATH --maxtext_model_path $GCS_PATH_TO_SAVE
 ```
 
+### Quantized variants: K2-Thinking, K2.5, K2.6
+
+K2-Thinking/K2.5/K2.6 ship routed-expert weights as int4 (compressed-tensors pack-quantized, group_size=32 symmetric); the converter dequantizes them inline to bf16. The FP8→bf16 path above does not apply. Other tensors (attention, shared experts, dense MLP, `lm_head`) are already bf16 in the checkpoint and pass through unchanged.
+
+K2.5 and K2.6 are multimodal wrappers (`KimiK25ForConditionalGeneration`); the converter strips the `language_model.` prefix and silently drops `vision_tower.*` / `mm_projector.*` keys to produce a text-only MaxText checkpoint.
+
+1. Download the model from HuggingFace (pick one):
+```sh
+hf download moonshotai/Kimi-K2-Thinking --local-dir $LOCAL_HF_PATH  # --model_size kimi-k2-thinking
+hf download moonshotai/Kimi-K2.5 --local-dir $LOCAL_HF_PATH         # --model_size kimi-k2.5-text
+hf download moonshotai/Kimi-K2.6 --local-dir $LOCAL_HF_PATH         # --model_size kimi-k2.6-text
+```
+
+2. Convert directly to Orbax (no intermediate FP8→BF16 pass):
+```sh
+python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_family_ckpt \
+  --model_size <model_size> \
+  --base_model_path $LOCAL_HF_PATH \
+  --maxtext_model_path $GCS_PATH_TO_SAVE
+```
+Use `convert_deepseek_family_unscanned_ckpt.py` with the same `--model_size` for the unscanned (decoding) layout.
+
+> **Note:** Pre-training / fine-tuning / decoding flows below currently reference `model_name=kimi-k2-1t`. To run those for the new variants, a parallel `src/maxtext/configs/models/<model_name>.yml` must be added; the converter changes here only cover the checkpoint side.
+
 ## Pre-training
 
 You can train from scratch to generate a new checkpoint. One example command to run pre-training with Kimi K2 on tpu7x-512 (adjust parallelism for the 1T parameter scale). To use MuonClip optimizer, you need `optax>=0.2.7` and `tokamax>=0.0.11`.
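Before launching a full conversion, one tensor can be spot-checked straight from a downloaded shard. This is an illustrative sketch only: the shard filename and tensor key are placeholders (list `f.keys()` to find real ones, and drop the `language_model.` prefix for K2-Thinking), and it assumes the converter module from this change is on the Python path.

```python
# Illustrative spot check of a single pack-quantized expert weight (names are placeholders).
from safetensors import safe_open

from maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_family_ckpt import (
    dequantize_pack_quantized_int4,
)

shard = "model-00001-of-000xx.safetensors"  # hypothetical shard under $LOCAL_HF_PATH
base = "language_model.model.layers.2.mlp.experts.0.gate_proj"  # hypothetical tensor key

with safe_open(shard, framework="pt", device="cpu") as f:
  w = dequantize_pack_quantized_int4(
      f.get_tensor(base + ".weight_packed"),
      f.get_tensor(base + ".weight_scale"),
      f.get_tensor(base + ".weight_shape").tolist(),
  )
  print(w.shape, w.dtype)  # a gate_proj tile should come out as [2048, 7168] bf16
```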
diff --git a/tests/unit/dequantize_pack_quantized_int4_test.py b/tests/unit/dequantize_pack_quantized_int4_test.py
new file mode 100644
index 0000000000..fd7e419484
--- /dev/null
+++ b/tests/unit/dequantize_pack_quantized_int4_test.py
@@ -0,0 +1,160 @@
+# Copyright 2023–2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for dequantize_pack_quantized_int4 (kimi-k2-thinking / kimi-k2.5 path).
+
+Validated against the canonical compressed_tensors packer/unpacker:
+  - unpack matches `unpack_from_int32` bit-for-bit
+  - bf16 dequant matches the (int * bf16_scale).bf16 reference exactly
+  - signed boundary values (-8, +7) decode correctly
+  - zero-magnitude groups dequant to exactly 0
+  - end-to-end kimi-k2-thinking expert tile shape (in=7168, out=2048, group=32)
+
+Not run in GitHub runners (depends on torch + compressed_tensors).
+"""
+
+import unittest
+import pytest
+import torch
+
+from compressed_tensors.compressors.quantized_compressors.pack_quantized import (
+    pack_to_int32,
+    unpack_from_int32,
+)
+
+from maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_family_ckpt import (
+    INT4_GROUP_SIZE,
+    dequantize_pack_quantized_int4,
+)
+
+
+_NUM_BITS = 4
+_QMAX = 2 ** (_NUM_BITS - 1) - 1  # 7 for symmetric int4
+_QMIN = -(_QMAX + 1)  # -8
+
+
+def _quantize_per_group(w_fp32: torch.Tensor, group_size: int = INT4_GROUP_SIZE):
+  """Symmetric per-group int4 quantization → (int8 values in [-8, 7], fp32 scale)."""
+  out_features, in_features = w_fp32.shape
+  groups = w_fp32.reshape(out_features, in_features // group_size, group_size)
+  absmax = groups.abs().amax(dim=-1, keepdim=True)
+  scale = (absmax / _QMAX).clamp(min=1e-12)
+  w_int = torch.round(groups / scale).clamp(_QMIN, _QMAX).to(torch.int8)
+  return w_int.reshape(out_features, in_features), scale.squeeze(-1)
+
+
+@pytest.mark.cpu_only
+class DequantizePackQuantizedInt4Test(unittest.TestCase):
+
+  def setUp(self):
+    torch.manual_seed(0)
+
+  def test_unpack_matches_canonical_bit_exact(self):
+    """Helper's unpack (with scale=1) recovers signed int values matching the
+    reference unpacker exactly. Pins down nibble endianness and the +8 offset
+    used by compressed_tensors' pack_to_int32."""
+    out_features, in_features = 64, 128
+    w_fp = torch.randn(out_features, in_features) * 0.1
+    w_int, _ = _quantize_per_group(w_fp)
+    packed = pack_to_int32(w_int, num_bits=_NUM_BITS)
+    self.assertEqual(packed.dtype, torch.int32)
+    self.assertEqual(tuple(packed.shape), (out_features, in_features // 8))
+
+    canonical = unpack_from_int32(packed, num_bits=_NUM_BITS, shape=torch.Size([out_features, in_features]))
+    self.assertTrue(torch.equal(canonical.to(torch.int8), w_int))
+
+    ones = torch.ones(out_features, in_features // INT4_GROUP_SIZE, dtype=torch.float32)
+    ours = dequantize_pack_quantized_int4(packed, ones, [out_features, in_features])
+    self.assertTrue(torch.equal(ours.to(torch.int8), w_int))
+
+  def test_dequant_matches_bf16_reference(self):
+    """Full dequant matches (int * bf16_scale).bf16 bit-for-bit."""
+    out_features, in_features = 64, 128
+    w_fp = torch.randn(out_features, in_features) * 0.1
+    w_int, scale_fp32 = _quantize_per_group(w_fp)
+    packed = pack_to_int32(w_int, num_bits=_NUM_BITS)
+
+    scale_bf16 = scale_fp32.to(torch.bfloat16)
+    # Round-trip the scale through bf16 on both sides for a fair comparison.
+    ref = (
+        (
+            w_int.reshape(out_features, in_features // INT4_GROUP_SIZE, INT4_GROUP_SIZE).to(torch.float32)
+            * scale_bf16.to(torch.float32).unsqueeze(-1)
+        )
+        .reshape(out_features, in_features)
+        .to(torch.bfloat16)
+    )
+
+    ours = dequantize_pack_quantized_int4(packed, scale_bf16, [out_features, in_features])
+    self.assertEqual(ours.dtype, torch.bfloat16)
+    self.assertTrue(torch.equal(ours, ref))
+
+  def test_signed_boundary_values_decode(self):
+    """Most-negative (-8) and most-positive (+7) int4 values decode correctly."""
+    out_features, in_features = 16, 64
+    scale = torch.ones(out_features, in_features // INT4_GROUP_SIZE, dtype=torch.bfloat16)
+
+    minus_eight = torch.full((out_features, in_features), -8, dtype=torch.int8)
+    packed = pack_to_int32(minus_eight, num_bits=_NUM_BITS)
+    out = dequantize_pack_quantized_int4(packed, scale, [out_features, in_features])
+    self.assertTrue(torch.all(out.to(torch.float32) == -8.0))
+
+    plus_seven = torch.full((out_features, in_features), 7, dtype=torch.int8)
+    packed = pack_to_int32(plus_seven, num_bits=_NUM_BITS)
+    out = dequantize_pack_quantized_int4(packed, scale, [out_features, in_features])
+    self.assertTrue(torch.all(out.to(torch.float32) == 7.0))
+
+  def test_zero_group_dequants_to_zero(self):
+    """Groups whose magnitudes round to zero ints must dequant to exactly 0."""
+    out_features, in_features = 8, 64
+    zero_int = torch.zeros(out_features, in_features, dtype=torch.int8)
+    packed = pack_to_int32(zero_int, num_bits=_NUM_BITS)
+    scale = torch.full((out_features, in_features // INT4_GROUP_SIZE), 0.123, dtype=torch.bfloat16)
+    out = dequantize_pack_quantized_int4(packed, scale, [out_features, in_features])
+    self.assertTrue(torch.all(out == 0))
+
+  def test_in_features_not_divisible_raises(self):
+    out_features, in_features = 4, 48  # 48 % 32 != 0
+    int_vals = torch.zeros(out_features, in_features, dtype=torch.int8)
+    packed = pack_to_int32(int_vals, num_bits=_NUM_BITS)
+    scale = torch.ones(out_features, 2, dtype=torch.bfloat16)
+    with self.assertRaises(ValueError):
+      dequantize_pack_quantized_int4(packed, scale, [out_features, in_features])
+
+  def test_kimi_k2_thinking_expert_tile_shape(self):
+    """Realistic tile size: a single expert's gate_proj is [moe_intermediate_size,
+    hidden_size] = [2048, 7168] for kimi-k2 /
+    kimi-k2-thinking / kimi-k2.5."""
+    out_features, in_features = 2048, 7168
+    w_fp = torch.randn(out_features, in_features) * 0.05
+    w_int, scale_fp32 = _quantize_per_group(w_fp)
+    packed = pack_to_int32(w_int, num_bits=_NUM_BITS)
+
+    scale_bf16 = scale_fp32.to(torch.bfloat16)
+    out = dequantize_pack_quantized_int4(packed, scale_bf16, [out_features, in_features])
+    self.assertEqual(tuple(out.shape), (out_features, in_features))
+    self.assertEqual(out.dtype, torch.bfloat16)
+
+    ref = (
+        (
+            w_int.reshape(out_features, in_features // INT4_GROUP_SIZE, INT4_GROUP_SIZE).to(torch.float32)
+            * scale_bf16.to(torch.float32).unsqueeze(-1)
+        )
+        .reshape(out_features, in_features)
+        .to(torch.bfloat16)
+    )
+    self.assertTrue(torch.equal(out, ref))
+
+
+if __name__ == "__main__":
+  unittest.main()
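The test module above needs `compressed_tensors` for the canonical packer, which is why it is excluded from the GitHub runners. As a rough, self-contained cross-check, the bit layout the helper assumes can also be reproduced with plain torch; `pack_int4_reference` below is a hypothetical stand-in for `pack_to_int32`, written from the docstring's description (low nibble first, +8 offset), and is not part of this change.

```python
# Minimal sketch, assuming only torch: pack signed int4 values the way the helper
# expects (add 8, place weight k in bits [4k, 4k+4) of an int32), then check that
# dequantize_pack_quantized_int4 inverts it with unit scales.
import torch

from maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_family_ckpt import (
    INT4_GROUP_SIZE,
    dequantize_pack_quantized_int4,
)


def pack_int4_reference(w_int: torch.Tensor) -> torch.Tensor:
  """Packs signed int4 weights [out, in] into int32 [out, in/8], low nibble first."""
  out_features, in_features = w_int.shape
  u = (w_int.to(torch.int64) + 8).reshape(out_features, in_features // 8, 8)  # values 0..15
  shifts = torch.arange(8, dtype=torch.int64) * 4
  packed = (u << shifts).sum(dim=-1)  # unsigned 32-bit pattern held in int64
  packed = torch.where(packed >= 2**31, packed - 2**32, packed)  # reinterpret as signed int32
  return packed.to(torch.int32)


out_features, in_features = 4, 64
w_int = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int8)
scale = torch.ones(out_features, in_features // INT4_GROUP_SIZE, dtype=torch.bfloat16)

packed = pack_int4_reference(w_int)
w = dequantize_pack_quantized_int4(packed, scale, [out_features, in_features])
assert torch.equal(w.to(torch.int8), w_int)
```

Accumulating in int64 and wrapping explicitly before the final cast keeps the packing well defined even when the top nibble sets the sign bit of the int32.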