13 changes: 12 additions & 1 deletion docs/guides/optimization/sharding.md
@@ -402,7 +402,18 @@ We are actively investing in Multiple Program Multiple Data (`MPMD`) style jax t

### PP + FSDP/DP

Pipelining and FSDP/DP interactions have to be considered together to achieve optimal performance. Generally we want to reduce the gradients across DP replicas only once outside of the pipeline loop, as opposed to every microbatch (we want the gradient reduction performed locally across microbatches first, and only once across DP replicas). We rely on the XLA compiler for this optimization. Similarly, for FSDP we want to all-gather the weights across FSDP only once before the pipeline loop, as opposed to every microbatch - we have implemented this in maxtext with `pipeline_fsdp_ag_once` and generally recommend it with small batch sizes. However, this comes with a large memory cost - the weights and gradients are not sharded by FSDP, so a significant amount of other sharding (PP, EP, TP) must be used. This is roughly equivalent to ZeRO-1 sharding: FSDP only shards the optimizer state, not the weights and gradients.
To achieve optimal performance, PP and FSDP/DP must be co-optimized. As a general rule, gradients should be reduced across DP replicas only *once* outside the pipeline microbatch loop, rather than for each microbatch: the gradients are first accumulated locally across microbatches, followed by a single global reduction across DP replicas. MaxText currently relies on the XLA compiler to perform this optimization automatically, and we expect to adopt the [JAX explicit sharding feature](https://docs.jax.dev/en/latest/parallel.html#explicit-sharding-mode-makes-sharding-queryable-at-trace-time) for more granular control.
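The equivalence behind this optimization can be sketched with a toy JAX example (the model, shapes, and replica count below are illustrative, not MaxText code): reducing across DP replicas after every microbatch produces the same gradient as accumulating microbatch gradients locally and performing a single cross-replica reduction, because both are linear operations.

```python
import jax
import jax.numpy as jnp

# Toy linear model with a mean-squared-error loss.
def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

grad_fn = jax.grad(loss)

# Simulate 2 DP replicas x 4 microbatches of (batch=8, features=3) data.
xs = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8, 3))
ys = jax.random.normal(jax.random.PRNGKey(1), (2, 4, 8))
w = jnp.zeros((3,))

# Naive: reduce (mean) across DP replicas after *every* microbatch.
naive = sum(
    jnp.mean(jnp.stack([grad_fn(w, xs[r, m], ys[r, m]) for r in range(2)]), axis=0)
    for m in range(4)
)

# Optimized: accumulate microbatch gradients locally on each replica first,
# then perform a single cross-replica reduction at the end.
local_sums = jnp.stack(
    [sum(grad_fn(w, xs[r, m], ys[r, m]) for m in range(4)) for r in range(2)]
)
optimized = jnp.mean(local_sums, axis=0)

assert jnp.allclose(naive, optimized, atol=1e-5)
```

On real hardware the second form replaces `num_microbatches` all-reduces with one, which is why we want the compiler to hoist the reduction out of the pipeline loop.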

When using FSDP/DP combined with PP, MaxText supports three operational modes:

1. **Default:** All-gathers weights across FSDP for every microbatch.
2. **Per-Repeat All-Gather (`pipeline_fsdp_ag_per_repeat=True`):** All-gathers weights ahead of each pipeline repeat. *(Note: This is only supported when `num_layers_per_pipeline_stage=1`).*
3. **Single All-Gather (`pipeline_fsdp_ag_once=True`):** All-gathers weights only once ahead of all pipeline repeats.

Moving from Mode 1 to Mode 3 significantly improves computational performance, but at the cost of heavily increased memory usage.

- **For small models:** We recommend starting with `pipeline_fsdp_ag_once=True` (Mode 3). This provides the most efficient pipeline parallelism and should yield the best performance.
- **For large models:** Mode 3 is usually too memory-intensive. To strike the best balance between high performance and avoiding OOM, use `pipeline_fsdp_ag_per_repeat=True` (Mode 2).
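A launch command selecting Mode 3 might look like the sketch below. Only `pipeline_fsdp_ag_once`, `pipeline_fsdp_ag_per_repeat`, and `num_layers_per_pipeline_stage` come from this guide; the script path, `run_name`, and the surrounding parallelism flags are illustrative assumptions, so check them against your MaxText config before use.

```shell
# Illustrative invocation (flag names other than the pipeline_fsdp_* ones
# are assumptions, not documented here).
python3 -m MaxText.train MaxText/configs/base.yml \
  run_name=pp_fsdp_demo \
  ici_pipeline_parallelism=4 \
  ici_fsdp_parallelism=2 \
  pipeline_fsdp_ag_once=True   # Mode 3: best throughput, highest memory

# For large models, prefer Mode 2 instead:
#   pipeline_fsdp_ag_per_repeat=True   (requires num_layers_per_pipeline_stage=1)
```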

### PP Arithmetic Intensity
