Skip to content

Commit 70fef4f

Browse files
TimDettmers and claude committed
feat: Add 1F1B pipeline parallelism engine
Implements a custom one-forward-one-backward pipeline schedule for training across multiple stages. Key design: - generate_1f1b_schedule: produces per-stage operation sequences with warmup (S-1-s forwards), steady state (interleaved F/B), and cooldown - PipelineEngine: single-process execution with correct dependency ordering — forwards left-to-right, backwards right-to-left - SequentialStage: generic wrapper for composing model layers into stages - split_model_layers: even distribution of layers across stages Non-last stages use B-before-F ordering in steady state to bound in-flight micro-batches. Last stage uses F-before-B since it must receive activations before computing backward. 14 tests: schedule generation (coverage, ordering, warmup counts, bounded in-flight), gradient correctness (2/3/4 stages), loss matching, nonlinear models, multiple training steps. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 47e34dd commit 70fef4f

File tree

2 files changed

+775
-0
lines changed

2 files changed

+775
-0
lines changed

bitsandbytes/pipeline.py

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
"""1F1B Pipeline Parallelism Engine.
2+
3+
Custom implementation of one-forward-one-backward pipeline schedule for
4+
training large models across multiple stages. Each stage processes a
5+
subset of model layers.
6+
7+
Supports:
8+
- Single-process mode: all stages on one GPU (for testing)
9+
- Multi-process NCCL mode: one stage per GPU, activation transfer via isend/irecv
10+
11+
The 1F1B schedule minimizes peak activation memory by interleaving forward
12+
and backward passes, keeping at most (num_stages) micro-batches in flight.
13+
"""
14+
15+
import torch
16+
import torch.nn as nn
17+
18+
19+
def generate_1f1b_schedule(num_stages, num_micro_batches):
    """Generate the 1F1B (one-forward-one-backward) pipeline schedule.

    The schedule for each stage consists of:
    1. Warmup phase: (num_stages - 1 - stage_id) forward passes
    2. Steady state: alternating backward+forward (non-last) or forward+backward (last)
    3. Cooldown phase: remaining backward passes

    Args:
        num_stages: Number of pipeline stages.
        num_micro_batches: Number of micro-batches per training step.
            Must be >= num_stages.

    Returns:
        List of lists: schedule[stage_id] = [(op, micro_batch_id), ...]
        where op is 'F' (forward) or 'B' (backward).

    Raises:
        ValueError: If num_micro_batches < num_stages (the pipeline could
            never be filled, so no valid 1F1B schedule exists).
    """
    # Validate with a real exception rather than `assert`: asserts are
    # stripped under `python -O`, which would silently admit bad configs.
    if num_micro_batches < num_stages:
        raise ValueError(
            f"Need at least {num_stages} micro-batches for {num_stages} stages, "
            f"got {num_micro_batches}"
        )

    S = num_stages
    M = num_micro_batches
    schedules = [[] for _ in range(S)]

    for s in range(S):
        # Earlier stages start more micro-batches before their first backward
        # arrives, so they need more warmup forwards.
        warmup_forwards = S - 1 - s
        is_last_stage = s == S - 1

        # Warmup: forward-only passes to fill the pipeline
        for m in range(warmup_forwards):
            schedules[s].append(("F", m))

        # Steady state: interleave F and B, one of each per iteration
        f_idx = warmup_forwards
        b_idx = 0
        num_steady = M - warmup_forwards

        for _ in range(num_steady):
            if is_last_stage:
                # Last stage: F then B (must receive activation before backward)
                schedules[s].append(("F", f_idx))
                f_idx += 1
                schedules[s].append(("B", b_idx))
                b_idx += 1
            else:
                # Non-last stages: B then F (drain an in-flight micro-batch
                # before admitting a new one, bounding activation memory)
                schedules[s].append(("B", b_idx))
                b_idx += 1
                schedules[s].append(("F", f_idx))
                f_idx += 1

        # Cooldown: remaining backward passes
        while b_idx < M:
            schedules[s].append(("B", b_idx))
            b_idx += 1

    return schedules
