Skip to content

Commit a0f16f8

Browse files
Merge pull request #3660 from AI-Hypercomputer:parambole/502609970
PiperOrigin-RevId: 899819197
2 parents a4e2510 + 0d1c41c commit a0f16f8

1 file changed

Lines changed: 32 additions & 0 deletions

File tree

src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,38 @@ def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info,
   if has_mtp:
     max_logging.log("Processing MTP Layer")

+    # Initialize the mtp_block dictionary structure
+    jax_weights["mtp_block"] = {
+        "mtp_layer_1": {
+            "mtp_1_embedding_norm": {"scale": None},
+            "mtp_1_hidden_state_norm": {"scale": None},
+            "mtp_1_projection": {"kernel": None},
+            "mtp_1_transformer_layer": {
+                "pre_self_attention_layer_norm": {"scale": None},
+                "post_self_attention_layer_norm": {"scale": None},
+                "self_attention": {
+                    "kv_norm": {"scale": None},
+                    "wkv_a": {"kernel": None},
+                    "wkv_b": {"kernel": None},
+                    "out": {"kernel": None},
+                },
+                "DeepSeekMoeBlock_0": {
+                    "MoeBlock_0": {
+                        "wi_0": None,
+                        "wi_1": None,
+                        "wo": None,
+                        "gate": {"kernel": None},
+                    },
+                    "shared_experts": {
+                        "wi_0": {"kernel": None},
+                        "wi_1": {"kernel": None},
+                        "wo": {"kernel": None},
+                    },
+                },
+            },
+        },
+    }
+
   # MTP unique components
   jax_weights["mtp_block"]["mtp_layer_1"]["mtp_1_embedding_norm"]["scale"] = (
       chkpt_vars["mtp_block.mtp_layer_1.mtp_1_embedding_norm.scale"].to(torch.float16).numpy()

0 commit comments

Comments
 (0)