Skip to content

Commit e9fa350

Browse files
committed
.
1 parent f3099c3 commit e9fa350

3 files changed

Lines changed: 25 additions & 12 deletions

File tree

tests/end_to_end/tpu/deepseek/Run_DeepSeek.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,10 @@ python3 -m maxtext.trainers.post_train.sft.train_sft_deprecated src/maxtext/conf
170170
```
171171

172172
## Continued pre-training for V3.2 Sparse Attention
173-
**DeepSeek Sparse Attention (DSA)** enhances the Multi-Head Latent Attention (MLA) architecture by introducing a **Lightning Indexer**, which selects the top-k tokens for attention. DeepSeek-V3.2 is instantiated from DeepSeek-V3.1 and undergoes continued pre-training to adapt this indexer via a two-stage strategy: **Dense Warm-up** and **Sparse Training**.
173+
174+
**DeepSeek Sparse Attention (DSA)** enhances the Multi-Head Latent Attention (MLA) architecture by introducing a **Lightning Indexer**, which selects the top-k tokens for attention. Note that the Indexer is activated only if `max_target_length` > `indexer_topk` (2048).
175+
176+
DeepSeek-V3.2 is instantiated from DeepSeek-V3.1 and undergoes continued pre-training to adapt this indexer via a two-stage strategy: **Dense Warm-up** and **Sparse Training**.
174177

175178
1. **Dense Warm-up Stage**
176179
The indexer is trained exclusively using dense indexer loss while all other model parameters remain frozen.
@@ -186,6 +189,7 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
186189
async_checkpointing=false \
187190
ici_fsdp_parallelism=128 \
188191
steps=5 \
192+
# Indexer is activated only if max_target_length > indexer_topk (2048)
189193
max_target_length=4096 \
190194
attention=flash \
191195
dtype=bfloat16 \
@@ -212,6 +216,7 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
212216
async_checkpointing=false \
213217
ici_fsdp_parallelism=128 \
214218
steps=5 \
219+
# Indexer is activated only if max_target_length > indexer_topk (2048)
215220
max_target_length=4096 \
216221
attention=flash \
217222
dtype=bfloat16 \

tests/end_to_end/tpu/kimi/Run_Kimi.md

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,11 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
7272
dataset_type=synthetic \
7373
scan_layers=True \
7474
use_ring_of_experts=True \
75+
# muon optimizer
7576
opt_type=muon \
7677
muon_consistent_rms=0.2 \
7778
muon_weight_decay=0.1 \
79+
# qk clip
7880
use_qk_clip=True \
7981
qk_clip_threshold=100
8082
```
@@ -109,9 +111,11 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
109111
scan_layers=True \
110112
load_parameters_path=${SCANNED_CHECKPOINT?} \
111113
use_ring_of_experts=True \
114+
# muon optimizer
112115
opt_type=muon \
113116
muon_consistent_rms=0.2 \
114117
muon_weight_decay=0.1 \
118+
# qk clip
115119
use_qk_clip=True \
116120
qk_clip_threshold=100
117121
```
@@ -122,18 +126,21 @@ Example command to run decoding with Kimi K2. Given its 1T size, high tensor par
122126
```sh
123127
python3 -m maxtext.inference.decode src/maxtext/configs/base.yml \
124128
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
125-
load_parameters_path=${CONVERTED_CHECKPOINT?} \
126129
run_name=kimi_decode \
127-
per_device_batch_size=1 \
128130
model_name=kimi-k2-1t \
129-
max_target_length=2048 \
130131
tokenizer_type=huggingface \
131132
tokenizer_path=moonshotai/Kimi-K2-Instruct \
133+
hf_access_token=${HF_TOKEN?} \
134+
load_parameters_path=${UNSCANNED_CKPT_PATH?} \
135+
scan_layers=False \
136+
enable_checkpointing=true \
137+
async_checkpointing=false \
138+
per_device_batch_size=1 \
139+
max_target_length=2048 \
132140
attention=dot_product \
133141
ici_tensor_parallelism=128 \
134142
ici_fsdp_parallelism=1 \
135-
prompt="The primary goal of agentic intelligence is to " \
136-
scan_layers=False
143+
prompt="The primary goal of agentic intelligence is to "
137144
```
138145

139146
## Correctness
@@ -158,6 +165,8 @@ python3 -m tests.assets.logits_generation.generate_hf_golden_logits \
158165
--trust-remote-code=True
159166
```
160167

168+
Run the command below to compare logits between HuggingFace and MaxText.
169+
161170
```sh
162171
JAX_PLATFORMS=cpu python3 -m tests.forward_pass_logit_checker \
163172
src/maxtext/configs/base.yml \

tests/unit/train_compile_test.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,6 @@ def test_moe_deepseek_scanned_bf16(self):
569569
)
570570
)
571571

572-
@pytest.mark.skip(reason="Fix sharding issue of all layers of DeepSeek")
573572
@pytest.mark.cpu_only
574573
def test_moe_deepseek_unscanned_bf16(self):
575574
temp_dir = gettempdir()
@@ -964,10 +963,10 @@ def test_qk_clip_with_dot_product(self):
964963
"per_device_batch_size=1",
965964
"dtype=bfloat16",
966965
"weight_dtype=float32",
967-
# dot product
966+
# dot product attention
968967
"attention=dot_product",
969968
"use_tokamax_splash=True",
970-
# qk
969+
# qk clip
971970
"use_qk_clip=true",
972971
"qk_clip_threshold=100",
973972
)
@@ -993,14 +992,14 @@ def test_muon_clip_with_tokamax_splash(self):
993992
"per_device_batch_size=1",
994993
"dtype=bfloat16",
995994
"weight_dtype=float32",
996-
# tokamax splash
995+
# tokamax splash attention
997996
"attention=flash",
998997
"use_tokamax_splash=True",
999-
# muon
998+
# muon optimizer
1000999
"opt_type=muon",
10011000
"muon_consistent_rms=0.2",
10021001
"muon_weight_decay=0.1",
1003-
# qk
1002+
# qk clip
10041003
"use_qk_clip=true",
10051004
"qk_clip_threshold=100",
10061005
)

0 commit comments

Comments
 (0)