Merged
42 changes: 2 additions & 40 deletions docs/tutorials/posttraining/rl.md
@@ -42,47 +42,9 @@ rely on the vLLM library.

Let's get started!

## Create virtual environment and Install MaxText dependencies
## Install MaxText and post-training dependencies

```bash
# Create a virtual environment
export VENV_NAME=<your virtual env name> # e.g., maxtext_venv
pip install uv
uv venv --python 3.12 --seed ${VENV_NAME?}
source ${VENV_NAME?}/bin/activate
```

### Option 1: From PyPI releases (Recommended)

Run the following commands to get all the necessary installations.

```bash
uv pip install "maxtext[tpu-post-train]>=0.2.0" --resolution=lowest
install_maxtext_tpu_post_train_extra_deps
```

This installs MaxText and then, for post-training, primarily the following:

a. [Tunix](https://github.com/google/tunix) as the LLM Post-Training Library, and

b. `vllm-tpu` which is
[vllm](https://github.com/vllm-project/vllm) and
[tpu-inference](https://github.com/vllm-project/tpu-inference) and thereby
providing TPU inference for vLLM, with unified JAX and PyTorch support.

### Option 2: From Github

To use a version newer than the latest PyPI release, you can also install the latest vetted versions of the dependencies from the MaxText repository as follows:

```bash
# 1. Clone the repository
git clone https://github.com/AI-Hypercomputer/maxtext.git
cd maxtext

# 2. Install dependencies in editable mode
uv pip install -e .[tpu-post-train] --resolution=lowest
install_maxtext_tpu_post_train_extra_deps
```
For instructions on installing MaxText with post-training dependencies on your VM, please refer to the [official documentation](https://maxtext.readthedocs.io/en/latest/install_maxtext.html) and use the `maxtext[tpu-post-train]` installation path to include all necessary post-training dependencies.

## Setup environment variables

25 changes: 6 additions & 19 deletions docs/tutorials/posttraining/sft.md
Expand Up @@ -26,20 +26,7 @@ In this tutorial we use a single host TPU VM such as `v6e-8/v5p-8`. Let's get st

## Install MaxText and Post-Training dependencies

```bash
# Create a virtual environment
export VENV_NAME=<your virtual env name> # e.g., maxtext_venv
pip install uv
uv venv --python 3.12 --seed ${VENV_NAME?}
source ${VENV_NAME?}/bin/activate
```

Run the following commands to get all the necessary installations.

```bash
uv pip install "maxtext[tpu-post-train]>=0.2.0" --resolution=lowest
install_maxtext_tpu_post_train_extra_deps
```
For instructions on installing MaxText with post-training dependencies on your VM, please refer to the [official documentation](https://maxtext.readthedocs.io/en/latest/install_maxtext.html) and use the `maxtext[tpu-post-train]` installation path to include all necessary post-training dependencies.

## Setup environment variables

@@ -53,7 +40,7 @@ Set the following environment variables before running SFT.

```sh
# -- Model configuration --
export PRE_TRAINED_MODEL=<model name> # e.g., 'llama3.1-8b-Instruct'
export MODEL=<MaxText Model> # e.g., 'llama3.1-8b-Instruct'

# -- MaxText configuration --
export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory
@@ -76,15 +63,15 @@ This section explains how to prepare your model checkpoint for use with MaxText.
If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.

```sh
export PRE_TRAINED_MODEL_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
export MAXTEXT_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
```
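Throughout these commands, the docs use the `${VAR?}` form of shell parameter expansion, which makes bash abort with an error when a required variable is unset instead of silently substituting an empty string. A minimal sketch of the behavior (the variable values and the `CKPT_UNSET` name below are illustrative, not part of MaxText):

```shell
#!/usr/bin/env bash
# ${VAR?} fails loudly when VAR is unset, so a forgotten export stops
# the run instead of producing a half-configured command line.

export MAXTEXT_CKPT_PATH="gs://my-bucket/my-model-checkpoint/0/items"

# Set variable: the expansion simply yields its value.
echo "checkpoint: ${MAXTEXT_CKPT_PATH?}"

# Unset variable: the subshell exits with an error before echo runs.
( echo "never printed: ${CKPT_UNSET?}" ) 2>/dev/null \
  || echo "aborted: CKPT_UNSET is unset"
```

Running the snippet prints the checkpoint line followed by the abort message, confirming that the unset expansion never reached `echo`.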

### Option 2: Converting a Hugging Face checkpoint

Refer to the steps in [Hugging Face to MaxText](../../guides/checkpointing_solutions/convert_checkpoint.md#hugging-face-to-maxtext) to convert a Hugging Face checkpoint to MaxText. Make sure the checkpoint files are correctly converted and saved. As in Option 1, set the following environment variable and move on.

```sh
export PRE_TRAINED_MODEL_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
export MAXTEXT_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
```

## Run SFT on Hugging Face Dataset
@@ -95,8 +82,8 @@ Now you are ready to run SFT using the following command:
python3 -m maxtext.trainers.post_train.sft.train_sft \
run_name=${RUN_NAME?} \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
model_name=${PRE_TRAINED_MODEL?} \
load_parameters_path=${PRE_TRAINED_MODEL_CKPT_PATH?} \
model_name=${MODEL?} \
load_parameters_path=${MAXTEXT_CKPT_PATH?} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE?} \
steps=${STEPS?} \
hf_path=${DATASET_NAME?} \
12 changes: 6 additions & 6 deletions docs/tutorials/posttraining/sft_on_multi_host.md
@@ -62,7 +62,7 @@ export STEPS=<Fine-Tuning Steps> # e.g., 1000
export HF_TOKEN=<Hugging Face Access Token>

# -- Model Configuration --
export MODEL_NAME=<Model Name> # e.g., deepseek3-671b
export MODEL=<MaxText Model> # e.g., deepseek3-671b

# -- Dataset configuration --
export DATASET_NAME=<Hugging Face Dataset Name> # e.g., HuggingFaceH4/ultrachat_200k
@@ -79,10 +79,10 @@ This section explains how to prepare your model checkpoint for use with MaxText.
If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.

```bash
export MODEL_CHECKPOINT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
export MAXTEXT_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
```

**Note:** Make sure that `MODEL_CHECKPOINT_PATH` has the checkpoints created using the correct storage flags:
**Note:** Make sure that the checkpoints at `MAXTEXT_CKPT_PATH` were created using the correct storage flags:

```bash
export USE_PATHWAYS=0 # Set to 1 for Pathways, 0 for McJAX.
@@ -95,7 +95,7 @@ checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS))
Refer to the steps in [Hugging Face to MaxText](../../guides/checkpointing_solutions/convert_checkpoint.md#hugging-face-to-maxtext) to convert a Hugging Face checkpoint to MaxText. Make sure the checkpoint files are correctly converted and saved. As in Option 1, set the following environment variable and move on.

```bash
export MODEL_CHECKPOINT_PATH=<gcs path for MaxText checkpoint> # gs://my-bucket/my-checkpoint-directory/0/items
export MAXTEXT_CKPT_PATH=<gcs path for MaxText checkpoint> # gs://my-bucket/my-checkpoint-directory/0/items
```
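The `$((1 - USE_PATHWAYS))` expression used for the storage flags above is plain shell arithmetic that inverts a 0/1 flag: McJAX (`USE_PATHWAYS=0`) enables zarr3/ocdbt storage, while Pathways (`USE_PATHWAYS=1`) disables it. A quick standalone sketch:

```shell
#!/usr/bin/env bash
# $((1 - USE_PATHWAYS)) inverts a 0/1 flag via $(( )) arithmetic.

export USE_PATHWAYS=0   # McJAX
echo "zarr3=$((1 - USE_PATHWAYS)) ocdbt=$((1 - USE_PATHWAYS))"
# → zarr3=1 ocdbt=1

export USE_PATHWAYS=1   # Pathways
echo "zarr3=$((1 - USE_PATHWAYS)) ocdbt=$((1 - USE_PATHWAYS))"
# → zarr3=0 ocdbt=0
```

Keeping both flags derived from the single `USE_PATHWAYS` variable means the checkpoint storage settings can never drift out of sync with the chosen runtime.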

## Submit workload on GKE cluster
@@ -113,7 +113,7 @@ xpk workload create \
--workload=${WORKLOAD_NAME?} \
--tpu-type=${TPU_TYPE?} \
--num-slices=${TPU_SLICE?} \
--command "python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane hf_path=${DATASET_NAME?} train_split=${TRAIN_SPLIT?} train_data_columns=${TRAIN_DATA_COLUMNS?}"
--command "python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL?} load_parameters_path=${MAXTEXT_CKPT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane hf_path=${DATASET_NAME?} train_split=${TRAIN_SPLIT?} train_data_columns=${TRAIN_DATA_COLUMNS?}"
```

Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`.
@@ -131,7 +131,7 @@ xpk workload create-pathways \
--workload=${WORKLOAD_NAME?} \
--tpu-type=${TPU_TYPE?} \
--num-slices=${TPU_SLICE?} \
--command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) enable_single_controller=True"
--command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL?} load_parameters_path=${MAXTEXT_CKPT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) enable_single_controller=True"
```

Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`.