@@ -653,8 +653,44 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None
           shard_mode=self.config.shard_mode,
           rngs=self.rngs,
       )
+    elif self.config.fused_mla_lora_proj:
+      # Fused Q+KV LoRA up-projection: single matmul (emb -> q_lora_rank + kv_lora_rank + rope_head_dim).
+      self.wq_kv_a = DenseGeneral(
+          in_features_shape=self.config.emb_dim,
+          out_features_shape=self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
+          axis=-1,
+          kernel_init=self.kernel_init,
+          kernel_axes=("embed", "q_kv_lora_up_proj"),
+          dtype=self.dtype,
+          weight_dtype=self.weight_dtype,
+          quant=self.quant,
+          matmul_precision=self.config.matmul_precision,
+          shard_mode=self.config.shard_mode,
+          rngs=self.rngs,
+      )
+      self.q_norm = RMSNorm(
+          num_features=self.q_lora_rank,
+          dtype=self.config.dtype,
+          weight_dtype=self.config.weight_dtype,
+          epsilon=self.config.normalization_layer_epsilon,
+          kernel_axes=("norm",),
+          rngs=self.rngs,
+      )
+      self.wq_b = DenseGeneral(
+          in_features_shape=self.q_lora_rank,
+          out_features_shape=(self.num_query_heads, self.qk_head_dim),
+          axis=-1,
+          kernel_init=self.kernel_init,
+          kernel_axes=("q_lora", "q_heads", "kv"),
+          dtype=self.dtype,
+          weight_dtype=self.weight_dtype,
+          quant=self.quant,
+          matmul_precision=self.config.matmul_precision,
+          shard_mode=self.config.shard_mode,
+          rngs=self.rngs,
+      )
     else:
-      # LoRA path for Q.
+      # Separate Q LoRA up-projection.
       self.wq_a = DenseGeneral(
           in_features_shape=self.config.emb_dim,
           out_features_shape=self.q_lora_rank,
@@ -690,20 +726,21 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None
           rngs=self.rngs,
       )
 
-    # KV LoRA path.
-    self.wkv_a = DenseGeneral(
-        in_features_shape=self.config.emb_dim,
-        out_features_shape=self.kv_lora_rank + self.qk_rope_head_dim,
-        axis=-1,
-        kernel_init=self.kernel_init,
-        kernel_axes=("embed", "kv_lora_up_proj"),
-        dtype=self.dtype,
-        weight_dtype=self.weight_dtype,
-        quant=self.quant,
-        matmul_precision=self.config.matmul_precision,
-        shard_mode=self.config.shard_mode,
-        rngs=self.rngs,
-    )
+    if not self.config.fused_mla_lora_proj:
+      # KV LoRA up-projection. When fused, wq_kv_a handles both Q and KV.
+      self.wkv_a = DenseGeneral(
+          in_features_shape=self.config.emb_dim,
+          out_features_shape=self.kv_lora_rank + self.qk_rope_head_dim,
+          axis=-1,
+          kernel_init=self.kernel_init,
+          kernel_axes=("embed", "kv_lora_up_proj"),
+          dtype=self.dtype,
+          weight_dtype=self.weight_dtype,
+          quant=self.quant,
+          matmul_precision=self.config.matmul_precision,
+          shard_mode=self.config.shard_mode,
+          rngs=self.rngs,
+      )
     self.kv_norm = RMSNorm(
         num_features=self.kv_lora_rank,
         dtype=self.config.dtype,
@@ -791,8 +828,11 @@ def mla_query_projection(
     if self.q_lora_rank == 0:
       q = self.query(inputs_q, out_sharding=query_sharding)
     else:
-      # LoRA path
-      low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding)  # [B, L, q_lora_rank]
+      # LoRA path: inputs_q is either raw embeddings (unfused) or the pre-split Q slice (fused).
+      if not self.config.fused_mla_lora_proj:
+        low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding)  # [B, L, q_lora_rank]
+      else:
+        low_rank_q = inputs_q  # already the q_lora_rank slice from wq_kv_a split in __call__
       low_rank_q = self.q_norm(low_rank_q)  # RMSNorm on low rank
       low_rank_q = checkpoint_name(low_rank_q, "mla_q")
       q = self.wq_b(low_rank_q, out_sharding=query_sharding)  # [B, L, n_heads, qk_head_dim]
@@ -931,7 +971,10 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
     else:
       wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ)
     wkva_out_sharding = create_sharding(self.mesh, wka_logical_name)
-    low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding)
+    if self.config.fused_mla_lora_proj:
+      low_rank = inputs  # already the kv_lora_rank+rope_head_dim slice from wq_kv_a split in __call__
+    else:
+      low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding)
     low_rank_main, low_rank_rope = jnp.split(low_rank, [self.kv_lora_rank], axis=-1)
     low_rank_main = self.kv_norm(low_rank_main)
     low_rank_main = checkpoint_name(low_rank_main, "mla_kv")
@@ -1002,12 +1045,23 @@ def __call__(
     inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names)
     out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV)
 
-    query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
-    if self.config.force_q_layout:
-      query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
-    key, value, cached_values = self.mla_kv_projection(
-        inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
-    )
+    if self.config.fused_mla_lora_proj:
+      # Single matmul for both Q and KV LoRA up-projections, then split.
+      fused_lora = self.wq_kv_a(inputs_q)
+      lora_q, lora_kv = jnp.split(fused_lora, [self.q_lora_rank], axis=-1)
+      query, low_rank_q = self.mla_query_projection(lora_q, inputs_positions, model_mode)
+      if self.config.force_q_layout:
+        query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
+      key, value, cached_values = self.mla_kv_projection(
+          lora_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
+      )
+    else:
+      query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
+      if self.config.force_q_layout:
+        query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
+      key, value, cached_values = self.mla_kv_projection(
+          inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
+      )
     query = checkpoint_name(query, "query_proj")
     key = checkpoint_name(key, "key_proj")
     value = checkpoint_name(value, "value_proj")
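
A minimal sketch of the shape bookkeeping the fused path relies on, kept separate from the diff above. It uses only jax.numpy (not the DenseGeneral/RMSNorm modules from this file), and the dimensions (emb_dim, q_lora_rank, kv_lora_rank, qk_rope_head_dim) are made-up illustration values, not this model's config:

    # Sketch only: one fused up-projection, then split into the Q-LoRA slice and
    # the KV-LoRA + rope slice, mirroring the wq_kv_a / jnp.split logic in __call__.
    import jax
    import jax.numpy as jnp

    emb_dim, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 32, 8, 6, 4  # toy sizes
    batch, seq = 2, 5

    k_x, k_w = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(k_x, (batch, seq, emb_dim))
    # Single fused kernel: emb_dim -> q_lora_rank + kv_lora_rank + qk_rope_head_dim.
    w_fused = jax.random.normal(k_w, (emb_dim, q_lora_rank + kv_lora_rank + qk_rope_head_dim))

    fused = jnp.einsum("bse,ed->bsd", x, w_fused)  # one matmul instead of two
    lora_q, lora_kv = jnp.split(fused, [q_lora_rank], axis=-1)
    # lora_q feeds the Q up-projection (q_norm then wq_b); lora_kv is split again
    # into the KV main part and the rope part, as mla_kv_projection does.
    kv_main, kv_rope = jnp.split(lora_kv, [kv_lora_rank], axis=-1)

    assert lora_q.shape == (batch, seq, q_lora_rank)
    assert kv_main.shape == (batch, seq, kv_lora_rank)
    assert kv_rope.shape == (batch, seq, qk_rope_head_dim)

The saving comes from issuing one embed-dimension matmul instead of two (wq_a and wkv_a); the downstream per-branch math is unchanged because the split boundaries match the separate projections' output widths.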