Training Performance Validation with the ROCm JAX MaxText Training Docker Image on AMD Instinct Accelerators
MaxText for ROCm is a specialized fork of the upstream MaxText framework, designed to enable training of large language models (LLMs) on AMD GPUs. On AMD Instinct™ MI300X and MI355X GPUs, MaxText delivers strong scalability, performance, and resource utilization for AI workloads. See the GitHub repository at ROCm/maxtext.
AMD provides a ready-to-use Docker image for AMD Instinct MI300X and MI355X GPUs that includes the essential components for accelerating training workloads, such as JAX, XLA, ROCm libraries, and MaxText utilities. The image contains the following software components:
Note
Shardy is a new configuration option in JAX 0.6.0. You might encounter related errors if it isn't configured correctly. For now, you can turn it off by setting shardy=False for the training run, or you can follow the migration guide to enable it.
Note
- You may see NaNs in the losses when training on real data (not synthetic data) with packing=True and NVTE_CK_IS_V3_ATOMIC_FP32=0. Set NVTE_CK_IS_V3_ATOMIC_FP32=1 for production training on real data with input sequence packing (packing=True).
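To catch the packing and NaN combination before a long run starts, you could add a small pre-flight check to your launch script. This is a minimal sketch under stated assumptions. `check_packing_env` is a hypothetical helper and is not part of MaxText or MAD. Its arguments mirror the MaxText `packing` and `dataset_type` config keys. The check treats an unset `NVTE_CK_IS_V3_ATOMIC_FP32` as "not 1", so you must set the variable explicitly.

```shell
#!/usr/bin/env bash
# Hypothetical pre-flight helper (not part of MaxText or MAD): returns non-zero
# when packing is enabled on real data but NVTE_CK_IS_V3_ATOMIC_FP32 is not 1.
# An unset variable is treated as "not 1", so the value must be set explicitly.
# Usage: check_packing_env <packing: True|False> <dataset_type>
check_packing_env() {
  local packing="$1" dataset_type="$2"
  if [[ "$packing" == "True" && "$dataset_type" != "synthetic" \
        && "${NVTE_CK_IS_V3_ATOMIC_FP32:-}" != "1" ]]; then
    echo "WARNING: set NVTE_CK_IS_V3_ATOMIC_FP32=1 for real-data training with packing=True" >&2
    return 1
  fi
  return 0
}

# Example, before launching real-data training with packing=True:
# export NVTE_CK_IS_V3_ATOMIC_FP32=1
# check_packing_env True hf || exit 1
```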
Note
There is a known slight performance regression for DeepSeek-V2-lite (16B) in v26.3. This is being tracked and will be addressed in a future release.
Note
JAX 0.9.1 Early Access known issues:
- There is a known performance regression for MoE models (DeepSeek-V2-lite and Mixtral-8x7B).
- The trace viewer in profiling may be missing some information in the flame graph.
| Software component | Version |
|---|---|
| ROCm | 7.2.1 |
| Jax | 0.8.2 |
| Python | 3.12.3 |
| Transformer Engine | 2.8.0.dev0+9b312832 |
| hipBLASLt | 1.3.0+bfcf25fa18 |
MaxText supports the following key features to train large language models efficiently:
- Transformer Engine (TE)
- Flash Attention (FA) 3, with or without input sequence packing
- GEMM tuning
- Multi-node support
- NANOO FP8 (for MI300X) or FP8 (for MI355X)
The following models are pre-optimized for performance on AMD Instinct MI300X and MI355X accelerators.
- Llama 2 7B
- Llama 2 70B
- Llama 3/3.1 8B
- Llama 3/3.1 70B
- Llama 3.1 405B
- Llama 3.3 70B
- DeepSeek-V2-lite (16B)
- Mixtral-8x7B
- Qwen3 14B
- Qwen3 30B-A3B
Note: Some models, such as Llama 3, require an external license agreement through a third party (for example, Meta).
If you have already validated your system, skip this step. Otherwise, please complete the following system validation and optimization steps to set up your system before starting training.
This Docker image is optimized for specific model configurations outlined below. Performance can vary for other training workloads, as AMD doesn’t validate configurations and run conditions outside those described.
For multi-node training, make sure the packages required by your network device are installed. See the multi-node examples for how to install these packages before running the workload. Complete the setup below only if you are running multi-node training over RDMA. Otherwise, skip this part.
Install the packages below for building and installing the RDMA driver:
apt install iproute2 -y
apt install -y linux-headers-"$(uname -r)" libelf-dev
apt install -y gcc make libtool autoconf librdmacm-dev rdmacm-utils infiniband-diags ibverbs-utils perftest ethtool libibverbs-dev rdma-core strace libibmad5 libibnetdisc5 ibverbs-providers libibumad-dev libibumad3 libibverbs1 libnl-3-dev libnl-route-3-dev
Refer to your NIC manufacturer's website for further steps on compiling and installing the RoCE driver. For example, for Broadcom NICs, see the section Compiling Broadcom NIC Software from Source in the Ethernet Networking Guide for AMD Instinct MI300X GPU Clusters.
Set the following environment variables. See the multi-node examples for how to set them.
- Master Address: Change localhost to the master node's hostname:
export MASTER_ADDR="${MASTER_ADDR:-localhost}"
- Number of Nodes: Set the number of nodes you want to train on (for example, 2, 4, or 8):
export NNODES="${NNODES:-1}"
- Node Rank: Set the rank of each node (0 for the master node, 1 for the first worker node, and so on):
export NODE_RANK="${NODE_RANK:-0}"
- Network Interface: Update the network interface in the script to match your system's network interface. To find your network interface, run the following command outside the container:
ip a
Then update the following variable in the script:
export NCCL_SOCKET_IFNAME=ens50f0np0
- RDMA Interface: First, make sure the packages above are installed on all nodes. Then set the RDMA interfaces to use for communication:
# If using a Broadcom NIC
export NCCL_IB_HCA=rdma0,rdma1,rdma2,rdma3,rdma4,rdma5,rdma6,rdma7
# If using a Mellanox NIC
export NCCL_IB_HCA=mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_8,mlx5_9
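You need to know the RDMA device names on each node to set NCCL_IB_HCA. One way to find them is to list the devices the kernel exposes under `/sys/class/infiniband`. The sketch below assumes that sysfs layout, and `list_rdma_devices` is a hypothetical helper. Review its output before exporting it. The Mellanox example above skips mlx5_6 and mlx5_7, so not every device listed is necessarily meant for GPU traffic.

```shell
#!/usr/bin/env bash
# Hypothetical helper: print the RDMA device names found under a sysfs
# directory (default /sys/class/infiniband) as a comma-separated list,
# suitable as a starting point for NCCL_IB_HCA.
list_rdma_devices() {
  local sysfs_dir="${1:-/sys/class/infiniband}"
  local names=()
  local dev
  for dev in "$sysfs_dir"/*; do
    [[ -e "$dev" ]] || continue   # glob matched nothing: no RDMA devices
    names+=("$(basename "$dev")")
  done
  local IFS=,                     # join array elements with commas
  echo "${names[*]}"
}

# Example (drop any devices not used for GPU traffic before exporting):
# export NCCL_IB_HCA="$(list_rdma_devices)"
```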
Note
The only models supported in this workflow are those listed in the above section.
This container should not be expected to provide generalized performance across all training workloads. You can expect the container to perform as expected in the model configurations described below, but AMD does not validate other configurations and run conditions. Use the following instructions to set up the environment, configure the script to train models, and reproduce the benchmark results on MI300X, MI325X, MI350X, and MI355X accelerators with the Docker image.
There are two options for reproducing the benchmark results with the Model Automation and Dashboarding (MAD) repository.
JAX MaxText has also been integrated into Primus, which supports multiple backends, including Megatron-LM, TorchTitan, and JAX MaxText, alongside ROCm-optimized components. You can use the unified primus-cli to run training jobs with the JAX MaxText backend.
Clone the ROCm Model Automation and Dashboarding (MAD) repository to a local directory and install the required packages on the host machine.
git clone https://github.com/ROCm/MAD
cd MAD
pip install -r requirements.txt
Run models through MAD-integrated benchmarking with the following command:
export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
python3 tools/run_models.py --tags <mad_model> --keep-model-dir --live-output --timeout 28800
For example, use this command to run a performance benchmark test of the Llama 2 7B model on one GPU with the bf16 data type on the host machine.
export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
python3 tools/run_models.py --tags jax_maxtext_train_llama-2-7b --keep-model-dir --live-output --timeout 28800
Note
The madengine package is now available and can be used in place of run_models.py.
export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
python3 madengine run --tags jax_maxtext_train_llama-2-7b --keep-model-dir --live-output --timeout 28800
ROCm MAD launches a Docker container named container_ci-jax_maxtext_train_llama-2-7b. The latency and throughput reports of the model are collected in the following path:
~/MAD/perf.csv
The following MAD model tags are available:
| model_name |
|---|
| jax_maxtext_train_llama-2-7b |
| jax_maxtext_train_llama-2-70b |
| jax_maxtext_train_llama-3.1-8b |
| jax_maxtext_train_llama-3.1-70b |
| jax_maxtext_train_llama-3.1-405b |
| jax_maxtext_train_llama-3.3-70b |
| jax_maxtext_train_deepseek-v2-lite-16b |
| jax_maxtext_train_mixtral-8x7b |
| jax_maxtext_train_qwen3-14b |
| jax_maxtext_train_qwen3-30b-a3b |
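To benchmark several of these tags in one session, you can loop over them with the same run_models.py options shown above. This is a minimal sketch. `run_mad_tags` and the `DRY_RUN` switch are hypothetical and not part of MAD. The sketch assumes you run it from the root of the MAD repository with MAD_SECRETS_HFTOKEN already exported.

```shell
#!/usr/bin/env bash
# Hypothetical wrapper: run each MAD model tag in turn with the options used
# in this guide. Set DRY_RUN=1 to print the commands instead of running them.
run_mad_tags() {
  local tag
  for tag in "$@"; do
    local cmd=(python3 tools/run_models.py --tags "$tag" --keep-model-dir --live-output --timeout 28800)
    if [[ "${DRY_RUN:-0}" == "1" ]]; then
      echo "${cmd[*]}"
    else
      # Keep going if one model fails, but report it on stderr.
      "${cmd[@]}" || echo "Run failed for ${tag}" >&2
    fi
  done
}

# Example:
# run_mad_tags jax_maxtext_train_llama-2-7b jax_maxtext_train_llama-3.1-8b
```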
Download and launch the Docker image
Use the following command to pull the Docker image from Docker Hub.
docker pull rocm/jax-training:maxtext-v26.3
Note
Please adjust the following variables based on your environment.
Export variables
- MAD_SECRETS_HFTOKEN is your Hugging Face token for accessing models, tokenizers, and data. See this page for more info.
- HF_HOME is where huggingface_hub stores local data. Refer to the Hugging Face CLI documentation for how to download the data. If you have already downloaded or cached Hugging Face artifacts, set this variable to that path. Downloaded files are typically cached in a location like:
~/.cache/huggingface
export MAD_SECRETS_HFTOKEN=<Your HuggingFace token>
export HF_HOME=<Location of saved/cached HuggingFace models>
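The docker run command below passes both variables into the container. If HF_HOME is empty, the `-v $HF_HOME:/hf_cache` mount becomes invalid. A quick check before launching could look like the following sketch. `check_docker_env` is a hypothetical helper. It takes the stricter position that HF_HOME must already exist as a directory.

```shell
#!/usr/bin/env bash
# Hypothetical pre-launch check for the variables used by `docker run` below.
# Returns non-zero (and explains why) if either variable is missing.
check_docker_env() {
  local status=0
  if [[ -z "${MAD_SECRETS_HFTOKEN:-}" ]]; then
    echo "MAD_SECRETS_HFTOKEN is not set" >&2
    status=1
  fi
  if [[ -z "${HF_HOME:-}" || ! -d "$HF_HOME" ]]; then
    echo "HF_HOME is unset or not a directory: ${HF_HOME:-<unset>}" >&2
    status=1
  fi
  return "$status"
}

# Example:
# check_docker_env && echo "Ready to launch the container"
```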
Launch the Docker container.
docker run -it --device /dev/dri --device /dev/kfd --network host --ipc host --group-add video --cap-add SYS_PTRACE --security-opt seccomp=unconfined --privileged -v $HOME:$HOME -v $HOME/.ssh:/root/.ssh -v $HF_HOME:/hf_cache -e HF_HOME=/hf_cache -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN --shm-size 64G --name training_env rocm/jax-training:maxtext-v26.3
Start and enter the training_env container (optional if you are already inside the container):
docker start training_env
docker exec -it training_env bash
Clone Model Automation and Dashboarding (MAD) repo
git clone https://github.com/ROCm/MAD.git
cd MAD/scripts/jax-maxtext
Run setup scripts to install libraries and datasets needed for benchmarking
./jax-maxtext_benchmark_setup.sh -m <model>
Run the benchmark in quantized or unquantized mode.
# For unquantized training
./jax-maxtext_benchmark_report.sh -m <model>
# Or for quantized training (nanoo_fp8 on MI300X, fp8 on MI355X)
./jax-maxtext_benchmark_report.sh -m <model> -q nanoo_fp8
The performance results should be written to a file in the parent folder.
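The examples below use `-q nanoo_fp8` on MI300X (gfx942) and `-q fp8` on MI355X (gfx950). A script can pick the flag from the GPU architecture with a sketch like this one. `quant_flag_for_arch` and `detect_gfx_arch` are hypothetical helpers. The detection step assumes `rocminfo` is on PATH and takes the first gfx target it reports. Some models, such as Llama 3.1 70B and Llama 3.3 70B, list only the fp8 (MI355X) variant below.

```shell
#!/usr/bin/env bash
# Hypothetical helper: map a GPU architecture to the -q value used by
# jax-maxtext_benchmark_report.sh in this guide.
quant_flag_for_arch() {
  case "$1" in
    gfx942) echo "nanoo_fp8" ;;  # MI300X
    gfx950) echo "fp8" ;;        # MI355X
    *) echo "Unsupported architecture: $1" >&2; return 1 ;;
  esac
}

# Assumes rocminfo is available; prints the first gfx target it reports.
detect_gfx_arch() {
  rocminfo | grep -oE 'gfx[0-9a-f]+' | head -n 1
}

# Example:
# arch="$(detect_gfx_arch)"
# ./jax-maxtext_benchmark_report.sh -m Llama-2-7B -q "$(quant_flag_for_arch "$arch")"
```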
- Single-node training with Llama 2 7B model
Setup
./jax-maxtext_benchmark_setup.sh -m Llama-2-7B
For unquantized training
./jax-maxtext_benchmark_report.sh -m Llama-2-7B
Or for nanoo_fp8 quantized training on MI300X
./jax-maxtext_benchmark_report.sh -m Llama-2-7B -q nanoo_fp8
Or for fp8 quantized training on MI355X
./jax-maxtext_benchmark_report.sh -m Llama-2-7B -q fp8
- Single-node training with Llama 2 70B model
Setup
./jax-maxtext_benchmark_setup.sh -m Llama-2-70B
For unquantized training
./jax-maxtext_benchmark_report.sh -m Llama-2-70B
Or for nanoo_fp8 quantized training on MI300X
./jax-maxtext_benchmark_report.sh -m Llama-2-70B -q nanoo_fp8
Or for fp8 quantized training on MI355X
./jax-maxtext_benchmark_report.sh -m Llama-2-70B -q fp8
- Single-node training with Llama 3.1 8B model
Setup
./jax-maxtext_benchmark_setup.sh -m Llama-3.1-8B
For unquantized training
./jax-maxtext_benchmark_report.sh -m Llama-3.1-8B
Or for nanoo_fp8 quantized training on MI300X
./jax-maxtext_benchmark_report.sh -m Llama-3.1-8B -q nanoo_fp8
Or for fp8 quantized training on MI355X
./jax-maxtext_benchmark_report.sh -m Llama-3.1-8B -q fp8
- Single-node training with Llama 3.1 70B model
Setup
./jax-maxtext_benchmark_setup.sh -m Llama-3.1-70B
For unquantized training
./jax-maxtext_benchmark_report.sh -m Llama-3.1-70B
Or for fp8 quantized training on MI355X
./jax-maxtext_benchmark_report.sh -m Llama-3.1-70B -q fp8
- Single-node training with Llama 3.3 70B model
Setup
./jax-maxtext_benchmark_setup.sh -m Llama-3.3-70B
For unquantized training
./jax-maxtext_benchmark_report.sh -m Llama-3.3-70B
Or for fp8 quantized training on MI355X
./jax-maxtext_benchmark_report.sh -m Llama-3.3-70B -q fp8
- Single-node training with DeepSeek-V2-lite (16B) model
Setup
./jax-maxtext_benchmark_setup.sh -m DeepSeek-V2-lite
For unquantized training
./jax-maxtext_benchmark_report.sh -m DeepSeek-V2-lite
Or for nanoo_fp8 quantized training on MI300X
./jax-maxtext_benchmark_report.sh -m DeepSeek-V2-lite -q nanoo_fp8
Or for fp8 quantized training on MI355X
./jax-maxtext_benchmark_report.sh -m DeepSeek-V2-lite -q fp8
- Single-node training with Mixtral-8x7B model
Setup
./jax-maxtext_benchmark_setup.sh -m Mixtral-8x7B
For unquantized training
./jax-maxtext_benchmark_report.sh -m Mixtral-8x7B
Or for nanoo_fp8 quantized training on MI300X
./jax-maxtext_benchmark_report.sh -m Mixtral-8x7B -q nanoo_fp8
Or for fp8 quantized training on MI355X
./jax-maxtext_benchmark_report.sh -m Mixtral-8x7B -q fp8
- Single-node training with Qwen3 14B model
Setup
./jax-maxtext_benchmark_setup.sh -m Qwen3-14B
For unquantized training
./jax-maxtext_benchmark_report.sh -m Qwen3-14B
Or for nanoo_fp8 quantized training on MI300X
./jax-maxtext_benchmark_report.sh -m Qwen3-14B -q nanoo_fp8
Or for fp8 quantized training on MI355X
./jax-maxtext_benchmark_report.sh -m Qwen3-14B -q fp8
- Single-node training with Qwen3 30B-A3B model (MoE)
Setup
./jax-maxtext_benchmark_setup.sh -m Qwen3-30B-A3B
For unquantized training
./jax-maxtext_benchmark_report.sh -m Qwen3-30B-A3B
Or for nanoo_fp8 quantized training on MI300X
./jax-maxtext_benchmark_report.sh -m Qwen3-30B-A3B -q nanoo_fp8
Or for fp8 quantized training on MI355X
./jax-maxtext_benchmark_report.sh -m Qwen3-30B-A3B -q fp8
Note: These scripts launch the Docker container and execute the benchmark, so run them outside any Docker container.
The examples below use Slurm for running on multiple nodes. The unified multinode benchmark script accepts a configuration file that specifies the model and training parameters.
To run multi-node training, use the following command:
sbatch -N <NUM_NODES> jax_maxtext_multinode_benchmark.sh <config_file.yml> [docker_image]
Parameters:
- <NUM_NODES>: Number of nodes to use for training (for example, 2, 4, or 8)
- <config_file.yml>: Path to the YAML configuration file containing the model and training parameters
- [docker_image]: (Optional) Docker image to use. If not specified, defaults to rocm/jax-training:maxtext-v26.3
Configuration files are available in the scripts/jax-maxtext/env_scripts/ directory for different models and GPU architectures:
For MI300X (gfx942):
- llama2_7b.yml: Llama 2 7B
- llama2_70b.yml: Llama 2 70B
- llama3_8b.yml: Llama 3 8B
- llama3_70b.yml: Llama 3 70B
- qwen3_14b.yml: Qwen3 14B
- qwen3_30b_a3b.yml: Qwen3 30B-A3B
For MI355X (gfx950):
- gfx950_llama2_7b.yml: Llama 2 7B
- gfx950_llama2_70b.yml: Llama 2 70B
- gfx950_llama3_8b.yml: Llama 3 8B
- gfx950_llama3_70b.yml: Llama 3 70B
- gfx950_llama3.1_405b.yml: Llama 3.1 405B
- gfx950_qwen3_14b.yml: Qwen3 14B
- gfx950_qwen3_30b_a3b.yml: Qwen3 30B-A3B
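The two lists follow a simple naming pattern: MI355X configs carry a `gfx950_` prefix. A submission script can therefore build the path from a model stem and the architecture. The sketch below assumes that pattern holds. `config_for` is a hypothetical helper, and it does not check that the file exists. For example, llama3.1_405b is listed only for gfx950.

```shell
#!/usr/bin/env bash
# Hypothetical helper: build the env_scripts/ config path for a model stem,
# following the naming used above (gfx950 configs are prefixed "gfx950_").
config_for() {
  local stem="$1" arch="$2"
  case "$arch" in
    gfx942) echo "env_scripts/${stem}.yml" ;;
    gfx950) echo "env_scripts/gfx950_${stem}.yml" ;;
    *) echo "Unsupported architecture: $arch" >&2; return 1 ;;
  esac
}

# Example:
# sbatch -N 2 jax_maxtext_multinode_benchmark.sh "$(config_for llama3_8b gfx950)"
```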
- Multi-node training with Llama 2 7B model on 2 nodes:
sbatch -N 2 jax_maxtext_multinode_benchmark.sh env_scripts/llama2_7b.yml
- Multi-node training with Llama 2 70B model on 4 nodes with a custom image:
sbatch -N 4 jax_maxtext_multinode_benchmark.sh env_scripts/llama2_70b.yml rocm/jax-training:maxtext-v26.3
- Multi-node training with Llama 3 8B model on 2 nodes:
sbatch -N 2 jax_maxtext_multinode_benchmark.sh env_scripts/llama3_8b.yml
- Multi-node training with Llama 3 70B model on 8 nodes:
sbatch -N 8 jax_maxtext_multinode_benchmark.sh env_scripts/llama3_70b.yml
- Multi-node training with Llama 3.1 405B model on MI355X (gfx950) with 8 nodes:
sbatch -N 8 jax_maxtext_multinode_benchmark.sh env_scripts/gfx950_llama3.1_405b.yml
Clone the Primus repository
git clone https://github.com/AMD-AIG-AIMA/Primus.git
cd Primus
git checkout main
git submodule update --init third_party/maxtext/
Run the training job with primus-cli
For detailed usage of primus-cli, please refer to Primus CLI User Guide.
Here are some examples of using primus-cli to run training jobs with the JAX MaxText backend.
Direct Mode: Run the training directly on the current host or within an existing Docker container.
./primus-cli direct -- train pretrain --config examples/maxtext/configs/MI355X/llama2_7B-pretrain.yaml
Container Mode: Run the training in Docker/Podman containers.
./primus-cli container --image rocm/jax-training:maxtext-v26.3 \
-- train pretrain --config examples/maxtext/configs/MI355X/llama2_7B-pretrain.yaml
Slurm Mode: Run distributed training on a Slurm cluster.
# Use a custom config file, where you can specify the Docker image and set environment variables.
./primus-cli --config my_maxtext_config.yaml slurm srun -N 8 \
-- train pretrain --config examples/maxtext/configs/MI355X/llama2_7B-pretrain.yaml
MaxText has built-in XPlane profiling support via JAX's profiler. Traces capture GPU kernel timelines, RCCL collectives, HLO graphs, and more. The output can be viewed in TensorBoard's Trace Viewer or analyzed with TraceLens.
The following MaxText config keys control profiling:
profiler=xplane # Use xplane format (produces .xplane.pb files)
skip_first_n_steps_for_profiler=2 # Skip compilation/warmup steps
profiler_steps=5 # Number of steps to profile
upload_all_profiler_results=True # Save all GPU profiles (not just GPU0)
Choosing step counts:
- steps should be > skip_first_n_steps_for_profiler + profiler_steps. For example, steps=12 with skip=2 and profile=5 gives 2 skipped + 5 profiled + 5 remaining steps.
- skip_first_n_steps_for_profiler=2 skips step 0 (compilation) and step 1 (warmup).
- profiler_steps=5 is typically enough; more steps produce larger .xplane.pb files.
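You can check the step arithmetic above before editing a config. This sketch assumes, as stated above, that steps are counted from 0 and profiling starts at step index `skip_first_n_steps_for_profiler`. `profile_window` is a hypothetical helper and not a MaxText tool.

```shell
#!/usr/bin/env bash
# Hypothetical helper: validate steps > skip + profile and show how the run
# splits into skipped, profiled, and remaining steps (0-based step indices).
# Usage: profile_window <steps> <skip_first_n_steps_for_profiler> <profiler_steps>
profile_window() {
  local steps="$1" skip="$2" profile="$3"
  if (( steps <= skip + profile )); then
    echo "steps must be > skip_first_n_steps_for_profiler + profiler_steps" >&2
    return 1
  fi
  local first=$skip
  local last=$(( skip + profile - 1 ))
  echo "skipped=${skip} profiled=${profile} (steps ${first}-${last}) remaining=$(( steps - skip - profile ))"
}

# Example:
# profile_window 12 2 5
```

With the values used in this guide (steps=12, skip=2, profiler_steps=5), steps 2 through 6 are profiled and 5 steps run afterward.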
The model YAML configs under scripts/jax-maxtext/env_scripts/ already include a profiler key (set to "" by default). To enable profiling when running through MAD or madengine, edit the YAML config for your model and set the profiler fields:
profiler: "xplane"
skip_first_n_steps_for_profiler: 2
profiler_steps: 5
upload_all_profiler_results: True
steps: 12
Then run the benchmark as usual:
# Via madengine
python3 madengine run --tags jax_maxtext_train_llama-3.1-8b --keep-model-dir --live-output --timeout 28800
# Or via run_models.py
python3 tools/run_models.py --tags jax_maxtext_train_llama-3.1-8b --keep-model-dir --live-output --timeout 28800
Profile output is written under the base_output_directory specified in the YAML (see Output Structure below). Use --keep-model-dir so that the container's output directory is preserved after the run.
#!/bin/bash
# pipefail: fail if docker run fails, not only if tee fails
set -eo pipefail
IMAGE="${1:?Usage: $0 <docker_image> <tag>}"   # Docker image, e.g. rocm/jax-training:maxtext-v26.3
TAG="${2:?Usage: $0 <docker_image> <tag>}"     # Short tag for output folder, e.g. v26.3_llama2_7b
PROFILE_DIR="/path/to/profiles/${TAG}"
mkdir -p "${PROFILE_DIR}"
docker run --rm --privileged --network=host \
--device=/dev/dri --device=/dev/kfd --ipc=host \
-v "${PROFILE_DIR}:/mnt/profile" \
"${IMAGE}" bash -c '
export XLA_PYTHON_CLIENT_MEM_FRACTION=.97
export LD_LIBRARY_PATH=/usr/local/lib/:/opt/rocm/lib:$LD_LIBRARY_PATH
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=True --xla_gpu_enable_command_buffer= <your other XLA flags>"
export GPU_MAX_HW_QUEUES=2
cd /workspace/maxtext
python3 -m MaxText.train src/MaxText/configs/base.yml \
run_name=profile \
base_output_directory=/mnt/profile \
hardware=gpu \
steps=12 \
model_name=<your-model> \
dataset_type=synthetic \
enable_checkpointing=False \
enable_goodput_recording=False \
monitor_goodput=False \
<your model-specific flags> \
profiler=xplane \
skip_first_n_steps_for_profiler=2 \
profiler_steps=5 \
upload_all_profiler_results=True
' 2>&1 | tee "${PROFILE_DIR}/run.log"
echo "Profile files:"
find "${PROFILE_DIR}" -name "*.xplane.pb" -o -name "*.trace.json.gz" 2>/dev/null
MaxText writes profiles in TensorBoard format:
<base_output_directory>/
└── profile/
└── tensorboard/
└── plugins/
└── profile/
└── <YYYY_MM_DD_HH_MM_SS>/
├── <hostname>.xplane.pb # Raw XPlane proto (GPU timelines)
├── <hostname>.trace.json.gz # Trace viewer data
└── *.hlo_proto.pb # HLO graphs for each compiled module
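Each profiling run adds a new timestamped folder. You could use a sketch like the following to find the most recent one. It assumes the layout shown above, where the folder under base_output_directory is named after run_name (`profile` in the script above). `latest_profile_dir` is a hypothetical helper. The `<YYYY_MM_DD_HH_MM_SS>` names sort chronologically as plain strings.

```shell
#!/usr/bin/env bash
# Hypothetical helper: print the newest timestamped profile directory.
# Usage: latest_profile_dir <base_output_directory> [run_name, default "profile"]
latest_profile_dir() {
  local base="$1" run_name="${2:-profile}"
  local root="${base}/${run_name}/tensorboard/plugins/profile"
  local latest
  # Timestamped names sort lexically in time order, so the last entry is newest.
  latest="$(find "$root" -mindepth 1 -maxdepth 1 -type d 2>/dev/null | sort | tail -n 1)"
  if [[ -z "$latest" ]]; then
    echo "No profile runs found under ${root}" >&2
    return 1
  fi
  echo "$latest"
}

# Example:
# ls "$(latest_profile_dir /path/to/profiles/<TAG>)"
```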
pip install tensorboard tensorboard-plugin-profile
# Point --logdir at the directory containing the tensorboard/ folder
tensorboard --logdir /path/to/profiles/<TAG>/profile --port 6006
Navigate to Profile > Trace Viewer in the TensorBoard UI.
Tips:
- Zoom into a single training step. Skip the first profiled step, as it may include residual warmup.
- Look at individual GPU streams to see compute/RCCL overlap.
- Use profiler_steps=5 (not more) to keep .xplane.pb files under ~100 MB. Too many steps can produce files larger than 500 MB that TensorBoard struggles to load.
- enable_checkpointing=False avoids checkpoint I/O noise in the trace.
- dataset_type=synthetic eliminates data-loading variability.
If you need to collect a trace and the JAX profiler isn't working, you can use rocprofv3 as a temporary workaround:
rocprofv3 --hip-trace --kernel-trace --memory-copy-trace --rccl-trace --output-format pftrace -d ./v3_traces -- python3 app.py
- Replace python3 app.py with any command you want to profile, such as ./jax-maxtext_benchmark_report.sh -m Llama-2-7B.
- Set the directory where the trace files are saved with -d <TRACE_DIRECTORY>.
- The resulting traces can be opened in Perfetto: https://ui.perfetto.dev/
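If you profile several commands this way, a small wrapper avoids retyping the rocprofv3 options. The sketch below reuses the options from the command above. `trace_with_rocprofv3`, `TRACE_DIR`, and `DRY_RUN` are hypothetical names introduced here. The sketch assumes rocprofv3 is on PATH in the environment where you run it.

```shell
#!/usr/bin/env bash
# Hypothetical wrapper: run any command under rocprofv3 with the tracing
# options shown above. TRACE_DIR defaults to ./v3_traces.
# Set DRY_RUN=1 to print the full command instead of running it.
trace_with_rocprofv3() {
  local trace_dir="${TRACE_DIR:-./v3_traces}"
  local cmd=(rocprofv3 --hip-trace --kernel-trace --memory-copy-trace --rccl-trace
             --output-format pftrace -d "$trace_dir" -- "$@")
  if [[ "${DRY_RUN:-0}" == "1" ]]; then
    printf '%q ' "${cmd[@]}"   # shell-quoted, space-separated
    echo
  else
    "${cmd[@]}"
  fi
}

# Example:
# TRACE_DIR=./llama2_traces trace_with_rocprofv3 ./jax-maxtext_benchmark_report.sh -m Llama-2-7B
```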