 absl.logging.set_verbosity(absl.logging.INFO) # for max_logging.log
 
 
+# Group size used by Kimi-K2 quantized variants for per-group symmetric int4 scales.
+# Matches `quantization_config.config_groups.group_0.weights.group_size` in the HF config.
+INT4_GROUP_SIZE = 32
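+# e.g. a weight row with in_features=7168 packs into 7168/8 = 896 int32s and
+# carries 7168/32 = 224 per-group scales.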
+
+
+def dequantize_pack_quantized_int4(
+    packed: torch.Tensor,
+    scale: torch.Tensor,
+    out_shape,
+) -> torch.Tensor:
+  """Dequantize a compressed-tensors pack-quantized int4 weight to bf16.
+
+  packed: int32 [out, in/8]. Each int32 packs 8 weights along the input dim;
+    weight k lives in bits [4k : 4k+4] (so the first weight is in the low 4 bits).
+  scale: bf16/fp16 [out, in/group_size], symmetric (no zero point).
+
+  Each 4-bit value is unsigned 0..15; subtract 8 to get the signed weight in [-8, 7].
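+
+  Example (illustrative): with every scale equal to 1.0, the int32 0x76543210
+  unpacks low-nibble-first to [0, 1, 2, 3, 4, 5, 6, 7] and dequantizes to
+  [-8, -7, -6, -5, -4, -3, -2, -1].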
+  """
+  out_features, in_features = int(out_shape[0]), int(out_shape[1])
+  if in_features % INT4_GROUP_SIZE != 0:
+    raise ValueError(f"in_features={in_features} not divisible by group_size={INT4_GROUP_SIZE}")
+
+  shifts = torch.arange(8, dtype=torch.int32) * 4
+  nibbles = (packed.to(torch.int32).unsqueeze(-1) >> shifts) & 0xF
+  w_int = (nibbles - 8).reshape(out_features, -1)[:, :in_features].to(torch.float32)
+
+  s = scale.to(torch.float32).unsqueeze(-1)
+  w = (w_int.reshape(out_features, in_features // INT4_GROUP_SIZE, INT4_GROUP_SIZE) * s).reshape(
+      out_features, in_features
+  )
+  return w.to(torch.bfloat16)
+
+
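+# Minimal round-trip sketch of the layout above (illustrative only; not used by
+# the converter): one output row with in_features=32, i.e. a single scale group
+# packed into four int32s.
+#
+#   packed = torch.full((1, 4), 0x76543210, dtype=torch.int32)
+#   scale = torch.full((1, 1), 2.0, dtype=torch.bfloat16)
+#   w = dequantize_pack_quantized_int4(packed, scale, (1, 32))
+#   assert w[0, :8].tolist() == [2.0 * v for v in range(-8, 0)]
+
+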
 MODEL_PARAMS_DICT = {
     "deepseek2-16b": {
         "num_layers": 27,

         "v_head_dim": 128,
         "has_mtp": False,
     },
+    "kimi-k2-thinking": {
+        "num_layers": 61,
+        "first_num_dense_layers": 1,
+        "base_num_query_heads": 64,
+        "base_emb_dim": 7168,
+        "num_experts": 384,
+        "q_lora_rank": 1536,
+        "kv_lora_rank": 512,
+        "qk_nope_head_dim": 128,
+        "qk_rope_head_dim": 64,
+        "v_head_dim": 128,
+        "has_mtp": False,
+        # Only routed-expert projections are int4 (group_size=32, symmetric);
+        # attention, shared experts, dense MLP, and lm_head stay bf16.
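+        # e.g. `model.layers.3.mlp.experts.0.down_proj` ships weight_packed /
+        # weight_scale / weight_shape, while `model.layers.3.self_attn.q_a_proj.weight`
+        # is a plain bf16 tensor (illustrative key names).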
+        "compressed_int4": True,
+    },
+    # Multimodal wrapper; the text branch matches kimi-k2-thinking but keys are
+    # prefixed `language_model.`. Vision keys are dropped (text-only target).
+    "kimi-k2.5-text": {
+        "num_layers": 61,
+        "first_num_dense_layers": 1,
+        "base_num_query_heads": 64,
+        "base_emb_dim": 7168,
+        "num_experts": 384,
+        "q_lora_rank": 1536,
+        "kv_lora_rank": 512,
+        "qk_nope_head_dim": 128,
+        "qk_rope_head_dim": 64,
+        "v_head_dim": 128,
+        "has_mtp": False,
+        "compressed_int4": True,
+        "hf_key_prefix": "language_model.",
+    },
+    # K2.6 reuses the K2.5 multimodal wrapper (KimiK25ForConditionalGeneration);
+    # the text-branch shape and quantization layout are identical to kimi-k2.5-text.
+    "kimi-k2.6-text": {
+        "num_layers": 61,
+        "first_num_dense_layers": 1,
+        "base_num_query_heads": 64,
+        "base_emb_dim": 7168,
+        "num_experts": 384,
+        "q_lora_rank": 1536,
+        "kv_lora_rank": 512,
+        "qk_nope_head_dim": 128,
+        "qk_rope_head_dim": 64,
+        "v_head_dim": 128,
+        "has_mtp": False,
+        "compressed_int4": True,
+        "hf_key_prefix": "language_model.",
+    },
 }
 
 
@@ -232,20 +315,52 @@ def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info, |
 
   ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.safetensors"))
   chkpt_vars = {}
+  is_compressed = bool(model_params.get("compressed_int4", False))
+  hf_key_prefix = model_params.get("hf_key_prefix", "")  # set for multimodal text-only variants
+
+  def _normalize(raw_key):
+    """Strip the multimodal prefix; return None to drop keys outside the text branch."""
+    if not hf_key_prefix:
+      return raw_key
+    if raw_key.startswith(hf_key_prefix):
+      return raw_key[len(hf_key_prefix) :]
+    return None
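+
+  # e.g. _normalize("language_model.model.embed_tokens.weight") returns
+  # "model.embed_tokens.weight", while _normalize("vision_tower.blocks.0.attn.weight")
+  # returns None (illustrative key names).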
+
   for i, ckpt_path in enumerate(ckpt_paths):
     max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
     with safe_open(ckpt_path, framework="pt", device="cpu") as f:
-      for key in f.keys():
-        parts = key.split(".")
-        layer = int(parts[2]) if "layers" in key else 0
+      for raw_key in f.keys():
+        key = _normalize(raw_key)
+        if key is None:
+          continue  # e.g. vision_tower / mm_projector keys when targeting text-only
         if key.endswith("_scale_inv"):
           raise ValueError("fp8 checkpoint is not supported.")
-        if is_key_allowed(key, MTP_KEYS_TO_SKIP):
-          mapped_key = hf_to_maxtext_mapping(
-              layer, num_experts, first_num_dense_layers, base_num_decoder_layers, has_mtp
-          ).get(key)
-          if mapped_key:
-            chkpt_vars[mapped_key] = f.get_tensor(key)
+        if is_compressed and key.endswith((".weight_scale", ".weight_shape")):
+          continue  # consumed below, alongside the matching .weight_packed
+
+        # Compressed weights are stored under ".weight_packed"; resolve back to the
+        # logical ".weight" name so they match the hf_to_maxtext_mapping table.
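+        # e.g. "model.layers.4.mlp.experts.7.gate_proj.weight_packed"
+        #   -> "model.layers.4.mlp.experts.7.gate_proj.weight"  (illustrative key).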
+        is_packed = is_compressed and key.endswith(".weight_packed")
+        hf_key = key[: -len(".weight_packed")] + ".weight" if is_packed else key
+
+        if not is_key_allowed(hf_key, MTP_KEYS_TO_SKIP):
+          continue
+        layer = int(hf_key.split(".")[2]) if "layers" in hf_key else 0
+        mapped_key = hf_to_maxtext_mapping(
+            layer, num_experts, first_num_dense_layers, base_num_decoder_layers, has_mtp
+        ).get(hf_key)
+        if not mapped_key:
+          continue
+
+        if is_packed:
+          raw_base = raw_key[: -len(".weight_packed")]
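+          # Sibling scale/shape tensors live under their original (possibly
+          # prefixed) names in the file, hence raw_key rather than the normalized key.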
+          chkpt_vars[mapped_key] = dequantize_pack_quantized_int4(
+              f.get_tensor(raw_key),
+              f.get_tensor(raw_base + ".weight_scale"),
+              f.get_tensor(raw_base + ".weight_shape").tolist(),
+          )
+        else:
+          chkpt_vars[mapped_key] = f.get_tensor(raw_key)
 
   logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
 