28 changes: 3 additions & 25 deletions docs/guides/optimization/sharding.md
@@ -47,11 +47,11 @@ Explanation: Both the activations ($BM$) and weights ($ME$) are sharded on the M
| $M$ | mlp_dim (aka intermediate dim) |
| $X$ | expert |

-Note for the feed forward computation the batch and sequence dimensions act the same and thus we use only one $B$ axis (which you can think of as a token batch dimension, a reshaping of batch and sequence into one axis), but for context and sequence parallelism they act differently and thus we use both a $B$ and $S$ dimension and the $B$ dimension is really batch in sequences. For example a matmul with an explicit sequence dimension might look like
+Note for the feed forward computation the batch and sequence dimensions act the same and thus we use only one $B$ axis (which you can think of as a token batch dimension, a reshaping of batch and sequence into one axis), but for context parallelism they act differently and thus we use both a $B$ and $S$ dimension and the $B$ dimension is really batch in sequences. For example a matmul with an explicit sequence dimension might look like

$$BSE \times EM = BSM$$

-But for arithmetic intensity roofline analysis purposes the $B$ and $S$ axis act as one, and generally we omit the $S$ axis except for when its needed (context/sequence parallelism), thus we only write
+But for arithmetic intensity roofline analysis purposes the $B$ and $S$ axis act as one, and generally we omit the $S$ axis except for when its needed (context parallelism), thus we only write

$$BE \times EM = BM$$
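A quick way to see why $B$ and $S$ can be treated as one token axis here: flattening them leaves the matmul result unchanged. A standalone numpy sketch of this (shapes are made up for illustration; MaxText itself expresses these as JAX einsums):

```python
import numpy as np

B, S, E, M = 4, 16, 8, 32  # hypothetical batch, seq, embed, mlp dims
rng = np.random.default_rng(0)
x = rng.standard_normal((B, S, E))  # activations, BSE
w = rng.standard_normal((E, M))     # feed-forward weights, EM

# With an explicit sequence dimension: BSE x EM = BSM
y_bsm = np.einsum("bse,em->bsm", x, w)

# Batch and sequence folded into one token axis of B*S tokens: BE x EM = BM
y_bm = np.einsum("be,em->bm", x.reshape(B * S, E), w)

assert np.allclose(y_bsm.reshape(B * S, M), y_bm)
```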

@@ -113,7 +113,7 @@ $$B_yM_x \times M_xE = B_yE \rightarrow \text{RS over x } \rightarrow B_yE_x $$

**Ratio (Arithmetic Intensity)** = $|M_x|$ Flops/byte

-This "independence" of sharding strategies is true for the main four parallelisms (data, model (tensor), pipeline, and expert). Note that data, fsdp, context and sequence parallelism are all roughly the same for the purpose of
+This "independence" of sharding strategies is true for the main four parallelisms (data, model (tensor), pipeline, and expert). Note that data, fsdp, and context parallelism are all roughly the same for the purpose of
arithmetic intensity analysis since they shard the batch, as we will illustrate in the individual sections below. In addition both data and pipeline parallelism (microbatches) shard the batch which decreases the HBM arithmetic intensity.
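Ratios like the $|M_x|$ flops/byte above can be sanity-checked in a few lines. A hypothetical helper (not MaxText code; bf16 assumed, i.e. 2 bytes per element):

```python
def comm_intensity(b_y: int, m_x: int, e: int, bytes_per_elem: int = 2) -> float:
    """Flops per byte communicated for B_y M_x x M_x E = B_y E, reduce-scattered over x.

    flops: 2 * B_y * M_x * E for the local matmul shard.
    comm:  the B_y x E partial output is reduce-scattered once over x.
    """
    flops = 2 * b_y * m_x * e
    comm_bytes = bytes_per_elem * b_y * e
    return flops / comm_bytes

# The ratio depends only on the local mlp shard size |M_x|:
assert comm_intensity(b_y=1024, m_x=4096, e=8192) == 4096.0
```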

## Code implementation of sharding in MaxText
@@ -270,28 +270,6 @@ The extra cost of all gathering of keys and values is small, especially for long

**Ratio**: `seq_len * query_heads / (kv_heads * |CP|)`
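This ratio falls out of dividing the attention flops by the key/value all-gather volume; the batch and head_dim factors cancel up to small constants. A standalone sketch (hypothetical helper, not MaxText code):

```python
def cp_attention_intensity(seq_len: int, query_heads: int, kv_heads: int, cp: int) -> float:
    """Flops per communicated element for context-parallel attention.

    compute: proportional to batch * seq_len**2 * query_heads * head_dim / cp per shard
    comm:    all-gathering K and V, proportional to batch * seq_len * kv_heads * head_dim
    batch and head_dim cancel (up to small constant factors), leaving the ratio below.
    """
    return seq_len * query_heads / (kv_heads * cp)

# Long context plus GQA keeps this large, e.g. 32 query heads over 8 kv heads:
assert cp_attention_intensity(seq_len=8192, query_heads=32, kv_heads=8, cp=16) == 2048.0
```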

-## Sequence Parallelism (SP)
-
-Sequence parallelism is very similar to context parallelism - we shard the layer inputs and feed forward activations along the sequence dimension. The difference is for attention - we shard the queries, keys, and values along the head dimension instead of sequence dimension (this is fairly MaxText specific, you might not see this in other codebases). This is because the head dimension is easy to shard on for attention (it is not a contracting dimension), and thus can be more efficient than context parallelism as long as there are enough heads. Both sequence parallelism and tensor parallelism shard the heads, so we are constrained by `tensor_parallelism * sequence_parallelism < kv_heads`. E.g. if there are only 8 `kv_heads` as for llama3 and we use `tensor_parallelism=8`, then we cannot use any `sequence_parallelism` (e.g. `sequence_parallelism=1`)
-
-Sequence parallelism is currently only supported with TPUs attention kernel, for GPUs we recommend context parallelism above.
-
-### SP Arithmetic Intensity
-
-The main communications are the same as `FSDP` (all gather weights and synchronize gradients), with an arithmetic intensity of `local_batch` / `sparsity`
-
-#### SP Extra A2A cost
-
-Sequence parallelism has an additional cost of transferring the sharding from sequence to heads (and back again) for attention. This is executed via and all-to-all which are generally cheap operations, analyzed below:
-
-**Compute**: Attention ($4 * batch * seq_len^2 * heads * head_dim / |SP|$)
-
-**Communicate:** A2A QKV activations and output activations (roughly $4 * batch * seq_len * heads * head_dim$)
-
-**Ratio (Arithmetic Intensity)**: Proportional to $seq_len / |SP|$
-
-The exact ratio depends on MHA vs GQA, how many kv heads there are and the efficiency of an all-to-all on the given hardware.
-
## Tensor Parallelism (TP)

Shard the activations along the feature dimensions (e.g. model or `embed` dimension and intermediate or `mlp` dimension) instead of the batch dimension. Tensor parallelism communicates the activations as opposed to the weights as in DP/FSDP. Tensor parallelism can be used to replace some amount of DP/FSDP when the batch size is small and/or when the model is large (when the `mlp` dim is large). Tensor parallelism is needed to run with small batches, such as fraction `per_device_batch_size` < 1. For instance if we use `TP=4` then we can use the rest with FSDP and set `per_device_batch_size=0.25` since the `global_batch = per_device_batch_size * TP * FSDP = 0.25 * 4 * FSDP = FSDP`, and this is shardable among `FSDP` devices (each device will get a shard of `FSDP/FSDP = 1` of the batch axis in this case). For the attention activations (query, key, value), we shard the heads on `TP` since that is the easiest dimension to shard on and use an attention kernel like flash attention (the heads are not a contracting dimension during the attention computation).
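The batch bookkeeping in this paragraph is easy to sanity-check (a standalone sketch, not MaxText code):

```python
def global_batch(per_device_batch_size: float, tp: int, fsdp: int) -> float:
    """global_batch = per_device_batch_size * (device count), with TP * FSDP devices."""
    return per_device_batch_size * tp * fsdp

# TP=4 lets per_device_batch_size drop to 0.25: with FSDP=8 the global batch is 8,
# so each of the 8 FSDP shards still gets a whole batch element.
gb = global_batch(0.25, tp=4, fsdp=8)
assert gb == 8.0 and gb / 8 == 1.0
```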
58 changes: 27 additions & 31 deletions src/maxtext/configs/base.yml
@@ -445,29 +445,27 @@ compile_xla_flags: "" # Compiler options e.g. compile_xla_flags="--xla_tpu_num_s
# Parallelism
shard_mode: "auto" # can be either auto or explicit
custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying yml name under config/mesh_and_rule/.
-mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
+mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
logical_axis_rules: [
# ==========================================
# Vocabulary Embedding
# ==========================================
# Vocab Activations
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']],
['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_vocab', ['tensor', 'tensor_transpose']],
['activation_vocab', 'tensor_sequence'],
-  ['activation_vocab', ['sequence', 'context']],
# Vocab Weights
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
-  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
# ==========================================
# Attention
# ==========================================
# Attention Activations
['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence', 'autoregressive']],
-  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
-  ['activation_length_attn', ['sequence', 'context']],
+  ['activation_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
+  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_length_attn', ['context']],
['activation_q_length', ['context']],
['activation_kv_length', []],
@@ -482,52 +480,50 @@ logical_axis_rules: [
['qkv', []],
['kv', []],
['kv_head_dim', []],
-  ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
+  ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']],
+  ['q_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']],
+  ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
+  ['q_lora', ['fsdp', 'context', 'expert']],
["q_lora_up_proj", []],
-  ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
+  ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']],
+  ['kv_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']],
+  ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
+  ['kv_lora', ['fsdp', 'context', 'expert']],
["kv_lora_up_proj", []],
# ==========================================
# Mixture of Experts (MoE)
# ==========================================
# MoE Activations
['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
-  ['activation_length_moe', ['sequence', 'context']],
+  ['activation_length_moe', ['context']],
-  ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
['activation_norm_length_moe', ['tensor_sequence', 'context']],
['activation_embed_moe', ['tensor', 'tensor_transpose']],
['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_exp', ['expert']],
# MoE Weights
['exp', 'expert'],
['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
-  ['embed_moe', ['fsdp', 'sequence', 'context']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'tensor_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'context']],
# ==========================================
# Standard MLP / Dense Layers / Model Structure
# ==========================================
# Dense Activations
['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
# Note activation batch and length also get used in vocab
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_length', ['sequence', 'context']],
+  ['activation_length', ['context']],
-  ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
['activation_norm_length', ['tensor_sequence', 'context']],
['activation_embed', ['tensor', 'tensor_transpose']],
['activation_stage', 'stage'],
# General Weights
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context', 'expert']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'context', 'expert']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']],
+  ['embed', ['fsdp', 'tensor_transpose', 'context', 'expert']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
+  ['embed', ['fsdp', 'context', 'expert']],
['norm', ['tensor', 'tensor_transpose']],
['layers', 'stage'],
['diloco', 'diloco'],
@@ -538,11 +534,11 @@ logical_axis_rules: [
# ==========================================
# Inference(Prefill, Decode, Cache)
# ==========================================
-  ['prefill_activation_length', ['sequence', 'context']],
-  ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
+  ['prefill_activation_length', ['context']],
['prefill_activation_norm_length', ['tensor_sequence', 'context']],
['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['decode_length', ['sequence']],
+  ['decode_length', []],
['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
['paged_kv_heads', ['tensor']],
@@ -562,7 +558,7 @@ logical_axis_rules: [
['exp_with_fsdp', 'fsdp'],
]
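Duplicate entries in the rule list above (e.g. the two `cache_heads` rules) act as priority-ordered fallbacks: a later rule for the same logical axis applies when an earlier rule's mesh axes are already in use for that array. A simplified first-match sketch of that idea (an illustration only, not the actual Flax/MaxText resolution code):

```python
def resolve(logical_axis, rules, mesh_axes_in_use):
    """Return the mesh axes for logical_axis from the first applicable rule.

    A rule is skipped when any of its mesh axes is already assigned
    to another dimension of the same array.
    """
    for name, mesh in rules:
        if name != logical_axis:
            continue
        mesh = [mesh] if isinstance(mesh, str) else list(mesh)
        if not any(m in mesh_axes_in_use for m in mesh):
            return mesh
    return []  # no applicable rule: leave the dimension unsharded

rules = [
    ("cache_heads", ["autoregressive", "tensor", "tensor_transpose", "tensor_sequence"]),
    ("cache_heads", ["autoregressive", "tensor", "tensor_sequence"]),
]
# When 'tensor_transpose' is already taken, the second (fallback) entry applies:
assert resolve("cache_heads", rules, {"tensor_transpose"}) == ["autoregressive", "tensor", "tensor_sequence"]
```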
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
-data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
+data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length']
# Determines which physical axis plays the role of context parallelism for input data processing and load balancing
# only supports "context" or "expert" (when custom_mesh_and_rule=ep-as-cp)