
Commit 8e0fcfb

Add fused_mla_lora_proj config flag for MLA LoRA up-projections
Add a `fused_mla_lora_proj` boolean config (default False) that fuses the separate wq_a (emb→q_lora_rank) and wkv_a (emb→kv_lora_rank+rope_head_dim) MLA LoRA up-projection matmuls into a single wq_kv_a matmul (emb→q_lora_rank+kv_lora_rank+rope_head_dim), followed by a split. This halves the number of matmul kernel launches for the LoRA up-projection step. The flag is modelled after the existing `fused_qkv` config and requires `attention_type='mla'` and `q_lora_rank > 0`. Note: wq_kv_a uses a different weight name than wq_a/wkv_a, so checkpoints are not cross-compatible between fused and unfused modes.
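
For intuition, the fusion relies on a standard identity: concatenating the two up-projection weights along the output axis and splitting the matmul result reproduces the two separate matmuls. A minimal JAX sketch with toy, illustrative sizes (none of these dimensions or variable shapes come from this commit):

```python
import jax
import jax.numpy as jnp

# Toy sizes for illustration only.
emb_dim, q_lora_rank, kv_lora_rank, rope_head_dim = 64, 24, 16, 8
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

x = jax.random.normal(k1, (2, 5, emb_dim))                         # [batch, seq, emb]
wq_a = jax.random.normal(k2, (emb_dim, q_lora_rank))               # emb -> q_lora_rank
wkv_a = jax.random.normal(k3, (emb_dim, kv_lora_rank + rope_head_dim))

# Unfused: two separate matmuls (two kernel launches).
low_rank_q = x @ wq_a
low_rank_kv = x @ wkv_a

# Fused: one matmul over the concatenated weight, then a split.
wq_kv_a = jnp.concatenate([wq_a, wkv_a], axis=-1)
lora_q, lora_kv = jnp.split(x @ wq_kv_a, [q_lora_rank], axis=-1)

assert jnp.allclose(lora_q, low_rank_q, atol=1e-4)
assert jnp.allclose(lora_kv, low_rank_kv, atol=1e-4)
```

In the actual change the fused weight is trained directly under a new name (`wq_kv_a`) rather than built by concatenation, which is why fused and unfused checkpoints are not interchangeable.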
1 parent 21e3372 commit 8e0fcfb

3 files changed

Lines changed: 88 additions & 24 deletions


src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
@@ -368,6 +368,7 @@ qk_clip_threshold: 100.0 # Threshold for clipping (tau in the paper)
 
 # Combine matmuls for QKV and MLP
 fused_qkv: False
+fused_mla_lora_proj: False # Fuse MLA Q+KV LoRA up-projections (wq_a+wkv_a) into a single matmul. Requires q_lora_rank > 0.
 fused_mlp: False
 
 record_internal_nn_metrics: 0

src/maxtext/configs/types.py

Lines changed: 9 additions & 0 deletions
@@ -431,6 +431,10 @@ class ModelArchitecture(BaseModel):
   )
   normalization_layer_epsilon: float = Field(1.0e-05, description="Epsilon value for normalization layers.")
   fused_qkv: bool = Field(False, description="If supported, fuse the Q, K, and V projections.")
+  fused_mla_lora_proj: bool = Field(
+      False,
+      description="Fuse MLA Q and KV LoRA up-projections (wq_a + wkv_a) into a single matmul. Requires q_lora_rank > 0.",
+  )
   attention_bias: bool = Field(
       False,
       description="If True, adds a learnable bias to the query, key, and value projections.",
@@ -2505,6 +2509,11 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
         "Please disable attn_logits_soft_cap when using use_qk_clip."
     )
 
+    if self.fused_mla_lora_proj and self.q_lora_rank == 0:
+      raise ValueError("`fused_mla_lora_proj` requires `q_lora_rank > 0`.")
+    if self.fused_mla_lora_proj and self.attention_type != "mla":
+      raise ValueError("`fused_mla_lora_proj` is only valid with `attention_type='mla'`.")
+
     # I. FINAL TYPE CONVERSIONS AND DERIVED LISTS
     # Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
     if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage":
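
The two new guards follow the existing cross-field validation style in `types.py`. As a standalone pydantic sketch of the same pattern (this is not the MaxText `ModelArchitecture` class; the class name, field defaults, and validator name are illustrative):

```python
from pydantic import BaseModel, Field, model_validator

class MlaFusionConfig(BaseModel):
  """Illustrative stand-in for the real config; only these three fields are modelled."""
  attention_type: str = Field("mla", description="Attention variant.")
  q_lora_rank: int = Field(0, description="Rank of the Q LoRA projection; 0 disables it.")
  fused_mla_lora_proj: bool = Field(False, description="Fuse the Q and KV LoRA up-projections.")

  @model_validator(mode="after")
  def _check_fused_mla_lora_proj(self):
    # Mirrors the two checks added in this commit.
    if self.fused_mla_lora_proj and self.q_lora_rank == 0:
      raise ValueError("`fused_mla_lora_proj` requires `q_lora_rank > 0`.")
    if self.fused_mla_lora_proj and self.attention_type != "mla":
      raise ValueError("`fused_mla_lora_proj` is only valid with `attention_type='mla'`.")
    return self

# MlaFusionConfig(fused_mla_lora_proj=True, q_lora_rank=1536) validates;
# MlaFusionConfig(fused_mla_lora_proj=True) raises a ValidationError.
```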

src/maxtext/layers/attention_mla.py

Lines changed: 78 additions & 24 deletions
@@ -653,8 +653,44 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
           shard_mode=self.config.shard_mode,
           rngs=self.rngs,
       )
+    elif self.config.fused_mla_lora_proj:
+      # Fused Q+KV LoRA up-projection: single matmul (emb -> q_lora_rank + kv_lora_rank + rope_head_dim).
+      self.wq_kv_a = DenseGeneral(
+          in_features_shape=self.config.emb_dim,
+          out_features_shape=self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
+          axis=-1,
+          kernel_init=self.kernel_init,
+          kernel_axes=("embed", "q_kv_lora_up_proj"),
+          dtype=self.dtype,
+          weight_dtype=self.weight_dtype,
+          quant=self.quant,
+          matmul_precision=self.config.matmul_precision,
+          shard_mode=self.config.shard_mode,
+          rngs=self.rngs,
+      )
+      self.q_norm = RMSNorm(
+          num_features=self.q_lora_rank,
+          dtype=self.config.dtype,
+          weight_dtype=self.config.weight_dtype,
+          epsilon=self.config.normalization_layer_epsilon,
+          kernel_axes=("norm",),
+          rngs=self.rngs,
+      )
+      self.wq_b = DenseGeneral(
+          in_features_shape=self.q_lora_rank,
+          out_features_shape=(self.num_query_heads, self.qk_head_dim),
+          axis=-1,
+          kernel_init=self.kernel_init,
+          kernel_axes=("q_lora", "q_heads", "kv"),
+          dtype=self.dtype,
+          weight_dtype=self.weight_dtype,
+          quant=self.quant,
+          matmul_precision=self.config.matmul_precision,
+          shard_mode=self.config.shard_mode,
+          rngs=self.rngs,
+      )
     else:
-      # LoRA path for Q.
+      # Separate Q LoRA up-projection.
       self.wq_a = DenseGeneral(
           in_features_shape=self.config.emb_dim,
           out_features_shape=self.q_lora_rank,
@@ -690,20 +726,21 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
           rngs=self.rngs,
       )
 
-    # KV LoRA path.
-    self.wkv_a = DenseGeneral(
-        in_features_shape=self.config.emb_dim,
-        out_features_shape=self.kv_lora_rank + self.qk_rope_head_dim,
-        axis=-1,
-        kernel_init=self.kernel_init,
-        kernel_axes=("embed", "kv_lora_up_proj"),
-        dtype=self.dtype,
-        weight_dtype=self.weight_dtype,
-        quant=self.quant,
-        matmul_precision=self.config.matmul_precision,
-        shard_mode=self.config.shard_mode,
-        rngs=self.rngs,
-    )
+    if not self.config.fused_mla_lora_proj:
+      # KV LoRA up-projection. When fused, wq_kv_a handles both Q and KV.
+      self.wkv_a = DenseGeneral(
+          in_features_shape=self.config.emb_dim,
+          out_features_shape=self.kv_lora_rank + self.qk_rope_head_dim,
+          axis=-1,
+          kernel_init=self.kernel_init,
+          kernel_axes=("embed", "kv_lora_up_proj"),
+          dtype=self.dtype,
+          weight_dtype=self.weight_dtype,
+          quant=self.quant,
+          matmul_precision=self.config.matmul_precision,
+          shard_mode=self.config.shard_mode,
+          rngs=self.rngs,
+      )
     self.kv_norm = RMSNorm(
         num_features=self.kv_lora_rank,
         dtype=self.config.dtype,
@@ -791,8 +828,11 @@ def mla_query_projection(
     if self.q_lora_rank == 0:
       q = self.query(inputs_q, out_sharding=query_sharding)
     else:
-      # LoRA path
-      low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding)  # [B, L, q_lora_rank]
+      # LoRA path: inputs_q is either raw embeddings (unfused) or the pre-split Q slice (fused).
+      if not self.config.fused_mla_lora_proj:
+        low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding)  # [B, L, q_lora_rank]
+      else:
+        low_rank_q = inputs_q  # already the q_lora_rank slice from wq_kv_a split in __call__
       low_rank_q = self.q_norm(low_rank_q)  # RMSNorm on low rank
       low_rank_q = checkpoint_name(low_rank_q, "mla_q")
       q = self.wq_b(low_rank_q, out_sharding=query_sharding)  # [B, L, n_heads, qk_head_dim]
@@ -931,7 +971,10 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
     else:
       wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ)
       wkva_out_sharding = create_sharding(self.mesh, wka_logical_name)
-    low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding)
+    if self.config.fused_mla_lora_proj:
+      low_rank = inputs  # already the kv_lora_rank+rope_head_dim slice from wq_kv_a split in __call__
+    else:
+      low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding)
     low_rank_main, low_rank_rope = jnp.split(low_rank, [self.kv_lora_rank], axis=-1)
     low_rank_main = self.kv_norm(low_rank_main)
     low_rank_main = checkpoint_name(low_rank_main, "mla_kv")
@@ -1002,12 +1045,23 @@ def __call__(
     inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names)
     out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV)
 
-    query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
-    if self.config.force_q_layout:
-      query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
-    key, value, cached_values = self.mla_kv_projection(
-        inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
-    )
+    if self.config.fused_mla_lora_proj:
+      # Single matmul for both Q and KV LoRA up-projections, then split.
+      fused_lora = self.wq_kv_a(inputs_q)
+      lora_q, lora_kv = jnp.split(fused_lora, [self.q_lora_rank], axis=-1)
+      query, low_rank_q = self.mla_query_projection(lora_q, inputs_positions, model_mode)
+      if self.config.force_q_layout:
+        query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
+      key, value, cached_values = self.mla_kv_projection(
+          lora_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
+      )
+    else:
+      query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
+      if self.config.force_q_layout:
+        query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
+      key, value, cached_values = self.mla_kv_projection(
+          inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
+      )
     query = checkpoint_name(query, "query_proj")
     key = checkpoint_name(key, "key_proj")
     value = checkpoint_name(value, "value_proj")
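
To trace shapes through the fused path in `__call__` and `mla_kv_projection`, here is a shape-only sketch; the dimensions are illustrative MLA-style values, not read from this diff:

```python
import jax.numpy as jnp

batch, seq = 2, 4
q_lora_rank, kv_lora_rank, rope_head_dim = 1536, 512, 64  # illustrative sizes

# Stand-in for the wq_kv_a output: emb -> q_lora_rank + kv_lora_rank + rope_head_dim.
fused_lora = jnp.zeros((batch, seq, q_lora_rank + kv_lora_rank + rope_head_dim))

# Split performed in __call__: Q slice vs. KV slice.
lora_q, lora_kv = jnp.split(fused_lora, [q_lora_rank], axis=-1)
print(lora_q.shape)   # (2, 4, 1536) -> pre-split Q slice passed to mla_query_projection
print(lora_kv.shape)  # (2, 4, 576)  -> kv_lora_rank + rope_head_dim passed to mla_kv_projection

# Second split inside mla_kv_projection: main KV latent vs. RoPE part.
low_rank_main, low_rank_rope = jnp.split(lora_kv, [kv_lora_rank], axis=-1)
print(low_rank_main.shape)  # (2, 4, 512)
print(low_rank_rope.shape)  # (2, 4, 64)
```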
