@@ -528,6 +528,38 @@ def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info,
   if has_mtp:
     max_logging.log("Processing MTP Layer")
 
+    # Initialize the mtp_block dictionary structure
+    jax_weights["mtp_block"] = {
+        "mtp_layer_1": {
+            "mtp_1_embedding_norm": {"scale": None},
+            "mtp_1_hidden_state_norm": {"scale": None},
+            "mtp_1_projection": {"kernel": None},
+            "mtp_1_transformer_layer": {
+                "pre_self_attention_layer_norm": {"scale": None},
+                "post_self_attention_layer_norm": {"scale": None},
+                "self_attention": {
+                    "kv_norm": {"scale": None},
+                    "wkv_a": {"kernel": None},
+                    "wkv_b": {"kernel": None},
+                    "out": {"kernel": None},
+                },
+                "DeepSeekMoeBlock_0": {
+                    "MoeBlock_0": {
+                        "wi_0": None,
+                        "wi_1": None,
+                        "wo": None,
+                        "gate": {"kernel": None},
+                    },
+                    "shared_experts": {
+                        "wi_0": {"kernel": None},
+                        "wi_1": {"kernel": None},
+                        "wo": {"kernel": None},
+                    },
+                },
+            },
+        }
+    }
+
     # MTP unique components
     jax_weights["mtp_block"]["mtp_layer_1"]["mtp_1_embedding_norm"]["scale"] = (
         chkpt_vars["mtp_block.mtp_layer_1.mtp_1_embedding_norm.scale"].to(torch.float16).numpy()
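
A minimal sanity-check sketch, not part of this diff, that one could run after all MTP assignments complete to confirm every `None` placeholder in the pre-initialized structure was filled (it assumes `jax_weights` is the dict built above):

import jax

# Treat None placeholders as leaves (JAX normally drops None from pytrees),
# then assert that no placeholder survived the conversion unfilled.
leaves = jax.tree_util.tree_leaves(
    jax_weights["mtp_block"], is_leaf=lambda x: x is None
)
assert all(leaf is not None for leaf in leaves), "unfilled MTP weight placeholder"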