diff --git a/README.md b/README.md index 63254a9108..07d21f8560 100644 --- a/README.md +++ b/README.md @@ -118,25 +118,25 @@ MaxText aims to provide you with the best OSS models, whether as a reference imp * Gemma 2 (2B, 9B, 27B) * Gemma 1 (2B, 7B) * Alibaba - * Qwen 2.5 (1.5B, 7B, 14B) - * Qwen 3 MoE 2507 (235B, 480B) - * Qwen 3 MoE (30B, 235B) + * Qwen 3 Next (80B) + * Qwen 3 MoE (30B, 235B), Qwen 3 MoE 2507 (235B, 480B) * Qwen 3 Dense (0.6B, 1.7B, 4B, 8B, 14B, 32B) -* DeepSeek + * Qwen 2.5 (1.5B, 7B, 14B) +* DeepSeek AI * DeepSeek V3.2 (671B) * DeepSeek V3.1 (671B) - * DeepSeek V3 0324 (671B) & DeepSeek R1 0528 (671B) + * DeepSeek V3 0324 (671B), DeepSeek R1 0528 (671B) * DeepSeek V2 (16B, 236B) -* Kimi - * Kimi K2 +* Moonshot AI + * Kimi K2 (1T) * Meta * Llama 4 Scout (109B) & Maverick (400B) - * Llama 3.3 70B, 3.1 (8B, 70B, 405B), 3.0 (8B, 70B, 405B) + * Llama 3.3 (70B), 3.1 (8B, 70B, 405B), 3.0 (8B, 70B, 405B) * Llama 2 (7B, 13B, 70B) -* Open AI +* OpenAI * GPT-OSS (20B, 120B) * GPT3 (52K, 6B, 22B, 175B) -* Mistral +* Mistral AI * Mixtral (8x7B, 8x22B) * Mistral (7B) * Diffusion Models diff --git a/src/maxtext/configs/pyconfig.py b/src/maxtext/configs/pyconfig.py index c562a05357..c1b9adab36 100644 --- a/src/maxtext/configs/pyconfig.py +++ b/src/maxtext/configs/pyconfig.py @@ -256,16 +256,6 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]: Please pass tokenizer_path in your command if this is not intended." 
) - # Preprocess muon_consistent_rms to be None or float - if key == "muon_consistent_rms": - if value in ["None", "none"]: - new_value = None - else: - try: - new_value = float(value) - except ValueError as e: - raise ValueError("muon_consistent_rms should be None or float") from e - pydantic_kwargs[key] = new_value return pydantic_kwargs diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index af809e2b42..b4bf97c66c 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1343,7 +1343,7 @@ class Muon(BaseModel): 0, description="Strength of the weight decay regularization. This is multiplied with the learning rate.", ) - muon_consistent_rms: None | float = Field( + muon_consistent_rms: float | None = Field( None, description="If None, apply width scaling to updates. If float, apply consistent rms scaling (recommend 0.2).", ) diff --git a/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md b/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md index 973811ee9f..6432ac9825 100644 --- a/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md +++ b/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md @@ -170,7 +170,10 @@ python3 -m maxtext.trainers.post_train.sft.train_sft_deprecated src/maxtext/conf ``` ## Continued pre-training for V3.2 Sparse Attention -**DeepSeek Sparse Attention (DSA)** 
enhances the Multi-Head Latent Attention (MLA) architecture by introducing a **Lightning Indexer**, which selects the top-k tokens for attention. DeepSeek-V3.2 is instantiated from DeepSeek-V3.1 and undergoes continued pre-training to adapt this indexer via a two-stage strategy: **Dense Warm-up** and **Sparse Training**. + +**DeepSeek Sparse Attention (DSA)** enhances the Multi-Head Latent Attention (MLA) architecture by introducing a **Lightning Indexer**, which selects the top-k tokens for attention. Note that Indexer is activated only if `max_target_length` > `indexer_topk` (2048). + +DeepSeek-V3.2 is instantiated from DeepSeek-V3.1 and undergoes continued pre-training to adapt this indexer via a two-stage strategy: **Dense Warm-up** and **Sparse Training**. 1. **Dense Warmup Stage** The indexer is trained exclusively using dense indexer loss while all other model parameters remain frozen. diff --git a/tests/end_to_end/tpu/kimi/Run_Kimi.md b/tests/end_to_end/tpu/kimi/Run_Kimi.md index 2e80205d18..dad1155da5 100644 --- a/tests/end_to_end/tpu/kimi/Run_Kimi.md +++ b/tests/end_to_end/tpu/kimi/Run_Kimi.md @@ -16,10 +16,10 @@ # Kimi -Kimi is a family of high-performance, open-weights sparse MoE models by Moonshot AI designed for agentic intelligence. The currently supported models are **Kimi K2 (1T)**. 
+Kimi is a family of high-performance, open-weights sparse MoE models by Moonshot AI designed for agentic intelligence. The currently supported model is **Kimi K2 (1T)**. * **[Kimi K2](https://arxiv.org/pdf/2507.20534)** features a massive 1.04 trillion total parameters with 32 billion activated parameters. The architecture is similar to DeepSeek-V3. It utilizes **Multi-Head Latent Attention (MLA)** and an ultra-sparse MoE with **384 experts**, optimized for long-context and agentic tasks. -* **MuonClip Optimizer**: Kimi K2 was trained using the token-efficient [Muon](https://kellerjordan.github.io/posts/muon) optimizer combined with a novel **QK-clip** technique to ensure training stability and eliminate loss spikes during large-scale pre-training. +* **MuonClip Optimizer**: Kimi K2 was trained using the token-efficient **[Muon optimizer](https://kellerjordan.github.io/posts/muon)** combined with a novel **QK-clip** technique to ensure training stability and eliminate loss spikes during large-scale pre-training. * **Agentic Excellence**: K2 is specifically post-trained using a large-scale agentic data synthesis pipeline and Reinforcement Learning (RL), achieving state-of-the-art performance on benchmarks like Tau2-Bench and SWE-Bench. ## Checkpoint Conversion @@ -46,7 +46,7 @@ python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_fam ``` ## Pre-training -You can train from scratch to generate a new checkpoint. One example command to run pre-training with Kimi K2 on tpu7x-512 (adjust parallelism for the 1T parameter scale). To use MuonClip optimizer, you need `optax>=0.2.7` and `tokamax>=0.0.11`. +You can train from scratch to generate a new checkpoint. One example command to run pre-training with Kimi K2 on tpu7x-512 with 256 chips. To use **MuonClip optimizer**, you need `optax>=0.2.7` and `tokamax>=0.0.11`. 
```sh python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \ @@ -122,18 +122,21 @@ Example command to run decoding with Kimi K2. Given its 1T size, high tensor par ```sh python3 -m maxtext.inference.decode src/maxtext/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ - load_parameters_path=${CONVERTED_CHECKPOINT?} \ run_name=kimi_decode \ - per_device_batch_size=1 \ model_name=kimi-k2-1t \ - max_target_length=2048 \ tokenizer_type=huggingface \ tokenizer_path=moonshotai/Kimi-K2-Instruct \ + hf_access_token=${HF_TOKEN?} \ + load_parameters_path=${UNSCANNED_CKPT_PATH?} \ + scan_layers=False \ + enable_checkpointing=true \ + async_checkpointing=false \ + per_device_batch_size=1 \ + max_target_length=2048 \ attention=dot_product \ ici_tensor_parallelism=128 \ ici_fsdp_parallelism=1 \ - prompt="The primary goal of agentic intelligence is to " \ - scan_layers=False + prompt="The primary goal of agentic intelligence is to " ``` ## Correctness @@ -158,6 +161,8 @@ python3 -m tests.assets.logits_generation.generate_hf_golden_logits \ --trust-remote-code=True ``` +Run the command below to compare logits between HuggingFace and MaxText. 
+ ```sh JAX_PLATFORMS=cpu python3 -m tests.forward_pass_logit_checker \ src/maxtext/configs/base.yml \ diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py index 4230c46174..0a40da2ab4 100644 --- a/tests/unit/train_compile_test.py +++ b/tests/unit/train_compile_test.py @@ -569,7 +569,6 @@ def test_moe_deepseek_scanned_bf16(self): ) ) - @pytest.mark.skip(reason="Fix sharding issue of all layers of DeepSeek") @pytest.mark.cpu_only def test_moe_deepseek_unscanned_bf16(self): temp_dir = gettempdir() @@ -734,7 +733,7 @@ def test_gpt3_6b(self): "", get_test_config_path(), f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5p-256", + "compile_topology=v5p-8", "compile_topology_num_slices=1", "model_name=gpt3-6b", "per_device_batch_size=1", @@ -766,7 +765,7 @@ def test_qwen3_next(self): "", get_test_config_path(), f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5p-256", + "compile_topology=v5p-64", "compile_topology_num_slices=1", "model_name=qwen3-next-80b-a3b", "per_device_batch_size=1", @@ -796,9 +795,6 @@ def test_deepseek32(self): "use_tokamax_splash=True", "dtype=bfloat16", "weight_dtype=bfloat16", - # without_device_limit - "n_routing_groups=-1", - "topk_routing_group=-1", ) ) @@ -948,9 +944,9 @@ def test_circular_pipeline_ag_per_repeat_ep_ds(self): ) @pytest.mark.cpu_only - def test_qk_clip(self): - """AOT test for qk-clip with DeepSeek3 Tiny model""" - compiled_trainstep_file = "/tmp/test_qk_clip.pickle" + def test_qk_clip_with_dot_product(self): + """AOT test for AdamW optimizer with QK clip on dot product attention for DeepSeek3 Tiny model""" + compiled_trainstep_file = "/tmp/test_qk_clip_with_dot_product.pickle" train_compile_main( ( "", @@ -963,13 +959,47 @@ def test_qk_clip(self): "sparse_matmul=True", "megablox=True", "use_tokamax_gmm=False", - # TODO(agagik): update to flash after support + "max_target_length=128", + "per_device_batch_size=1", + "dtype=bfloat16", + 
"weight_dtype=float32", + # dot product attention "attention=dot_product", "use_tokamax_splash=True", + # qk clip + "use_qk_clip=true", + "qk_clip_threshold=100", + ) + ) + + @pytest.mark.cpu_only + def test_muon_clip_with_tokamax_splash(self): + """AOT test for Muon optimizer with QK clip on tokamax splash attention for DeepSeek3 Tiny model""" + compiled_trainstep_file = "/tmp/test_muon_clip_with_tokamax_splash.pickle" + train_compile_main( + ( + "", + get_test_config_path(), + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5p-8", + "compile_topology_num_slices=1", + "model_name=deepseek3-tiny", + "scan_layers=True", + "sparse_matmul=True", + "megablox=True", + "use_tokamax_gmm=False", "max_target_length=128", "per_device_batch_size=1", "dtype=bfloat16", "weight_dtype=float32", + # tokamax splash attention + "attention=flash", + "use_tokamax_splash=True", + # muon optimizer + "opt_type=muon", + "muon_consistent_rms=0.2", + "muon_weight_decay=0.1", + # qk clip "use_qk_clip=true", "qk_clip_threshold=100", )