
Commit 5e96345

Merge pull request #3655 from AI-Hypercomputer:chengnuojin-update-ppdoc
PiperOrigin-RevId: 899193435
2 parents 6bd68a5 + 36039d0 commit 5e96345

1 file changed

Lines changed: 12 additions & 1 deletion

File changed: docs/guides/optimization/sharding.md
@@ -402,7 +402,18 @@ We are actively investing in Multiple Program Multiple Data (`MPMD`) style jax t
 
 ### PP + FSDP/DP
 
-Pipelining and FSDP/DP interactions have to be considered together to achieve optimal performance. Generally we want to reduce the gradients across DP replicas only once outside of the pipeline loop as opposed to every microbatch (we want the gradient reduction performed locally across microbatches first and only once across DP replicas). We rely on the XLA compiler for this optimization. Similarly for FSDP we want to all-gather the weights across FSDP only once before the pipeline loop as opposed to every microbatch - we have implemented this in maxtext with `pipeline_fsdp_ag_once` and generally recommend this with small batch sizes. However this comes with a huge memory cost - the weights and gradients are not sharded by FSDP, and thus a significant amount of other sharding (PP, EP, TP) must be used. This is roughly equivalent 0-1 sharding, FSDP only shards the optimizer state, not the weights and gradients.
+To achieve optimal performance, PP and FSDP/DP must be co-optimized. As a general rule, gradients should be reduced across DP replicas only *once*, outside the pipeline microbatch loop, rather than for each microbatch: gradients are first accumulated locally across microbatches, followed by a single global reduction across DP replicas. MaxText currently relies on the XLA compiler to perform this optimization automatically, and we expect to adopt the [JAX explicit sharding feature](https://docs.jax.dev/en/latest/parallel.html#explicit-sharding-mode-makes-sharding-queryable-at-trace-time) for more granular control.
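The accumulate-locally-then-reduce-once pattern described above can be sketched in plain JAX. This is a toy illustration, not MaxText code: `dp_step`, the shapes, and the `pmap`/`pmean` setup are all assumptions made for the example.

```python
# Toy sketch (not MaxText internals): accumulate microbatch gradients locally,
# then reduce across DP replicas exactly once, outside the microbatch loop.
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

def dp_step(w, xs, ys):
    # xs, ys hold this replica's microbatches: [num_microbatches, batch, ...].
    grad_fn = jax.grad(loss_fn)
    acc = jnp.zeros_like(w)
    for i in range(xs.shape[0]):
        acc = acc + grad_fn(w, xs[i], ys[i])  # local accumulation, no comms
    acc = acc / xs.shape[0]
    # Single cross-replica reduction per step, after the microbatch loop.
    return jax.lax.pmean(acc, axis_name="dp")

ndev = jax.local_device_count()
w = jnp.array([1.0, -1.0])
# Per-replica data: [ndev, num_microbatches, batch, feature].
xs = jnp.ones((ndev, 3, 4, 2))
ys = jnp.zeros((ndev, 3, 4))
g = jax.pmap(dp_step, axis_name="dp", in_axes=(None, 0, 0))(w, xs, ys)
```

With these inputs `x @ w - y` is identically zero, so the reduced gradient is zero on every replica; the point is the structure: one `pmean` per step instead of one per microbatch.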
+
+When using FSDP/DP combined with PP, MaxText supports three operational modes:
+
+1. **Default:** Weights are all-gathered across FSDP for every microbatch.
+2. **Per-Repeat All-Gather (`pipeline_fsdp_ag_per_repeat=True`):** All-gathers weights ahead of each pipeline repeat. *(Note: this is only supported when `num_layers_per_pipeline_stage=1`.)*
+3. **Single All-Gather (`pipeline_fsdp_ag_once=True`):** All-gathers weights only once, ahead of all pipeline repeats.
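The mode flags above are ordinary MaxText config keys, so they can be set as command-line overrides. A hypothetical invocation is sketched below; the script path and the parallelism flag are illustrative assumptions, only `pipeline_fsdp_ag_once` comes from this doc.

```shell
# Hypothetical sketch: enabling Mode 3 (single all-gather) for a small model.
# Script path and ici_pipeline_parallelism are illustrative, not from this diff.
python3 MaxText/train.py MaxText/configs/base.yml \
  ici_pipeline_parallelism=4 \
  pipeline_fsdp_ag_once=True
```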
+
+Moving from Mode 1 to Mode 3 significantly improves compute performance, but at the cost of substantially increased memory usage.
+
+- **For small models:** We recommend starting with `pipeline_fsdp_ag_once=True` (Mode 3). This provides the most efficient pipeline parallelism and should yield the best performance.
+- **For large models:** Mode 3 is usually too memory-intensive. To balance high performance against OOM risk, use `pipeline_fsdp_ag_per_repeat=True` (Mode 2).
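The performance/memory trade-off between the modes comes down to where the weight all-gather sits relative to the loop. A toy sketch (not MaxText internals; `run_pipeline` and `all_gather` are stand-ins invented for illustration):

```python
# Toy sketch contrasting the default (re-gather weights every microbatch)
# with pipeline_fsdp_ag_once-style hoisting (gather once, weights stay live).
import jax.numpy as jnp

def run_pipeline(w_shards, microbatches, ag_once=True):
    def all_gather(shards):
        # Stand-in for the FSDP weight all-gather.
        return jnp.concatenate(shards, axis=0)

    outs = []
    if ag_once:
        w = all_gather(w_shards)  # Mode 3: one gather; full weights resident.
    for x in microbatches:
        if not ag_once:
            w = all_gather(w_shards)  # Mode 1: gather repeated per microbatch.
        outs.append(x @ w)
    return outs

shards = [jnp.ones((1, 2)), 2.0 * jnp.ones((1, 2))]
mbs = [jnp.ones((3, 2)) for _ in range(4)]
out_once = run_pipeline(shards, mbs, ag_once=True)
out_each = run_pipeline(shards, mbs, ag_once=False)
```

Both modes compute identical outputs; hoisting trades the repeated communication for keeping the unsharded weights live across the whole loop, which is the memory cost discussed above.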
 
 ### PP Arithmetic Intensity
