Skip to content

Commit 8718c74

Browse files
committed
consolidate checkpoint conversion back to HF in Gemma3 tests
1 parent 4bcec6a commit 8718c74

5 files changed

Lines changed: 41 additions & 44 deletions

File tree

tests/end_to_end/tpu/gemma3/4b/test_gemma3.sh

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
# 1. Run inference on the pre-converted checkpoint.
77
# 2. Run pre-training starting from the pre-converted checkpoint.
88
# 3. Run inference on the checkpoint produced by the pre-training run.
9-
# 4. Convert the checkpoint produced by the pre-training run back to HuggingFace format.
109

1110
# Usage:
1211
# export HF_TOKEN=<your Hugging Face access token>
@@ -30,10 +29,7 @@ UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/it
3029
# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
3130
DATASET_PATH=gs://maxtext-dataset
3231

33-
# Step 1: Install torch
34-
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
35-
36-
# Step 2: Run inference on the original checkpoint converted from Hugging Face
32+
# Step 1: Run inference on the original checkpoint converted from Hugging Face
3733
if [ ${USE_MULTIMODAL} == true ]; then
3834
python3 -m maxtext.inference.decode \
3935
model_name=${MODEL_NAME} tokenizer_type="huggingface" \
@@ -53,7 +49,7 @@ else
5349
scan_layers=false prompt='I love to' attention=\'dot_product\'
5450
fi
5551

56-
# Step 3: Run Pre-training on the converted checkpoint
52+
# Step 2: Run Pre-training on the converted checkpoint
5753
# We can also run training by using the scanned converted checkpoint
5854
# Note that scanned checkpoint helps with efficient training
5955
python3 -m maxtext.trainers.pre_train.train \
@@ -65,7 +61,7 @@ python3 -m maxtext.trainers.pre_train.train \
6561
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
6662
model_name=${MODEL_NAME} scan_layers=false use_multimodal=${USE_MULTIMODAL}
6763

68-
# Step 4: Run inference on the checkpoint generated from the previous run
64+
# Step 3: Run inference on the checkpoint generated from the previous run
6965
if [ ${USE_MULTIMODAL} == true ]; then
7066
python3 -m maxtext.inference.decode \
7167
model_name=${MODEL_NAME} tokenizer_type="huggingface" \
@@ -85,10 +81,4 @@ else
8581
scan_layers=false prompt='I love to' attention=\'dot_product\'
8682
fi
8783

88-
# Step 5: Convert the checkpoint from MaxText format to Hugging Face format
89-
python3 -m maxtext.checkpoint_conversion.to_huggingface \
90-
model_name=${MODEL_NAME} tokenizer_type="huggingface" \
91-
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \
92-
base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \
93-
use_multimodal=${USE_MULTIMODAL} \
94-
scan_layers=false
84+

tests/end_to_end/tpu/gemma3/4b/test_gemma3_multimodal_sft.sh

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
# 1. Run inference on the pre-converted checkpoint.
77
# 2. Run SFT of Gemma3-4B on ChartQA dataset with the converted checkpoint.
88
# 3. Run inference on the checkpoint produced by the SFT run.
9-
# 4. Convert the checkpoint produced by the SFT run back to HuggingFace format.
109

1110
# Usage:
1211
# export HF_TOKEN=<your Hugging Face access token>
@@ -26,8 +25,7 @@ BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}
2625
UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items
2726
SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/scanned/${run_id}/0/items
2827

29-
# Step 1: Install torch and google-jetstream
30-
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
28+
# Step 1: Install google-jetstream
3129
python3 -m pip install google-jetstream@https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip --no-deps
3230

3331
# Step 2: Run inference on the original checkpoint converted from Hugging Face
@@ -79,9 +77,3 @@ python3 -m maxtext.inference.decode \
7977
image_path=\'tests/assets/test_image.jpg\' \
8078
attention=\'dot_product\'
8179

82-
# Step 5: Convert the SFT checkpoint back to HuggingFace format
83-
python3 -m maxtext.checkpoint_conversion.to_huggingface \
84-
model_name=${MODEL_NAME} \
85-
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/multimodal/sft/${run_id}/checkpoints/4/items \
86-
base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \
87-
use_multimodal=true scan_layers=false

tests/end_to_end/tpu/gemma3/4b/test_gemma3_rl.sh

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
# 1. Run inference on the pre-converted checkpoint.
77
# 2. Run RL starting from the pre-converted checkpoint.
88
# 3. Run inference on the checkpoint produced by the RL run.
9-
# 4. Convert the checkpoint produced by the RL run back to HuggingFace format.
109

