Skip to content

Commit 9d79644

Browse files
committed
feat: consolidate the step 5 (checkpoint conversion back to HF)
1 parent 4bcec6a commit 9d79644

5 files changed

Lines changed: 34 additions & 30 deletions

File tree

tests/end_to_end/tpu/gemma3/4b/test_gemma3.sh

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
# 1. Run inference on the pre-converted checkpoint.
77
# 2. Run pre-training starting from the pre-converted checkpoint.
88
# 3. Run inference on the checkpoint produced by the pre-training run.
9-
# 4. Convert the checkpoint produced by the pre-training run back to HuggingFace format.
109

1110
# Usage:
1211
# export HF_TOKEN=<your Hugging Face access token>
@@ -85,10 +84,4 @@ else
8584
scan_layers=false prompt='I love to' attention=\'dot_product\'
8685
fi
8786

88-
# Step 5: Convert the checkpoint from MaxText format to Hugging Face format
89-
python3 -m maxtext.checkpoint_conversion.to_huggingface \
90-
model_name=${MODEL_NAME} tokenizer_type="huggingface" \
91-
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \
92-
base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \
93-
use_multimodal=${USE_MULTIMODAL} \
94-
scan_layers=false
87+

tests/end_to_end/tpu/gemma3/4b/test_gemma3_multimodal_sft.sh

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
# 1. Run inference on the pre-converted checkpoint.
77
# 2. Run SFT of Gemma3-4B on ChartQA dataset with the converted checkpoint.
88
# 3. Run inference on the checkpoint produced by the SFT run.
9-
# 4. Convert the checkpoint produced by the SFT run back to HuggingFace format.
109

1110
# Usage:
1211
# export HF_TOKEN=<your Hugging Face access token>
@@ -79,9 +78,3 @@ python3 -m maxtext.inference.decode \
7978
image_path=\'tests/assets/test_image.jpg\' \
8079
attention=\'dot_product\'
8180

82-
# Step 5: Convert the SFT checkpoint back to HuggingFace format
83-
python3 -m maxtext.checkpoint_conversion.to_huggingface \
84-
model_name=${MODEL_NAME} \
85-
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/multimodal/sft/${run_id}/checkpoints/4/items \
86-
base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \
87-
use_multimodal=true scan_layers=false

tests/end_to_end/tpu/gemma3/4b/test_gemma3_rl.sh

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
# 1. Run inference on the pre-converted checkpoint.
77
# 2. Run RL starting from the pre-converted checkpoint.
88
# 3. Run inference on the checkpoint produced by the RL run.
9-
# 4. Convert the checkpoint produced by the RL run back to HuggingFace format.
109

1110
# Usage:
1211
# export HF_TOKEN=<your Hugging Face access token>
@@ -58,10 +57,4 @@ python3 -m maxtext.inference.vllm_decode \
5857
hbm_utilization_vllm=0.5 \
5958
prompt='Suggest some famous landmarks in London.' \
6059
use_chat_template=True scan_layers=true
61-
62-
# Step 5: Convert the checkpoint from MaxText format to Hugging Face format
63-
python3 -m maxtext.checkpoint_conversion.to_huggingface \
64-
model_name=${MODEL_NAME} \
65-
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/rl/${run_id}/checkpoints/actor/5/model_params \
66-
base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \
67-
use_multimodal=false scan_layers=true
60+

tests/end_to_end/tpu/gemma3/4b/test_gemma3_sft.sh

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
# 1. Run inference on the pre-converted checkpoint.
77
# 2. Run SFT starting from the pre-converted checkpoint.
88
# 3. Run inference on the checkpoint produced by the SFT run.
9-
# 4. Convert the checkpoint produced by the SFT run back to HuggingFace format.
109

1110
# Usage:
1211
# export HF_TOKEN=<your Hugging Face access token>
@@ -55,9 +54,4 @@ python3 -m maxtext.inference.vllm_decode \
5554
prompt="Suggest some famous landmarks in London." \
5655
use_chat_template=True scan_layers=true
5756

58-
# Step 5: Convert the checkpoint from MaxText format to Hugging Face format
59-
python3 -m maxtext.checkpoint_conversion.to_huggingface \
60-
model_name=${MODEL_NAME} \
61-
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/sft/${run_id}/checkpoints/5/model_params \
62-
base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \
63-
use_multimodal=false scan_layers=true
57+
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#!/bin/bash

# Converts a MaxText gemma3-4b checkpoint to a Hugging Face model checkpoint
# by invoking maxtext.checkpoint_conversion.to_huggingface.
#
# Usage:
#   export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S)
#   bash test_gemma3_to_hf.sh "$RUN_ID" "$CHECKPOINT_PATH" [USE_MULTIMODAL] [SCAN_LAYERS]
#
# Arguments:
#   $1  RUN_ID          (required) last path component of base_output_directory.
#   $2  CHECKPOINT_PATH (required) forwarded as load_parameters_path.
#   $3  USE_MULTIMODAL  (optional, default: false) forwarded as use_multimodal.
#   $4  SCAN_LAYERS     (optional, default: false) forwarded as scan_layers;
#                       also selects the "scanned"/"unscanned" output subdirectory.

# -e: stop at the first failing command; -x: trace every command that is run.
set -ex

# Abort with a usage hint instead of launching the converter with an empty
# run id or checkpoint path when a required argument is missing.
usage="usage: bash test_gemma3_to_hf.sh RUN_ID CHECKPOINT_PATH [USE_MULTIMODAL] [SCAN_LAYERS]"
run_id=${1:?${usage}}
CKPT_PATH=${2:?${usage}}
USE_MULTIMODAL=${3:-false}
SCAN_LAYERS=${4:-false}

MODEL_NAME='gemma3-4b'
BASE_OUTPUT_DIRECTORY="gs://runner-maxtext-logs/${MODEL_NAME}"

# Case-insensitive match on "true" (true/True/TRUE/...); anything else is
# treated as unscanned. A case pattern is used so no bash>=4 expansion is needed.
case "${SCAN_LAYERS}" in
  [Tt][Rr][Uu][Ee]) scan_status="scanned" ;;
  *) scan_status="unscanned" ;;
esac

# All expansions are quoted so each key=value pair reaches the converter as a
# single argument. SCAN_LAYERS is forwarded exactly as given by the caller.
python3 -m maxtext.checkpoint_conversion.to_huggingface \
  model_name="${MODEL_NAME}" \
  tokenizer_type="huggingface" \
  load_parameters_path="${CKPT_PATH}" \
  base_output_directory="${BASE_OUTPUT_DIRECTORY}/to_huggingface/${scan_status}/${run_id}" \
  use_multimodal="${USE_MULTIMODAL}" \
  scan_layers="${SCAN_LAYERS}"

0 commit comments

Comments
 (0)