@@ -63,6 +63,8 @@ class MoEAOAConfigParams:
6363 # Runtime config
6464 model_prefix : str = "model."
6565
66+ index_n_heads : int = 0
67+
6668 # Extra statements to add
6769 extra_statements : List [str ] = field (default_factory = list )
6870
@@ -129,6 +131,7 @@ def _extract_params(cls, config: Any) -> MoEAOAConfigParams:
129131 use_qk_norm = getattr (config , "use_qk_norm" , False ),
130132 has_shared_experts = cls ._has_shared_experts (config ),
131133 model_prefix = cls ._get_model_prefix (config ),
134+ index_n_heads = getattr (config , "index_n_heads" , 0 ),
132135 )
133136
134137 @classmethod
@@ -369,6 +372,23 @@ def _get_mla_attention_statements(cls, params: MoEAOAConfigParams, prefix: str,
369372 ]
370373 )
371374
375+ if params .index_n_heads and params .index_n_heads > 0 :
376+ indexer_weights = [
377+ "wq_b" ,
378+ "wk" ,
379+ "weights_proj" ,
380+ ]
381+ statements .extend (
382+ [
383+ f"{ prefix } .self_attn.indexer.{ weight_name } .weight^T -> { prefix_offset } .self_attn.core_attention.indexer.{ weight_name } .weight"
384+ for weight_name in indexer_weights
385+ ]
386+ )
387+ statements += [
388+ f"{ prefix } .self_attn.indexer.k_norm.bias -> { prefix_offset } .self_attn.core_attention.indexer.k_norm.bias" ,
389+ f"{ prefix } .self_attn.indexer.k_norm.weight -> { prefix_offset } .self_attn.core_attention.indexer.k_norm.weight" ,
390+ ]
391+
372392 return statements
373393
374394 # ==================== MoE Expert Weights ====================
@@ -725,6 +745,23 @@ def _get_inv_mla_attention_statements(
725745 ]
726746 )
727747
748+ if params .index_n_heads and params .index_n_heads > 0 :
749+ indexer_weights = [
750+ "wq_b" ,
751+ "wk" ,
752+ "weights_proj" ,
753+ ]
754+ statements .extend (
755+ [
756+ f"{ prefix_offset } .self_attn.core_attention.indexer.{ weight_name } .weight^T -> { prefix } .self_attn.indexer.{ weight_name } .weight"
757+ for weight_name in indexer_weights
758+ ]
759+ )
760+ statements += [
761+ f"{ prefix_offset } .self_attn.core_attention.indexer.k_norm.bias -> { prefix } .self_attn.indexer.k_norm.bias" ,
762+ f"{ prefix_offset } .self_attn.core_attention.indexer.k_norm.weight -> { prefix } .self_attn.indexer.k_norm.weight" ,
763+ ]
764+
728765 return statements
729766
730767 # ==================== Inverse MoE Expert Weights ====================
0 commit comments