Skip to content

Add MoE and MLA remat policies#3414

Merged
copybara-service[bot] merged 3 commits intoAI-Hypercomputer:mainfrom
abhinavgoel95:abgoel/add-moe-mla-remat-policies
Mar 25, 2026
Merged

Add MoE and MLA remat policies#3414
copybara-service[bot] merged 3 commits intoAI-Hypercomputer:mainfrom
abhinavgoel95:abgoel/add-moe-mla-remat-policies

Conversation

@abhinavgoel95
Copy link
Copy Markdown
Contributor

@abhinavgoel95 abhinavgoel95 commented Mar 13, 2026

  • Added moe_mlpwi, moe_mlpwi_0, moe_mlpwi_1, moe_mlpwo for MoE layers
  • Added query_wa_proj, kv_wa_proj for MLA layers
  • Updated base.yml, types.py, and pyconfig_deprecated.py

Description

This PR adds rematerialization policy support for Mixture of Experts (MoE) and Multi-head Latent Attention (MLA) layer tensors.

Previously, MaxText only supported remat policies for standard dense layer tensors. This prevented fine-grained memory optimization for MoE models (like Mixtral, DeepSeek V3) and models using MLA architecture (like DeepSeek V3).

This change adds six new configurable remat tensors:

  • MoE tensors: moe_mlpwi, moe_mlpwi_0, moe_mlpwi_1, moe_mlpwo
  • MLA tensors: query_wa_proj, kv_wa_proj

Users can now configure these tensors with device, offload, or remat policies in their config files, enabling better memory management for large MoE models (e.g., DeepSeek V3 671B).

Files modified:

  • src/maxtext/configs/base.yml - Added default 'remat' values
  • src/maxtext/configs/types.py - Added Field definitions with descriptions
  • src/maxtext/configs/pyconfig_deprecated.py - Added to validation whitelist

All new tensors default to 'remat', maintaining backward compatibility.

Tests

Tested with DeepSeek V3 671B (41 layers) on 128 GPUs with various remat configurations:

  • Baseline with all tensors set to remat - ✅ Works
  • Custom policies with selective offload and device placement - ✅ Works
  • Verified backward compatibility with Llama models (no regression)

Example config usage:

moe_mlpwi: 'offload'
query_wa_proj: 'device'

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in https://maxtext.readthedocs.io/en/latest/development.html#adding-new-documentation-files.

Comment thread src/maxtext/layers/moe.py
layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
if self.config.mlp_bias:
layer_w0 = layer_w0 + w0_bias
layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Heads up: this might affect all legacy TPU recipes/performance for MoE models. We should make an announcement after it gets merged. Thanks!

@abhinavgoel95 abhinavgoel95 force-pushed the abgoel/add-moe-mla-remat-policies branch from a3779fc to d7fd385 Compare March 16, 2026 17:52
@codecov
Copy link
Copy Markdown

codecov Bot commented Mar 19, 2026

Codecov Report

❌ Patch coverage is 87.50000% with 1 line in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/layers/attention_mla.py 50.00% 1 Missing ⚠️

📢 Thoughts on this report? Let us know!

@abhinavgoel95 abhinavgoel95 force-pushed the abgoel/add-moe-mla-remat-policies branch from 2fc4407 to 5109d24 Compare March 24, 2026 18:48
@copybara-service copybara-service Bot merged commit de51021 into AI-Hypercomputer:main Mar 25, 2026
3 checks passed
Copy link
Copy Markdown
Collaborator

@NuojCheng NuojCheng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also update

def get_remat_policy(self):
"""Get remat policy"""
policy = None
cfg = self.config
if cfg.remat_policy != "none":
if cfg.remat_policy in ("minimal_with_context", "minimal_flash"):
# save all
if cfg.remat_policy == "minimal_flash":
max_logging.log("WARNING: 'minimal_flash' will be deprecated soon, please use 'minimal_with_context' instead.")
max_logging.log("WARNING: 'minimal_flash' will be deprecated soon, please use 'minimal_with_context' instead.")
policy = self.minimal_policy(with_context=True)
elif cfg.remat_policy == "minimal":
# save all except context
policy = self.minimal_policy()
elif cfg.remat_policy == "minimal_with_quantization":
if cfg.scan_layers:
warnings.warn(
"Scan layers can introduce overhead to checkpointed values that in some configurations is slower"
"than not checkpointing at all. If you are using scan layers, benchmark with and without quantization "
"checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is "
"beneficial for performance."
)
policy = self.minimal_policy(with_context=False, with_quantization=True)
elif cfg.remat_policy == "minimal_with_context_and_quantization":
if cfg.scan_layers:
warnings.warn(
"Scan layers can introduce overhead to checkpointed values that in some configurations is slower"
"than not checkpointing at all. If you are using scan layers, benchmark with and without quantization "
"checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is "
"beneficial for performance."
)
policy = self.minimal_policy(with_context=True, with_quantization=True)
elif cfg.remat_policy == "save_dot_with_context_except_mlp":
policy = jax.checkpoint_policies.save_only_these_names(
"query_proj",
"value_proj",
"key_proj",
"qkv_proj",
"context",
"out_proj",
)
elif cfg.remat_policy == "save_dot_except_mlpwi":
policy = jax.checkpoint_policies.save_only_these_names(
"query_proj",
"value_proj",
"key_proj",
"qkv_proj",
"out_proj",
"mlpwo",
)
elif cfg.remat_policy == "save_dot_except_mlp":
policy = jax.checkpoint_policies.save_only_these_names(
"query_proj",
"value_proj",
"key_proj",
"qkv_proj",
"out_proj",
)
elif cfg.remat_policy == "save_qkv_proj":
policy = jax.checkpoint_policies.save_only_these_names(
"query_proj",
"value_proj",
"key_proj",
"qkv_proj",
)
elif cfg.remat_policy == "qkv_proj_offloaded":
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
names_which_can_be_saved=[],
names_which_can_be_offloaded=["query_proj", "value_proj", "key_proj"],
offload_src="device",
offload_dst="pinned_host",
)
elif cfg.remat_policy == "minimal_offloaded":
# offload all except context
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
names_which_can_be_saved=[],
names_which_can_be_offloaded=[
"query_proj",
"value_proj",
"key_proj",
"qkv_proj",
"out_proj",
"mlpwi_0",
"mlpwi_1",
"mlpwi",
"mlpwo",
],
offload_src="device",
offload_dst="pinned_host",
)
elif cfg.remat_policy == "custom":
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
names_which_can_be_saved=cfg.tensors_on_device,
names_which_can_be_offloaded=cfg.tensors_to_offload,
offload_src="device",
offload_dst="pinned_host",
)
elif cfg.remat_policy == "save_out_proj":
policy = jax.checkpoint_policies.save_only_these_names(
"out_proj",
)
else:
assert cfg.remat_policy == "full", "Remat policy needs to be on list of remat policies"
policy = None
return policy
?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants