This guide provides comprehensive instructions for setting up MaxText on a local machine or single-host environment, covering everything from cloning the repo and dependency installation to building with Docker. By walking through the process of pre-training a small model, you will gain the foundational knowledge to run jobs on TPUs/GPUs.
Before you can begin a training run, you need to configure your storage environment and set up the basic MaxText configuration.
You'll need a GCS bucket to store all your training artifacts, such as logs, metrics, and model checkpoints.
- In your Google Cloud project, create a new storage bucket.
- Your TPU or GPU VMs require read/write access to this bucket. The simplest way to grant this is by assigning the Storage Admin (`roles/storage.admin`) role to the service account associated with your VMs.
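As a sketch, the bucket creation and IAM binding can be done with `gcloud`; the project, bucket, location, and service-account names below are placeholders, not values from this guide:

```shell
# Placeholder names: substitute your own project, bucket, and service account.
PROJECT_ID="my-gcp-project"
BUCKET="gs://my-maxtext-artifacts"
SERVICE_ACCOUNT="my-vm-sa@${PROJECT_ID}.iam.gserviceaccount.com"

# Create the bucket; pick a location near your TPU/GPU VMs.
gcloud storage buckets create "${BUCKET}" \
  --project="${PROJECT_ID}" \
  --location=us-central1

# Grant the VM service account read/write access via Storage Admin.
gcloud storage buckets add-iam-policy-binding "${BUCKET}" \
  --member="serviceAccount:${SERVICE_ACCOUNT}" \
  --role="roles/storage.admin"
```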
MaxText uses a primary YAML file, `configs/base.yml`, to manage its settings. This default configuration sets up a Llama2-style decoder-only model with approximately 1 billion parameters.
- Before running your first model, take a moment to review this file. Pay special attention to these core settings:
  - `run_name`: The name for your experiment.
  - `per_device_batch_size`: Controls how many examples are processed per chip. You may need to lower this for larger models to avoid running out of memory.
  - `max_target_length`: The maximum sequence length for the model.
  - `learning_rate`: The core hyperparameter for the optimizer.
  - Model shape parameters: `base_num_decoder_layers`, `base_emb_dim`, `base_num_query_heads`, `base_num_kv_heads`, and `head_dim`.
- Override settings (optional): You can modify training parameters in two ways: by editing `configs/base.yml` directly, or by passing them as command-line arguments to the training script, which is the recommended method. For example, to change the number of training steps, pass `steps=500` when running `train.py`.
- Note: You must update the `base_output_directory` variable, which is initialized in `configs/base.yml`, to point to a folder within the GCS bucket you just created (e.g., `gs://your-bucket-name/maxtext-output`).
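Putting these settings together, a minimal edit to `configs/base.yml` might look like the fragment below; the values are illustrative defaults for a small run, not tuned recommendations:

```yaml
run_name: "my-first-run"
base_output_directory: "gs://your-bucket-name/maxtext-output"
per_device_batch_size: 4
max_target_length: 2048
learning_rate: 3.0e-4
steps: 500
```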
Local development on a single-host TPU/GPU VM is a convenient way to run MaxText. It doesn't scale to multiple hosts, but it is a good way to learn about MaxText. The following describes how to run MaxText on single-host TPU/GPU VMs.
- Create and SSH to the single-host VM of your choice. You can use any available single-host TPU, such as `v5litepod-8`, `v5p-8`, or `v4-8`. For GPUs, you can use `nvidia-h100-mega-80gb`, `nvidia-h200-141gb`, or `nvidia-b200`. For setting up a TPU VM, use the Cloud TPU documentation available at https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm. For a GPU setup, refer to the guide at https://cloud.google.com/compute/docs/gpus/create-vm-with-gpus.
- For instructions on installing MaxText on your VM, please refer to the official documentation.
After the installation is complete, run a short training job using synthetic data to confirm everything is working correctly. This command trains a model for just 10 steps. Remember to replace `${YOUR_JOB_NAME}` with a unique name for your run and `gs://<my-bucket>` with the path to the GCS bucket you configured in the prerequisites.
```shell
python3 -m maxtext.trainers.pre_train.train \
  run_name=${YOUR_JOB_NAME?} \
  base_output_directory=gs://<my-bucket> \
  dataset_type=synthetic \
  steps=10
```

Optional: If you want to try training on a real dataset, see Data Input Pipeline for data input options from sources like HuggingFace, Grain, and TFDS.
To demonstrate model output, run the following command:
```shell
python3 -m maxtext.inference.decode \
  run_name=${YOUR_JOB_NAME?} \
  base_output_directory=gs://<my-bucket> \
  per_device_batch_size=1
```

Note: Because the model hasn't been properly trained, the output text will be random. To generate meaningful output, you need to load a trained checkpoint using the `load_parameters_path` argument.
MaxText provides many OSS model configs that you can use directly to run training jobs on those model-specific architectures. These model-specific YAML files are located in `src/maxtext/configs/models` for TPU-oriented defaults, and `src/maxtext/configs/models/gpu` for GPU-oriented defaults.
To use a pre-configured model for TPUs, you override the `model_name` parameter, and MaxText will automatically load the corresponding configuration from the `src/maxtext/configs/models` directory and merge it with the settings from `src/maxtext/configs/base.yml`.
llama3-8b (TPU)
```shell
python3 -m maxtext.trainers.pre_train.train \
  model_name=llama3-8b \
  run_name=${YOUR_JOB_NAME?} \
  base_output_directory=gs://<my-bucket> \
  dataset_type=synthetic \
  steps=10
```

qwen3-4b (TPU)
```shell
python3 -m maxtext.trainers.pre_train.train \
  model_name=qwen3-4b \
  run_name=${YOUR_JOB_NAME?} \
  base_output_directory=gs://<my-bucket> \
  dataset_type=synthetic \
  steps=10
```

To use a GPU-optimized configuration, specify the path to the model's YAML file within the `src/maxtext/configs/models/gpu` directory as the main config file in the command. These files typically inherit from `base.yml` and set the appropriate `model_name` internally, along with GPU-specific settings.
mixtral-8x7b (GPU)
```shell
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/models/gpu/mixtral_8x7b.yml \
  run_name=${YOUR_JOB_NAME?} \
  base_output_directory=gs://<my-bucket> \
  dataset_type=synthetic \
  steps=10
```

This will load `models/gpu/mixtral_8x7b.yml`, which inherits from `base.yml`.
llama3-8b (GPU)
```shell
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/models/gpu/llama3-8b.yml \
  run_name=${YOUR_JOB_NAME?} \
  base_output_directory=gs://<my-bucket> \
  dataset_type=synthetic \
  steps=10
```