Skip to content

Commit a3779fc

Browse files
Abhinav Goel
authored and committed
Add checkpoint_name calls for moe_mlpwi/mlpwo and query/kv_wa_proj
Wire up the remat config keys added in 717ddf5 to actual checkpoint_name call sites in the layer code: in moe.py, rename mlpwi_0/1 and mlpwo to moe_mlpwi_0/1 and moe_mlpwo (9 sites); in attention_mla.py, add query_wa_proj after wq_a and kv_wa_proj after wkv_a.
1 parent 0033295 commit a3779fc

2 files changed

Lines changed: 11 additions & 9 deletions

File tree

src/maxtext/layers/attention_mla.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,7 @@ def mla_query_projection(
794794
else:
795795
# LoRA path
796796
low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding) # [B, L, q_lora_rank]
797+
low_rank_q = checkpoint_name(low_rank_q, "query_wa_proj")
797798
low_rank_q = self.q_norm(low_rank_q) # RMSNorm on low rank
798799
low_rank_q = checkpoint_name(low_rank_q, "mla_q")
799800
q = self.wq_b(low_rank_q, out_sharding=query_sharding) # [B, L, n_heads, qk_head_dim]
@@ -933,6 +934,7 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
933934
wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ)
934935
wkva_out_sharding = create_sharding(self.mesh, wka_logical_name)
935936
low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding)
937+
low_rank = checkpoint_name(low_rank, "kv_wa_proj")
936938
low_rank_main, low_rank_rope = jnp.split(low_rank, [self.kv_lora_rank], axis=-1)
937939
low_rank_main = self.kv_norm(low_rank_main)
938940
low_rank_main = checkpoint_name(low_rank_main, "mla_kv")

src/maxtext/layers/moe.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,7 +1274,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
12741274
layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
12751275
if self.config.mlp_bias:
12761276
layer_w0 = layer_w0 + w0_bias
1277-
layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
1277+
layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")
12781278

12791279
layer_w1 = gmm_fn(
12801280
x,
@@ -1288,7 +1288,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
12881288
layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
12891289
if self.config.mlp_bias:
12901290
layer_w1 = layer_w1 + w1_bias
1291-
layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
1291+
layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
12921292
intermediate_layer = self.apply_ffn_activation(layer_w0, layer_w1)
12931293

12941294
intermediate_output = gmm_fn(
@@ -1305,7 +1305,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
13051305
)
13061306
if self.config.mlp_bias:
13071307
intermediate_output = intermediate_output + wo_bias
1308-
intermediate_output = adc.checkpoint_name(intermediate_output, "mlpwo")
1308+
intermediate_output = adc.checkpoint_name(intermediate_output, "moe_mlpwo")
13091309

13101310
if self.config.use_ring_of_experts:
13111311
# Set the outputs of tokens which were not processed to 0.
@@ -1860,7 +1860,7 @@ def dense_matmul(
18601860
layer_w0,
18611861
mlp_axis,
18621862
)
1863-
layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
1863+
layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")
18641864
with jax.named_scope("wi_1"):
18651865
w1_kernel_axes = ("exp", None, "mlp")
18661866
w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes)
@@ -1876,7 +1876,7 @@ def dense_matmul(
18761876
layer_w1,
18771877
mlp_axis,
18781878
)
1879-
layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
1879+
layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
18801880
layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1)
18811881
with jax.named_scope("wo"):
18821882
wo_kernel_axes = ("exp", "mlp", None)
@@ -1902,7 +1902,7 @@ def dense_matmul(
19021902
"activation_embed",
19031903
),
19041904
)
1905-
intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
1905+
intermediate_layer = adc.checkpoint_name(intermediate_layer, "moe_mlpwo")
19061906
with jax.named_scope("combine"):
19071907
# Matmul & element wise operation
19081908
output = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=COMBINE)(
@@ -1931,7 +1931,7 @@ def dense_matmul(
19311931
layer_w0 = layer_w0 + w0_bias[None, None, :, :]
19321932
if self.config.activations_in_float32:
19331933
layer_w0 = layer_w0.astype(jnp.float32)
1934-
layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
1934+
layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")
19351935
with jax.named_scope("wi_1"):
19361936
layer_w1 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)(
19371937
"BSM,EMH -> BSEH", inputs, w1_kernel, precision=matmul_precision
@@ -1940,7 +1940,7 @@ def dense_matmul(
19401940
layer_w1 = layer_w1 + w1_bias[None, None, :, :]
19411941
if self.config.activations_in_float32:
19421942
layer_w1 = layer_w1.astype(jnp.float32)
1943-
layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
1943+
layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
19441944
layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1)
19451945

19461946
with jax.named_scope("wo"):
@@ -1954,7 +1954,7 @@ def dense_matmul(
19541954
intermediate_layer = intermediate_layer + wo_bias[None, None, :, :]
19551955
if self.config.activations_in_float32:
19561956
intermediate_layer = intermediate_layer.astype(jnp.float32)
1957-
intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
1957+
intermediate_layer = adc.checkpoint_name(intermediate_layer, "moe_mlpwo")
19581958
with jax.named_scope("weight_sum"):
19591959
if is_llama4_decoder_layer:
19601960
weights = self.reshape_and_update_weights(jnp.ones_like(top_k_weights), top_k_indices)

0 commit comments

Comments (0)