
Pipeline parallel: memory-proportional splitting and inference sync #1137

Open
qubitcontracting wants to merge 4 commits into ml-explore:main from qubitcontracting:pipeline-heterogeneous

Conversation


qubitcontracting commented Apr 9, 2026

Summary

Adds heterogeneous pipeline parallel support for distributed inference across nodes with different memory capacities.

Changes

mlx_lm/models/pipeline.py

  • Memory-proportional layer splitting based on each node's Metal working set size
  • pipeline_warmup() method: compiles Metal shaders incrementally to avoid GPU command buffer timeout on cold start. Phase 1 compiles locally (no distributed ops), Phase 2 runs full pipeline pass with pre-compiled shaders.
  • Falls back to equal splitting when memory info unavailable
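For reference, a minimal sketch of what the proportional split amounts to (`split_layers` is a hypothetical helper, not the PR's actual code; the real implementation also subtracts per-layer compute overhead):

```python
# Sketch only: proportional layer split from per-rank memory sizes.
import mlx.core as mx

def split_layers(num_layers, mem_per_rank):
    total = sum(mem_per_rank)
    if total <= 0:  # memory info unavailable -> equal split
        n = len(mem_per_rank)
        return [num_layers // n + (i < num_layers % n) for i in range(n)]
    counts = [int(num_layers * m / total) for m in mem_per_rank]
    # Give the rounding remainder to the largest-memory rank.
    counts[mem_per_rank.index(max(mem_per_rank))] += num_layers - sum(counts)
    return counts

# Each rank's memory size comes from its Metal working-set limit:
local_mem = mx.metal.device_info()["max_recommended_working_set_size"]
# e.g. split_layers(94, [256, 64, 48]) -> [66, 16, 12] before the
# overhead adjustment described above shifts the split.
```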

mlx_lm/generate.py

  • _pipeline_sync() function returns an all_sum collective for distributed sync
  • Prefill loop: evaluates logits + sync (not just cache states) so all ranks participate
  • Generation loop: uses mx.eval with sync instead of mx.async_eval for pipeline
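A rough sketch of the sync pattern (the `group` argument and the call site are assumptions; see the diff for the real signatures):

```python
# Sketch only: a cheap collective that every rank must evaluate,
# which forces the lazy send/recv nodes to fire on all ranks together.
import mlx.core as mx

def _pipeline_sync(group):
    return mx.distributed.all_sum(mx.array(1.0), group=group)

# Illustrative generation-loop usage: evaluate the logits and the sync
# op in the same mx.eval so no rank is left in a different eval path.
#   mx.eval(logits, _pipeline_sync(group))
```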

Why warmup is needed

On asymmetric splits (e.g. 80/14 layers), the first cold model() call builds a massive compute graph. Metal's command buffer timeout fires before shader compilation finishes. pipeline_warmup() runs layers in chunks of 5 locally first, then a full pipeline pass — all fast because shaders are pre-compiled. Adds ~10s one-time startup cost.
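Roughly, the warmup looks like the following (a sketch under assumed names; real transformer layers also take mask/cache arguments, and the actual method lives on PipelineMixin):

```python
# Sketch only: incremental shader compilation before the first real pass.
import mlx.core as mx

def pipeline_warmup(model, hidden_size, chunk=5):
    x = mx.zeros((1, 1, hidden_size))
    # Phase 1: run local layers in chunks; each mx.eval compiles only a
    # small graph, staying under Metal's command buffer timeout.
    for i in range(0, len(model.layers), chunk):
        y = x
        for layer in model.layers[i : i + chunk]:
            y = layer(y)  # real layers also take mask/cache args
        mx.eval(y)
    # Phase 2: one full pipeline pass (recv/send/all_gather); fast now,
    # since every local shader is already compiled.
    mx.eval(model(mx.zeros((1, 1), dtype=mx.int32)))
```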

Testing

  • 2-node (256GB + 64GB): 80/14 split, Qwen3-235B-A22B 8-bit, 17.8 tok/s
  • 3-node (256GB + 64GB + 48GB): 70/14/10 split, 11.3 tok/s
  • Both tested from fresh reboot via automated startup script
  • mlx-lm 0.31.1 and 0.31.2 verified

The existing PipelineMixin splits layers equally across ranks. On
heterogeneous clusters (e.g. 256GB + 64GB + 48GB Mac nodes over
Thunderbolt 5 RDMA), equal splitting either wastes large nodes or
OOMs small ones.

This adds:

1. Memory-proportional layer splitting in PipelineMixin.pipeline()
   - Queries each node's Metal working set via mx.metal.device_info()
   - Accounts for per-layer compute overhead (KV cache, attention
     scores, MoE activations) to avoid OOM
   - Falls back to equal splitting when psutil is unavailable

2. Pipeline sync in generate_step() to prevent deadlocks
   - In pipeline parallel, each rank runs a subset of layers. The
     model's forward pass contains distributed send/recv ops as lazy
     graph nodes. Without a collective sync at each mx.eval, rank 0
     may not trigger the ops needed by other ranks, causing a deadlock.
   - Adds _pipeline_sync() helper that returns an all_sum collective
   - Patches prefill eval to evaluate logits (which contain the
     distributed ops in their graph) rather than cache states alone

Tested on a 3-node JACCL cluster (Mac Studio M3 Ultra 256GB + 2x Mac
mini M4 Pro 64GB/48GB) running Qwen3-235B-A22B (94 layers split
65/16/13) and MiniMax-M2.5.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When sharded_load calls pipeline() on a lazy model, each rank may
have different weight files present locally. This causes
tree_flatten(self.parameters()) to report different total bytes
per rank, leading to inconsistent layer splits.

Fix: share local_model_bytes via all_gather and use rank 0's value
as the reference, since rank 0 always has access to the full model.
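Sketched, the fix amounts to something like this (variable names are assumptions):

```python
# Sketch only: agree on one model size across all ranks.
import mlx.core as mx

group = mx.distributed.init()   # assumed distributed group
local_model_bytes = 123         # placeholder: this rank's local tally

# Gather every rank's byte count and adopt rank 0's value, since
# rank 0 always sees the full model. int64 holds large byte counts.
gathered = mx.distributed.all_gather(
    mx.array([local_model_bytes], dtype=mx.int64), group=group
)
model_bytes = int(gathered[0].item())
```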

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
qubitcontracting marked this pull request as ready for review April 9, 2026 13:53
Thomas and others added 2 commits April 11, 2026 00:34
On asymmetric pipeline splits (e.g. 80/14 layers across 2 nodes),
the first forward pass builds a compute graph spanning all local
layers plus distributed ops. Metal's command buffer timeout fires
before shader compilation finishes.

pipeline_warmup() fixes this by:
1. Running layers locally in chunks of 5 (no distributed ops) to
   compile each layer's Metal shaders independently
2. Running a full pipeline forward pass (recv/send/all_gather) which
   succeeds because all shaders are already compiled

Called automatically by sharded_load() after weight materialization.
Adds ~10s one-time startup cost. Tested on 2-node (80/14, 17.9 tok/s)
and 3-node (70/14/10, 11.4 tok/s) with Qwen3-235B-A22B 8-bit.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
BatchGenerator does not include the pipeline sync collectives
needed for distributed inference. When pipeline_group is set,
force is_batchable = False so the server uses the single-request
stream_generate path which has proper pipeline sync.

Without this, the server's first inference request causes a GPU
timeout — the fast rank's collective op waits indefinitely for
the slow rank which is in a different eval path.
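The guard itself is small; roughly (attribute names assumed):

```python
# Sketch only: route pipeline-parallel runs away from BatchGenerator.
def is_batchable(model) -> bool:
    # Pipeline-parallel models must use the single-request
    # stream_generate path, which includes the sync collectives.
    return getattr(model, "pipeline_group", None) is None
```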

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>