
Customizing model configs for TPUs

Introduction

This document provides a guide to optimizing and customizing your LLM model configurations for higher performance (i.e., higher MFU) on Cloud TPU. Note that this document focuses exclusively on performance tuning; the analysis of model quality and convergence behavior is out of scope.

Step 1. Identify initial configs

To begin, identify your model's size, review open-source model configs, and establish the initial configuration for each block. You can use our reference calculator (on Colab) to estimate the parameter count and FLOPs for dense, Mixtral-like Mixture of Experts (MoE), and DeepSeek-like MoE models.

Based on resources like Language Modeling from Scratch, common architectural ratios include:

Dense models

  • mlp_dim / emb_dim: 2.5-4
  • head_dim * num_query_heads / emb_dim: 1-2
  • emb_dim / num_decoder_layers: 100-200

MoE models

  • sparsity (num_experts / num_experts_per_tok): 4-32
  • moe_mlp_dim / emb_dim: 0.3-3
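
As a quick sanity check, the sketch below computes the dense-model ratios listed above for a candidate config and prints them next to the suggested ranges. The config values are purely illustrative, not a recommendation.

```python
# Hypothetical dense config, used only to illustrate the ratio checks above.
config = {
    "emb_dim": 8192,
    "mlp_dim": 28672,
    "head_dim": 128,
    "num_query_heads": 64,
    "num_decoder_layers": 80,
}

ratios = {
    "mlp_dim / emb_dim (target 2.5-4)": config["mlp_dim"] / config["emb_dim"],
    "head_dim * num_query_heads / emb_dim (target 1-2)":
        config["head_dim"] * config["num_query_heads"] / config["emb_dim"],
    "emb_dim / num_decoder_layers (target 100-200)":
        config["emb_dim"] / config["num_decoder_layers"],
}

for name, value in ratios.items():
    print(f"{name}: {value:.2f}")
# 3.50, 1.00, 102.40 -- all within the suggested ranges.
```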

Step 2. Consider TPU best practices

Model configs

To unlock peak performance on TPUs, it is critical to keep the Matrix Multiply Unit (MXU) fully utilized. The MXU is the primary computational engine, with the Trillium and Ironwood chips specifically optimized for 256×256 matrix multiplications (earlier TPU versions, like v4/v5e/v5p, are optimized for 128×128 operations). Processing smaller matrix multiplications (e.g., two 128×128 operations on Trillium and Ironwood) will halve the efficiency compared to a single, fully-utilized 256×256 operation.

Therefore, for optimal efficiency:

  • Model and MLP Dimensions: Design your model's emb_dim and mlp_dim to be multiples of 256 (for Trillium and Ironwood) or 128 (for older TPUs).
  • Self-Attention Head Dimension: Ensure your attention head_dim is also a multiple of 256 (for Trillium and Ironwood) or 128 (for older TPUs).

Generally, larger multiples are more efficient. If achieving these specific multiples isn't possible, round dimensions to a multiple of 128, or at least 8, to help the XLA compiler optimize memory and computation.

To achieve efficient memory usage on a TPU, configure your training with the largest batch size that fits within its memory limits (configure a rematerialization policy with offloading to achieve the best MFU). Each TPU core leverages internal 8×128 vector registers for highly optimized matrix multiplications. Therefore, for peak performance and to minimize padding, your batch size should ideally be a multiple of 128. If a multiple of 128 is not feasible, try a multiple of 8. For more detailed explanations, see this performance guide.
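
As a minimal illustration of the alignment advice above, the helper below rounds a dimension up to the preferred multiple for a given TPU generation (256 for Trillium/Ironwood, 128 for older TPUs, with a fallback to 8). The function names and example values are ours, purely for illustration.

```python
import math

def round_up(value: int, multiple: int) -> int:
    """Round value up to the nearest multiple."""
    return math.ceil(value / multiple) * multiple

def align_dim(dim: int, trillium_or_newer: bool = True, relaxed: bool = False) -> int:
    """Suggest an MXU-friendly dimension.

    Prefer multiples of 256 on Trillium/Ironwood (128 on older TPUs);
    fall back to a multiple of 8 when the padding cost would be too high.
    """
    preferred = 256 if trillium_or_newer else 128
    multiple = 8 if relaxed else preferred
    return round_up(dim, multiple)

print(align_dim(13000))                # 13056: next multiple of 256
print(align_dim(13000, relaxed=True))  # 13000: already a multiple of 8
```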

Ironwood

Ironwood is engineered for cutting-edge, large-scale AI model training and inference. To unlock its full potential, the primary goal is to continuously supply data to its powerful TensorCores, preventing bottlenecks from memory or the Inter-Chip Interconnect (ICI).

We have published optimized recipes for models like DeepSeek v3, GPT-OSS, Qwen3, and Llama3 on Ironwood, covering both BF16 and FP8 precision, available in this guide.

Key strategies to maximize performance on Ironwood include:

  • Adopt FP8 Precision: Ironwood delivers 2x throughput with FP8 compared to BF16. Design models to use mixed-precision training, employing FP8 for weights and activations where possible to maximize computational speed.
  • Offload to SparseCores: Ironwood's enhanced SparseCores are crucial for efficiency. Offload collective communication and data management to them to keep TensorCores focused on compute.
  • Leverage the dual-chiplet architecture: Each Ironwood chip contains two TensorCores with an ultra-fast die-to-die interconnect (6x faster than a 1D ICI link).

Given Ironwood's high compute power, communication bandwidth can easily become the limiting factor. To address this:

  • Leverage SparseCore offloading: By default, collective operations (like All-Reduce, All-Gather, etc.) are offloaded to SparseCore, allowing them to run in parallel with TensorCore computations. This effectively hides communication latency and improves Model Flop Utilization (MFU). If the default collective operations do not meet your performance requirements or fail to offload to SparseCore as intended, you can tune these XLA flags to maximize throughput.
  • Optimize sharding strategies: Align your model distribution with the hardware topology. Choose sharding strategies (e.g., data, tensor, pipeline parallelism) that minimize data transfer over the ICI and maximize the overlap between computation and communication.

Performance configs

Use these general runtime configurations to improve your model's performance.

  • Multi-Head Attention (MHA). If you are using MHA, we recommend setting fused_qkv=True to fuse the query, key, and value computations into a single, more efficient operation.

  • Flash Attention. Use the largest possible block size to maximize throughput.

  • Memory usage. To free up memory with large models, use a custom remat policy to offload layer activations (including inputs, attention, and MLP blocks) to the host CPU.

  • Compiler flags. XLA is the backend compiler for TPUs. Many critical performance settings can be controlled directly through XLA flags. We suggest beginning with the proven flags we have tested and provided here.

  • Benchmark. For consistent speed tests, set reuse_example_batch=1 to repeatedly use the same data batch, isolating computation speed from data loading. Or use on-the-fly generated data by setting dataset_type=synthetic.

Step 3. Choose efficient sharding strategies using Roofline Analysis

To achieve good performance, it's often necessary to co-design the model's dimensions (like the MLP dimension) along with the sharding strategy. We have included examples for v5p, Trillium, and Ironwood that demonstrate which sharding approaches work well for specific models. We also recommend reading JAX's scaling book.

TPU Type ICI Arithmetic Intensity
v5p 2550 for 1D-ICI
Trillium 5100 for 1D-ICI (1D with wraparound or 2D without wraparound)
2550 for 2D-ICI (2D with wraparound on both dimensions), particularly for v6e-256
Ironwood 12800 for 1D-ICI

Fully Sharded Data Parallelism (FSDP)

Pure FSDP

For pure FSDP to be effective, it must have enough memory to hold both a large data batch and a full, single layer of weights at the same time.

FSDP AI: global batch / sparsity (sparsity = num_experts / num_experts_per_tok).

Example with a sparsity of 16:

  • global batch / sparsity > hardware AI

v5p:

  • global batch / 16 > 2550
  • global batch > 40k (in tokens)

Trillium:

  • global batch / 16 > 2550 (16x16 with wraparound)
  • global batch > 40k (in tokens)

We also need a single layer of weights to fit into memory, which can be an issue for medium/large MoE models. For example, DeepSeek has roughly 10B params per layer, which corresponds to 40GiB of bf16 weights and gradients and will not fit into Trillium’s 32GiB of HBM. So pure FSDP on Trillium is feasible for models with layers not exceeding roughly 5B parameters; larger models need Expert or Tensor Parallelism.

Ironwood:

  • global batch / 16 > 12800
  • global batch > 205k (in tokens)
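
The sketch below restates the pure-FSDP roofline check numerically for the three TPU generations, using the ICI arithmetic intensities from the table above and the sparsity-16 example; the results match the bullets above.

```python
# ICI arithmetic intensities from the table in Step 3 (1D-ICI for v5p/Ironwood,
# 2D-ICI with wraparound for a 16x16 Trillium pod).
hardware_ai = {"v5p": 2550, "trillium_16x16": 2550, "ironwood": 12800}

sparsity = 16  # num_experts / num_experts_per_tok

# Pure FSDP roofline: global_batch / sparsity > hardware AI
for tpu, ai in hardware_ai.items():
    min_global_batch = ai * sparsity
    print(f"{tpu}: global batch > {min_global_batch:,} tokens")
# v5p > 40,800 (~40k); trillium_16x16 > 40,800 (~40k); ironwood > 204,800 (~205k)
```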

Mix FSDP

For sparse models, large models, or when scaling to a large number of chips, FSDP can be used in conjunction with other sharding strategies, such as Expert Parallelism (EP), Tensor Parallelism (TP), and Pipeline Parallelism (PP).

The same AI as derived in the Pure FSDP section above still holds: we need global batch / (sparsity * FSDP) > hardware AI, which is equivalent to per device batch (pdb) * TP * EP * PP / sparsity > hardware AI.

Example with EP=16, FSDP=16, and sparsity=32:

  • pdb * EP / sparsity > hardware AI

v5p:

  • pdb * 16 / 32 > 2550
  • pdb > 2550 * 32 / 16 = 5k (in tokens)

Trillium:

  • pdb * 16 / 32 > 5100
  • pdb > 5100 * 32 / 16 = 10k (in tokens)

Ironwood:

  • pdb * 16 / 32 > 12800
  • pdb > 12800 * 32 / 16 = 26k (in tokens)

We need a per device batch of at least 5k for v5p, 10k for Trillium, and 26k for Ironwood in this case.
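
Here is the same check written in per-device terms, for the EP=16, sparsity=32 example above (TP=PP=1); the minimum per-device batch sizes match the bullets.

```python
hardware_ai = {"v5p": 2550, "trillium": 5100, "ironwood": 12800}

sparsity = 32           # num_experts / num_experts_per_tok
ep, tp, pp = 16, 1, 1   # model-parallel axes used alongside FSDP

# Mixed FSDP roofline: pdb * EP * TP * PP / sparsity > hardware AI
for tpu, ai in hardware_ai.items():
    min_pdb = ai * sparsity / (ep * tp * pp)
    print(f"{tpu}: per-device batch > {min_pdb:,.0f} tokens")
# v5p > 5,100 (~5k); trillium > 10,200 (~10k); ironwood > 25,600 (~26k)
```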

Expert Parallelism (EP)

If pure FSDP doesn’t work, either because of the AI requirement or because a layer’s weights don’t fit in memory, EP is generally the way to go for sparse models (large dense models should use TP).

AI of 1D EP on ICI rings = 4 * mlp_dim / EP (below we write M for mlp_dim). The communication cost of all-to-all is roughly 1/4 of that of all-gather and reduce-scatter.

Example with EP=4

v5p:

  • 4 * M > 2550 * 4
  • M > 2.5k

Trillium:

  • 4 * M > 5100 * 4
  • M > 5k

Ironwood:

  • 4 * M > 12800 * 4
  • M > 13k

These examples show that to use EP, we need a large enough MLP dimension.

It's important to note that this is only a roofline analysis. A no-capacity-factor (dropless) strategy with a high degree of EP introduces additional overhead: load balancing across expert groups becomes more challenging.
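
The corresponding numeric check for EP, using AI = 4 * M / EP with EP=4; the minimum MLP dimensions match the bullets above.

```python
hardware_ai = {"v5p": 2550, "trillium": 5100, "ironwood": 12800}

ep = 4

# EP roofline: 4 * M / EP > hardware AI  =>  M > hardware AI * EP / 4
for tpu, ai in hardware_ai.items():
    min_mlp_dim = ai * ep / 4
    print(f"{tpu}: mlp_dim (M) > {min_mlp_dim:,.0f}")
# v5p > 2,550 (~2.5k); trillium > 5,100 (~5k); ironwood > 12,800 (~13k)
```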

Tensor Parallelism (TP)

Tensor parallelism can be used for large dense models or very large sparse models. It is particularly helpful when a small per-device batch is needed, and it pairs well with PP.

AI of TP: M / TP

Example with TP=4

  • M / TP > hardware AI

v5p:

  • M / 4 > 2550
  • M > 10k

Trillium:

  • M / 4 > 5100
  • M > 20k

We have seen in practice that M should be even larger, ideally 40k+. This is what we use for Llama-405B (M=53k), and what was used for a custom sparse 10T model (M=40k, 64 experts). TP=4 corresponds to a custom Trillium mesh, an 8x8 ring of 2x2 subrings (the TP communication operates on the 2x2 ring). This 2x2 ring performs well (near roofline), but the 8x8 rings perform poorly (roughly 0.5x of a full axis). E.g. if we use FSDP=64, TP=4, the FSDP=64 communications will be slower than the hardware ICI roofline, so we prefer to use the full 16-chip axis when M is large enough.

Ironwood:

  • M / 4 > 12800
  • M > 51k

Example with TP=16

  • M / TP > hardware AI

v5p:

  • M / 16 > 2550
  • M > 41k

Trillium:

  • M / 16 > 5100
  • M > 82k

To use TP=16, we need M > 80k (ideally larger, 100k+). We have used this in a custom dense model (900B, M=131k), which performs very well even at 1k tokens per device (scaling to 25k+ chips while maintaining a reasonable global batch).
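
The same roofline check for TP, evaluated at TP=4 and TP=16; the minimum M values match the bullets above (and, as noted, practical values should be considerably larger). The Ironwood TP=16 figure is not quoted above and simply follows from the same formula.

```python
hardware_ai = {"v5p": 2550, "trillium": 5100, "ironwood": 12800}

# TP roofline: M / TP > hardware AI  =>  M > hardware AI * TP
for tp in (4, 16):
    for tpu, ai in hardware_ai.items():
        print(f"TP={tp}, {tpu}: mlp_dim (M) > {ai * tp:,}")
# TP=4:  v5p > 10,200; trillium > 20,400; ironwood > 51,200
# TP=16: v5p > 40,800; trillium > 81,600; ironwood > 204,800
```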

Pipeline Parallelism (PP)

Pipeline Parallelism is advantageous when global batch size limits per device batch size, making Data Parallelism (DP) inefficient. PP is associated with small communication costs since it only needs to permute the small layer inputs.

AI of PP: 3/2 * layers_per_pipeline_stage * M * num_experts_per_tok

Example with PP=16, layers_per_pipeline_stage=1, num_experts_per_tok=8

  • 3/2 * layers_per_pipeline_stage * M * num_experts_per_tok > hardware AI

v5p - PP over ICI:

  • 3 * M * 8 / 2 > 2550
  • M > 210

v5p - PP over DCN:

  • 3 * M * 8 / 2 > 73000
  • M > 6k

Trillium over ICI:

  • 3 * M * 8 / 2 > 5100
  • M > 420

Trillium over DCN:

  • 3 * M * 8 / 2 > 73000
  • M > 6k

Ironwood over ICI:

  • 3 * M * 8 / 2 > 12800
  • M > 1100

It is important to emphasize that this is a theoretical roofline analysis. Real-world performance will depend on the efficiency of the implementation and XLA compilation on the TPU. Refer to the link for specific challenges regarding PP + FSDP/DP.
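
A numeric version of the PP roofline above, for PP over ICI and over DCN with layers_per_pipeline_stage=1 and num_experts_per_tok=8; the DCN arithmetic intensity of 73000 is taken from the examples above.

```python
ici_ai = {"v5p": 2550, "trillium": 5100, "ironwood": 12800}
dcn_ai = 73000  # DCN arithmetic intensity used in the examples above

layers_per_stage = 1
experts_per_tok = 8

def min_mlp_dim(hardware_ai: float) -> float:
    # PP roofline: 3/2 * layers_per_stage * M * experts_per_tok > hardware AI
    return hardware_ai / (1.5 * layers_per_stage * experts_per_tok)

for tpu, ai in ici_ai.items():
    print(f"{tpu} (PP over ICI): M > {min_mlp_dim(ai):,.0f}")
print(f"PP over DCN: M > {min_mlp_dim(dcn_ai):,.0f}")
# v5p ~213, trillium ~425, ironwood ~1,067, DCN ~6,083 -- matching the bullets above.
```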

Step 4. Analyze experiments

With your configs, begin experimenting to evaluate the model's performance. We strongly recommend capturing a profile by following these instructions. If you are using MaxText, this can be done by simply setting profiler=xplane in your configuration.

After generating the profile, use a tool like xprof, xprofiler, or TensorBoard to analyze the results. This example (Profile TPU Programs) can serve as your guide. A key principle for maximizing training throughput is to ensure you are fully utilizing the available HBM. Once you achieve satisfactory performance, you can proceed with full training runs. Continue to analyze your model and refine your configurations as needed.

Example of dense model

900B dense model on Trillium

To use Trillium's 16x16 mesh efficiently for a large dense model, we would like to use TP=16. This requires a huge MLP dimension, of at least 5k * 16 = 80k. With a per-device batch size of 4k tokens, this model achieved 39.8% MFU. The model demonstrated excellent scalability, maintaining 37% MFU even when the batch size was reduced to just 1k tokens per device.

Final Configs
emb_dim 16384
mlp_dim 131072
head_dim 256
num_query_head 64
num_kv_head 16
num_decoder_layers 128
Total Params 9.15E+11
MFU (1 pod Trillium) 39.8%
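
For reference, here is a back-of-the-envelope parameter count for this config, assuming a standard decoder block with a gated MLP (three projection matrices), no norm/bias terms, and a hypothetical vocabulary size (the real one is not specified above). The total lands close to the reported 9.15E+11.

```python
emb_dim = 16384
mlp_dim = 131072
head_dim = 256
num_query_heads = 64
num_kv_heads = 16
num_layers = 128
vocab_size = 32_768  # hypothetical; not specified in the config above

# Attention projections (Q and O use all query heads; K and V use KV heads).
attn = (2 * emb_dim * num_query_heads * head_dim +
        2 * emb_dim * num_kv_heads * head_dim)

# Gated MLP: up, gate, and down projections.
mlp = 3 * emb_dim * mlp_dim

per_layer = attn + mlp
total = num_layers * per_layer + vocab_size * emb_dim

print(f"per layer: {per_layer / 1e9:.2f}B, total: {total / 1e12:.3f}T")
# ~7.11B per layer, ~0.911T total -- close to the reported 9.15E+11.
```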

Example of MoE model

700B Mixtral-like MoE on Trillium

Our objective was to develop a custom Mixtral-like MoE model capable of high MFU on Trillium TPUs, targeting a 1.5 capacity factor (the capacity factor is a multiplier that determines the processing capacity of each expert: Expert Capacity = (Tokens in Batch / Number of Experts) * Capacity Factor). We established an initial baseline of 43.1% MFU with a 1.0 capacity factor. Profiling revealed this configuration utilized approximately 20GiB of HBM. To better leverage Trillium's 32GiB of HBM and avoid potential convergence issues with large global batch sizes during scaling (maintaining a per-device batch size of 8k), we made the following architectural adjustments:

  • Increased the MLP dimension from 3x to 4x the model dimension (32,768 vs. emb_dim 8,192).
  • Increased query heads from 32 to 128 for each layer, while reducing the number of layers from 72 to 56 to preserve overall model size around 700B.

These changes, without updating the sharding strategies, initially yielded nearly 50% MFU. Upon increasing the capacity factor to 1.5 (adding a buffer that lets experts handle imbalance in token routing), MFU decreased to 38.1% on one pod and to 35.3% when scaling to 4 pods, which still exceeded our target of 35%. More detailed configs can be found here in the repo.

Initial Configs Experimental Config Final Configs
emb_dim 8192 8192 8192
mlp_dim 24576 32768 32768
num_experts 16 16 16
num_experts_per_tok 2 2 2
sparsity 8 8 8
head_dim 256 256 256
num_query_head 32 128 128
num_kv_head 8 8 8
num_decoder_layers 72 56 56
capacity_factor 1.0 1.0 1.5
Total Params 7.08E+11 7.54E+11 7.54E+11
Active Params 9.96E+10 1.23E+11 1.23E+11
MFU (1 pod Trillium) 43.1% 49.8% 38.1%
MFU (4 pod Trillium) n/a n/a 35.3%
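
As a small illustration of the capacity factor discussed above, the sketch below computes the per-expert capacity for a per-device batch of roughly 8k tokens with 16 experts, at capacity factors 1.0 and 1.5. The numbers are ours, derived only from the formula quoted earlier.

```python
tokens_in_batch = 8192   # roughly the per-device batch used in the example above
num_experts = 16

# Expert Capacity = (Tokens in Batch / Number of Experts) * Capacity Factor
for capacity_factor in (1.0, 1.5):
    capacity = tokens_in_batch / num_experts * capacity_factor
    print(f"capacity_factor={capacity_factor}: {capacity:.0f} tokens per expert")
# 1.0 -> 512 tokens/expert; 1.5 -> 768 tokens/expert (extra buffer for routing imbalance)
```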

10T Mixtral-like MoE on Trillium

The objective was to demonstrate reasonable MFU in a low-batch setting (2k tokens per device) for a highly sparse (sparsity=32) model on Trillium. This requires using pipeline parallelism over DCN, which in turn calls for EP+TP over ICI (EP=64, TP=4). This model achieved 26% MFU on 16 pods (PP=16), and MFU degrades only by a few percentage points when adding more DP replicas (24% MFU with PP=8 and DP=2), even at a small per-device batch size of only 2k (scaling to 25k+ chips while maintaining a reasonable global batch size).

Final Configs
emb_dim 10240
mlp_dim 40960
num_experts 64
num_experts_per_tok 2
sparsity 32
head_dim 256
num_query_head 64
num_kv_head 16
num_decoder_layers 128
capacity_factor 1.0
Total Params 1.04E+13
Active Params 3.76E+11
MFU (1 pod Trillium) 34.5%
MFU (16 pods Trillium) 26.2%