Choosing efficient sharding strategies is key to achieving good performance, especially at scale. Sharding is not the only knob: you should also make use of all your HBM (by tuning batch size and rematerialization policies). Here we discuss the sharding strategies supported in MaxText.
When considering different sharding strategies, the main concern is the amount of communication between chips. Different sharding strategies require different communication patterns: how often communication is needed and how much data must be communicated. A very helpful tool for understanding the performance implications of this communication is arithmetic intensity, which roughly gives the ratio of useful computation to communication cost. We highly recommend understanding arithmetic intensity if you are serious about optimizing performance; we recommend both the “Jax Train your LLM” document and the MaxText HighPerformanceLLM class (specifically classes 1-4). Below we briefly describe how to compute arithmetic intensities and then explain the sharding strategies supported in MaxText, starting with some notation.
- Sharding notation
- Arithmetic Intensity: whirlwind introduction example
- Arithmetic Intensity: Mixed sharding strategies
- Code implementation of sharding in MaxText
- Hierarchical Mesh
- Data Parallelism (DP)
- Fully Sharded Data Parallelism (FSDP)
- Fully Sharded Data Parallelism Transpose (FSDP Transpose)
- Context Parallelism (CP)
- Sequence Parallelism (SP)
- Tensor Parallelism (TP)
- Tensor Sequence Parallelism
- Tensor Parallelism Transpose (TP Transpose)
- Expert Parallelism (EP)
- Pipeline Parallelism (PP)
- Context Autoregressive
- Autoregressive
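We illustrate our sharding notation with an example matmul:

$$B_xE \times EM = B_xM$$

Where $B$, $E$ and $M$ are names of dimensions and a subscript denotes sharding. For example, $B_x$ means that the $B$ dimension is sharded over the mesh axis $x$, while $E$ and $M$ are unsharded.

We illustrate this notation on model parallelism as well:

$$BM_x \times M_xE = BE \text{ (partial result)} \rightarrow \text{RS over } x \rightarrow BE_x$$

Explanation: Both the activations ($BM$) and the weights ($ME$) are sharded on the $M$ dimension, so each device can perform the matmul locally with its shard of $M$. The local result has the full output shape ($BE$) but is only a partial sum. A reduce-scatter (RS) sums the partial results across devices and leaves the output sharded ($BE_x$).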
| Symbol | Description |
|---|---|
| $B$ | batch (either in tokens or sequences) |
| $S$ | sequence |
| $E$ | emb_dim (aka model dim) |
| $M$ | mlp_dim (aka intermediate dim) |
| $X$ | expert |
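Note that for the feed forward computation the batch and sequence dimensions act the same, and thus we use only one $B$ axis for both.
But for arithmetic intensity roofline analysis purposes the $B$ axis counts tokens, since every token contributes to the feed forward compute.
We recognize this overloads the definition of $B$.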
Arithmetic Intensity has a simple definition
Arithmetic Intensity:= Flops / Comms
We will see why this is a useful definition by walking through an example.
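We want to be compute bound (because there is a fixed amount of compute to perform), which means we want the compute to take longer than the communication. Consider the above example (model parallelism aka tensor parallelism):

$$BM_x \times M_xE = BE \rightarrow \text{RS over } x \rightarrow BE_x$$

The compute is $2BME/|x|$ flops per device, where $|x|$ is the number of devices along the mesh axis $x$ (2 flops per multiply-add, with the contracting dimension $M$ split $|x|$ ways).
Compute time = Flops / compute speed = $(2BME/|x|)$ / compute speed
The required communication is the RS of the $BE$ partial results, which in bf16 (2 bytes per element) would take roughly $2BE$ bytes.
Comm time = comms bytes / comm speed = $2BE$ / comm speed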
We want to be compute bound, so we want:
Compute time > Communication time
Compute Flops / compute speed > Comm bytes / comm speed
Arithmetic Intensity simplifies and generalizes this analysis by re-arranging this inequality to put everything about the model on one side, and everything about the hardware on the other:
Compute Flops / Comm bytes > Compute Speed / comm speed
Operation Arithmetic Intensity > Hardware Arithmetic Intensity
The LHS (Compute Flops / Comm bytes) is the “Operation” or “Model” arithmetic intensity, whereas the RHS (Compute Speed / comm speed) is the hardware arithmetic intensity. This re-arrangement has a huge benefit in that it separates model from hardware - the operational intensity is independent of the hardware. Note however that arithmetic intensity has the funky unit of flops/byte - intuitively you can think of this as the amount of flops unlocked by communicating a certain amount of bytes.
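The comparison can be sketched in a few lines of Python. This is a minimal illustration rather than MaxText code: the function name is hypothetical and the flop, byte and speed values are made-up round numbers, not measurements of any real chip.

```python
def is_compute_bound(flops, comm_bytes, compute_speed, comm_speed):
    """Compare operation vs. hardware arithmetic intensity (both in flops/byte)."""
    operation_ai = flops / comm_bytes         # depends only on the model/sharding
    hardware_ai = compute_speed / comm_speed  # depends only on the hardware
    return operation_ai, hardware_ai, operation_ai > hardware_ai


# Hypothetical numbers, for illustration only.
op_ai, hw_ai, bound = is_compute_bound(
    flops=4e12, comm_bytes=1e9, compute_speed=1e15, comm_speed=1e12
)
print(op_ai, hw_ai, bound)  # 4000.0 1000.0 True
```

Comparing the two intensities gives the same verdict as comparing compute time (0.004 s here) against communication time (0.001 s).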
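Operation Arithmetic Intensity for this example: ($2BME/|x|$ flops) / ($2BE$ bytes) $= M/|x|$ flops/byte, where $|x|$ is the number of devices the $M$ dimension is sharded over.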
Hardware Arithmetic Intensity: Compute speed / comm speed
Example hardware for trillium (See https://cloud.google.com/tpu/docs/v6e), compute speed =
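When we use multiple sharding strategies together it seems intractable to keep track of all of the compute vs communication ratios. However it turns out (not obvious at first) that the arithmetic intensity analysis of a “pure” sharding strategy generalizes to when it's used in a mix. For instance, if we added data parallelism to the above tensor parallelism example then the batch dimension $B$ would also be sharded, say over a second mesh axis $y$:

$$B_yM_x \times M_xE = B_yE \rightarrow \text{RS over } x \rightarrow B_yE_x$$

Compute: $2BME/(|x||y|)$ flops
TP comms (RS): $2BE/|y|$ bytes
Ratio (Arithmetic Intensity): $M/|x|$ flops/byte, the same as for pure tensor parallelism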
This "independence" of sharding strategies is true for the main four parallelisms (data, model (tensor), pipeline, and expert). Note that data, fsdp, context and sequence parallelism are all roughly the same for the purpose of arithmetic intensity analysis since they shard the batch, as we will illustrate in the individual sections below. In addition both data and pipeline parallelism (microbatches) shard the batch which decreases the HBM arithmetic intensity.
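A quick numerical check of this independence, as a sketch under the usual assumptions (2 flops per multiply-add, 2 bytes per bf16 element, a reduce-scatter that moves roughly the full local output once). The function names and dimension sizes are hypothetical.

```python
def tp_matmul_costs(B, M, E, tp, dp=1):
    """Per-device flops and reduce-scatter bytes for BM x ME -> BE.

    The contracting dimension M is sharded tp ways (tensor parallelism) and
    the batch dimension B is sharded dp ways (data parallelism).
    """
    flops = 2 * (B / dp) * (M / tp) * E
    rs_bytes = 2 * (B / dp) * E  # bf16 partial outputs
    return flops, rs_bytes


def tp_intensity(B, M, E, tp, dp=1):
    flops, rs_bytes = tp_matmul_costs(B, M, E, tp, dp)
    return flops / rs_bytes


pure_tp = tp_intensity(B=4096, M=16384, E=4096, tp=4)
tp_plus_dp = tp_intensity(B=4096, M=16384, E=4096, tp=4, dp=8)
print(pure_tp, tp_plus_dp)  # 4096.0 4096.0
```

Adding data parallelism shrinks both the compute and the TP communication by the same factor, so the ratio stays at M / tp.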
Sharding in MaxText is split into 3 layers:

- Physical mesh axes (e.g. `data`, `fsdp`, `tensor`)
  - Mesh is created via `create_device_mesh`
  - Mesh given names in `train.py` via `Mesh`
- Logical axes which map a meaningful axis name to physical axes
  - E.g. the logical axis `activation_batch` is sharded by the physical axes `data` and `fsdp` (among others) since those sharding strategies shard the batch. `activation_batch` is a common axis among most activation tensors. Note that if we use `data_parallelism=4` and `fsdp_parallelism=2`, then the `activation_batch` dimension will get sharded over both, e.g. $4*2=8$ ways.
- Individual tensors have sharding constraints, generally specified by logical rules
  - Example for weights: `kernel_axes` in `MlpBlock`, which in turn relies on flax’s param argument `nn.with_logical_partitioning`
  - For activations we use `nn.with_logical_constraint` to give sharding hints to the compiler. Sharding hints for the activations are not strictly necessary, but the compiler may do funky/inefficient things without them.
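The relationship between logical and physical axes can be sketched in plain Python. The rule table and mesh sizes below are hypothetical and heavily simplified; the actual MaxText rules contain many more entries.

```python
from math import prod

# Hypothetical parallelism degrees for the physical mesh axes.
mesh_shape = {"data": 4, "fsdp": 2, "tensor": 1}

# Hypothetical, simplified logical-to-physical rules.
logical_axis_rules = {
    "activation_batch": ("data", "fsdp"),
    "mlp": ("tensor",),
}


def num_shards(logical_axis):
    """Number of ways a logical axis is split across the physical mesh."""
    physical_axes = logical_axis_rules.get(logical_axis, ())
    return prod(mesh_shape[axis] for axis in physical_axes)


print(num_shards("activation_batch"))  # 8
```

A logical axis with no rule is simply left unsharded (one shard).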
Constructing a hierarchical mesh and specifying shardings is very similar to a “flat” mesh, except we use the nice API `create_hybrid_device_mesh` and specify the degrees of the lower-level faster network (e.g. TPU ICI) and the higher-level slower network (e.g. DCN) separately. For example, if we want to use 4x fsdp parallelism over ICI and 2x data parallelism over DCN then we specify:

```python
import jax
from jax.experimental import mesh_utils

devices = jax.devices()  # expects 8 devices: 2 slices of 4 chips each
mesh = mesh_utils.create_hybrid_device_mesh(
    (1, 4),  # (1 data, 4 fsdp) over ICI
    (2, 1),  # (2 data, 1 fsdp) over DCN
    devices,
)
```
For TPUs this two-level hierarchy is (within slice, across slices) using (ICI, DCN). For v5e and Trillium there are at most 256 chips within a slice, whereas v4, v5p, and the upcoming Ironwood can span up to 8k/9k chips within a slice.
For GPUs this two level hierarchy is (within NVL domain, across NVL Domains) using (NVLink, DCN). Starting with Grace Blackwell chips these NVL domains can span multiple hosts (e.g. 72 hosts or 576 chips).
XLA will perform efficient hierarchical collectives (all-gathers, all-reduces, reduce-scatters) that communicate the minimal amount of information over the slower upper layer of the network. See the Data Parallel Hierarchical Section for an analysis of these communications.
The simplest parallelization is data parallelization. Each chip works on a different batch of data, and the forward pass is embarrassingly parallel. No communication is needed in the forward pass. The gradients are synchronized in the backward pass (averaged or summed) - which is typically achieved with an all reduce.
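Roughly approximate the entire backward pass:
Compute: 4 * local_batch * params flops
We saw above that each matmul performs 2 * local_batch * params flops (2 flops per token per weight). The backward pass performs two such matmuls per weight matrix: one for the gradient with respect to the activations and one for the gradient with respect to the weights.
Communicate: All reduce size of params (bf16): 4 * params bytes (2* since bf16, another 2* since an optimal all reduce algorithm turns out to require two passes of communicating data, generally a reduce scatter followed by an all-gather)
Ratio (arithmetic intensity): local_batch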
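As a sanity check, the estimate can be written out directly. This is a rough sketch that assumes the backward pass costs about 4 flops per parameter per token and that the all-reduce moves 2 bytes (bf16) times 2 passes per parameter; the function name and the example sizes are hypothetical.

```python
def dp_arithmetic_intensity(local_batch, params):
    """Backward-pass flops divided by gradient all-reduce bytes (flops/byte)."""
    backward_flops = 4 * local_batch * params  # two matmuls per weight, 2 flops per MAC
    all_reduce_bytes = 2 * 2 * params          # bf16 * (reduce-scatter + all-gather)
    return backward_flops / all_reduce_bytes


print(dp_arithmetic_intensity(local_batch=8192, params=7_000_000_000))  # 8192.0
```

The parameter count cancels, which is why the ratio depends only on the local batch (in tokens).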
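For an MoE architecture, we can imagine the batch axis is reshaped into [batch_per_expert, expert], where the total number of token-to-expert assignments is preserved:
batch_per_expert * expert = batch * expert_per_token
i.e. the original activations have grown by a factor of expert_per_token, and after reshaping the new batch axis is:
batch_per_expert = batch * (expert_per_token/expert) = batch / sparsity
Writing local_batch_per_expert for the local (per device) batch_per_expert, the analysis is the same as in the dense case:
Compute: 4 * local_batch_per_expert * params flops
Comms: All Reduce Gradient of size params (bf16): 4 * params bytes
Ratio (arithmetic intensity): local_batch_per_expert = local_batch / sparsity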
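The reshaping arithmetic can be checked with a small sketch. The function name and the example configuration (256 experts, 8 experts per token) are hypothetical.

```python
def local_batch_per_expert(local_batch, num_experts, experts_per_token):
    """Tokens each expert sees per device: local_batch / sparsity."""
    sparsity = num_experts / experts_per_token
    return local_batch / sparsity


bpe = local_batch_per_expert(local_batch=8192, num_experts=256, experts_per_token=8)
print(bpe)  # 256.0

# The reshape preserves the total number of token-to-expert assignments.
assert bpe * 256 == 8192 * 8
```

With a sparsity of 32, a local batch of 8192 tokens behaves like a batch of only 256 tokens for the purposes of the data-parallel arithmetic intensity.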
For a hierarchical mesh (TPU: within slice ICI, across slice DCN; GPU: within NVL domain, across NVL Domains), only one set of gradients needs to be communicated across the slower network per slice/NVL Domain (as opposed to one set per chip). This is generally achieved for us automatically by the XLA compiler:
Reduce Scatter grads on fast network, then All Reduce the resulting shards across the slow network, then All Gather on the fast network.
We can compute the arithmetic intensity of these cross slice/NVL Domain comms by imagining the chips forming a slice or NVL Domain as one "super chip". This "super chip" processes all of the tokens within its domain, but it only has to share one copy of the gradients to its super chip neighbors.
If the local per device batch size is local batch, then we can imagine each "super chip" has a batch of
super batch = # devices per slice * local batch
We can then perform the same arithmetic intensity analysis as before, and indeed get the same result with the super batch in place of the local batch:
Compute (per super chip): 4 * super batch * params flops
Comms (per super chip): All reduce params (bf16): 4 * params bytes
Ratio (arithmetic intensity): super batch = # devices per slice * local batch
This illustrates that there is more than one way to calculate arithmetic intensity - we could also derive the same expression at the chip level as long as we are consistent for the compute and comms: either both the compute and comms should be at the super chip level, or both should be at the regular chip level.
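The two viewpoints can be compared numerically. This sketch assumes that after the reduce-scatter on the fast network each chip sends only its 1/devices_per_slice share of the gradients over the slow network; the sizes are hypothetical.

```python
local_batch = 16
devices_per_slice = 256
params = 10**9

# Super chip view: one slice processes the super batch and all-reduces one
# full copy of the gradients (2 bytes bf16 * 2 passes) across slices.
super_batch = devices_per_slice * local_batch
super_chip_ai = (4 * super_batch * params) / (4 * params)

# Chip view: each chip computes on its local batch and sends only its shard
# of the gradients across slices.
chip_flops = 4 * local_batch * params
chip_cross_slice_bytes = 4 * params / devices_per_slice
chip_ai = chip_flops / chip_cross_slice_bytes

print(super_chip_ai, chip_ai)  # 4096.0 4096.0
```

Both views give devices_per_slice * local_batch, as long as compute and comms are counted at the same level.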
Similar to data parallelism, except the model weights are also sharded to save memory. Generally the weights must get all-gathered before computation.
In addition to the weights all-gathering, the gradient communications are synchronized in the backward pass similar to DP (optimally will be synchronized with a reduce scatter which is 2x faster than an all-reduce, but only certain sizes of weight matrices allow for this efficient reduce scatter operation). The arithmetic intensity of this grad comm is thus either the same or 2x better than in the DP case, which has an arithmetic intensity of local_batch.
Fully sharded data parallelism (aka ZeRO-3) is used when the full model weights do not fit into HBM and thus should be sharded as well. Generally we recommend using FSDP over TPU ICI or GPU NVLink, and DP across slices (TPUs) or across hosts (GPUs), although for large models even more FSDP may be required.
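Approximate a typical weight @ activation = activation matmul:
Start with activations sharded like $B_xE$ and weights sharded like $E_xM$, where $|x|$ is the number of FSDP shards. The weights must first be all-gathered (AG): $E_xM \rightarrow EM$.
Compute: $B_xE \times EM = B_xM$
This takes $2BEM/|x|$ flops.
Note that $B$ is the global batch, whereas $B/|x|$ is the local_batch.
Communicate: All gather params $EM$ (bf16): $2EM$ bytes
Ratio (arithmetic intensity): $(2BEM/|x|) / (2EM) = B/|x|$ = local_batch flops/byte (local_batch / sparsity for sparse)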
The sparsity factor for sparse models shows up for the same reason as derived in the DP Sparse Section
Note: You may notice that in the DP arithmetic intensity we analyzed the entire backward pass, whereas here we analyzed a single matmul. Both approaches should give the same answer, and it is useful to understand both. Certain shardings are easier to analyze with a global view, whereas others are better analyzed with a local view, so it is useful to practice switching between them.
This is nearly identical to FSDP above except we choose to shard the main feedforward weights on the larger mlp dim instead of embed dim. This can be useful when the embed dim cannot be sharded further or does not have enough powers of 2 for efficient reduce scatter algorithms on TPUs. You may try swapping between FSDP and FSDP_transpose, their performance should be very similar, but one may offer a ~1% MFU improvement.
Context parallelism is similar to FSDP except we shard the sequence dimension of activations instead of batch to allow for smaller batch dimensions (correspondingly smaller per device batch, including fractional per device batch sizes). A smaller per device batch dimension is often needed for large sequence lengths so that the activations fit into memory. Also a smaller per device batch size is needed so that the global token count (global batch size) stays under some desired global batch size limit for optimal training - generally smaller global batch sizes can achieve better losses given a fixed number of total tokens (e.g. Llama3 used 16M global batch in tokens, DeepSeek uses 61M).
Care needs to be taken to shard the sequence dimension for attention - only the queries are sharded by sequence, the keys and values need to be all-gathered to perform the full computation. Additionally if we naively shard the sequence dimension then the attention computation is not evenly distributed due to the lower triangular causal mask - shards corresponding to later queries have more non-zero mask and thus become the bottleneck. Instead we “stripe” the inputs, so that the first shard has the first and last chunk of the sequence, the second shard has the second and second to last, etc. This striping is done on the initial data inputs (instead of every layer), so it is a small cost.
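The effect of striping can be illustrated with a toy model. This is a sketch of the idea only, not the MaxText implementation: the sequence is cut into 2 * num_shards chunks, and the work of a query chunk is counted as the number of key chunks it attends to under the causal mask (ignoring that the diagonal block is only half full).

```python
def striped_chunks(num_shards):
    """Shard i gets chunk i and its mirror chunk (2 * num_shards - 1 - i)."""
    return [(i, 2 * num_shards - 1 - i) for i in range(num_shards)]


def causal_work(chunk):
    """Query chunk c attends to key chunks 0..c, so work is proportional to c + 1."""
    return chunk + 1


num_shards = 4
striped = [sum(causal_work(c) for c in pair) for pair in striped_chunks(num_shards)]
naive = [causal_work(2 * i) + causal_work(2 * i + 1) for i in range(num_shards)]

print(striped)  # [9, 9, 9, 9]
print(naive)    # [3, 7, 11, 15]
```

With contiguous (naive) sharding the last shard does 5x the work of the first, whereas the striped layout gives every shard the same amount.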
Note that in general there are many flavors of CP, such as ring attention, which in theory can hide all of the comms (as opposed to this implementation, where the KV all-gathers are probably exposed). This all-gather is relatively cheap, so we have implemented this flavor for now as a good trade-off between complexity and performance.
Currently Context Parallelism is only supported for GPUs (Sequence parallelism below is supported on TPUs). We plan to land context parallelism on TPUs shortly.
The main communications are the same as FSDP (all gather weights and synchronize gradients), with an arithmetic intensity of local_batch / sparsity.
The extra cost of all gathering of keys and values is small, especially for long sequences, analyzed below assuming group query attention:
Compute: Attention - 4 * batch * seq_len^2 * query_heads * head_dim/|CP|
Communicate (KV all gather): All-gather keys and values - 4 * batch * seq_len * kv_heads * head_dim
Ratio: seq_len * query_heads / (kv_heads * |CP|)
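Plugging in numbers shows how large this ratio typically is. The sketch below evaluates the compute and communication expressions above; the function name and the example shape (8k sequence, 32 query heads, 8 kv heads, CP=4) are hypothetical.

```python
def cp_kv_all_gather_intensity(batch, seq_len, query_heads, kv_heads, head_dim, cp):
    """Attention flops per device divided by KV all-gather bytes (flops/byte)."""
    flops = 4 * batch * seq_len**2 * query_heads * head_dim / cp
    kv_bytes = 4 * batch * seq_len * kv_heads * head_dim  # K and V, 2 bytes each in bf16
    return flops / kv_bytes


ratio = cp_kv_all_gather_intensity(
    batch=1, seq_len=8192, query_heads=32, kv_heads=8, head_dim=128, cp=4
)
print(ratio)  # 8192.0
```

The result matches seq_len * query_heads / (kv_heads * |CP|) = 8192 * 32 / (8 * 4), and batch and head_dim cancel.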
Sequence parallelism is very similar to context parallelism - we shard the layer inputs and feed forward activations along the sequence dimension. The difference is for attention - we shard the queries, keys, and values along the head dimension instead of the sequence dimension (this is fairly MaxText specific, you might not see this in other codebases). This is because the head dimension is easy to shard on for attention (it is not a contracting dimension), and thus can be more efficient than context parallelism as long as there are enough heads. Both sequence parallelism and tensor parallelism shard the heads, so we are constrained by tensor_parallelism * sequence_parallelism <= kv_heads. E.g. if there are only 8 kv_heads as for llama3 and we use tensor_parallelism=8, then we cannot use any sequence parallelism (i.e. sequence_parallelism=1).
Sequence parallelism is currently only supported with the TPU attention kernel; for GPUs we recommend context parallelism above.
The main communications are the same as FSDP (all gather weights and synchronize gradients), with an arithmetic intensity of local_batch / sparsity
Sequence parallelism has the additional cost of transferring the sharding from sequence to heads (and back again) for attention. This is executed via an all-to-all, which is generally a cheap operation, analyzed below:
Compute: Attention (4 * batch * seq_len^2 * heads * head_dim / |SP|)
Communicate: A2A QKV activations and output activations (roughly 4 * batch * seq_len * heads * head_dim)
Ratio (Arithmetic Intensity): Proportional to seq_len / |SP|
The exact ratio depends on MHA vs GQA, how many kv heads there are and the efficiency of an all-to-all on the given hardware.
Shard the activations along the feature dimensions (e.g. model or embed dimension and intermediate or mlp dimension) instead of the batch dimension. Tensor parallelism communicates the activations, as opposed to the weights as in DP/FSDP. Tensor parallelism can be used to replace some amount of DP/FSDP when the batch size is small and/or when the model is large (when the mlp dim is large). Tensor parallelism is needed to run with small batches, such as a fractional per_device_batch_size < 1. For instance, if we use TP=4 then we can use FSDP for the remaining devices and set per_device_batch_size=0.25, since global_batch = per_device_batch_size * TP * FSDP = 0.25 * 4 * FSDP = FSDP, which is shardable among the FSDP devices (each device gets a shard of FSDP/FSDP = 1 of the batch axis in this case). For the attention activations (query, key, value), we shard the heads on TP since that is the easiest dimension to shard on (the heads are not a contracting dimension during the attention computation) and use an attention kernel like flash attention.
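The batch arithmetic for fractional per-device batch sizes can be checked with a small helper. This is an illustrative sketch; the function name is hypothetical and not part of MaxText.

```python
def batch_rows_per_fsdp_shard(per_device_batch_size, tp, fsdp):
    """Rows of the batch axis held by each FSDP shard.

    Uses global_batch = per_device_batch_size * TP * FSDP and raises if the
    batch axis cannot be split evenly over the FSDP devices.
    """
    global_batch = per_device_batch_size * tp * fsdp
    rows = global_batch / fsdp
    if rows < 1 or rows != int(rows):
        raise ValueError(
            f"global_batch={global_batch} is not shardable over fsdp={fsdp}"
        )
    return int(rows)


print(batch_rows_per_fsdp_shard(0.25, tp=4, fsdp=16))  # 1
```

With per_device_batch_size=0.25 and only TP=2, each FSDP shard would hold half a row and the helper raises, which is why a fractional per-device batch needs enough tensor parallelism.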
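Analyze one pattern of TP as given above:

$$BM_{TP} \times M_{TP}E = BE \rightarrow \text{RS over } TP \rightarrow BE_{TP}$$

Compute: $2BME/|TP|$ flops
Communicate: Reduce scatter $BE$ (bf16): $2BE$ bytes
Ratio (arithmetic intensity): $M/|TP|$ flops/byte
Note this is one pattern of TP where the contracting dimension is sharded. By contrast for the initial feed forward matmul the non-contracting weight dimension is sharded:

$$BE_{TP} \rightarrow \text{AG over } TP \rightarrow BE, \qquad BE \times EM_{TP} = BM_{TP}$$

This is the same amount of compute, and also the same amount of communication - again activations of size $BE$, this time all-gathered before the matmul instead of reduce-scattered after it.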
This sharding strategy is very similar to tensor parallelism, except we shard the initial feed forward (FF) activations on the sequence dimension as opposed to the model dimension. The activations have to get all-gathered at the start of the FF and reduce-scattered at the end, but it's the same amount of total comms, just on a different axis (see above analysis for TP). The intermediate activations of shape [batch, sequence, mlp] are still sharded by mlp (since the weights are sharded on mlp). The benefits are explained in more detail in this paper; the TL;DR is that all-reduces for small normalizations are not needed, since the feature dimension is not sharded with TP sequence, as opposed to when it is sharded with regular TP. This is generally recommended for GPUs over tensor parallelism. See PR #1136 which introduces this parallelism.
Nearly identical to tensor parallelism above, except that a different axis gets all-gathered and reduce-scattered on; the arithmetic intensity is thus again mlp_dim / |TP|.
Similar to tensor parallelism, but instead of sharding the feed forward weights along the mlp_dim, shard them along the embed_dim. This will require communicating activations of the mlp_dim as opposed to the embed_dim, and thus is useful when the mlp_dim < embed_dim which is unusual but is true for some models such as DeepSeek V3.
TP and TP transpose can be used together (called "2D TP"), which can be more efficient than using only one of them for inference decoding, although this is still a work in progress/largely untested.
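This is really just swapping $E$ and $M$ in the tensor parallelism analysis:

$$BE_{TP} \times E_{TP}M = BM \rightarrow \text{RS over } TP \rightarrow BM_{TP}$$

Compute: $2BEM/|TP|$ flops
Communicate: Reduce scatter $BM$ (bf16): $2BM$ bytes
Ratio (arithmetic intensity): $E/|TP|$ flops/byte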
Shard expert feed forward computation (both weights and activations) by expert!
The feedforward layer is the only one that has experts - for this layer we shard the weights and the activations on the experts dimensions by EP. For attention operations (including projections) the EP dimension acts like FSDP. This is a choice by MaxText, we may implement more options in the future where instead EP could act like DP or CP/SP as well.
When using dropless strategies you may want to ensure that the shards are balanced. The balance can be improved by using a smaller EP degree so that each shard is averaged over more experts. For instance, imagine a scenario where expert 1 gets 3x more tokens routed to it than the rest. If EP = # experts = 64 then we will get terrible performance waiting for this one expert to finish its computation, which is 3x slower. However, if we set EP = 1/4 * # experts then the EP rank with expert 1 will have 4 experts, so it will have 3 + 1 + 1 + 1 = 6 compute to do compared to the average of 1 + 1 + 1 + 1 = 4, a ratio of 6/4 = 1.5x slower, which is a huge improvement over the 3x slower.
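The 3 + 1 + 1 + 1 example can be reproduced with a few lines. This sketch assumes experts are assigned to EP ranks in contiguous groups and compares the busiest rank against a typical (median) rank, mirroring the comparison in the text; the function name is hypothetical.

```python
from statistics import median


def slowdown_vs_typical_rank(tokens_per_expert, ep):
    """Load of the busiest EP rank divided by the load of a typical rank."""
    experts_per_rank = len(tokens_per_expert) // ep
    loads = [
        sum(tokens_per_expert[r * experts_per_rank:(r + 1) * experts_per_rank])
        for r in range(ep)
    ]
    return max(loads) / median(loads)


# 64 experts; expert 0 receives 3x the tokens of every other expert.
tokens = [3] + [1] * 63
print(slowdown_vs_typical_rank(tokens, ep=64))  # 3.0
print(slowdown_vs_typical_rank(tokens, ep=16))  # 1.5
```

Grouping 4 experts per rank dilutes the hot expert, bringing the worst-case slowdown from 3x down to 1.5x.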
An all-to-all (A2A) is needed to move between data sharding (fsdp) prior to the feed forward and the expert sharding during the feed forward. We denote
Compute
Analyze only 1 feed forward matmul
Communicate
Ideally this A2A only requires moving around GPUs and TPU DCN but not for TPU ICI)
With a true all-to-all network this takes 1/4 of all gathering the entire activation as nicely drawn here in jax's sharding doc.
Ratio (arithmetic intensity):
Note: The batch batch or batch_per_exp)
Shard the weights and computation by layers. There are many flavors of pipelining; MaxText currently supports gPipe and circular pipelines, which are discussed below.
Pipeline parallelism is generally needed when the per_device_batch size is too small for data parallelism to be efficient. Recall above the arithmetic intensity of data parallelism is given by the local_batch/sparsity, so when this becomes too small then the communications associated with data parallelism will be very costly. This occurs either for very sparse models (e.g. DeepSeek), or when scaling to a large number of chips and maintaining a fixed global batch size (and thus the per device batch size is small).
gPipe style pipelining shards layers across stages, where each stage can have multiple layers. E.g. if there are four stages and twelve layers, stage 0 will perform layers 0, 1, and 2, then pass the results to stage 1 which will perform layers 3, 4, and 5, etc. Naively implemented this isn’t parallel since stage 1 has to wait for stage 0 to finish; however, we can break the batch into microbatches to enable parallelism. E.g. as stage 1 works on microbatch 0, stage 0 can start working on a new microbatch 1. There is still a “bubble” - an amount of time each stage is idle while either waiting for the first microbatch or once it has finished all of its microbatches. This “bubble” time goes down with the number of microbatches:
Bubble = (PP - 1) / (Microbatches + PP - 1)
Circular pipelining also shards layers across stages, but the layers “wrap” back around. E.g. if we have 24 layers, 4 stages, and 2 repeats, then stage 0 will perform layers 0, 1, 2 and also layers 12, 13, 14. Stage 1 will perform layers 3, 4, 5 and also 15, 16, 17 etc. This pattern helps to reduce the bubble: stage 1 is able to start its set of layers earlier (only need to wait for a microbatch to finish 3 layers instead of 6 since there are two repeats).
Bubble = (PP - 1) / (repeats * Microbatches + PP - 1)
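The layer-to-stage assignment described above can be sketched as follows. This is an illustration of the wrap-around pattern only, not the MaxText pipeline code; the function name is hypothetical and it assumes the layer count divides evenly by stages * repeats.

```python
def stage_layers(num_layers, num_stages, repeats):
    """Map each stage to the layers it runs in a circular pipeline."""
    layers_per_visit = num_layers // (num_stages * repeats)
    assignment = {stage: [] for stage in range(num_stages)}
    for layer in range(num_layers):
        visit = layer // layers_per_visit  # index of the consecutive layer group
        assignment[visit % num_stages].append(layer)
    return assignment


circular = stage_layers(num_layers=24, num_stages=4, repeats=2)
print(circular[0])  # [0, 1, 2, 12, 13, 14]
print(circular[1])  # [3, 4, 5, 15, 16, 17]
```

Setting repeats=1 recovers the gPipe layout, e.g. twelve layers over four stages puts layers 3, 4, 5 on stage 1.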
There is a tradeoff of using many repeats - more repeats creates a schedule with a smaller bubble, however it also requires more PP comms between stages. The limiting case repeats=1 is a gPipe schedule with minimal communication overhead, but maximal bubble. Ideally the PP comms are overlapped as long as there is enough compute, however achieving overlap is a challenging problem for the compiler. To break the data dependency of the circular transfer (last stage to first), the number of microbatches must exceed the number of stages, and thus we generally recommend num_pipeline_microbatches = 2 * PP.
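The two bubble formulas can be compared directly. A minimal sketch, using the recommended num_pipeline_microbatches = 2 * PP with a hypothetical PP=4; the function name is illustrative.

```python
def bubble_fraction(pp, microbatches, repeats=1):
    """Bubble = (PP - 1) / (repeats * Microbatches + PP - 1); repeats=1 is gPipe."""
    return (pp - 1) / (repeats * microbatches + pp - 1)


pp = 4
microbatches = 2 * pp
gpipe = bubble_fraction(pp, microbatches)
circular = bubble_fraction(pp, microbatches, repeats=2)
print(f"{gpipe:.3f} {circular:.3f}")  # 0.273 0.158
```

Going from 1 to 2 repeats shrinks the bubble from about 27% to about 16% in this configuration, at the cost of more PP comms between stages.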
We are actively investing in Multiple Program Multiple Data (MPMD) style jax to support fancier pipeline schedules such as 1F1B and dualpipe which can achieve smaller bubbles while using less PP comms. Currently we only support gPipe and circular pipelines.
Pipelining and FSDP/DP interactions have to be considered together to achieve optimal performance. Generally we want to reduce the gradients across DP replicas only once, outside of the pipeline loop, as opposed to every microbatch (we want the gradient reduction performed locally across microbatches first and only once across DP replicas). We rely on the XLA compiler for this optimization. Similarly for FSDP we want to all-gather the weights across FSDP only once before the pipeline loop, as opposed to every microbatch - we have implemented this in MaxText with pipeline_fsdp_ag_once and generally recommend it with small batch sizes. However, this comes with a huge memory cost - the weights and gradients are not sharded by FSDP, and thus a significant amount of other sharding (PP, EP, TP) must be used. This is roughly equivalent to ZeRO-1 sharding: FSDP only shards the optimizer state, not the weights and gradients.
The arithmetic intensity is a bit harder to define for PP, and depends on the pipeline flavor. We analyze the circular pipeline below.
Compute
One stage worth. A stage can consist of multiple layers, if layers_per_pipeline_stage > 1. Each layer generally is a combination of a fully connected feed forward block and an attention block. Let's ignore attention since it's generally significantly smaller than the FF (for sequence length of 8k). A typical FF has 3 matmuls (2 in for silu, 1 out), for a total of layers_per_pipeline_stage * 6 * B * M * E flops
Communicate
The layer outputs between stages of size
Ratio (arithmetic intensity)
3/2 * layers_per_pipeline_stage * M * experts_per_token
Note that for MoE models, this arithmetic intensity grows by a factor of experts_per_token since the compute grows by this factor, but the communication is independent of this factor.
Context Autoregressive shards the KV cache on the sequence dimension. It shards feed forward layer by experts for both activations and weights. This is used for inference only, see inference.yml for the modified logical axis rules for inference.
Autoregressive shards weights, but not activations. This is used for inference only. See inference.yml for the modified logical axis rules for inference.