Pipeline parallel: memory-proportional splitting and inference sync #1137
Open
qubitcontracting wants to merge 4 commits into ml-explore:main from
Conversation
The existing PipelineMixin splits layers equally across ranks. On
heterogeneous clusters (e.g. 256GB + 64GB + 48GB Mac nodes over
Thunderbolt 5 RDMA), equal splitting either wastes memory on the
large nodes or OOMs the small ones.
This adds:
1. Memory-proportional layer splitting in PipelineMixin.pipeline() (see the first sketch after this list)
- Queries each node's Metal working set via mx.metal.device_info()
- Accounts for per-layer compute overhead (KV cache, attention
scores, MoE activations) to avoid OOM
- Falls back to equal splitting when psutil is unavailable
2. Pipeline sync in generate_step() to prevent deadlocks (see the second sketch after this list)
- In pipeline parallel, each rank runs a subset of layers. The
model's forward pass contains distributed send/recv ops as lazy
graph nodes. Without a collective sync at each mx.eval, rank 0
may not trigger the ops needed by other ranks, causing a deadlock.
- Adds _pipeline_sync() helper that returns an all_sum collective
- Patches prefill eval to evaluate logits (which contain the
distributed ops in their graph) rather than cache states alone
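A minimal sketch of the proportional split, assuming an initialized mx.distributed group. The helper name, the GB conversion, and the remainder distribution are mine, and the per-layer overhead weighting the PR describes is omitted:

```python
import mlx.core as mx

def proportional_split(num_layers: int, group) -> list[int]:
    # Hypothetical helper, not the PR's actual implementation.
    # Each rank reports its recommended Metal working set size (in GB).
    local_gb = mx.metal.device_info()["max_recommended_working_set_size"] / (1 << 30)
    # all_gather concatenates along the first axis: shape (group.size(),).
    mems = mx.distributed.all_gather(mx.array([local_gb]), group=group)
    mx.eval(mems)
    weights = [float(m) for m in mems]
    total = sum(weights)
    # Floor the proportional counts, then hand out the remainder one by one.
    counts = [int(num_layers * w / total) for w in weights]
    for i in range(num_layers - sum(counts)):
        counts[i % len(counts)] += 1
    return counts
```

On the cluster above (256/64/48 GB), 94 layers land close to the 65/16/13 split reported just below.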
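And a sketch of the sync helper; the body below is reconstructed from the description (a cheap all_sum used as a barrier), not copied from the PR:

```python
import mlx.core as mx

def _pipeline_sync(group):
    # A scalar all_sum acts as a barrier: evaluating it forces every rank to
    # participate in the same collective, which in turn triggers the lazy
    # send/recv nodes embedded in the model's forward graph.
    return mx.distributed.all_sum(mx.array(0.0), group=group)
```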
Tested on a 3-node JACCL cluster (Mac Studio M3 Ultra 256GB + 2x Mac
mini M4 Pro 64GB/48GB) running Qwen3-235B-A22B (94 layers split
65/16/13) and MiniMax-M2.5.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When sharded_load calls pipeline() on a lazy model, each rank may have different weight files present locally. This causes tree_flatten(self.parameters()) to report different total bytes per rank, leading to inconsistent layer splits. Fix: share local_model_bytes via all_gather and use rank 0's value as the reference, since rank 0 always has access to the full model.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
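A sketch of what the fix amounts to; reference_model_bytes is a hypothetical name and the dtype choice is mine:

```python
import mlx.core as mx
from mlx.utils import tree_flatten

def reference_model_bytes(model, group):
    # Lazy ranks may only have some weight files locally, so before the fix
    # this sum differs per rank.
    local_bytes = sum(p.nbytes for _, p in tree_flatten(model.parameters()))
    gathered = mx.distributed.all_gather(
        mx.array([local_bytes], dtype=mx.int64), group=group
    )
    mx.eval(gathered)
    # Rank 0 always sees the full model; use its value on every rank so all
    # ranks compute the same layer split.
    return int(gathered[0])
```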
On asymmetric pipeline splits (e.g. 80/14 layers across 2 nodes), the first forward pass builds a compute graph spanning all local layers plus distributed ops. Metal's command buffer timeout fires before shader compilation finishes. pipeline_warmup() fixes this by:
1. Running layers locally in chunks of 5 (no distributed ops) to compile each layer's Metal shaders independently
2. Running a full pipeline forward pass (recv/send/all_gather), which succeeds because all shaders are already compiled
Called automatically by sharded_load() after weight materialization. Adds ~10s one-time startup cost.
Tested on 2-node (80/14, 17.9 tok/s) and 3-node (70/14/10, 11.4 tok/s) with Qwen3-235B-A22B 8-bit.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
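A rough sketch of the two-phase warmup, assuming each local layer can be called on a hidden-state tensor alone and the shard exposes model.layers; the PR's actual pipeline_warmup() will differ in detail:

```python
import mlx.core as mx

def pipeline_warmup(model, hidden_size: int, chunk: int = 5):
    # Phase 1: run the local layers with no distributed ops, evaluating every
    # `chunk` layers so each small graph's shader compilation finishes well
    # inside Metal's command buffer timeout.
    h = mx.zeros((1, 1, hidden_size))  # dummy activation (assumed shape)
    for i, layer in enumerate(model.layers):
        h = layer(h)
        if (i + 1) % chunk == 0:
            mx.eval(h)
    mx.eval(h)
    # Phase 2: one full pipeline forward pass (recv/send/all_gather included).
    # It fits in the timeout now because every shader is already compiled.
    mx.eval(model(mx.zeros((1, 1), dtype=mx.int32)))
```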
BatchGenerator does not include the pipeline sync collectives needed for distributed inference. When pipeline_group is set, force is_batchable = False so the server uses the single-request stream_generate path, which has proper pipeline sync. Without this, the server's first inference request causes a GPU timeout: the fast rank's collective op waits indefinitely for the slow rank, which is in a different eval path.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
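As I read it, the server-side check has roughly this shape (the names pipeline_group and is_batchable come from the text above; the helper itself is illustrative):

```python
def is_batchable_request(model) -> bool:
    # BatchGenerator has no pipeline sync collectives, so in pipeline
    # parallel the fast rank's collective would wait forever on a slow rank
    # stuck in a different eval path. Force the single-request
    # stream_generate path instead, which syncs on every eval.
    return getattr(model, "pipeline_group", None) is None
```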
Summary
Adds heterogeneous pipeline parallel support for distributed inference across nodes with different memory capacities.
Changes
- mlx_lm/models/pipeline.py: pipeline_warmup() method compiles Metal shaders incrementally to avoid a GPU command buffer timeout on cold start. Phase 1 compiles locally (no distributed ops); Phase 2 runs a full pipeline pass with pre-compiled shaders.
- mlx_lm/generate.py: _pipeline_sync() function returns an all_sum collective for distributed sync; the pipeline path uses mx.eval with this sync instead of mx.async_eval (sketch below).
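Illustratively, the pipeline eval then looks like this (names are mine; the non-pipeline path keeps mx.async_eval):

```python
import mlx.core as mx

def eval_with_pipeline_sync(logits, group):
    # A blocking eval of the logits (whose graph contains the distributed
    # send/recv ops) plus an all_sum barrier keeps every rank in lock-step,
    # replacing the mx.async_eval used in the single-node path.
    mx.eval(logits, mx.distributed.all_sum(mx.array(0.0), group=group))
```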
Why warmup is needed
On asymmetric splits (e.g. 80/14 layers), the first cold model() call builds a massive compute graph. Metal's command buffer timeout fires before shader compilation finishes. pipeline_warmup() runs layers in chunks of 5 locally first, then a full pipeline pass, all fast because the shaders are pre-compiled. Adds a ~10s one-time startup cost.
Testing