Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions tests/end_to_end/tpu/gemma3/4b/test_gemma3.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# 1. Run inference on the pre-converted checkpoint.
# 2. Run pre-training starting from the pre-converted checkpoint.
# 3. Run inference on the checkpoint produced by the pre-training run.
# 4. Convert the checkpoint produced by the pre-training run back to HuggingFace format.

# Usage:
# export HF_TOKEN=<your Hugging Face access token>
Expand All @@ -30,10 +29,7 @@ UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/it
# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
DATASET_PATH=gs://maxtext-dataset

# Step 1: Install torch
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# Step 2: Run inference on the original checkpoint converted from Hugging Face
# Step 1: Run inference on the original checkpoint converted from Hugging Face
if [ ${USE_MULTIMODAL} == true ]; then
python3 -m maxtext.inference.decode \
model_name=${MODEL_NAME} tokenizer_type="huggingface" \
Expand All @@ -53,7 +49,7 @@ else
scan_layers=false prompt='I love to' attention=\'dot_product\'
fi

# Step 3: Run Pre-training on the converted checkpoint
# Step 2: Run Pre-training on the converted checkpoint
# We can also run training by using the scanned converted checkpoint
# Note that scanned checkpoint helps with efficient training
python3 -m maxtext.trainers.pre_train.train \
Expand All @@ -65,7 +61,7 @@ python3 -m maxtext.trainers.pre_train.train \
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False \
model_name=${MODEL_NAME} scan_layers=false use_multimodal=${USE_MULTIMODAL}

# Step 4: Run inference on the checkpoint generated from the previous run
# Step 3: Run inference on the checkpoint generated from the previous run
if [ ${USE_MULTIMODAL} == true ]; then
python3 -m maxtext.inference.decode \
model_name=${MODEL_NAME} tokenizer_type="huggingface" \
Expand All @@ -85,10 +81,4 @@ else
scan_layers=false prompt='I love to' attention=\'dot_product\'
fi

# Step 5: Convert the checkpoint from MaxText format to Hugging Face format
python3 -m maxtext.checkpoint_conversion.to_huggingface \
model_name=${MODEL_NAME} tokenizer_type="huggingface" \
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/train/${run_id}/checkpoints/4/items \
base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \
use_multimodal=${USE_MULTIMODAL} \
scan_layers=false

10 changes: 1 addition & 9 deletions tests/end_to_end/tpu/gemma3/4b/test_gemma3_multimodal_sft.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# 1. Run inference on the pre-converted checkpoint.
# 2. Run SFT of Gemma3-4B on ChartQA dataset with the converted checkpoint.
# 3. Run inference on the checkpoint produced by the SFT run.
# 4. Convert the checkpoint produced by the SFT run back to HuggingFace format.

# Usage:
# export HF_TOKEN=<your Hugging Face access token>
Expand All @@ -26,8 +25,7 @@ BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}
UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items
SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/scanned/${run_id}/0/items

# Step 1: Install torch and google-jetstream
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
# Step 1: Install google-jetstream
python3 -m pip install google-jetstream@https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip --no-deps

# Step 2: Run inference on the original checkpoint converted from Hugging Face
Expand Down Expand Up @@ -79,9 +77,3 @@ python3 -m maxtext.inference.decode \
image_path=\'tests/assets/test_image.jpg\' \
attention=\'dot_product\'

# Step 5: Convert the SFT checkpoint back to HuggingFace format
python3 -m maxtext.checkpoint_conversion.to_huggingface \
model_name=${MODEL_NAME} \
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/multimodal/sft/${run_id}/checkpoints/4/items \
base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \
use_multimodal=true scan_layers=false
18 changes: 4 additions & 14 deletions tests/end_to_end/tpu/gemma3/4b/test_gemma3_rl.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# 1. Run inference on the pre-converted checkpoint.
# 2. Run RL starting from the pre-converted checkpoint.
# 3. Run inference on the checkpoint produced by the RL run.
# 4. Convert the checkpoint produced by the RL run back to HuggingFace format.

# Usage:
# export HF_TOKEN=<your Hugging Face access token>
Expand All @@ -25,10 +24,7 @@ BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}
UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items
SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/scanned/${run_id}/0/items

# Step 1: Install torch
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# Step 2: Run inference on the original checkpoint converted from Hugging Face
# Step 1: Run inference on the original checkpoint converted from Hugging Face
python3 -m maxtext.inference.vllm_decode \
model_name=${MODEL_NAME} \
load_parameters_path=${UNSCANNED_CKPT_PATH} \
Expand All @@ -37,7 +33,7 @@ python3 -m maxtext.inference.vllm_decode \
prompt='Suggest some famous landmarks in London.' \
use_chat_template=True scan_layers=false

