diff --git a/docs/tutorials/posttraining/rl.md b/docs/tutorials/posttraining/rl.md
index 34a405e912..d8cfc17a44 100644
--- a/docs/tutorials/posttraining/rl.md
+++ b/docs/tutorials/posttraining/rl.md
@@ -42,47 +42,9 @@ rely on the vLLM library.
Let's get started!

-## Create virtual environment and Install MaxText dependencies
+## Install MaxText and post-training dependencies

-```bash
-# Create a virtual environment
-export VENV_NAME= # e.g., maxtext_venv
-pip install uv
-uv venv --python 3.12 --seed ${VENV_NAME?}
-source ${VENV_NAME?}/bin/activate
-```
-### Option 1: From PyPI releases (Recommended)
-Run the following commands to get all the necessary installations.
-```bash
-uv pip install "maxtext[tpu-post-train]>=0.2.0" --resolution=lowest
-install_maxtext_tpu_post_train_extra_deps
-```
-It installs MaxText and then for post-training, it installs primarily the following:
-a. [Tunix](https://github.com/google/tunix) as the LLM Post-Training Library, and
-b. `vllm-tpu` which is
-[vllm](https://github.com/vllm-project/vllm) and
-[tpu-inference](https://github.com/vllm-project/tpu-inference) and thereby
-providing TPU inference for vLLM, with unified JAX and PyTorch support.
-### Option 2: From Github
-For using a version newer than the latest PyPI release, you could also install the latest vetted versions of the dependencies from MaxText in the following way:
-```bash
-# 1. Clone the repository
-git clone https://github.com/AI-Hypercomputer/maxtext.git
-cd maxtext
-# 2. Install dependencies in editable mode
-uv pip install -e .[tpu-post-train] --resolution=lowest
-install_maxtext_tpu_post_train_extra_deps
-```
+To install MaxText and all required post-training dependencies on your VM, follow the [official documentation](https://maxtext.readthedocs.io/en/latest/install_maxtext.html) and use the `maxtext[tpu-post-train]` installation path.

## Setup environment variables
diff --git a/docs/tutorials/posttraining/sft.md b/docs/tutorials/posttraining/sft.md
index 6559abbd79..c7ed9f45c5 100644
--- a/docs/tutorials/posttraining/sft.md
+++ b/docs/tutorials/posttraining/sft.md
@@ -26,20 +26,7 @@ In this tutorial we use a single host TPU VM such as `v6e-8/v5p-8`. Let's get st
## Install MaxText and Post-Training dependencies

-```bash
-# Create a virtual environment
-export VENV_NAME= # e.g., maxtext_venv
-pip install uv
-uv venv --python 3.12 --seed ${VENV_NAME?}
-source ${VENV_NAME?}/bin/activate
-```
-Run the following commands to get all the necessary installations.
-```bash
-uv pip install "maxtext[tpu-post-train]>=0.2.0" --resolution=lowest
-install_maxtext_tpu_post_train_extra_deps
-```
+To install MaxText and all required post-training dependencies on your VM, follow the [official documentation](https://maxtext.readthedocs.io/en/latest/install_maxtext.html) and use the `maxtext[tpu-post-train]` installation path.

## Setup environment variables
@@ -53,7 +40,7 @@ Set the following environment variables before running SFT.
```sh
# -- Model configuration --
-export PRE_TRAINED_MODEL= # e.g., 'llama3.1-8b-Instruct'
+export MODEL= # e.g., 'llama3.1-8b-Instruct'

# -- MaxText configuration --
export BASE_OUTPUT_DIRECTORY= # e.g., gs://my-bucket/my-output-directory
@@ -76,7 +63,7 @@ This section explains how to prepare your model checkpoint for use with MaxText.
If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.

```sh
-export PRE_TRAINED_MODEL_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items
+export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items
```

### Option 2: Converting a Hugging Face checkpoint
@@ -84,7 +71,7 @@ export PRE_TRAINED_MODEL_CKPT_PATH= # e.g., gs:
Refer the steps in [Hugging Face to MaxText](../../guides/checkpointing_solutions/convert_checkpoint.md#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on.

```sh
-export PRE_TRAINED_MODEL_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items
+export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items
```

## Run SFT on Hugging Face Dataset
@@ -95,8 +82,8 @@ Now you are ready to run SFT using the following command:
python3 -m maxtext.trainers.post_train.sft.train_sft \
run_name=${RUN_NAME?} \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
- model_name=${PRE_TRAINED_MODEL?} \
- load_parameters_path=${PRE_TRAINED_MODEL_CKPT_PATH?} \
+ model_name=${MODEL?} \
+ load_parameters_path=${MAXTEXT_CKPT_PATH?} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE?} \
steps=${STEPS?} \
hf_path=${DATASET_NAME?} \
diff --git a/docs/tutorials/posttraining/sft_on_multi_host.md b/docs/tutorials/posttraining/sft_on_multi_host.md
index cc5c63b2ac..b54819cee0 100644
--- a/docs/tutorials/posttraining/sft_on_multi_host.md
+++ b/docs/tutorials/posttraining/sft_on_multi_host.md
@@ -62,7 +62,7 @@ export STEPS= # e.g., 1000
export HF_TOKEN=

# -- Model Configuration --
-export MODEL_NAME= # e.g., deepseek3-671b
+export MODEL= # e.g., deepseek3-671b

# -- Dataset configuration --
export DATASET_NAME= # e.g., HuggingFaceH4/ultrachat_200k
@@ -79,10 +79,10 @@ This section explains how to prepare your model checkpoint for use with MaxText.
If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.

```bash
-export MODEL_CHECKPOINT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items
+export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items
```

-**Note:** Make sure that `MODEL_CHECKPOINT_PATH` has the checkpoints created using the correct storage flags:
+**Note:** Make sure that `MAXTEXT_CKPT_PATH` points to checkpoints created with the correct storage flags:

```
export USE_PATHWAYS=0 # Set to 1 for Pathways, 0 for McJAX.
@@ -95,7 +95,7 @@ checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS))
Refer the steps in [Hugging Face to MaxText](../../guides/checkpointing_solutions/convert_checkpoint.md#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on.
```bash
-export MODEL_CHECKPOINT_PATH= # gs://my-bucket/my-checkpoint-directory/0/items
+export MAXTEXT_CKPT_PATH= # gs://my-bucket/my-checkpoint-directory/0/items
```

## Submit workload on GKE cluster
@@ -113,7 +113,7 @@ xpk workload create \
--workload=${WORKLOAD_NAME?} \
--tpu-type=${TPU_TYPE?} \
--num-slices=${TPU_SLICE?} \
---command "python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane hf_path=${DATASET_NAME?} train_split=${TRAIN_SPLIT?} train_data_columns=${TRAIN_DATA_COLUMNS?}"
+--command "python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL?} load_parameters_path=${MAXTEXT_CKPT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane hf_path=${DATASET_NAME?} train_split=${TRAIN_SPLIT?} train_data_columns=${TRAIN_DATA_COLUMNS?}"
```

Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`.
@@ -131,7 +131,7 @@ xpk workload create-pathways \
--workload=${WORKLOAD_NAME?} \
--tpu-type=${TPU_TYPE?} \
--num-slices=${TPU_SLICE?} \
---command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) enable_single_controller=True"
+--command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL?} load_parameters_path=${MAXTEXT_CKPT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) enable_single_controller=True"
```

Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`.