SFT on single-host TPUs

Supervised fine-tuning (SFT) further trains a pre-trained large language model on a labeled dataset so that it performs better on specific tasks.

This tutorial gives step-by-step instructions for setting up the environment and then training a model on a Hugging Face dataset with SFT.

We use Tunix, a JAX-based library designed for post-training tasks, to perform SFT.

This tutorial uses a single-host TPU VM, such as a v6e-8 or v5p-8. Let's get started!

Install MaxText and post-training dependencies

# Create a virtual environment
export VENV_NAME=<your virtual env name> # e.g., maxtext_venv
pip install uv
uv venv --python 3.12 --seed ${VENV_NAME?}
source ${VENV_NAME?}/bin/activate

Then run the following commands to install MaxText with its TPU post-training extras and the extra dependencies it needs.

uv pip install "maxtext[tpu-post-train]>=0.2.0" --resolution=lowest
install_maxtext_tpu_post_train_extra_deps
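Before going further, you may want to confirm that the key packages import inside the active virtual environment. The helper below is a minimal sketch, not part of MaxText; `check_python_modules` is a hypothetical name, and it assumes bash and that `python3` resolves to the venv's interpreter.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of MaxText): report which Python modules
# fail to import in the currently active environment.
check_python_modules() {
  local module
  local missing=0
  for module in "$@"; do
    # importlib.import_module raises ImportError if the module is absent.
    if python3 -c 'import importlib, sys; importlib.import_module(sys.argv[1])' "$module" >/dev/null 2>&1; then
      echo "ok: ${module}"
    else
      echo "missing: ${module}" >&2
      missing=1
    fi
  done
  return "$missing"
}

# Usage, after the install commands above:
#   check_python_modules jax maxtext
```

Once both modules import, running `python3 -c 'import jax; print(jax.devices())'` on the TPU VM should list its TPU devices.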

Set up environment variables

Set the following environment variables before running SFT.

# -- Model configuration --
export PRE_TRAINED_MODEL=<model name> # e.g., 'llama3.1-8b'
export PRE_TRAINED_MODEL_TOKENIZER=<tokenizer path> # e.g., 'meta-llama/Llama-3.1-8B-Instruct'
export HF_TOKEN=<Hugging Face access token>

# -- MaxText configuration --
export BASE_OUTPUT_DIRECTORY=<output directory for run logs and checkpoints> # e.g., gs://my-bucket/my-output-directory
export RUN_NAME=<name for this run> # e.g., $(date +%Y-%m-%d-%H-%M-%S)
export STEPS=<number of fine-tuning steps to run> # e.g., 1000
export PER_DEVICE_BATCH_SIZE=<batch size per device> # e.g., 1

# -- Dataset configuration --
export DATASET_NAME=<Hugging Face dataset name> # e.g., HuggingFaceH4/ultrachat_200k
export TRAIN_SPLIT=<data split for train> # e.g., train_sft
export TRAIN_DATA_COLUMNS=<data columns to train on> # e.g., "['messages']" (quoted so the shell does not strip the inner quotes)
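The SFT command uses `${VAR?}` expansions, so it stops at the first unset variable and reports only that one. A pre-flight check can list every missing or empty variable at once. The sketch below is a hypothetical helper (`check_sft_env` is not part of MaxText); it assumes bash, because it relies on indirect expansion (`${!name}`), and it only checks the variables exported above.

```shell
#!/usr/bin/env bash
# Hypothetical pre-flight check (not part of MaxText): list every required
# variable that is unset or empty, and validate STEPS, before launching SFT.
check_sft_env() {
  local required=(
    PRE_TRAINED_MODEL PRE_TRAINED_MODEL_TOKENIZER HF_TOKEN
    BASE_OUTPUT_DIRECTORY RUN_NAME STEPS PER_DEVICE_BATCH_SIZE
    DATASET_NAME TRAIN_SPLIT TRAIN_DATA_COLUMNS
  )
  local name
  local missing=()
  for name in "${required[@]}"; do
    # ${!name} is bash indirect expansion: the value of the variable called $name.
    if [[ -z "${!name:-}" ]]; then
      missing+=("$name")
    fi
  done
  if (( ${#missing[@]} > 0 )); then
    echo "unset or empty: ${missing[*]}" >&2
    return 1
  fi
  # STEPS counts training steps, so require a positive integer.
  if [[ ! "$STEPS" =~ ^[1-9][0-9]*$ ]]; then
    echo "STEPS must be a positive integer, got '${STEPS}'" >&2
    return 1
  fi
  echo "all SFT variables are set"
}

# Usage, after the exports above:
#   check_sft_env
```

PER_DEVICE_BATCH_SIZE is deliberately only checked for presence, since its valid range is up to MaxText.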

Get your model checkpoint

This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.

Option 1: Using an existing MaxText checkpoint

If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.

export PRE_TRAINED_MODEL_CKPT_PATH=<GCS path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items

Option 2: Converting a Hugging Face checkpoint

Follow the steps in Hugging Face to MaxText to convert a Hugging Face checkpoint to the MaxText format, and confirm that the converted checkpoint files were saved correctly. Then, as in Option 1, set the following environment variable and move on.

export PRE_TRAINED_MODEL_CKPT_PATH=<GCS path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
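Whichever option you chose, a quick shape check on the path can catch typos before a long launch. This is a minimal sketch with a hypothetical helper name (`check_ckpt_path`, not part of MaxText). It assumes, based only on the examples above, that the path is a `gs://` URI ending in the `items` directory of a checkpoint step, and it only warns when the path looks different.

```shell
#!/usr/bin/env bash
# Hypothetical sanity check (not part of MaxText) for the checkpoint path.
# Assumption: as in the examples above, the path is a gs:// URI ending in the
# "items" directory of a saved checkpoint step, e.g. .../my-model-checkpoint/0/items.
check_ckpt_path() {
  local ckpt_path="${1:-${PRE_TRAINED_MODEL_CKPT_PATH:-}}"
  if [[ -z "$ckpt_path" ]]; then
    echo "PRE_TRAINED_MODEL_CKPT_PATH is not set" >&2
    return 1
  fi
  # Warnings only: a path that differs from the examples may still be intentional.
  if [[ "$ckpt_path" != gs://* ]]; then
    echo "warning: ${ckpt_path} is not a gs:// path" >&2
  fi
  if [[ "${ckpt_path%/}" != */items ]]; then
    echo "warning: ${ckpt_path} does not end in /items" >&2
  fi
  return 0
}

# Usage, after exporting PRE_TRAINED_MODEL_CKPT_PATH:
#   check_ckpt_path
```

To confirm the path exists and is readable from the VM, `gcloud storage ls "${PRE_TRAINED_MODEL_CKPT_PATH}"` should list its contents.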

Run SFT on a Hugging Face dataset

Now you are ready to run SFT using the following command:

python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml \
    run_name=${RUN_NAME?} \
    base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
    model_name=${PRE_TRAINED_MODEL?} \
    load_parameters_path=${PRE_TRAINED_MODEL_CKPT_PATH?} \
    hf_access_token=${HF_TOKEN?} \
    tokenizer_path=${PRE_TRAINED_MODEL_TOKENIZER?} \
    per_device_batch_size=${PER_DEVICE_BATCH_SIZE?} \
    steps=${STEPS?} \
    hf_path=${DATASET_NAME?} \
    train_split=${TRAIN_SPLIT?} \
    train_data_columns="${TRAIN_DATA_COLUMNS?}" \
    profiler=xplane
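If you also want a local copy of the training output, one option is to pipe the command through `tee`. A plain pipe would report `tee`'s exit status and hide a failed run, so the sketch below enables `pipefail`. `run_logged` is a hypothetical wrapper, not part of MaxText, and assumes bash.

```shell
#!/usr/bin/env bash
# Hypothetical wrapper (not part of MaxText): run a command, mirror its stdout
# and stderr to a log file, and return non-zero if the command fails.
run_logged() {
  local log_file="$1"
  shift
  # The subshell scopes pipefail; without it the pipeline would report
  # tee's status (almost always 0) instead of the command's.
  ( set -o pipefail; "$@" 2>&1 | tee "$log_file" )
}

# Usage: prefix the SFT command above, keeping its arguments unchanged, e.g.
#   run_logged "sft-${RUN_NAME}.log" python3 -m maxtext.trainers.post_train.sft.train_sft \
#       src/maxtext/configs/post_train/sft.yml run_name=${RUN_NAME?} ...
```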

Your fine-tuned model checkpoints will be saved here: $BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints.
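To see what a run has written, you can list that directory. The helpers below are a minimal sketch with hypothetical names (`sft_checkpoint_dir`, `list_sft_checkpoints`, not part of MaxText); they assume bash, a `gs://` BASE_OUTPUT_DIRECTORY, and an installed `gcloud` CLI with read access to the bucket.

```shell
#!/usr/bin/env bash
# Hypothetical helpers (not part of MaxText) for finding a run's checkpoints.
sft_checkpoint_dir() {
  # ${BASE_OUTPUT_DIRECTORY%/} drops one trailing slash so that
  # gs://bucket/out and gs://bucket/out/ produce the same directory.
  echo "${BASE_OUTPUT_DIRECTORY%/}/${RUN_NAME}/checkpoints"
}

list_sft_checkpoints() {
  # Requires the gcloud CLI and read access to the bucket.
  gcloud storage ls "$(sft_checkpoint_dir)/"
}

# Usage, with BASE_OUTPUT_DIRECTORY and RUN_NAME still exported:
#   list_sft_checkpoints
```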