Commit 95ef3e1

Merge pull request #3108 from AI-Hypercomputer:qinwen/add_checkpoint
PiperOrigin-RevId: 866687753
2 parents f62ee44 + 94f24a0

5 files changed: 18 additions & 0 deletions

File tree

src/MaxText/configs/base.yml
src/MaxText/configs/types.py
src/MaxText/layers/attention_mla.py
src/MaxText/layers/attentions.py
src/MaxText/layers/deepseek_batchsplit.py

src/MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions
@@ -311,6 +311,7 @@ qkv_proj: 'remat'
 out_proj: 'remat'
 mla_q: 'remat'
 mla_kv: 'remat'
+attention_out: 'remat'
 
 optimizer_memory_host_offload: False
 parameter_memory_host_offload: False
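The new `attention_out: 'remat'` entry joins the existing per-tensor remat knobs (`qkv_proj`, `out_proj`, `mla_q`, `mla_kv`). A minimal sketch of how such per-tensor settings can be turned into a `jax.checkpoint` policy; the helper below is hypothetical, not MaxText's actual wiring:

```python
# Sketch (hypothetical helper): map per-tensor settings like the YAML above
# onto a jax.checkpoint policy. Tensors set to 'device' are saved on the
# forward pass; tensors set to 'remat' are recomputed in the backward pass.
import jax

def policy_from_settings(settings: dict[str, str]):
  saved = [name for name, where in settings.items() if where == "device"]
  return jax.checkpoint_policies.save_only_these_names(*saved)

policy = policy_from_settings({
    "qkv_proj": "remat",
    "attention_out": "device",  # keep the attention output instead of recomputing it
})
# Used as: jax.checkpoint(layer_fn, policy=policy)
```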

src/MaxText/configs/types.py

Lines changed: 6 additions & 0 deletions
@@ -876,6 +876,11 @@ class RematAndOffload(BaseModel):
       RematLocation.REMAT,
       description="Remat policy for the mla's key and value projection.",
   )
+  attention_out: RematLocation = Field(
+      RematLocation.REMAT,
+      description="Remat policy for the attention output.",
+  )
+
   optimizer_memory_host_offload: bool = Field(False, description="Offload optimizer state to host memory.")
   parameter_memory_host_offload: bool = Field(False, description="Offload parameters to host memory.")

@@ -2064,6 +2069,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
         "mla_kv",
         "mla_q",
         "qkv_proj",
+        "attention_out",
         "out_proj",
     ]
     self.tensors_on_device = [t for t in tensors if getattr(self, t) == "device"]
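For context, a self-contained sketch of the `Field` pattern used in `RematAndOffload`; the `RematLocation` enum values here are assumptions inferred from the `'remat'`/`'device'` strings in base.yml and in the `tensors_on_device` filter above, while the real enum lives in MaxText's types module:

```python
# Sketch: a pydantic model with an enum-valued field and a REMAT default,
# mirroring the attention_out field added above. Enum members are assumed.
from enum import Enum
from pydantic import BaseModel, Field

class RematLocation(str, Enum):
  REMAT = "remat"
  DEVICE = "device"

class RematAndOffload(BaseModel):
  attention_out: RematLocation = Field(
      RematLocation.REMAT,
      description="Remat policy for the attention output.",
  )

print(RematAndOffload().attention_out.value)                       # -> "remat" (default)
print(RematAndOffload(attention_out="device").attention_out.value)  # str coerces to the enum
```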

src/MaxText/layers/attention_mla.py

Lines changed: 1 addition & 0 deletions
@@ -1038,6 +1038,7 @@ def __call__(
     # Pass the index_mask to the Attention Op
     out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values, index_mask=index_mask)
 
+    out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
     if model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
       out = self._maybe_shard_with_logical(out, self.ep_out_axis_names)
     else:
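The tag added here is a runtime no-op: `jax.ad_checkpoint.checkpoint_name` returns its input unchanged and only attaches a name that remat policies can match on. A minimal sketch:

```python
# Sketch: checkpoint_name is the identity on its input. f and g compute the
# same values and gradients; they differ only in whether the intermediate
# is nameable by a jax.checkpoint policy.
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def f(x):
  return jnp.sum(jnp.cos(jnp.sin(x)))

def g(x):
  y = checkpoint_name(jnp.sin(x), "attention_out")
  return jnp.sum(jnp.cos(y))

x = jnp.arange(4.0)
assert jnp.allclose(f(x), g(x))
assert jnp.allclose(jax.grad(f)(x), jax.grad(g)(x))
```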

src/MaxText/layers/attentions.py

Lines changed: 1 addition & 0 deletions
@@ -1132,6 +1132,7 @@ def __call__(
         bidirectional_mask,
         self.sinks,
     )
+    out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
     if model_mode == MODEL_MODE_PREFILL:
       out = self._maybe_shard_with_logical(out, self.prefill_out_axis_names)
     elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
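With the activation tagged in both attention paths, a `save_only_these_names` policy can keep `attention_out` for the backward pass instead of recomputing it, and `jax.ad_checkpoint.print_saved_residuals` shows the effect. A sketch with stand-in shapes and a stand-in attention op:

```python
# Sketch: tag an activation, then ask jax.checkpoint to save exactly that
# name. print_saved_residuals lists which values jax.grad would store.
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name, print_saved_residuals

def block(x, w):
  out = checkpoint_name(jnp.tanh(x @ w), "attention_out")  # stand-in attention op
  return jnp.sum(out * out)

policy = jax.checkpoint_policies.save_only_these_names("attention_out")
remat_block = jax.checkpoint(block, policy=policy)

x, w = jnp.ones((4, 4)), jnp.ones((4, 4))
print_saved_residuals(remat_block, x, w)  # the named residual is saved
print(jax.grad(remat_block)(x, w))
```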

src/MaxText/layers/deepseek_batchsplit.py

Lines changed: 9 additions & 0 deletions
@@ -339,6 +339,7 @@ def mla(
       qk_nope_head_dim=qk_nope_head_dim,
       mscale=mscale,
   )
+  query = jax.ad_checkpoint.checkpoint_name(query, "query_proj")
   key, value = kv_projection(
       inputs,
       positions,
@@ -358,6 +359,8 @@ def mla(
       qk_nope_head_dim=qk_nope_head_dim,
       num_query_heads=num_query_heads,
   )
+  key = jax.ad_checkpoint.checkpoint_name(key, "key_proj")
+  value = jax.ad_checkpoint.checkpoint_name(value, "value_proj")
   out = attention_op_fn(
       query,
       key,
@@ -366,7 +369,9 @@ def mla(
       model_mode,
       cached_values=[None, None],
   )
+  out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
   out = dot(out, out_weights, axes=2)
+  out = jax.ad_checkpoint.checkpoint_name(out, "out_proj")
   return out
 
 
@@ -405,6 +410,7 @@ def query_projection(
       epsilon=epsilon,
       dtype=dtype,
   )
+  low_rank_q = jax.ad_checkpoint.checkpoint_name(low_rank_q, "mla_q")
   q = dot(low_rank_q, wq_b_weights)
 
   # Split into non-positional and rotary parts.
@@ -454,6 +460,7 @@ def kv_projection(
       epsilon=kv_norm_epsilon,
       dtype=dtype,
   )
+  low_rank_main = jax.ad_checkpoint.checkpoint_name(low_rank_main, "mla_kv")
   key_rope = jnp.expand_dims(low_rank_rope, axis=2)
   key_rope = yarn(
       key_rope,
@@ -693,6 +700,8 @@ def compute(x, w0, w1, wo, group_sizes, weights, *, wi_tile_size, wo_tile_size,
   )
   layer_w0 = gmm_fn(x, w0, tiling=wi_tile_size)
   layer_w1 = gmm_fn(x, w1, tiling=wi_tile_size)
+  layer_w0 = jax.ad_checkpoint.checkpoint_name(layer_w0, "mlpwi_0")
+  layer_w1 = jax.ad_checkpoint.checkpoint_name(layer_w1, "mlpwi_1")
   intermediate_layer = jax.nn.silu(layer_w0) * layer_w1
   intermediate_layer *= weights[:, None]
   return gmm_fn(intermediate_layer, wo, tiling=wo_tile_size)
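deepseek_batchsplit.py tags the full set of MLA and MoE intermediates (`query_proj`, `key_proj`, `value_proj`, `attention_out`, `out_proj`, `mla_q`, `mla_kv`, `mlpwi_0`, `mlpwi_1`). The same names also compose with host offloading. A sketch, assuming a JAX version that provides `save_and_offload_only_these_names` and a backend with pinned host memory (e.g. TPU/GPU); the wiring below is hypothetical, not MaxText's:

```python
# Sketch (hypothetical wiring): keep "mlpwi_0" on device for the backward
# pass, move "attention_out" to pinned host memory, recompute everything else.
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=["mlpwi_0"],
    names_which_can_be_offloaded=["attention_out"],
    offload_src="device",
    offload_dst="pinned_host",
)

def block(x, w):
  h = checkpoint_name(jnp.tanh(x @ w), "attention_out")  # stand-in attention
  h = checkpoint_name(jax.nn.silu(h), "mlpwi_0")         # stand-in MLP branch
  return jnp.sum(h)

g = jax.grad(jax.checkpoint(block, policy=policy))
print(g(jnp.ones((2, 2)), jnp.ones((2, 2))))
```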
