This guide provides comprehensive instructions for setting up MaxText on a local machine or single-host environment, covering everything from cloning the repo and installing dependencies to building with Docker. By walking through the process of pre-training a small model, you will gain the foundational knowledge to run jobs on TPUs/GPUs.
Before you can begin a training run, you need to configure your storage environment and set up the basic MaxText configuration.
You'll need a GCS bucket to store all your training artifacts, such as logs, metrics, and model checkpoints.
- In your Google Cloud project, create a new storage bucket.
- Your TPU or GPU VMs require read/write access to this bucket. The simplest way to grant this is by assigning the Storage Admin (roles/storage.admin) role to the service account associated with your VMs.
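As a rough sketch, the two steps above could be done from the command line with the gcloud CLI. The bucket name, location, and service account below are hypothetical placeholders, and the small `run` helper is an illustrative addition (not part of MaxText) that only prints each command unless `EXECUTE=1` is set, so you can review it first.

```shell
# Hypothetical values; replace them with your own project details.
BUCKET_NAME="your-bucket-name"
LOCATION="us-central1"
SERVICE_ACCOUNT="my-vm-sa@my-project.iam.gserviceaccount.com"

# Illustrative helper: print the command by default, execute it only when EXECUTE=1.
run() { if [ "${EXECUTE:-0}" = "1" ]; then "$@"; else echo "$@"; fi; }

# Create the bucket that will hold logs, metrics, and checkpoints.
run gcloud storage buckets create "gs://${BUCKET_NAME}" --location="${LOCATION}"

# Grant the VM service account the Storage Admin role on that bucket.
run gcloud storage buckets add-iam-policy-binding "gs://${BUCKET_NAME}" \
  --member="serviceAccount:${SERVICE_ACCOUNT}" \
  --role="roles/storage.admin"
```

Run the script once to inspect the printed commands, then again with `EXECUTE=1` once the values match your project.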
MaxText uses a primary YAML file, configs/base.yml, to manage its settings. This default configuration sets up a Llama 2-style decoder-only model with approximately 1 billion parameters.
- Before running your first model, take a moment to review this file. Pay special attention to these core settings:
  - run_name: The name for your experiment.
  - per_device_batch_size: Controls how many examples are processed per chip. You may need to lower this for larger models to avoid running out of memory.
  - max_target_length: The maximum sequence length for the model.
  - learning_rate: The core hyperparameter for the optimizer.
  - Model shape parameters: base_num_decoder_layers, base_emb_dim, base_num_query_heads, base_num_kv_heads, and head_dim.
- Override settings (optional): You can modify training parameters in two ways: by editing configs/base.yml directly, or by passing them as command-line arguments to the training script, which is the recommended method. For example, to change the number of training steps, you can pass steps=500 when running train.py.
- Note: You must update the base_output_directory variable, which is initialized in configs/base.yml, to point to a folder within the GCS bucket you just created (e.g., gs://your-bucket-name/maxtext-output).
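To make the command-line override approach concrete, here is a minimal sketch that assembles a training command from key=value overrides. The run name and the values chosen are hypothetical; the parameter names and the module path are the ones used elsewhere in this guide. The script only prints the command so you can check it before running it.

```shell
# Hypothetical values for illustration; adjust them for your setup.
RUN_NAME="my-first-run"
OUTPUT_DIR="gs://your-bucket-name/maxtext-output"

# Each key=value pair overrides the matching setting from configs/base.yml.
TRAIN_CMD="python3 -m maxtext.trainers.pre_train.train \
run_name=${RUN_NAME} \
base_output_directory=${OUTPUT_DIR} \
steps=500 \
per_device_batch_size=1"

# Print the assembled command for review.
echo "${TRAIN_CMD}"
```

Keeping overrides on the command line leaves configs/base.yml unchanged, which makes it easier to compare runs that differ in only a few settings.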
Running on a single-host TPU/GPU VM is a convenient way to develop with MaxText locally. It doesn't scale to multiple hosts, but it is a good way to learn about MaxText. The following steps describe how to run MaxText on TPU/GPU VMs.
- Create and SSH to the single-host VM of your choice. You can use any available single-host TPU, such as v5litepod-8, v5p-8, or v4-8. For GPUs, you can use nvidia-h100-mega-80gb, nvidia-h200-141gb, or nvidia-b200. For setting up a TPU VM, use the Cloud TPU documentation available at https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm. For a GPU setup, refer to the guide at https://cloud.google.com/compute/docs/gpus/create-vm-with-gpus.
- For instructions on installing MaxText on your VM, please refer to the official documentation.
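For the TPU case, creating and connecting to a VM might look roughly like the sketch below. The VM name and zone are hypothetical, and the runtime version is deliberately left as a placeholder because the correct value depends on the TPU type; look it up in the Cloud TPU documentation linked above. As an illustrative safeguard, the `run` helper only prints each command unless `EXECUTE=1` is set.

```shell
# Hypothetical names; replace them with your own values.
TPU_NAME="maxtext-dev"
ZONE="us-central1-a"
ACCELERATOR_TYPE="v5litepod-8"
# Placeholder: take the runtime version for your TPU type from the Cloud TPU docs.
RUNTIME_VERSION="${RUNTIME_VERSION:-<runtime-version>}"

# Illustrative helper: print the command by default, execute it only when EXECUTE=1.
run() { if [ "${EXECUTE:-0}" = "1" ]; then "$@"; else echo "$@"; fi; }

# Create the single-host TPU VM.
run gcloud compute tpus tpu-vm create "${TPU_NAME}" \
  --zone="${ZONE}" \
  --accelerator-type="${ACCELERATOR_TYPE}" \
  --version="${RUNTIME_VERSION}"

# Open an SSH session on the new VM.
run gcloud compute tpus tpu-vm ssh "${TPU_NAME}" --zone="${ZONE}"
```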
After the installation is complete, run a short training job using synthetic data to confirm everything is working correctly. This command trains a model for just 10 steps. Remember to replace $YOUR_JOB_NAME with a unique name for your run and gs://<my-bucket> with the path to the GCS bucket you configured in the prerequisites.
python3 -m maxtext.trainers.pre_train.train \
run_name=${YOUR_JOB_NAME?} \
base_output_directory=gs://<my-bucket> \
dataset_type=synthetic \
steps=10

Optional: If you want to try training on a real dataset, see Data Input Pipeline for data input options from sources like Hugging Face, Grain, and TFDS.
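The training command writes the run name as `${YOUR_JOB_NAME?}`. In a POSIX shell, that form expands to the variable's value when it is set and aborts the command with an error when it is not, which helps avoid launching a run with an empty name. The sketch below sets the variable and demonstrates both cases; the timestamp-based naming scheme is just one possible convention, not something MaxText requires.

```shell
# Set a unique run name; the timestamp suffix is only one possible convention.
export YOUR_JOB_NAME="maxtext-test-$(date +%Y%m%d-%H%M%S)"

# With the variable set, ${VAR?} simply expands to its value.
echo "run_name=${YOUR_JOB_NAME?}"

# Demonstrate the failure mode in a subshell so the current shell keeps running.
if ( unset YOUR_JOB_NAME; : "${YOUR_JOB_NAME?}" ) 2>/dev/null; then
  echo "unexpected: expansion succeeded"
else
  echo "unset variable rejected, as intended"
fi
```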
To demonstrate model output, run the following command:
python3 -m maxtext.inference.decode \
run_name=${YOUR_JOB_NAME?} \
base_output_directory=gs://<my-bucket> \
per_device_batch_size=1

Note: Because the model hasn't been properly trained, the output text will be random. To generate meaningful output, you need to load a trained checkpoint using the load_parameters_path argument. For instructions on how to convert pre-trained Hugging Face model checkpoints (like Llama or Gemma) to MaxText's Orbax format, please refer to the Checkpoint Conversion Guide.
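As a sketch of how load_parameters_path fits into the decode command, the script below appends it to the same arguments used above. The checkpoint path is a hypothetical placeholder, and a real run will likely need further settings that match the checkpoint (for example model_name), which are not shown here. The script only prints the command for review.

```shell
# Hypothetical checkpoint location; point this at a real converted checkpoint.
CHECKPOINT_PATH="gs://<my-bucket>/path/to/checkpoint"

# Same decode arguments as above, plus the checkpoint to load.
DECODE_CMD="python3 -m maxtext.inference.decode \
run_name=${YOUR_JOB_NAME:-my-decode-run} \
base_output_directory=gs://<my-bucket> \
per_device_batch_size=1 \
load_parameters_path=${CHECKPOINT_PATH}"

# Print the command; run it once the placeholders are filled in.
echo "${DECODE_CMD}"
```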
MaxText provides many OSS model configs that you can use directly to run training jobs on those model-specific architectures. These model-specific YAML files are located in src/maxtext/configs/models for TPU-oriented defaults, and src/maxtext/configs/models/gpu for GPU-oriented defaults.
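To see which model configs are available in your checkout, one option is to list the YAML files in that directory. The helper below is an illustrative sketch, not part of MaxText, and it assumes that each file's base name (without the .yml extension) corresponds to a model config name; verify that against your copy of the repository.

```shell
# Illustrative helper: print the base name of every .yml file in a directory.
list_model_names() {
  # $1: directory holding model YAML files
  for f in "$1"/*.yml; do
    # Skip the unexpanded pattern when the directory is missing or empty.
    [ -e "$f" ] || continue
    basename "$f" .yml
  done
}

# Run from the repository root; prints nothing if the directory is not found.
list_model_names src/maxtext/configs/models
```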
To use a pre-configured model for TPUs, you override the model_name parameter, and MaxText will automatically load the corresponding configuration from the src/maxtext/configs/models directory and merge it with the settings from src/maxtext/configs/base.yml.
llama3-8b (TPU)
python3 -m maxtext.trainers.pre_train.train \
model_name=llama3-8b \
run_name=${YOUR_JOB_NAME?} \
base_output_directory=gs://<my-bucket> \
dataset_type=synthetic \
steps=10

qwen3-4b (TPU)
python3 -m maxtext.trainers.pre_train.train \
model_name=qwen3-4b \
run_name=${YOUR_JOB_NAME?} \
base_output_directory=gs://<my-bucket> \
dataset_type=synthetic \
steps=10

To use a GPU-optimized configuration, specify the path to the model's YAML file within the src/maxtext/configs/models/gpu directory as the main config file in the command. These files typically inherit from base.yml and set the appropriate model_name internally, as well as GPU-specific settings.
mixtral-8x7b (GPU)
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/gpu/models/mixtral_8x7b.yml \
run_name=${YOUR_JOB_NAME?} \
base_output_directory=gs://<my-bucket> \
dataset_type=synthetic \
steps=10

This will load gpu/mixtral_8x7b.yml, which inherits from base.yml.
llama3-8b (GPU)
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/gpu/models/llama3-8b.yml \
run_name=${YOUR_JOB_NAME?} \
base_output_directory=gs://<my-bucket> \
dataset_type=synthetic \
steps=10
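Because the TPU examples differ only in model_name, a short loop can generate one smoke-test command per model. This is a hedged sketch: the function name and the smoke- run-name prefix are hypothetical, the bucket is a placeholder, and the commands are printed rather than executed so you can run them one at a time.

```shell
# Placeholder bucket; replace it with the bucket you configured earlier.
BUCKET="gs://<my-bucket>"

# Illustrative helper: print one 10-step synthetic-data command per model name.
print_smoke_commands() {
  for MODEL in llama3-8b qwen3-4b; do
    echo "python3 -m maxtext.trainers.pre_train.train model_name=${MODEL} run_name=smoke-${MODEL} base_output_directory=${BUCKET} dataset_type=synthetic steps=10"
  done
}

print_smoke_commands
```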