1110
# Usage:
1211
# export HF_TOKEN=<your Hugging Face access token>
@@ -25,10 +24,7 @@ BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}
2524
UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items
2625
SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/scanned/${run_id}/0/items
2726

28-
# Step 1: Install torch
29-
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
30-
31-
# Step 2: Run inference on the original checkpoint converted from Hugging Face
27+
# Step 1: Run inference on the original checkpoint converted from Hugging Face
3228
python3 -m maxtext.inference.vllm_decode \
3329
model_name=${MODEL_NAME} \
3430
load_parameters_path=${UNSCANNED_CKPT_PATH} \
@@ -37,7 +33,7 @@ python3 -m maxtext.inference.vllm_decode \
3733
prompt='Suggest some famous landmarks in London.' \
3834
use_chat_template=True scan_layers=false
3935

40-
# Step 3: Run RL on the converted checkpoint
36+
# Step 2: Run RL on the converted checkpoint
4137
python3 -m maxtext.trainers.post_train.rl.train_rl \
4238
base_output_directory=${BASE_OUTPUT_DIRECTORY}/rl \
4339
load_parameters_path=${SCANNED_CKPT_PATH} \
@@ -50,18 +46,12 @@ python3 -m maxtext.trainers.post_train.rl.train_rl \
5046
vllm_additional_config='{"maxtext_config": {"model_name": "gemma3-4b", "log_config": "false"}}'
5147

5248

53-
# Step 4: Run inference on the checkpoint generated from the previous run
49+
# Step 3: Run inference on the checkpoint generated from the previous run
5450
python3 -m maxtext.inference.vllm_decode \
5551
model_name=${MODEL_NAME} \
5652
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/rl/${run_id}/checkpoints/actor/5/model_params \
5753
vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
5854
hbm_utilization_vllm=0.5 \
5955
prompt='Suggest some famous landmarks in London.' \
6056
use_chat_template=True scan_layers=true
61-
62-
# Step 5: Convert the checkpoint from MaxText format to Hugging Face format
63-
python3 -m maxtext.checkpoint_conversion.to_huggingface \
64-
model_name=${MODEL_NAME} \
65-
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/rl/${run_id}/checkpoints/actor/5/model_params \
66-
base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \
67-
use_multimodal=false scan_layers=true
57+

tests/end_to_end/tpu/gemma3/4b/test_gemma3_sft.sh

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
# 1. Run inference on the pre-converted checkpoint.
77
# 2. Run SFT starting from the pre-converted checkpoint.
88
# 3. Run inference on the checkpoint produced by the SFT run.
9-
# 4. Convert the checkpoint produced by the SFT run back to HuggingFace format.
109

1110
# Usage:
1211
# export HF_TOKEN=<your Hugging Face access token>
@@ -55,9 +54,4 @@ python3 -m maxtext.inference.vllm_decode \
5554
prompt="Suggest some famous landmarks in London." \
5655
use_chat_template=True scan_layers=true
5756

58-
# Step 5: Convert the checkpoint from MaxText format to Hugging Face format
59-
python3 -m maxtext.checkpoint_conversion.to_huggingface \
60-
model_name=${MODEL_NAME} \
61-
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/sft/${run_id}/checkpoints/5/model_params \
62-
base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \
63-
use_multimodal=false scan_layers=true
57+
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#!/bin/bash
2+
3+
# Converts a MaxText checkpoint to a Hugging Face model checkpoint.
4+
5+
# Usage:
6+
# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S)
7+
# bash test_gemma3_to_hf.sh $RUN_ID $CHECKPOINT_PATH $USE_MULTIMODAL $SCAN_LAYERS
8+
9+
set -ex
10+
11+
run_id=$1
12+
CKPT_PATH=$2
13+
USE_MULTIMODAL=${3:-false}
14+
SCAN_LAYERS=${4:-false}
15+
16+
MODEL_NAME='gemma3-4b'
17+
BASE_OUTPUT_DIRECTORY="gs://runner-maxtext-logs/${MODEL_NAME}"
18+
19+
if [ "${SCAN_LAYERS,,}" = "true" ]; then
20+
scan_status="scanned"
21+
else
22+
scan_status="unscanned"
23+
fi
24+
25+
python3 -m maxtext.checkpoint_conversion.to_huggingface \
26+
model_name=${MODEL_NAME} \
27+
tokenizer_type="huggingface" \
28+
load_parameters_path=${CKPT_PATH} \
29+
base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/${scan_status}/${run_id} \
30+
use_multimodal=${USE_MULTIMODAL} \
31+
scan_layers=$SCAN_LAYERS

0 commit comments

Comments
 (0)