You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/guides/optimization/sharding.md
+3-25Lines changed: 3 additions & 25 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -47,11 +47,11 @@ Explanation: Both the activations ($BM$) and weights ($ME$) are sharded on the M
47
47
| $M$ | mlp_dim (aka intermediate dim) |
48
48
| $X$ | expert |
49
49
50
-
Note for the feed forward computation the batch and sequence dimensions act the same and thus we use only one $B$ axis (which you can think of as a token batch dimension, a reshaping of batch and sequence into one axis), but for context and sequence parallelism they act differently and thus we use both a $B$ and $S$ dimension and the $B$ dimension is really batch in sequences. For example a matmul with an explicit sequence dimension might look like
50
+
Note for the feed forward computation the batch and sequence dimensions act the same and thus we use only one $B$ axis (which you can think of as a token batch dimension, a reshaping of batch and sequence into one axis), but for context parallelism they act differently and thus we use both a $B$ and $S$ dimension and the $B$ dimension is really batch in sequences. For example a matmul with an explicit sequence dimension might look like
51
51
52
52
$$BSE \times EM = BSM$$
53
53
54
-
But for arithmetic intensity roofline analysis purposes the $B$ and $S$ axis act as one, and generally we omit the $S$ axis except for when its needed (context/sequence parallelism), thus we only write
54
+
But for arithmetic intensity roofline analysis purposes the $B$ and $S$ axis act as one, and generally we omit the $S$ axis except for when its needed (context parallelism), thus we only write
55
55
56
56
$$BE \times EM = BM$$
57
57
@@ -113,7 +113,7 @@ $$B_yM_x \times M_xE = B_yE \rightarrow \text{RS over x } \rightarrow B_yE_x $$
This "independence" of sharding strategies is true for the main four parallelisms (data, model (tensor), pipeline, and expert). Note that data, fsdp, context and sequence parallelism are all roughly the same for the purpose of
116
+
This "independence" of sharding strategies is true for the main four parallelisms (data, model (tensor), pipeline, and expert). Note that data, fsdp, and context parallelism are all roughly the same for the purpose of
117
117
arithmetic intensity analysis since they shard the batch, as we will illustrate in the individual sections below. In addition both data and pipeline parallelism (microbatches) shard the batch which decreases the HBM arithmetic intensity.
118
118
119
119
## Code implementation of sharding in MaxText
@@ -270,28 +270,6 @@ The extra cost of all gathering of keys and values is small, especially for long
Sequence parallelism is very similar to context parallelism - we shard the layer inputs and feed forward activations along the sequence dimension. The difference is for attention - we shard the queries, keys, and values along the head dimension instead of sequence dimension (this is fairly MaxText specific, you might not see this in other codebases). This is because the head dimension is easy to shard on for attention (it is not a contracting dimension), and thus can be more efficient than context parallelism as long as there are enough heads. Both sequence parallelism and tensor parallelism shard the heads, so we are constrained by `tensor_parallelism * sequence_parallelism < kv_heads`. E.g. if there are only 8 `kv_heads` as for llama3 and we use `tensor_parallelism=8`, then we cannot use any `sequence_parallelism` (e.g. `sequence_parallelism=1`)
276
-
277
-
Sequence parallelism is currently only supported with TPUs attention kernel, for GPUs we recommend context parallelism above.
278
-
279
-
### SP Arithmetic Intensity
280
-
281
-
The main communications are the same as `FSDP` (all gather weights and synchronize gradients), with an arithmetic intensity of `local_batch` / `sparsity`
282
-
283
-
#### SP Extra A2A cost
284
-
285
-
Sequence parallelism has an additional cost of transferring the sharding from sequence to heads (and back again) for attention. This is executed via and all-to-all which are generally cheap operations, analyzed below:
**Ratio (Arithmetic Intensity)**: Proportional to $seq_len / |SP|$
292
-
293
-
The exact ratio depends on MHA vs GQA, how many kv heads there are and the efficiency of an all-to-all on the given hardware.
294
-
295
273
## Tensor Parallelism (TP)
296
274
297
275
Shard the activations along the feature dimensions (e.g. model or `embed` dimension and intermediate or `mlp` dimension) instead of the batch dimension. Tensor parallelism communicates the activations as opposed to the weights as in DP/FSDP. Tensor parallelism can be used to replace some amount of DP/FSDP when the batch size is small and/or when the model is large (when the `mlp` dim is large). Tensor parallelism is needed to run with small batches, such as fraction `per_device_batch_size` < 1. For instance if we use `TP=4` then we can use the rest with FSDP and set `per_device_batch_size=0.25` since the `global_batch = per_device_batch_size * TP * FSDP = 0.25 * 4 * FSDP = FSDP`, and this is shardable among `FSDP` devices (each device will get a shard of `FSDP/FSDP = 1` of the batch axis in this case). For the attention activations (query, key, value), we shard the heads on `TP` since that is the easiest dimension to shard on and use an attention kernel like flash attention (the heads are not a contracting dimension during the attention computation).
0 commit comments