20 changes: 10 additions & 10 deletions README.md
@@ -118,25 +118,25 @@ MaxText aims to provide you with the best OSS models, whether as a reference imp
* Gemma 2 (2B, 9B, 27B)
* Gemma 1 (2B, 7B)
* Alibaba
* Qwen 2.5 (1.5B, 7B, 14B)
* Qwen 3 MoE 2507 (235B, 480B)
* Qwen 3 MoE (30B, 235B)
* Qwen 3 Next (80B)
* Qwen 3 MoE (30B, 235B), Qwen 3 MoE 2507 (235B, 480B)
* Qwen 3 Dense (0.6B, 1.7B, 4B, 8B, 14B, 32B)
* DeepSeek
* Qwen 2.5 (1.5B, 7B, 14B)
* DeepSeek AI
* DeepSeek V3.2 (671B)
* DeepSeek V3.1 (671B)
* DeepSeek V3 0324 (671B) & DeepSeek R1 0528 (671B)
* DeepSeek V3 0324 (671B), DeepSeek R1 0528 (671B)
* DeepSeek V2 (16B, 236B)
* Kimi
* Kimi K2
* Moonshot AI
* Kimi K2 (1T)
* Meta
* Llama 4 Scout (109B) & Maverick (400B)
* Llama 3.3 70B, 3.1 (8B, 70B, 405B), 3.0 (8B, 70B, 405B)
* Llama 3.3 (70B), 3.1 (8B, 70B, 405B), 3.0 (8B, 70B, 405B)
* Llama 2 (7B, 13B, 70B)
* Open AI
* OpenAI
* GPT-OSS (20B, 120B)
* GPT3 (52K, 6B, 22B, 175B)
* Mistral
* Mistral AI
* Mixtral (8x7B, 8x22B)
* Mistral (7B)
* Diffusion Models
10 changes: 0 additions & 10 deletions src/maxtext/configs/pyconfig.py
@@ -256,16 +256,6 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]:
Please pass tokenizer_path in your command if this is not intended."
)

# Preprocess muon_consistent_rms to be None or float
if key == "muon_consistent_rms":
if value in ["None", "none"]:
new_value = None
else:
try:
new_value = float(value)
except ValueError as e:
raise ValueError("muon_consistent_rms should be None or float") from e

pydantic_kwargs[key] = new_value

return pydantic_kwargs
2 changes: 1 addition & 1 deletion src/maxtext/configs/types.py
@@ -1343,7 +1343,7 @@ class Muon(BaseModel):
0,
description="Strength of the weight decay regularization. This is multiplied with the learning rate.",
)
muon_consistent_rms: None | float = Field(
muon_consistent_rms: float | None = Field(
None,
description="If None, apply width scaling to updates. If float, apply consistent rms scaling (recommend 0.2).",
)
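
For intuition, here is a sketch of what the two modes of `muon_consistent_rms` mean for a 2-D update matrix. This is illustrative only: the helper name is invented, and the exact width-scaling formula used for the `None` case is an assumption (one common Muon convention), not necessarily what MaxText computes.

```python
import numpy as np

def scale_update(update, muon_consistent_rms=None):
    """Sketch of the two scaling modes (hypothetical helper).

    muon_consistent_rms=None  -> width scaling: factor grows with the
                                 row/column ratio (assumed convention).
    muon_consistent_rms=0.2   -> rescale the update to a fixed RMS of 0.2.
    """
    rows, cols = update.shape
    if muon_consistent_rms is None:
        # Width scaling: larger fan-out relative to fan-in gets a bigger step.
        return update * np.sqrt(max(1.0, rows / cols))
    # Consistent-RMS scaling: normalize the update's RMS to the target value.
    rms = np.sqrt(np.mean(update ** 2)) + 1e-12
    return update * (muon_consistent_rms / rms)
```

With `muon_consistent_rms=0.2`, every update leaves this function with RMS 0.2 regardless of its original magnitude, which is why a single recommended value can work across layers of different widths.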
2 changes: 2 additions & 0 deletions src/maxtext/optimizers/optimizers.py
@@ -197,6 +197,8 @@ def get_optimizer(config, learning_rate_schedule, model=None):
muon_weight_dimension_numbers = get_muon_weight_dimension_numbers(model, config)
else:
raise ValueError("Please specify model to extract muon dimension number.")
# TODO(shuningjin): remove
print(f"DEBUG: {config.muon_consistent_rms}, {type(config.muon_consistent_rms)}")
muon_kwargs = {
# Shared parameters: "nesterov" uses default
"learning_rate": learning_rate_schedule,
7 changes: 6 additions & 1 deletion tests/end_to_end/tpu/deepseek/Run_DeepSeek.md
@@ -170,7 +170,10 @@ python3 -m maxtext.trainers.post_train.sft.train_sft_deprecated src/maxtext/conf
```

## Continued pre-training for V3.2 Sparse Attention
**DeepSeek Sparse Attention (DSA)** enhances the Multi-Head Latent Attention (MLA) architecture by introducing a **Lightning Indexer**, which selects the top-k tokens for attention. DeepSeek-V3.2 is instantiated from DeepSeek-V3.1 and undergoes continued pre-training to adapt this indexer via a two-stage strategy: **Dense Warm-up** and **Sparse Training**.

**DeepSeek Sparse Attention (DSA)** enhances the Multi-Head Latent Attention (MLA) architecture by introducing a **Lightning Indexer**, which selects the top-k tokens for attention. Note that the indexer is activated only if `max_target_length` > `indexer_topk` (2048).
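
As a rough illustration (not the MaxText implementation; the function name and shapes are invented for this sketch), the indexer's top-k selection, including the activation condition above, looks like:

```python
import numpy as np

def select_topk_tokens(scores, indexer_topk=2048):
    """Build a boolean attention mask keeping only the top-k scored key
    tokens per query row. If the sequence is no longer than indexer_topk,
    the indexer is a no-op (all tokens attended), mirroring the
    max_target_length > indexer_topk activation rule."""
    seq_len = scores.shape[-1]
    if seq_len <= indexer_topk:
        # Indexer inactive: dense attention over all tokens.
        return np.ones_like(scores, dtype=bool)
    # Indices of the k largest indexer scores per query.
    topk_idx = np.argpartition(scores, -indexer_topk, axis=-1)[..., -indexer_topk:]
    mask = np.zeros_like(scores, dtype=bool)
    np.put_along_axis(mask, topk_idx, True, axis=-1)
    return mask
```

Only the selected keys participate in attention, which is what makes DSA's cost sub-quadratic for long contexts while leaving short sequences unchanged.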

DeepSeek-V3.2 is instantiated from DeepSeek-V3.1 and undergoes continued pre-training to adapt this indexer via a two-stage strategy: **Dense Warm-up** and **Sparse Training**.

1. **Dense Warmup Stage**
The indexer is trained exclusively using dense indexer loss while all other model parameters remain frozen.
@@ -186,6 +189,7 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
async_checkpointing=false \
ici_fsdp_parallelism=128 \
steps=5 \
max_target_length=4096 \
attention=flash \
dtype=bfloat16 \
@@ -212,6 +216,7 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
async_checkpointing=false \
ici_fsdp_parallelism=128 \
steps=5 \
max_target_length=4096 \
attention=flash \
dtype=bfloat16 \
25 changes: 17 additions & 8 deletions tests/end_to_end/tpu/kimi/Run_Kimi.md
@@ -16,10 +16,10 @@

# Kimi

Kimi is a family of high-performance, open-weights sparse MoE models by Moonshot AI designed for agentic intelligence. The currently supported models are **Kimi K2 (1T)**.
Kimi is a family of high-performance, open-weights sparse MoE models by Moonshot AI designed for agentic intelligence. The currently supported model is **Kimi K2 (1T)**.

* **[Kimi K2](https://arxiv.org/pdf/2507.20534)** features a massive 1.04 trillion total parameters with 32 billion activated parameters. The architecture is similar to DeepSeek-V3. It utilizes **Multi-Head Latent Attention (MLA)** and an ultra-sparse MoE with **384 experts**, optimized for long-context and agentic tasks.
* **MuonClip Optimizer**: Kimi K2 was trained using the token-efficient [Muon](https://kellerjordan.github.io/posts/muon) optimizer combined with a novel **QK-clip** technique to ensure training stability and eliminate loss spikes during large-scale pre-training.
* **MuonClip Optimizer**: Kimi K2 was trained using the token-efficient **[Muon optimizer](https://kellerjordan.github.io/posts/muon)** combined with a novel **QK-clip** technique to ensure training stability and eliminate loss spikes during large-scale pre-training.
* **Agentic Excellence**: K2 is specifically post-trained using a large-scale agentic data synthesis pipeline and Reinforcement Learning (RL), achieving state-of-the-art performance on benchmarks like Tau2-Bench and SWE-Bench.
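
A toy sketch of the QK-clip idea follows (simplified; the per-head bookkeeping and exact rescaling in the MuonClip paper and in MaxText may differ, and the function name is invented). Because attention logits are bilinear in the query and key projections, damping each by the square root of the clip ratio scales the worst-case logit back to the threshold:

```python
import numpy as np

def qk_clip(w_q, w_k, max_logit, threshold=100.0):
    """One QK-clip step (sketch): if the observed max attention logit for
    a head exceeds `threshold`, rescale that head's query and key
    projection weights so the logit shrinks back to the threshold."""
    gamma = min(1.0, threshold / max_logit)   # no-op when under threshold
    scale = np.sqrt(gamma)                    # split the damping across q and k
    return w_q * scale, w_k * scale
```

Applied after each optimizer step, this keeps logits bounded without touching the loss, which is how MuonClip avoids the loss spikes seen in large-scale pre-training.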

## Checkpoint Conversion
@@ -46,7 +46,7 @@ python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_fam
```

## Pre-training
You can train from scratch to generate a new checkpoint. One example command to run pre-training with Kimi K2 on tpu7x-512 (adjust parallelism for the 1T parameter scale). To use MuonClip optimizer, you need `optax>=0.2.7` and `tokamax>=0.0.11`.
You can train from scratch to generate a new checkpoint. One example command to run pre-training with Kimi K2 on tpu7x-512 with 256 chips. To use **MuonClip optimizer**, you need `optax>=0.2.7` and `tokamax>=0.0.11`.

```sh
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
@@ -72,9 +72,11 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
dataset_type=synthetic \
scan_layers=True \
use_ring_of_experts=True \
opt_type=muon \
muon_consistent_rms=0.2 \
muon_weight_decay=0.1 \
use_qk_clip=True \
qk_clip_threshold=100
```
@@ -109,9 +111,11 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
scan_layers=True \
load_parameters_path=${SCANNED_CHECKPOINT?} \
use_ring_of_experts=True \
opt_type=muon \
muon_consistent_rms=0.2 \
muon_weight_decay=0.1 \
use_qk_clip=True \
qk_clip_threshold=100
```
@@ -122,18 +126,21 @@ Example command to run decoding with Kimi K2. Given its 1T size, high tensor par
```sh
python3 -m maxtext.inference.decode src/maxtext/configs/base.yml \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
load_parameters_path=${CONVERTED_CHECKPOINT?} \
run_name=kimi_decode \
per_device_batch_size=1 \
model_name=kimi-k2-1t \
max_target_length=2048 \
tokenizer_type=huggingface \
tokenizer_path=moonshotai/Kimi-K2-Instruct \
hf_access_token=${HF_TOKEN?} \
load_parameters_path=${UNSCANNED_CKPT_PATH?} \
scan_layers=False \
enable_checkpointing=true \
async_checkpointing=false \
per_device_batch_size=1 \
max_target_length=2048 \
attention=dot_product \
ici_tensor_parallelism=128 \
ici_fsdp_parallelism=1 \
prompt="The primary goal of agentic intelligence is to " \
scan_layers=False
prompt="The primary goal of agentic intelligence is to "
```

## Correctness
@@ -158,6 +165,8 @@ python3 -m tests.assets.logits_generation.generate_hf_golden_logits \
--trust-remote-code=True
```

Run the command below to compare logits between HuggingFace and MaxText.

```sh
JAX_PLATFORMS=cpu python3 -m tests.forward_pass_logit_checker \
src/maxtext/configs/base.yml \
50 changes: 40 additions & 10 deletions tests/unit/train_compile_test.py
@@ -569,7 +569,6 @@ def test_moe_deepseek_scanned_bf16(self):
)
)

@pytest.mark.skip(reason="Fix sharding issue of all layers of DeepSeek")
@pytest.mark.cpu_only
def test_moe_deepseek_unscanned_bf16(self):
temp_dir = gettempdir()
@@ -734,7 +733,7 @@ def test_gpt3_6b(self):
"",
get_test_config_path(),
f"compiled_trainstep_file={compiled_trainstep_file}",
"compile_topology=v5p-256",
"compile_topology=v5p-8",
"compile_topology_num_slices=1",
"model_name=gpt3-6b",
"per_device_batch_size=1",
@@ -766,7 +765,7 @@ def test_qwen3_next(self):
"",
get_test_config_path(),
f"compiled_trainstep_file={compiled_trainstep_file}",
"compile_topology=v5p-256",
"compile_topology=v5p-64",
"compile_topology_num_slices=1",
"model_name=qwen3-next-80b-a3b",
"per_device_batch_size=1",
@@ -796,9 +795,6 @@ def test_deepseek32(self):
"use_tokamax_splash=True",
"dtype=bfloat16",
"weight_dtype=bfloat16",
# without_device_limit
"n_routing_groups=-1",
"topk_routing_group=-1",
)
)

@@ -948,9 +944,9 @@ def test_circular_pipeline_ag_per_repeat_ep_ds(self):
)

@pytest.mark.cpu_only
def test_qk_clip(self):
"""AOT test for qk-clip with DeepSeek3 Tiny model"""
compiled_trainstep_file = "/tmp/test_qk_clip.pickle"
def test_qk_clip_with_dot_product(self):
"""AOT test for AdamW optimizer with QK clip on dot product attention for DeepSeek3 Tiny model"""
compiled_trainstep_file = "/tmp/test_qk_clip_with_dot_product.pickle"
train_compile_main(
(
"",
@@ -963,13 +959,47 @@ def test_qk_clip(self):
"sparse_matmul=True",
"megablox=True",
"use_tokamax_gmm=False",
# TODO(agagik): update to flash after support
"max_target_length=128",
"per_device_batch_size=1",
"dtype=bfloat16",
"weight_dtype=float32",
# dot product attention
"attention=dot_product",
"use_tokamax_splash=True",
# qk clip
"use_qk_clip=true",
"qk_clip_threshold=100",
)
)

@pytest.mark.cpu_only
def test_muon_clip_with_tokamax_splash(self):
"""AOT test for Muon optimizer with QK clip on tokamax splash attention for DeepSeek3 Tiny model"""
compiled_trainstep_file = "/tmp/test_muon_clip_with_tokamax_splash.pickle"
train_compile_main(
(
"",
get_test_config_path(),
f"compiled_trainstep_file={compiled_trainstep_file}",
"compile_topology=v5p-8",
"compile_topology_num_slices=1",
"model_name=deepseek3-tiny",
"scan_layers=True",
"sparse_matmul=True",
"megablox=True",
"use_tokamax_gmm=False",
"max_target_length=128",
"per_device_batch_size=1",
"dtype=bfloat16",
"weight_dtype=float32",
# tokamax splash attention
"attention=flash",
"use_tokamax_splash=True",
# muon optimizer
"opt_type=muon",
"muon_consistent_rms=0.2",
"muon_weight_decay=0.1",
# qk clip
"use_qk_clip=true",
"qk_clip_threshold=100",
)