# Step 3: Run RL on the converted checkpoint
# Step 2: Run RL on the converted checkpoint
python3 -m maxtext.trainers.post_train.rl.train_rl \
base_output_directory=${BASE_OUTPUT_DIRECTORY}/rl \
load_parameters_path=${SCANNED_CKPT_PATH} \
Expand All @@ -50,18 +46,12 @@ python3 -m maxtext.trainers.post_train.rl.train_rl \
vllm_additional_config='{"maxtext_config": {"model_name": "gemma3-4b", "log_config": "false"}}'


# Step 4: Run inference on the checkpoint generated from the previous run
# Step 3: Run inference on the checkpoint generated from the previous run
python3 -m maxtext.inference.vllm_decode \
model_name=${MODEL_NAME} \
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/rl/${run_id}/checkpoints/actor/5/model_params \
vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
hbm_utilization_vllm=0.5 \
prompt='Suggest some famous landmarks in London.' \
use_chat_template=True scan_layers=true

# Step 5: Convert the checkpoint from MaxText format to Hugging Face format
python3 -m maxtext.checkpoint_conversion.to_huggingface \
model_name=${MODEL_NAME} \
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/rl/${run_id}/checkpoints/actor/5/model_params \
base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \
use_multimodal=false scan_layers=true

17 changes: 4 additions & 13 deletions tests/end_to_end/tpu/gemma3/4b/test_gemma3_sft.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# 1. Run inference on the pre-converted checkpoint.
# 2. Run SFT starting from the pre-converted checkpoint.
# 3. Run inference on the checkpoint produced by the SFT run.
# 4. Convert the checkpoint produced by the SFT run back to HuggingFace format.

# Usage:
# export HF_TOKEN=<your Hugging Face access token>
Expand All @@ -25,10 +24,7 @@ BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/${MODEL_NAME}
UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/unscanned/${run_id}/0/items
SCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/to_maxtext/scanned/${run_id}/0/items

# Step 1: Install torch
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# Step 2: Run inference on the original checkpoint converted from Hugging Face
# Step 1: Run inference on the original checkpoint converted from Hugging Face
python3 -m maxtext.inference.vllm_decode \
model_name=${MODEL_NAME} \
load_parameters_path=${UNSCANNED_CKPT_PATH} \
Expand All @@ -37,7 +33,7 @@ python3 -m maxtext.inference.vllm_decode \
prompt="Suggest some famous landmarks in London." \
use_chat_template=True scan_layers=false

# Step 3: Run SFT on the converted checkpoint
# Step 2: Run SFT on the converted checkpoint
python3 -m maxtext.trainers.post_train.sft.train_sft \
base_output_directory=${BASE_OUTPUT_DIRECTORY}/sft \
load_parameters_path=${SCANNED_CKPT_PATH} \
Expand All @@ -46,7 +42,7 @@ python3 -m maxtext.trainers.post_train.sft.train_sft \
model_name=${MODEL_NAME} enable_single_controller=True \
checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False

# Step 4: Run inference on the checkpoint generated from the previous run
# Step 3: Run inference on the checkpoint generated from the previous run
python3 -m maxtext.inference.vllm_decode \
model_name=${MODEL_NAME} \
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/sft/${run_id}/checkpoints/5/model_params \
Expand All @@ -55,9 +51,4 @@ python3 -m maxtext.inference.vllm_decode \
prompt="Suggest some famous landmarks in London." \
use_chat_template=True scan_layers=true

# Step 5: Convert the checkpoint from MaxText format to Hugging Face format
python3 -m maxtext.checkpoint_conversion.to_huggingface \
model_name=${MODEL_NAME} \
load_parameters_path=${BASE_OUTPUT_DIRECTORY}/sft/${run_id}/checkpoints/5/model_params \
base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/unscanned/${run_id} \
use_multimodal=false scan_layers=true

31 changes: 31 additions & 0 deletions tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_hf.sh
Comment thread
YixuanWang-99 marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#!/bin/bash

# Converts a MaxText checkpoint to a Hugging Face model checkpoint.

# Usage:
# export RUN_ID=$(date +%Y-%m-%d-%H-%M-%S)
# bash test_gemma3_to_hf.sh $RUN_ID $CHECKPOINT_PATH $USE_MULTIMODAL $SCAN_LAYERS

set -ex

run_id=$1
CKPT_PATH=$2
USE_MULTIMODAL=${3:-false}
SCAN_LAYERS=${4:-false}

MODEL_NAME='gemma3-4b'
BASE_OUTPUT_DIRECTORY="gs://runner-maxtext-logs/${MODEL_NAME}"

if [ "${SCAN_LAYERS,,}" = "true" ]; then
scan_status="scanned"
else
scan_status="unscanned"
fi

python3 -m maxtext.checkpoint_conversion.to_huggingface \
model_name=${MODEL_NAME} \
tokenizer_type="huggingface" \
load_parameters_path=${CKPT_PATH} \
base_output_directory=${BASE_OUTPUT_DIRECTORY}/to_huggingface/${scan_status}/${run_id} \
use_multimodal=${USE_MULTIMODAL} \
scan_layers=$SCAN_LAYERS
Loading