Skip to content

Commit cbb413f

Browse files
Merge pull request #3750 from AI-Hypercomputer:shuningjin-fix3
PiperOrigin-RevId: 906473585
2 parents 2184fe6 + d2bda3f commit cbb413f

4 files changed

Lines changed: 10 additions & 11 deletions

File tree

src/maxtext/configs/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2625,7 +2625,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
26252625
supports_flash_splash = self.attention == "flash" and self.use_tokamax_splash
26262626
if not (supports_dot_product or supports_flash_splash):
26272627
raise NotImplementedError(
2628-
"Sparse indexer is only supported dot_product attention or flash attention with tokamax splash."
2628+
"Sparse indexer is only supported with dot_product attention or flash attention with tokamax splash."
26292629
)
26302630
if self.indexer_loss_scaling_factor > 0.0 and self.indexer_topk >= self.max_target_length:
26312631
raise ValueError(

tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@ export TOKENIZER_PATH='deepseek-ai/DeepSeek-V2-Lite'
2020
# Installing torch for checkpoint conversion and forward_pass_logit_checker.py
2121
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
2222

23-
# e.g., $HOME/maxtext/src/maxtext
24-
export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}"
25-
2623
if [ -z "${BASE_OUTPUT_PATH}" ]; then
2724
# Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing.
2825
# this bucket will store all the files generated by MaxText during a run

tests/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@ export TOKENIZER_PATH='deepseek-ai/DeepSeek-V3'
1818
# Installing torch for checkpoint conversion and forward_pass_logit_checker.py
1919
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
2020

21-
# e.g., $HOME/maxtext/src/maxtext
22-
export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}"
23-
2421
if [ -z "${BASE_OUTPUT_PATH}" ]; then
2522
# Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing.
2623
# this bucket will store all the files generated by MaxText during a run

tests/end_to_end/tpu/deepseek/v3.2-671b/2_test_deepseek.sh

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@ export TOKENIZER_PATH='deepseek-ai/DeepSeek-V3.2'
1515
# Installing torch for checkpoint conversion and forward_pass_logit_checker.py
1616
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
1717

18-
# e.g., $HOME/maxtext/src/maxtext
19-
export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}"
18+
# TODO: remove this version pin when deepseek32 becomes available in the transformers library
19+
# deepseek_v32 is missing from the transformers config, which causes an error with transformers==5.6.1
20+
python3 -m pip install transformers==4.57.3
2021

2122
if [ -z "${BASE_OUTPUT_PATH}" ]; then
2223
# Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing.
@@ -46,10 +47,14 @@ if [ ! -f "${GOLDEN_LOGITS_DISK_LOCATION}" ]; then
4647
gcloud storage cp ${GOLDEN_LOGITS_PATH} ${GOLDEN_LOGITS_DISK_LOCATION}
4748
fi
4849

49-
# override deepseek3.2-671b.yml with indexer_topk=2
50+
# Note: override deepseek3.2-671b.yml with indexer_topk=2 for testing
5051
OVERRIDE="override_model_config=True indexer_topk=2"
5152
python3 -m tests.utils.forward_pass_logit_checker ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=true attention=dot_product per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 async_checkpointing=false sparse_matmul=false ici_fsdp_parallelism=1 ici_expert_parallelism=-1 checkpoint_storage_concurrent_gb=1024 weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=highest float32_logits=true float32_qk_product=true ${OVERRIDE} --golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION} --max_kl_div=0.3
5253

54+
# Run pre-training - tokamax_gmm implementation
55+
# Note: use sgd due to memory constraints
56+
python3 -m maxtext.trainers.pre_train.train ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=tokamax_gmm_pre_training model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=synthetic enable_checkpointing=false attention=flash use_tokamax_splash=True sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 steps=5 max_target_length=4096 ici_fsdp_parallelism=-1 opt_type=sgd
57+
5358
# Run decoding - megablox implementation
54-
# Note decode requires the access token for huggingface tokenizer even if the model is not gated
59+
# Note: decode requires the access token for huggingface tokenizer even if the model is not gated
5560
python3 -m maxtext.inference.decode ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True use_tokamax_gmm=False dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=3072 max_target_length=4096 ici_fsdp_parallelism=1 ici_tensor_parallelism=-1 ici_expert_parallelism=1 checkpoint_storage_concurrent_gb=1024 mla_naive_kvcache=false prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is "

0 commit comments

Comments
 (0)