Training Performance Validation of Primus Docker with Megatron backend on the AMD Instinct Accelerators
The Primus framework with the Megatron backend is designed to enable efficient training of large-scale language models on AMD GPUs. By leveraging AMD Instinct™ MI300X/MI350X accelerators, the Primus Megatron framework delivers enhanced scalability, performance, and resource utilization for AI workloads. It is purpose-built to support models like Llama 2, Llama 3/3.1, DeepSeek-V2/V3, and Mixtral MoE, enabling developers to train next-generation AI models with greater efficiency. See the GitHub repository at AMD-AIG-AIMA/Primus.
Note
The rocm/pytorch-training Docker Hub registry will be deprecated in the future. Please use rocm/primus for the latest ROCm PyTorch training Docker images, which cover the PyTorch training ecosystem frameworks (e.g. TorchTitan, TorchTune, Megatron-LM).
The ROCm PyTorch Training Docker rocm/primus:v26.3 (rocm/pytorch-training:v26.3) container, available through AMD Infinity Hub, provides a prebuilt, optimized environment for pre-training a model on the AMD Instinct™ MI300X, MI325X, MI350X, and MI355X accelerators. This ROCm PyTorch Docker image includes the following components:
| Software component | Version |
|---|---|
| ROCm | 7.2.1 |
| Python | 3.12.3 |
| PyTorch | 2.10.0+git94c6e04 |
| Transformer Engine | 2.12.0.dev0+40434cf6 |
| Flash Attention | 2.8.3 |
| hipBLASLt | 1.3.0-c4b2dc9869 |
| Triton | 3.6.0 |
| RCCL | 2.27.7 |
Primus-Megatron-backend provides the following key features to train large language models efficiently:
- Primus Turbo with optimized attention and grouped GEMM kernels
- Transformer Engine (TE)
- APEX
- GEMM tuning
- Torch.compile
- Flash Attention (FA) 3
- AITER Attention
- Fused kernels
- Pre-training
- FP8-GEMM
- Multi-node Support
- 3D parallelism: TP + SP + CP
- Distributed optimizer
The following models are pre-optimized for performance on the AMD Instinct MI300X accelerator.
- Llama 2 7B
- Llama 2 70B
- Llama 3/3.1 8B
- Llama 3/3.1/3.3 70B
- DeepSeek-V2-lite
- DeepSeek-V3
- Mixtral 8x7B
- Mixtral 8x22B
- Qwen 2.5 7/72B
- Zebra-Llama 1B/3B/8B
- Qwen3-30B-A3B
- Qwen3-235B-A22B
- Qwen 3 32B (SFT/LoRA)
- GPT-OSS-20B
- GPT-OSS-120B
If you have already validated your system, skip this step. Otherwise, complete the following system validation and optimization steps before starting training.
Generally, application performance can benefit from disabling NUMA auto-balancing. However, it might be detrimental to performance with certain types of workloads.
Run the command cat /proc/sys/kernel/numa_balancing to check your current NUMA (Non-Uniform Memory Access) settings. Output 0 indicates this setting is disabled. If there is no output or the output is 1, run the following command to disable NUMA auto-balancing.
sudo sh -c 'echo 0 > /proc/sys/kernel/numa_balancing'
See Disable NUMA auto-balancing for more information.
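The check described above can be wrapped in a small helper. This is a minimal sketch: the function name numa_balancing_status and its optional file argument are illustrative and are not part of ROCm or Primus.

```shell
# Hypothetical helper: classify the NUMA auto-balancing setting.
# The optional argument (a file path) exists only to make the function easy to test.
numa_balancing_status() {
  numa_file="${1:-/proc/sys/kernel/numa_balancing}"
  if [ ! -r "$numa_file" ]; then
    echo "unknown"
    return 0
  fi
  case "$(cat "$numa_file")" in
    0) echo "disabled" ;;
    1) echo "enabled" ;;
    *) echo "unknown" ;;
  esac
}

status="$(numa_balancing_status)"
echo "NUMA auto-balancing: $status"
if [ "$status" != "disabled" ]; then
  echo "To disable: sudo sh -c 'echo 0 > /proc/sys/kernel/numa_balancing'"
fi
```

The helper only reports the state; disabling still requires the sudo command shown above.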
The pre-built ROCm Primus-Megatron-backend environment allows users to quickly validate system performance, run training benchmarks, and reproduce optimized performance for models like Llama 2 and Llama 3.1. The Docker image uses Primus Turbo optimizations to achieve this performance.
This container should not be expected to provide generalized performance across all training workloads. Users should expect the container to perform in the model configurations described below, but other configurations and run conditions are not validated by AMD. Use the following instructions to set up the environment, configure the script to train models, and reproduce the benchmark results on the MI300X accelerators with the AMD Megatron-LM Docker image.
- Download Docker Image: Download the Docker image required for training:
  # MI300/MI325/MI35X
  docker pull rocm/primus:v26.3
- Launch Docker Container: Start the Docker container:
  docker run -it --device /dev/dri --device /dev/kfd --device /dev/infiniband --network host --ipc host --group-add video --cap-add SYS_PTRACE --security-opt seccomp=unconfined --privileged -v $HOME:/userHome --shm-size 128G --name primus_training_env rocm/primus:v26.3
  Note: It's not recommended to bind the $HOME directory to the container using -v $HOME:$HOME. A good practice is to bind only the directories you need.
- Enter the primus_training_env container (optional if you are already inside the container):
  docker start primus_training_env
  docker exec -it primus_training_env bash
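The docker run command above passes three device paths into the container. A quick host-side check that they exist can save a failed launch. This is a minimal sketch; missing_paths is a hypothetical helper, not a Primus or Docker command.

```shell
# Hypothetical preflight: print every path from the argument list that does not exist.
missing_paths() {
  for path in "$@"; do
    [ -e "$path" ] || echo "$path"
  done
  return 0
}

# Device paths taken from the docker run command above.
missing="$(missing_paths /dev/dri /dev/kfd /dev/infiniband)"
if [ -n "$missing" ]; then
  echo "Not found on this host:"
  echo "$missing"
else
  echo "All device paths passed to docker run are present."
fi
```

A missing /dev/infiniband usually only matters for multi-node runs; on a single node without RDMA hardware, that --device argument may need to be dropped.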
The Docker container hosts verified commit e16b27b from the Primus repository.
Primus defines a training YAML file for each model in the examples/megatron/configs/ directory. For example, edit examples/megatron/configs/llama3.1_8B-pretrain.yaml to update the llama3.1_8B training parameters. The YAML files for the other supported models follow the examples/megatron/configs/${MODEL_NAME}-pretrain.yaml naming convention.
Users can adjust training parameters such as micro_batch_size, global_batch_size, and train_iters inside the pretrain YAML files.
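As a small illustration of the naming convention, the config path can be built from the model name and the common parameters inspected with grep. This is a sketch under the convention stated above; the launch commands elsewhere in this document use per-accelerator subdirectories such as MI300X/, so adjust the path to match your checkout.

```shell
# MODEL_NAME follows the naming convention described above.
MODEL_NAME="llama3.1_8B"
CONFIG="examples/megatron/configs/${MODEL_NAME}-pretrain.yaml"
echo "Config: $CONFIG"

# If run from the Primus checkout, show the current values of common training parameters.
if [ -f "$CONFIG" ]; then
  grep -nE 'micro_batch_size|global_batch_size|train_iters' "$CONFIG" || true
else
  echo "Config not found here; run this from the Primus repository root."
fi
```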
Note:
- Supported model definitions can be found in the primus/configs/models/megatron/ directory.
- To migrate an existing workload from ROCm/Megatron-LM to Primus, or to add a new workload, follow the Migration Guide.
You can use either mock data or real data for training.
- Mock Data: The pretraining YAML configs use mock_data: true by default.
- Real Data: To use real data for training, change train_data_path: null to your tokenized data path and set mock_data: false.
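The two modes differ only in these two settings. The sketch below prints the pair of YAML lines for each mode; data_settings is a hypothetical helper for illustration and the dataset path is a placeholder.

```shell
# Hypothetical helper: print the two data-related YAML settings.
# With no argument it prints the mock-data defaults; with a path it prints the real-data settings.
data_settings() {
  if [ -n "${1:-}" ]; then
    printf 'mock_data: false\ntrain_data_path: %s\n' "$1"
  else
    printf 'mock_data: true\ntrain_data_path: null\n'
  fi
}

data_settings
# /data/my_corpus_text_document is a placeholder path, not a dataset shipped with Primus.
data_settings /data/my_corpus_text_document
```

The printed lines are meant to be copied into the pretrain YAML by hand; the helper does not modify any file.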
In Primus, each model uses a tokenizer from Hugging Face. For example, the llama3.1-8B model uses tokenizer_model: meta-llama/Llama-3.1-8B and tokenizer_type: Llama3Tokenizer, as defined in the llama3.1-8B model definition. Use an HF_TOKEN with the right permissions to access the tokenizer for each model.
# Export your HF_TOKEN in the workspace
export HF_TOKEN=<your_hftoken>
To run model training on a single node, go to the /workspace/Primus/ folder and run the following commands for setup. Once the setup is complete, use the individual model commands to start training:
pip install -r requirements.txt
# Set these variables for better performance only on MI300/MI325X
export HSA_NO_SCRATCH_RECLAIM=1
export NVTE_CK_IS_V3_ATOMIC_FP32=1
export PRIMUS_TURBO_ATTN_V3_ATOMIC_FP32=1 # for better performance
- Llama3.1-8B FP8:
bash runner/primus-cli direct \
--log_file /tmp/primus_llama3.1_8B_fp8.log \
-- train pretrain \
--config examples/megatron/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml
- Llama3.1-8B BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_llama3.1_8B.log \
-- train pretrain \
--config examples/megatron/configs/MI300X/llama3.1_8B-BF16-pretrain.yaml
- Llama2-7B FP8:
bash runner/primus-cli direct \
--log_file /tmp/primus_llama2_7B_fp8.log \
-- train pretrain \
--config examples/megatron/configs/MI300X/llama2_7B-FP8-pretrain.yaml
- Llama2-7B BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_llama2_7B.log \
-- train pretrain \
--config examples/megatron/configs/MI300X/llama2_7B-BF16-pretrain.yaml
- Llama3.1-70B BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_llama3.1_70B.log \
-- train pretrain \
--config examples/megatron/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml
- Llama2-70B BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_llama2_70B.log \
-- train pretrain \
--config examples/megatron/configs/MI300X/llama2_70B-BF16-pretrain.yaml
- Llama3.3-70B BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_llama3.3_70B.log \
-- train pretrain \
--config examples/megatron/configs/MI300X/llama3.3_70B-BF16-pretrain.yaml
Examples for MoE models with expert parallelism enabled, i.e., expert_model_parallel_size > 1:
- DeepSeekV2-Lite BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_deepseek_v2_lite.log \
-- train pretrain \
--config examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml
- Mixtral 8x7B:
bash runner/primus-cli direct \
--log_file /tmp/primus_mixtral_8x7B.log \
-- train pretrain \
--config examples/megatron/configs/MI300X/mixtral_8x7B_v0.1-BF16-pretrain.yaml
- QWEN2.5 7B BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_qwen2.5_7B.log \
-- train pretrain \
--config examples/megatron/configs/MI300X/qwen2.5_7B-BF16-pretrain.yaml
- QWEN2.5 7B FP8:
bash runner/primus-cli direct \
--log_file /tmp/primus_qwen2.5_7B_fp8.log \
-- train pretrain \
--config examples/megatron/configs/MI300X/qwen2.5_7B-FP8-pretrain.yaml
- QWEN2.5 72B BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_qwen2.5_72B.log \
-- train pretrain \
--config examples/megatron/configs/MI300X/qwen2.5_72B-BF16-pretrain.yaml
- Zebra-Llama-1B BF16:
PRIMUS_TRAIN_RUNTIME=legacy bash runner/primus-cli direct \
--log_file /tmp/primus_zebra_llama_1B.log \
-- train pretrain \
--config examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml
- Qwen3-32B BF16 LoRA:
bash runner/primus-cli direct \
--log_file /tmp/primus_qwen3_32b.log \
-- train posttrain \
--config examples/megatron_bridge/configs/MI300X/qwen3_32b_lora_posttrain.yaml
- Qwen3-32B BF16 SFT:
bash runner/primus-cli direct \
--log_file /tmp/primus_qwen3_32b_sft.log \
-- train posttrain \
--config examples/megatron_bridge/configs/MI300X/qwen3_32b_sft_posttrain.yaml
- Qwen3-30B (A3B) BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_qwen3_30B_A3B.log \
-- train pretrain \
--config examples/megatron/configs/MI300X/qwen3_30B_A3B-BF16-pretrain.yaml
- Qwen3-30B (A3B) FP8:
bash runner/primus-cli direct \
--log_file /tmp/primus_qwen3_30B_A3B_fp8.log \
-- train pretrain \
--config examples/megatron/configs/MI300X/qwen3_30B_A3B-FP8-pretrain.yaml
- GPT-OSS-20B BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_gpt_oss_20B.log \
-- train pretrain \
--config examples/megatron/configs/MI300X/gpt_oss_20B-BF16-pretrain.yaml
- GPT-OSS-20B FP8:
bash runner/primus-cli direct \
--log_file /tmp/primus_gpt_oss_20B_fp8.log \
-- train pretrain \
--config examples/megatron/configs/MI300X/gpt_oss_20B-FP8-pretrain.yaml
The following commands use the MI355X configurations:
- Llama3.1-8B FP8:
bash runner/primus-cli direct \
--log_file /tmp/primus_llama3.1_8B_fp8.log \
-- train pretrain \
--config examples/megatron/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml
- Llama3.1-8B BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_llama3.1_8B.log \
-- train pretrain \
--config examples/megatron/configs/MI355X/llama3.1_8B-BF16-pretrain.yaml
- Llama2-7B FP8:
bash runner/primus-cli direct \
--log_file /tmp/primus_llama2_7B_fp8.log \
-- train pretrain \
--config examples/megatron/configs/MI355X/llama2_7B-FP8-pretrain.yaml
- Llama2-7B BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_llama2_7B.log \
-- train pretrain \
--config examples/megatron/configs/MI355X/llama2_7B-BF16-pretrain.yaml
- Llama3.1-70B BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_llama3.1_70B.log \
-- train pretrain \
--config examples/megatron/configs/MI355X/llama3.1_70B-BF16-pretrain.yaml
- Llama3.1-70B FP8:
bash runner/primus-cli direct \
--log_file /tmp/primus_llama3.1_70B_fp8.log \
-- train pretrain \
--config examples/megatron/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml
- Llama2-70B BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_llama2_70B.log \
-- train pretrain \
--config examples/megatron/configs/MI355X/llama2_70B-BF16-pretrain.yaml
- Llama3.3-70B BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_llama3.3_70B.log \
-- train pretrain \
--config examples/megatron/configs/MI355X/llama3.3_70B-BF16-pretrain.yaml
- DeepSeekV2-Lite BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_deepseek_v2_lite.log \
-- train pretrain \
--config examples/megatron/configs/MI355X/deepseek_v2_lite-BF16-pretrain.yaml
- Mixtral 8x7B BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_mixtral_8x7B.log \
-- train pretrain \
--config examples/megatron/configs/MI355X/mixtral_8x7B_v0.1-BF16-pretrain.yaml
- QWEN2.5 7B BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_qwen2.5_7B.log \
-- train pretrain \
--config examples/megatron/configs/MI355X/qwen2.5_7B-BF16-pretrain.yaml
- QWEN2.5 7B FP8:
bash runner/primus-cli direct \
--log_file /tmp/primus_qwen2.5_7B_fp8.log \
-- train pretrain \
--config examples/megatron/configs/MI355X/qwen2.5_7B-FP8-pretrain.yaml
- QWEN2.5 72B BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_qwen2.5_72B.log \
-- train pretrain \
--config examples/megatron/configs/MI355X/qwen2.5_72B-BF16-pretrain.yaml
- Zebra-Llama-1B BF16:
PRIMUS_TRAIN_RUNTIME=legacy bash runner/primus-cli direct \
--log_file /tmp/primus_zebra_llama_1B.log \
-- train pretrain \
--config examples/megatron/configs/MI355X/zebra_llama_1B-pretrain.yaml
- Qwen3-32B BF16 LoRA:
bash runner/primus-cli direct \
--log_file /tmp/primus_qwen3_32b_lora.log \
-- train posttrain \
--config examples/megatron_bridge/configs/MI355X/qwen3_32b_lora_posttrain.yaml
- Qwen3-30B (A3B) BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_qwen3_30B_A3B.log \
-- train pretrain \
--config examples/megatron/configs/MI355X/qwen3_30B_A3B-BF16-pretrain.yaml
- Qwen3-30B (A3B) FP8:
bash runner/primus-cli direct \
--log_file /tmp/primus_qwen3_30B_A3B_fp8.log \
-- train pretrain \
--config examples/megatron/configs/MI355X/qwen3_30B_A3B-FP8-pretrain.yaml
- GPT-OSS-20B BF16:
bash runner/primus-cli direct \
--log_file /tmp/primus_gpt_oss_20B.log \
-- train pretrain \
--config examples/megatron/configs/MI355X/gpt_oss_20B-BF16-pretrain.yaml
- GPT-OSS-20B FP8:
bash runner/primus-cli direct \
--log_file /tmp/primus_gpt_oss_20B_fp8.log \
-- train pretrain \
--config examples/megatron/configs/MI355X/gpt_oss_20B-FP8-pretrain.yaml
To run training on multiple nodes, you can use primus-cli (recommended) or the run_slurm_pretrain.sh script to launch multi-node workloads. The multi-node setup and example commands are listed below.
MultiNode Setup:
Verify the NCCL / network environment first. The primus-cli launcher script sets sensible NCCL_* defaults via base_env.sh, but auto-detection can pick the wrong device on multi-NIC nodes. Always confirm that NCCL_IB_HCA, NCCL_IB_GID_INDEX, NCCL_SOCKET_IFNAME, and GLOO_SOCKET_IFNAME (set to the same value as NCCL_SOCKET_IFNAME) are correct for your fabric. If necessary, you can export these environment variables before running.
git clone --recurse-submodules https://github.com/AMD-AGI/Primus.git
cd Primus/
git checkout release/v26.3
git submodule update --init --recursive
export DOCKER_IMAGE=rocm/primus:v26.3
export HF_TOKEN=<your_HF_token>
export NCCL_IB_HCA=<your_NCCL_IB_HCA> # specify which RDMA interfaces to use for communication
export NCCL_SOCKET_IFNAME=<your_NCCL_SOCKET_IFNAME> # your Network Interface
export GLOO_SOCKET_IFNAME=<your_GLOO_SOCKET_IFNAME> # your Network Interface
export NCCL_IB_GID_INDEX=3 # Set InfiniBand GID index for NCCL communication. Default is 3 for RoCE
# MI300/MI325 only extra settings
export HSA_NO_SCRATCH_RECLAIM=1
export NVTE_CK_IS_V3_ATOMIC_FP32=1
export PRIMUS_TURBO_ATTN_V3_ATOMIC_FP32=1 # for better performance
For clusters using AMD AINIC, the following environment variables should be set.
export USING_AINIC=1
export NCCL_PXN_DISABLE=0
export NCCL_IB_GID_INDEX=1
Notes:
- Make sure the correct network drivers are installed on the nodes. If running inside Docker, either install the drivers inside the container or pass the network drivers from the host when creating the container.
- If NCCL_IB_HCA and NCCL_SOCKET_IFNAME are not set, Primus will try to auto-detect them. However, since NICs can vary across clusters, it is encouraged to explicitly export the NCCL parameters for your cluster.
- To find your network interface, you can use ip a.
- To find RDMA interfaces, you can use ibv_devices to list all RDMA/IB devices.
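Before submitting a multi-node job, it can help to confirm that the variables from the setup above are actually exported in the launching shell. This is a minimal sketch; unset_vars is a hypothetical helper, not part of Primus, and the list of variable names is taken from the export commands above.

```shell
# Hypothetical preflight: print the names of variables that are unset or empty.
unset_vars() {
  for name in "$@"; do
    eval "value=\${$name:-}"
    [ -n "$value" ] || echo "$name"
  done
  return 0
}

missing="$(unset_vars DOCKER_IMAGE HF_TOKEN NCCL_IB_HCA NCCL_SOCKET_IFNAME GLOO_SOCKET_IFNAME)"
if [ -n "$missing" ]; then
  echo "Unset variables (Primus tries to auto-detect NCCL_IB_HCA and NCCL_SOCKET_IFNAME):"
  echo "$missing"
fi

# The setup notes ask for GLOO_SOCKET_IFNAME to match NCCL_SOCKET_IFNAME.
if [ "${GLOO_SOCKET_IFNAME:-}" != "${NCCL_SOCKET_IFNAME:-}" ]; then
  echo "Warning: GLOO_SOCKET_IFNAME differs from NCCL_SOCKET_IFNAME"
fi
```

The check only confirms that values are present and consistent; whether they name the right interface for your fabric still has to be verified with ip a and ibv_devices.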
- Llama3.1-8B FP8 8 Node:
# Adjust the training parameters, e.g. `global_batch_size: 8 * #single_node_bs` for 8 nodes in this case
NNODES=8 EXP=examples/megatron/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml bash ./examples/run_slurm_pretrain.sh --global_batch_size 1024
- Llama2-7B FP8 8 Node:
# Adjust the training parameters, e.g. `global_batch_size: 8 * #single_node_bs` for 8 nodes in this case
NNODES=8 EXP=examples/megatron/configs/MI300X/llama2_7B-FP8-pretrain.yaml bash ./examples/run_slurm_pretrain.sh --global_batch_size 2048
- Llama3.1-70B FP8 8 Nodes:
NNODES=8 EXP=examples/megatron/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml bash examples/run_slurm_pretrain.sh --micro_batch_size 4 --global_batch_size 256 --recompute_num_layers 80
- Llama3.1-70B BF16 8 Nodes:
NNODES=8 EXP=examples/megatron/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml bash examples/run_slurm_pretrain.sh --micro_batch_size 1 --global_batch_size 256 --recompute_num_layers 12
- Llama2-70B FP8 8 Nodes:
NNODES=8 EXP=examples/megatron/configs/MI300X/llama2_70B-FP8-pretrain.yaml bash examples/run_slurm_pretrain.sh --micro_batch_size 10 --global_batch_size 640 --recompute_num_layers 80
- Llama2-70B BF16 8 Nodes:
NNODES=8 EXP=examples/megatron/configs/MI300X/llama2_70B-BF16-pretrain.yaml bash ./examples/run_slurm_pretrain.sh --micro_batch_size 2 --global_batch_size 1536 --recompute_num_layers 12
- Llama3.3-70B FP8 8 Nodes:
NNODES=8 EXP=examples/megatron/configs/MI300X/llama3.3_70B-FP8-pretrain.yaml bash examples/run_slurm_pretrain.sh --micro_batch_size 4 --global_batch_size 256 --recompute_num_layers 80
- Llama3.3-70B BF16 8 Nodes:
NNODES=8 EXP=examples/megatron/configs/MI300X/llama3.3_70B-BF16-pretrain.yaml bash examples/run_slurm_pretrain.sh --micro_batch_size 1 --global_batch_size 256 --recompute_num_layers 12
- Mixtral 8x7B BF16 8 Nodes:
NNODES=8 EXP=examples/megatron/configs/MI300X/mixtral_8x7B_v0.1-BF16-pretrain.yaml bash examples/run_slurm_pretrain.sh --micro_batch_size 2 --global_batch_size 256
- Qwen2.5-72B FP8 8 Nodes:
NNODES=8 EXP=examples/megatron/configs/MI300X/qwen2.5_72B-FP8-pretrain.yaml bash examples/run_slurm_pretrain.sh --micro_batch_size 8 --global_batch_size 512 --recompute_num_layers 80
- Mixtral-8x22B BF16 4 Nodes MI355:
Launch the training using the primus-cli (recommended)
# In the Primus directory
./primus-cli slurm srun -N 4 -- train pretrain --config examples/megatron/configs/MI355X/mixtral_8x22B_v0.1-BF16-pretrain.yaml --micro_batch_size 1 --global_batch_size 512 --num_virtual_stages_per_pipeline_rank 2 --pipeline_model_parallel_size 4 --expert_model_parallel_size 8 --recompute_num_layers 1 --moe_use_legacy_grouped_gemm True --gradient_accumulation_fusion True
Launch the training using the legacy script
NNODES=4 EXP=examples/megatron/configs/MI355X/mixtral_8x22B_v0.1-BF16-pretrain.yaml bash examples/run_slurm_pretrain.sh --micro_batch_size 1 --global_batch_size 512 --num_virtual_stages_per_pipeline_rank 2 --pipeline_model_parallel_size 4 --expert_model_parallel_size 8 --recompute_num_layers 1 --moe_use_legacy_grouped_gemm True --gradient_accumulation_fusion True
- Llama3.1-405B FP8 8 Nodes MI325:
Launch the training using the primus-cli (recommended)
# In the Primus directory
./primus-cli slurm srun -N 8 -- train pretrain --config examples/megatron/configs/MI325X/llama3.1_405B-FP8-pretrain.yaml --micro_batch_size 1 --global_batch_size 256 --decoder_first_pipeline_num_layers 15 --decoder_last_pipeline_num_layers 15
We use TP=8 for the Llama3.1-405B model on 8 nodes. Because the model has 126 layers, which is not divisible by 8, the layers cannot be split evenly across pipeline stages, so decoder_first_pipeline_num_layers and decoder_last_pipeline_num_layers must be set explicitly.
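The layer split can be checked with shell arithmetic. This sketch assumes 8 pipeline stages, which the text does not state explicitly; it is inferred from the divisibility-by-8 remark and the two pipeline-layer flags, so adjust PP_STAGES if your configuration differs.

```shell
TOTAL_LAYERS=126   # Llama3.1-405B layer count, from the text above
PP_STAGES=8        # assumption: number of pipeline stages
FIRST=15           # --decoder_first_pipeline_num_layers
LAST=15            # --decoder_last_pipeline_num_layers

# A non-zero remainder means the layers cannot be split evenly.
echo "Even split remainder: $(( TOTAL_LAYERS % PP_STAGES ))"

# Layers left for the stages between the first and the last one.
MIDDLE_LAYERS=$(( TOTAL_LAYERS - FIRST - LAST ))
MIDDLE_STAGES=$(( PP_STAGES - 2 ))
echo "Middle stages: $MIDDLE_STAGES"
echo "Layers per middle stage: $(( MIDDLE_LAYERS / MIDDLE_STAGES ))"
echo "Middle remainder: $(( MIDDLE_LAYERS % MIDDLE_STAGES ))"
```

Under that assumption, placing 15 layers on the first and last stages leaves 96 layers, which divide evenly into 16 per stage across the 6 middle stages.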
Launch the training using the legacy script
NNODES=8 EXP=examples/megatron/configs/MI300X/llama3.1_405B-FP8-pretrain.yaml bash examples/run_slurm_pretrain.sh --micro_batch_size 1 --global_batch_size 256 --decoder_first_pipeline_num_layers 15 --decoder_last_pipeline_num_layers 15
Key options:
- fp8: `--fp8 hybrid` enables FP8 GEMMs.
- use_torch_fsdp2: use_torch_fsdp2: 1 enables torch FSDP v2. If FSDP is enabled, set use_distributed_optimizer: false and overlap_param_gather: false.
- profile: To enable PyTorch profiling, set all of these parameters:
  profile: true
  use_pytorch_profiler: true
  profile_step_start: 6
  profile_step_end: 7
- train_iters: Total number of training iterations (default: 50).
- mock_data: Set to true by default.
- micro_batch_size: Micro batch size.
- global_batch_size: Global batch size.
- recompute_granularity: Activation checkpointing (null, sel, full). Default: null. When set to full, also set recompute_num_layers and recompute_method (uniform or block).
- num_layers: Use a reduced number of layers to run a proxy model.
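When changing micro_batch_size and global_batch_size together, it helps to check that the values are compatible with the parallel layout. In Megatron-LM, the global batch size must be a multiple of the micro batch size times the data-parallel size; the sketch below assumes Primus inherits that constraint from its Megatron backend. All numeric values, including 8 accelerators per node and the parallel sizes, are illustrative assumptions rather than settings taken from a Primus config.

```shell
# Illustrative values only; read the real ones from your YAML config and launch command.
NNODES=8
GPUS_PER_NODE=8          # assumption: 8 accelerators per node
TP=8; PP=1; CP=1         # assumption: tensor, pipeline, and context parallel sizes
MICRO_BATCH_SIZE=1
GLOBAL_BATCH_SIZE=256

WORLD_SIZE=$(( NNODES * GPUS_PER_NODE ))
DP=$(( WORLD_SIZE / (TP * PP * CP) ))        # data-parallel size
SAMPLES_PER_PASS=$(( MICRO_BATCH_SIZE * DP ))  # samples consumed by one micro-batch on every replica

if [ $(( GLOBAL_BATCH_SIZE % SAMPLES_PER_PASS )) -ne 0 ]; then
  echo "global_batch_size must be a multiple of micro_batch_size * data-parallel size ($SAMPLES_PER_PASS)"
else
  echo "data-parallel size: $DP, micro-batches per step: $(( GLOBAL_BATCH_SIZE / SAMPLES_PER_PASS ))"
fi
```

The same arithmetic explains why the global batch size is usually scaled with the node count in the multi-node examples: adding nodes increases the data-parallel size, so keeping the per-replica work constant requires a proportionally larger global batch.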