
Commit e0f17a0

deprecate sequence axis

1 parent 586e692

10 files changed: 65 additions & 1298 deletions

docs/guides/optimization/sharding.md

Lines changed: 3 additions & 25 deletions
@@ -47,11 +47,11 @@ Explanation: Both the activations ($BM$) and weights ($ME$) are sharded on the M
 | $M$ | mlp_dim (aka intermediate dim) |
 | $X$ | expert |
 
-Note for the feed forward computation the batch and sequence dimensions act the same and thus we use only one $B$ axis (which you can think of as a token batch dimension, a reshaping of batch and sequence into one axis), but for context and sequence parallelism they act differently and thus we use both a $B$ and $S$ dimension and the $B$ dimension is really batch in sequences. For example a matmul with an explicit sequence dimension might look like
+Note that for the feed-forward computation the batch and sequence dimensions act the same, so we use only one $B$ axis (which you can think of as a token batch dimension, a reshaping of batch and sequence into one axis). For context parallelism they act differently, so we use both a $B$ and an $S$ dimension, where $B$ is really the batch of sequences. For example, a matmul with an explicit sequence dimension might look like
 
 $$BSE \times EM = BSM$$
 
-But for arithmetic intensity roofline analysis purposes the $B$ and $S$ axis act as one, and generally we omit the $S$ axis except for when its needed (context/sequence parallelism), thus we only write
+But for arithmetic intensity roofline analysis the $B$ and $S$ axes act as one, and we generally omit the $S$ axis except when it's needed (context parallelism), so we only write
 
 $$BE \times EM = BM$$
 
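The roofline quantity implied by $BE \times EM = BM$ can be checked with a few lines of arithmetic. This sketch is not from the commit; the bf16 assumption (2 bytes per element) and the example shapes are illustrative values chosen here.

```python
# Arithmetic intensity of the feed-forward matmul BE x EM = BM discussed above.
# Assumes every operand is read from / written to HBM exactly once, in bf16.
def matmul_arithmetic_intensity(B, E, M, bytes_per_elem=2):
    flops = 2 * B * E * M  # one multiply-add per (b, e, m) triple
    hbm_bytes = bytes_per_elem * (B * E + E * M + B * M)  # read acts, read weights, write out
    return flops / hbm_bytes

# Illustrative shapes: token batch 4096, embed 8192, mlp_dim 28672.
print(round(matmul_arithmetic_intensity(4096, 8192, 28672)))  # 2493
```

With a large token batch the intensity is in the thousands of flops/byte, which is why sharding the batch (data/FSDP/context parallelism, as described below) is what pushes this ratio back down toward the hardware roofline.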
@@ -113,7 +113,7 @@ $$B_yM_x \times M_xE = B_yE \rightarrow \text{RS over x } \rightarrow B_yE_x $$
 
 **Ratio (Arithmetic Intensity)** = $|M_x|$ Flops/byte
 
-This "independence" of sharding strategies is true for the main four parallelisms (data, model (tensor), pipeline, and expert). Note that data, fsdp, context and sequence parallelism are all roughly the same for the purpose of
+This "independence" of sharding strategies is true for the main four parallelisms (data, model (tensor), pipeline, and expert). Note that data, fsdp, and context parallelism are all roughly the same for the purpose of
 arithmetic intensity analysis since they shard the batch, as we will illustrate in the individual sections below. In addition both data and pipeline parallelism (microbatches) shard the batch which decreases the HBM arithmetic intensity.
 
 ## Code implementation of sharding in MaxText
@@ -270,28 +270,6 @@ The extra cost of all gathering of keys and values is small, especially for long
 
 **Ratio**: `seq_len * query_heads / (kv_heads * |CP|)`
 
-## Sequence Parallelism (SP)
-
-Sequence parallelism is very similar to context parallelism - we shard the layer inputs and feed forward activations along the sequence dimension. The difference is for attention - we shard the queries, keys, and values along the head dimension instead of sequence dimension (this is fairly MaxText specific, you might not see this in other codebases). This is because the head dimension is easy to shard on for attention (it is not a contracting dimension), and thus can be more efficient than context parallelism as long as there are enough heads. Both sequence parallelism and tensor parallelism shard the heads, so we are constrained by `tensor_parallelism * sequence_parallelism < kv_heads`. E.g. if there are only 8 `kv_heads` as for llama3 and we use `tensor_parallelism=8`, then we cannot use any `sequence_parallelism` (e.g. `sequence_parallelism=1`)
-
-Sequence parallelism is currently only supported with TPUs attention kernel, for GPUs we recommend context parallelism above.
-
-### SP Arithmetic Intensity
-
-The main communications are the same as `FSDP` (all gather weights and synchronize gradients), with an arithmetic intensity of `local_batch` / `sparsity`
-
-#### SP Extra A2A cost
-
-Sequence parallelism has an additional cost of transferring the sharding from sequence to heads (and back again) for attention. This is executed via an all-to-all which is generally a cheap operation, analyzed below:
-
-**Compute**: Attention ($4 * batch * seq_len^2 * heads * head_dim / |SP|$)
-
-**Communicate:** A2A QKV activations and output activations (roughly $4 * batch * seq_len * heads * head_dim$)
-
-**Ratio (Arithmetic Intensity)**: Proportional to $seq_len / |SP|$
-
-The exact ratio depends on MHA vs GQA, how many kv heads there are and the efficiency of an all-to-all on the given hardware.
-
 ## Tensor Parallelism (TP)
 
 Shard the activations along the feature dimensions (e.g. model or `embed` dimension and intermediate or `mlp` dimension) instead of the batch dimension. Tensor parallelism communicates the activations as opposed to the weights as in DP/FSDP. Tensor parallelism can be used to replace some amount of DP/FSDP when the batch size is small and/or when the model is large (when the `mlp` dim is large). Tensor parallelism is needed to run with small batches, such as fraction `per_device_batch_size` < 1. For instance if we use `TP=4` then we can use the rest with FSDP and set `per_device_batch_size=0.25` since the `global_batch = per_device_batch_size * TP * FSDP = 0.25 * 4 * FSDP = FSDP`, and this is shardable among `FSDP` devices (each device will get a shard of `FSDP/FSDP = 1` of the batch axis in this case). For the attention activations (query, key, value), we shard the heads on `TP` since that is the easiest dimension to shard on and use an attention kernel like flash attention (the heads are not a contracting dimension during the attention computation).
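The `global_batch = per_device_batch_size * TP * FSDP` arithmetic in the retained TP paragraph can be sketched as a small check. The axis sizes below are illustrative, and the helper name `global_batch` is invented here; MaxText's actual config validation is separate.

```python
# Sketch of the batch-shardability arithmetic from the TP discussion above.
# A fractional per_device_batch_size works as long as the resulting global
# batch is an integer that divides evenly over the FSDP axis.
def global_batch(per_device_batch_size, tp, fsdp):
    gb = per_device_batch_size * tp * fsdp
    assert gb == int(gb) and int(gb) % fsdp == 0, "batch not shardable over fsdp"
    return int(gb)

# TP=4 with per_device_batch_size=0.25 on an FSDP axis of 16:
# global_batch = 0.25 * 4 * 16 = 16 = FSDP, so each FSDP shard gets 1 element.
print(global_batch(0.25, tp=4, fsdp=16))  # 16
```

Halving `per_device_batch_size` to 0.125 with the same axes would give a global batch of 8 over 16 FSDP devices (half an element each), which the check rejects.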

src/maxtext/configs/base.yml

Lines changed: 27 additions & 31 deletions
@@ -445,29 +445,27 @@ compile_xla_flags: "" # Compiler options e.g. compile_xla_flags="--xla_tpu_num_s
 # Parallelism
 shard_mode: "auto" # can be either auto or explicit
 custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying yml name under config/mesh_and_rule/.
-mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
+mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
 logical_axis_rules: [
 # ==========================================
 # Vocabulary Embedding
 # ==========================================
 # Vocab Activations
 ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
-['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']],
 ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
 ['activation_vocab', ['tensor', 'tensor_transpose']],
 ['activation_vocab', 'tensor_sequence'],
-['activation_vocab', ['sequence', 'context']],
 # Vocab Weights
 ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
-['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
 # ==========================================
 # Attention
 # ==========================================
 # Attention Activations
 ['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-['activation_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence', 'autoregressive']],
-['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
-['activation_length_attn', ['sequence', 'context']],
+['activation_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
+['activation_kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence']],
 ['activation_length_attn', ['context']],
 ['activation_q_length', ['context']],
 ['activation_kv_length', []],
@@ -482,52 +480,50 @@ logical_axis_rules: [
 ['qkv', []],
 ['kv', []],
 ['kv_head_dim', []],
-['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
-['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
-['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
+['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']],
+['q_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']],
+['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
+['q_lora', ['fsdp', 'context', 'expert']],
 ["q_lora_up_proj", []],
-['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
-['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
-['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
+['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']],
+['kv_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']],
+['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
+['kv_lora', ['fsdp', 'context', 'expert']],
 ["kv_lora_up_proj", []],
 # ==========================================
 # Mixture of Experts (MoE)
 # ==========================================
 # MoE Activations
 ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
-['activation_length_moe', ['sequence', 'context']],
 ['activation_length_moe', ['context']],
-['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
+['activation_norm_length_moe', ['tensor_sequence', 'context']],
 ['activation_embed_moe', ['tensor', 'tensor_transpose']],
 ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
 ['activation_exp', ['expert']],
 # MoE Weights
 ['exp', 'expert'],
 ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
-['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
-['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
-['embed_moe', ['fsdp', 'sequence', 'context']],
+['embed_moe', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context']],
+['embed_moe', ['fsdp', 'tensor_transpose', 'context']],
+['embed_moe', ['fsdp', 'fsdp_transpose', 'context']],
+['embed_moe', ['fsdp', 'context']],
 # ==========================================
 # Standard MLP / Dense Layers / Model Structure
 # ==========================================
 # Dense Activations
 ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
 # Note activation batch and length also get used in vocab
 ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-['activation_length', ['sequence', 'context']],
 ['activation_length', ['context']],
-['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
+['activation_norm_length', ['tensor_sequence', 'context']],
 ['activation_embed', ['tensor', 'tensor_transpose']],
 ['activation_stage', 'stage'],
 # General Weights
 ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
-['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context', 'expert']],
-['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-['embed', ['fsdp', 'sequence', 'context', 'expert']],
+['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']],
+['embed', ['fsdp', 'tensor_transpose', 'context', 'expert']],
+['embed', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
+['embed', ['fsdp', 'context', 'expert']],
 ['norm', ['tensor', 'tensor_transpose']],
 ['layers', 'stage'],
 ['diloco', 'diloco'],
@@ -538,11 +534,11 @@ logical_axis_rules: [
 # ==========================================
 # Inference(Prefill, Decode, Cache)
 # ==========================================
-['prefill_activation_length', ['sequence', 'context']],
-['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
+['prefill_activation_length', ['context']],
+['prefill_activation_norm_length', ['tensor_sequence', 'context']],
 ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
 ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-['decode_length', ['sequence']],
+['decode_length', []],
 ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
 ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
 ['paged_kv_heads', ['tensor']],
@@ -562,7 +558,7 @@ logical_axis_rules: [
 ['exp_with_fsdp', 'fsdp'],
 ]
 # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
-data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
+data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
 input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length']
 # Determines which physical axis plays the role of context parallelism for input data processing and load balancing
 # only supports "context" or "expert" (when custom_mesh_and_rule=ep-as-cp)
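The effect of dropping `'sequence'` from `mesh_axes` and the `logical_axis_rules` can be illustrated with a deliberately simplified model of rule resolution: a logical axis is matched against the rule list in order, and a rule only applies if all of its physical mesh axes exist in the mesh. The real resolution is done by Flax/JAX logical partitioning inside MaxText; the mesh subset and rule subset below are illustrative, not the full config.

```python
# Simplified, illustrative model of logical-axis-rule lookup after the
# 'sequence' axis removal. Not MaxText's actual resolver.
MESH_AXES = ['data', 'fsdp', 'context', 'tensor', 'expert']  # subset, for illustration

RULES = [
    ('activation_length', ['context']),                       # sequence axis no longer listed
    ('activation_norm_length', ['tensor_sequence', 'context']),
    ('embed', ['fsdp', 'context', 'expert']),
]

def resolve(logical_axis, mesh_axes=MESH_AXES, rules=RULES):
    """Return the physical axes of the first rule whose axes all exist in the mesh."""
    for name, phys in rules:
        if name == logical_axis and all(a in mesh_axes for a in phys):
            return phys
    return []  # unmatched logical axes are left unsharded

print(resolve('activation_length'))       # ['context']
print(resolve('activation_norm_length'))  # [] ('tensor_sequence' not in this mesh)
```

With the `'sequence'` axis gone, logical axes like `activation_length` and `decode_length` resolve only through `'context'` (or to nothing), which is exactly what the rule changes in this hunk encode.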