78+
79+
80+
class PipelineEngine:
    """1F1B pipeline parallelism engine.

    Splits a model into pipeline stages and executes them using the 1F1B
    schedule. Supports single-process mode for testing and multi-process
    NCCL mode for multi-GPU training.

    The model must be provided as a list of stage callables. Each stage
    takes a hidden state tensor and returns the next hidden state.
    The last stage should include the loss computation.

    Args:
        stage_modules: List of nn.Module instances, one per stage.
            Each module's forward takes (hidden_states,) and returns hidden_states.
        loss_fn: Loss function taking (last_stage_output, labels) -> scalar loss.
            Used only at the last stage. If None, the last stage must return the loss.
        num_micro_batches: Number of micro-batches per training step.
        device: Device for all stages (single-process mode).
    """

    def __init__(
        self,
        stage_modules: list[nn.Module],
        loss_fn=None,
        num_micro_batches: int = 4,
        device: torch.device = None,
    ):
        self.stage_modules = stage_modules
        self.loss_fn = loss_fn
        self.num_stages = len(stage_modules)
        self.num_micro_batches = num_micro_batches
        # NOTE(review): `device` is stored but nothing here moves the stage
        # modules or inputs onto it — placement appears to be the caller's
        # responsibility. Confirm intended contract.
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Precompute the per-stage (op, micro_batch) sequences once;
        # raises if num_micro_batches < num_stages.
        self.schedule = generate_1f1b_schedule(self.num_stages, num_micro_batches)

    def step(self, micro_batch_inputs, micro_batch_labels=None):
        """Run one training step with 1F1B schedule (single-process mode).

        Executes all stages sequentially in a single process, following the
        1F1B schedule for correct ordering. At each schedule index:
        - Forward operations are processed left-to-right (stage 0 first)
        - Backward operations are processed right-to-left (last stage first)

        This respects data dependencies: forward outputs flow left-to-right,
        backward gradients flow right-to-left. Parameter gradients accumulate
        in-place on the stage modules (pre-scaled by 1/M in the backward
        steps); the caller owns the optimizer step and zero_grad.

        Args:
            micro_batch_inputs: List of M input tensors, one per micro-batch.
            micro_batch_labels: List of M label tensors. Required if loss_fn is set.

        Returns:
            dict with:
                loss: Average loss across micro-batches (float).
                losses: List of per-micro-batch losses.
                    Both are 0.0/empty when no loss_fn/labels were given.
        """
        S = self.num_stages
        M = self.num_micro_batches

        assert len(micro_batch_inputs) == M, (
            f"Expected {M} micro-batch inputs, got {len(micro_batch_inputs)}"
        )

        # Storage for intermediate activations, indexed [stage][micro_batch]:
        # fwd_inputs[s][m]  = input tensor to stage s (requires_grad, graph leaf
        #                     at the stage boundary)
        # fwd_outputs[s][m] = output tensor from stage s
        # grad_inputs[s][m] = d(loss)/d(stage s input), handed to stage s-1
        fwd_inputs = [[None] * M for _ in range(S)]
        fwd_outputs = [[None] * M for _ in range(S)]
        losses = [None] * M
        grad_inputs = [[None] * M for _ in range(S)]

        # Execute the 1F1B schedule with proper dependency ordering.
        # Stages have schedules of different lengths (earlier stages have
        # longer warmup/cooldown), so pad iteration to the longest one.
        max_ops = max(len(sched) for sched in self.schedule)

        for op_idx in range(max_ops):
            # Collect operations at this schedule index
            forward_ops = []
            backward_ops = []
            for s in range(S):
                if op_idx >= len(self.schedule[s]):
                    continue
                op, m = self.schedule[s][op_idx]
                if op == "F":
                    forward_ops.append((s, m))
                else:
                    backward_ops.append((s, m))

            # Process forward operations left-to-right (stage 0 first) so a
            # stage's input (the previous stage's output) exists before use.
            for s, m in sorted(forward_ops, key=lambda x: x[0]):
                self._forward_step(s, m, micro_batch_inputs, micro_batch_labels,
                                   fwd_inputs, fwd_outputs, losses)

            # Process backward operations right-to-left (last stage first) so
            # grad_inputs[s+1][m] is populated before stage s consumes it.
            for s, m in sorted(backward_ops, key=lambda x: -x[0]):
                self._backward_step(s, m, fwd_inputs, fwd_outputs, losses,
                                    grad_inputs)

        # Compute average loss over the micro-batches that produced one
        # (all of them when loss_fn and labels were supplied; none otherwise).
        valid_losses = [l.item() for l in losses if l is not None]
        avg_loss = sum(valid_losses) / len(valid_losses) if valid_losses else 0.0

        return {
            "loss": avg_loss,
            "losses": valid_losses,
        }

    def _forward_step(self, stage, micro_batch, inputs, labels,
                      fwd_inputs, fwd_outputs, losses):
        """Execute one forward step for a stage and micro-batch.

        Records the (requires_grad) stage input and the stage output in
        fwd_inputs/fwd_outputs; at the last stage, also records the loss.
        """
        S = self.num_stages

        # Get input
        if stage == 0:
            # First stage: use the micro-batch input directly.
            # NOTE(review): requires_grad_ below mutates the caller's input
            # tensor in place (flag persists after step); it will also raise
            # if the caller passed a non-leaf tensor — confirm acceptable.
            inp = inputs[micro_batch]
        else:
            # Get output from previous stage, detached so each stage's
            # autograd graph is cut at the pipeline boundary.
            inp = fwd_outputs[stage - 1][micro_batch].detach()

        # Enable gradient tracking at stage boundaries so inp.grad is
        # populated by this stage's backward and can be relayed upstream.
        inp = inp.requires_grad_(True)
        fwd_inputs[stage][micro_batch] = inp

        # Run forward through this stage's layers
        output = self.stage_modules[stage](inp)
        fwd_outputs[stage][micro_batch] = output

        # Last stage: compute loss (only when both loss_fn and labels exist;
        # otherwise the raw output is kept and backward uses it directly)
        if stage == S - 1 and self.loss_fn is not None and labels is not None:
            loss = self.loss_fn(output, labels[micro_batch])
            losses[micro_batch] = loss

    def _backward_step(self, stage, micro_batch, fwd_inputs, fwd_outputs,
                       losses, grad_inputs):
        """Execute one backward step for a stage and micro-batch.

        Runs autograd through this stage only (graphs are cut at stage
        boundaries by the detach in _forward_step) and stashes the input
        gradient in grad_inputs for the previous stage to consume.
        """
        S = self.num_stages

        output = fwd_outputs[stage][micro_batch]
        inp = fwd_inputs[stage][micro_batch]

        if stage == S - 1:
            # Last stage: backward from loss
            if losses[micro_batch] is not None:
                # Scale loss by 1/M so accumulated gradients equal the
                # gradient of the mean loss over all micro-batches.
                scaled_loss = losses[micro_batch] / self.num_micro_batches
                scaled_loss.backward(retain_graph=False)
            else:
                # If no loss_fn, backward on output directly with an
                # all-ones (1/M-scaled) upstream gradient.
                output.backward(
                    torch.ones_like(output) / self.num_micro_batches,
                    retain_graph=False,
                )
        else:
            # Non-last stage: backward using gradient from next stage.
            # The right-to-left ordering in step() guarantees it is set.
            grad_from_next = grad_inputs[stage + 1][micro_batch]
            if grad_from_next is not None:
                output.backward(grad_from_next, retain_graph=False)

        # Save input gradient for the previous stage (detached so the relayed
        # tensor carries no graph of its own)
        if inp.grad is not None:
            grad_inputs[stage][micro_batch] = inp.grad.detach()

    def parameters(self):
        """Yield all trainable parameters across all stages (generator),
        suitable for passing to an optimizer."""
        for stage_module in self.stage_modules:
            yield from stage_module.parameters()

    @staticmethod
    def split_model_layers(layers, num_stages):
        """Split a list of layers evenly across stages.

        Args:
            layers: List of nn.Module layers.
            num_stages: Number of pipeline stages.

        Returns:
            List of lists: stage_layers[stage_id] = [layer1, layer2, ...].
            When len(layers) is not divisible by num_stages, earlier stages
            receive one extra layer each.
        """
        n = len(layers)
        assert n >= num_stages, (
            f"Cannot split {n} layers into {num_stages} stages"
        )

        # Even split with remainder going to earlier stages
        base = n // num_stages
        remainder = n % num_stages

        stage_layers = []
        idx = 0
        for s in range(num_stages):
            count = base + (1 if s < remainder else 0)
            stage_layers.append(layers[idx:idx + count])
            idx += count

        return stage_layers
274+
275+
276+
class SequentialStage(nn.Module):
    """Pipeline stage that applies a fixed sequence of layers in order.

    Thin wrapper around a list of nn.Module layers: the output of each
    layer feeds the next. Used as the default stage module when a model
    is split into pipeline stages.
    """

    def __init__(self, layers):
        super().__init__()
        # ModuleList registers each layer so its parameters are tracked.
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        """Feed ``x`` through every layer in sequence and return the result."""
        out = x
        for module in self.layers:
            out = module(out)
        return out

0 commit comments

Comments
 (0)