
Commit de51021

Merge pull request #3414 from abhinavgoel95:abgoel/add-moe-mla-remat-policies

PiperOrigin-RevId: 888933852
Parents: 16b6848, 0ac53b5

6 files changed: 92 additions and 80 deletions

src/maxtext/configs/base.yml

Lines changed: 5 additions & 0 deletions
```diff
@@ -317,11 +317,16 @@ mlpwi: 'remat'
 mlpwi_0: 'remat'
 mlpwi_1: 'remat'
 mlpwo: 'remat'
+moe_mlpwi_0: 'remat'
+moe_mlpwi_1: 'remat'
+moe_mlpwo: 'remat'
 query_proj: 'remat'
 key_proj: 'remat'
 value_proj: 'remat'
 qkv_proj: 'remat'
 out_proj: 'remat'
+query_wa_proj: 'remat'
+kv_wa_proj: 'remat'
 mla_q: 'remat'
 mla_kv: 'remat'
 attention_out: 'remat'
```
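
Note: these per-tensor names only take effect when a run selects the custom remat policy. A hedged sketch of such an override (the `remat_policy: 'custom'` setting and the `'device'` value follow MaxText's existing remat conventions; only `'remat'` appears in this diff):

```yaml
remat_policy: 'custom'
# Recompute the MoE feed-forward activations on the backward pass, but keep
# the MLA low-rank projections resident on device ('device' is assumed here).
moe_mlpwi_0: 'remat'
moe_mlpwi_1: 'remat'
moe_mlpwo: 'remat'
query_wa_proj: 'device'
kv_wa_proj: 'device'
```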

src/maxtext/configs/pyconfig_deprecated.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -518,9 +518,14 @@ def validate_and_assign_remat_tensors(keys):
       "mlpwi_0",
       "mlpwi_1",
       "mlpwo",
+      "moe_mlpwi_0",
+      "moe_mlpwi_1",
+      "moe_mlpwo",
       "query_proj",
       "key_proj",
       "value_proj",
+      "query_wa_proj",
+      "kv_wa_proj",
       "out_proj",
   ]
   assert keys["decoder_layer_input"] != "remat", "Cannot rematerialize this tensor with scan_layers=True"
```
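
For reference, a minimal sketch of the check this list drives (illustrative only; the helper name and error style are assumptions, and only the tensor names come from the diff):

```python
VALID_REMAT_TENSORS = {
    "mlpwi_0", "mlpwi_1", "mlpwo",
    "moe_mlpwi_0", "moe_mlpwi_1", "moe_mlpwo",
    "query_proj", "key_proj", "value_proj",
    "query_wa_proj", "kv_wa_proj", "out_proj",
}

def check_remat_keys(keys):
  # Any tensor name outside the known list is rejected up front.
  unknown = sorted(set(keys) - VALID_REMAT_TENSORS)
  if unknown:
    raise ValueError(f"Unknown remat tensor(s): {unknown}")
```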

src/maxtext/configs/types.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -912,9 +912,29 @@ class RematAndOffload(BaseModel):
       RematLocation.REMAT,
       description="Remat policy for the second MLP layer's output.",
   )
+  moe_mlpwi_0: RematLocation = Field(
+      RematLocation.REMAT,
+      description="Remat policy for the first part of a gated MoE's output.",
+  )
+  moe_mlpwi_1: RematLocation = Field(
+      RematLocation.REMAT,
+      description="Remat policy for the second part of a gated MoE's output.",
+  )
+  moe_mlpwo: RematLocation = Field(
+      RematLocation.REMAT,
+      description="Remat policy for the second MoE layer's output.",
+  )
   query_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the query projection.")
   key_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the key projection.")
   value_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the value projection.")
+  query_wa_proj: RematLocation = Field(
+      RematLocation.REMAT,
+      description="Remat policy for the MLA query weighted attention projection.",
+  )
+  kv_wa_proj: RematLocation = Field(
+      RematLocation.REMAT,
+      description="Remat policy for the MLA key and value weighted attention projection.",
+  )
   qkv_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for fused QKV projection.")
   out_proj: RematLocation = Field(
       RematLocation.REMAT,
```
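
The typed config makes these knobs validated at construction time. A usage sketch (the import path mirrors the file location and is an assumption):

```python
from maxtext.configs.types import RematAndOffload, RematLocation  # path assumed

policy = RematAndOffload()                      # every field defaults to REMAT
assert policy.moe_mlpwo == RematLocation.REMAT

# Overrides are checked against the RematLocation enum, so an unsupported
# value raises a validation error instead of being silently ignored.
policy = RematAndOffload(kv_wa_proj=RematLocation.REMAT)
```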

src/maxtext/layers/attention_mla.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -817,6 +817,7 @@ def mla_query_projection(
     else:
       # LoRA path
       low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding)  # [B, L, q_lora_rank]
+      low_rank_q = checkpoint_name(low_rank_q, "query_wa_proj")
       low_rank_q = self.q_norm(low_rank_q)  # RMSNorm on low rank
       low_rank_q = checkpoint_name(low_rank_q, "mla_q")
       q = self.wq_b(low_rank_q, out_sharding=query_sharding)  # [B, L, n_heads, qk_head_dim]
@@ -956,6 +957,7 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
     wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ)
     wkva_out_sharding = create_sharding(self.mesh, wka_logical_name)
     low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding)
+    low_rank = checkpoint_name(low_rank, "kv_wa_proj")
     low_rank_main, low_rank_rope = jnp.split(low_rank, [self.kv_lora_rank], axis=-1)
     low_rank_main = self.kv_norm(low_rank_main)
     low_rank_main = checkpoint_name(low_rank_main, "mla_kv")
```
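
These tags are inert until a remat policy references them. A self-contained sketch of the underlying JAX mechanism (simplified; not MaxText's actual wiring):

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def low_rank_block(x, w_a, w_b):
  # Tag the low-rank activation, mirroring the pattern in mla_kv_projection.
  low_rank = checkpoint_name(x @ w_a, "kv_wa_proj")
  return jnp.tanh(low_rank) @ w_b

# Save only the tagged activation during the forward pass; everything else
# is recomputed on the backward pass.
remat_block = jax.checkpoint(
    low_rank_block,
    policy=jax.checkpoint_policies.save_only_these_names("kv_wa_proj"),
)

w_a, w_b = jnp.ones((16, 4)), jnp.ones((4, 16))
grads = jax.grad(lambda x: remat_block(x, w_a, w_b).sum())(jnp.ones((2, 16)))
```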

src/maxtext/layers/moe.py

Lines changed: 9 additions & 9 deletions
```diff
@@ -1274,7 +1274,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
     if self.config.mlp_bias:
       layer_w0 = layer_w0 + w0_bias
-    layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
+    layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")

     layer_w1 = gmm_fn(
         x,
@@ -1288,7 +1288,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
     if self.config.mlp_bias:
       layer_w1 = layer_w1 + w1_bias
-    layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
+    layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
     intermediate_layer = self.apply_ffn_activation(layer_w0, layer_w1)

     intermediate_output = gmm_fn(
@@ -1305,7 +1305,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
     )
     if self.config.mlp_bias:
       intermediate_output = intermediate_output + wo_bias
-    intermediate_output = adc.checkpoint_name(intermediate_output, "mlpwo")
+    intermediate_output = adc.checkpoint_name(intermediate_output, "moe_mlpwo")

     if self.config.use_ring_of_experts:
       # Set the outputs of tokens which were not processed to 0.
@@ -1862,7 +1862,7 @@ def dense_matmul(
           layer_w0,
           mlp_axis,
       )
-      layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
+      layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")
     with jax.named_scope("wi_1"):
       w1_kernel_axes = ("exp", None, "mlp")
       w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes)
@@ -1878,7 +1878,7 @@ def dense_matmul(
           layer_w1,
           mlp_axis,
       )
-      layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
+      layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
     layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1)
     with jax.named_scope("wo"):
       wo_kernel_axes = ("exp", "mlp", None)
@@ -1904,7 +1904,7 @@ def dense_matmul(
               "activation_embed_moe",
           ),
       )
-      intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
+      intermediate_layer = adc.checkpoint_name(intermediate_layer, "moe_mlpwo")
     with jax.named_scope("combine"):
       # Matmul & element wise operation
       output = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=COMBINE)(
@@ -1935,7 +1935,7 @@ def dense_matmul(
         layer_w0 = layer_w0 + w0_bias[None, None, :, :]
       if self.config.activations_in_float32:
         layer_w0 = layer_w0.astype(jnp.float32)
-      layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
+      layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")
     with jax.named_scope("wi_1"):
       layer_w1 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)(
           "BSM,EMH -> BSEH", inputs, w1_kernel, precision=matmul_precision
@@ -1944,7 +1944,7 @@ def dense_matmul(
         layer_w1 = layer_w1 + w1_bias[None, None, :, :]
       if self.config.activations_in_float32:
         layer_w1 = layer_w1.astype(jnp.float32)
-      layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
+      layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
     layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1)

     with jax.named_scope("wo"):
@@ -1958,7 +1958,7 @@ def dense_matmul(
       intermediate_layer = intermediate_layer + wo_bias[None, None, :, :]
       if self.config.activations_in_float32:
         intermediate_layer = intermediate_layer.astype(jnp.float32)
-      intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
+      intermediate_layer = adc.checkpoint_name(intermediate_layer, "moe_mlpwo")
     with jax.named_scope("weight_sum"):
       if is_llama4_decoder_layer:
         weights = self.reshape_and_update_weights(jnp.ones_like(top_k_weights), top_k_indices)
```
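
The rename matters because the MoE path previously reused the dense-MLP tags, so a remat policy could not treat the two differently. With distinct names, a policy can save the wide MoE activations while recomputing the dense ones, e.g. with the standard JAX helper:

```python
import jax

# Keep the MoE feed-forward outputs; recompute dense-MLP activations, which
# still carry the original mlpwi_0 / mlpwi_1 / mlpwo tags.
moe_saving_policy = jax.checkpoint_policies.save_only_these_names(
    "moe_mlpwi_0", "moe_mlpwi_1", "moe_mlpwo"
)
```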

tests/unit/attention_test.py

Lines changed: 51 additions & 71 deletions
```diff
@@ -56,62 +56,52 @@ def test_one_block_mask(self):
     bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0]])
     # pylint: disable=protected-access
     block_mask = _make_bidirectional_block_mask(bidirectional_mask)
-    expected_mask = np.asarray(
-        [
-            [
-                [False, False, False, False, False, False],
-                [False, True, True, True, False, False],
-                [False, True, True, True, False, False],
-                [False, True, True, True, False, False],
-                [False, False, False, False, False, False],
-                [False, False, False, False, False, False],
-            ]
-        ]
-    )
+    expected_mask = np.asarray([[
+        [False, False, False, False, False, False],
+        [False, True, True, True, False, False],
+        [False, True, True, True, False, False],
+        [False, True, True, True, False, False],
+        [False, False, False, False, False, False],
+        [False, False, False, False, False, False],
+    ]])
     np.testing.assert_array_equal(block_mask, expected_mask)

   def test_two_blocks_mask(self):
     bidirectional_mask = np.asarray([[0, 1, 1, 0, 1, 1]])
     # pylint: disable=protected-access
     block_mask = _make_bidirectional_block_mask(bidirectional_mask)
-    expected_mask = np.asarray(
-        [
-            [
-                [False, False, False, False, False, False],
-                [False, True, True, False, False, False],
-                [False, True, True, False, False, False],
-                [False, False, False, False, False, False],
-                [False, False, False, False, True, True],
-                [False, False, False, False, True, True],
-            ]
-        ]
-    )
+    expected_mask = np.asarray([[
+        [False, False, False, False, False, False],
+        [False, True, True, False, False, False],
+        [False, True, True, False, False, False],
+        [False, False, False, False, False, False],
+        [False, False, False, False, True, True],
+        [False, False, False, False, True, True],
+    ]])
     np.testing.assert_array_equal(block_mask, expected_mask)

   def test_batch_block_masks(self):
     bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0], [0, 1, 1, 0, 1, 1]])
     # pylint: disable=protected-access
     block_mask = _make_bidirectional_block_mask(bidirectional_mask)
-    expected_mask = np.asarray(
+    expected_mask = np.asarray([
         [
-            [
-                [False, False, False, False, False, False],
-                [False, True, True, True, False, False],
-                [False, True, True, True, False, False],
-                [False, True, True, True, False, False],
-                [False, False, False, False, False, False],
-                [False, False, False, False, False, False],
-            ],
-            [
-                [False, False, False, False, False, False],
-                [False, True, True, False, False, False],
-                [False, True, True, False, False, False],
-                [False, False, False, False, False, False],
-                [False, False, False, False, True, True],
-                [False, False, False, False, True, True],
-            ],
-        ]
-    )
+            [False, False, False, False, False, False],
+            [False, True, True, True, False, False],
+            [False, True, True, True, False, False],
+            [False, True, True, True, False, False],
+            [False, False, False, False, False, False],
+            [False, False, False, False, False, False],
+        ],
+        [
+            [False, False, False, False, False, False],
+            [False, True, True, False, False, False],
+            [False, True, True, False, False, False],
+            [False, False, False, False, False, False],
+            [False, False, False, False, True, True],
+            [False, False, False, False, True, True],
+        ],
+    ])
     np.testing.assert_array_equal(block_mask, expected_mask)

   def test_empty_block_mask(self):
@@ -141,34 +131,24 @@ def test_combine_with_causal_mask(self):
     # pylint: disable=protected-access
     image_mask = _make_bidirectional_block_mask(bidirectional_mask)
     combined_mask = causal_mask | image_mask[:, None, None, ...]
-    expected_mask = np.asarray(
-        [
-            [
-                [
-                    [
-                        [True, False, False, False, False, False],
-                        [True, True, True, True, False, False],
-                        [True, True, True, True, False, False],
-                        [True, True, True, True, False, False],
-                        [True, True, True, True, True, False],
-                        [True, True, True, True, True, True],
-                    ]
-                ]
-            ],
-            [
-                [
-                    [
-                        [True, False, False, False, False, False],
-                        [True, True, True, False, False, False],
-                        [True, True, True, False, False, False],
-                        [True, True, True, True, False, False],
-                        [True, True, True, True, True, True],
-                        [True, True, True, True, True, True],
-                    ]
-                ]
-            ],
-        ]
-    )
+    expected_mask = np.asarray([
+        [[[
+            [True, False, False, False, False, False],
+            [True, True, True, True, False, False],
+            [True, True, True, True, False, False],
+            [True, True, True, True, False, False],
+            [True, True, True, True, True, False],
+            [True, True, True, True, True, True],
+        ]]],
+        [[[
+            [True, False, False, False, False, False],
+            [True, True, True, False, False, False],
+            [True, True, True, False, False, False],
+            [True, True, True, True, False, False],
+            [True, True, True, True, True, True],
+            [True, True, True, True, True, True],
+        ]]],
+    ])
     np.testing.assert_array_equal(combined_mask, expected_mask)
```
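
For readers following the tests: `_make_bidirectional_block_mask` itself is not shown in this diff, but its behavior is pinned down by the cases above (full attention within each contiguous run of 1s, nothing across runs or at 0 positions). A hypothetical reconstruction consistent with those cases, not the actual MaxText implementation:

```python
import numpy as np

def make_bidirectional_block_mask(bidirectional_mask):
  """[batch, length] 0/1 mask -> [batch, length, length] boolean block mask."""
  mask = np.asarray(bidirectional_mask, dtype=bool)
  # A block starts wherever a 1 is not preceded by another 1.
  starts = mask & ~np.pad(mask, ((0, 0), (1, 0)))[:, :-1]
  block_ids = np.cumsum(starts, axis=1) * mask  # 0 outside any block
  # Two positions may attend iff they share the same nonzero block id.
  return (block_ids[:, :, None] == block_ids[:, None, :]) & mask[:, :, None]
```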