Skip to content

Commit a60011f

Browse files
TimDettmers and claude
committed
feat: Add distributed pipeline engine with NCCL/gloo support
DistributedPipelineEngine runs one pipeline stage per process, communicating activations and gradients via torch.distributed send/recv. Supports both NCCL (for multi-GPU) and gloo (for single-GPU multi-process testing) backends. Verified with torchrun --nproc_per_node=2: all 4 layer gradients match single-device reference within 1e-5 tolerance. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent e408054 commit a60011f

File tree

2 files changed

+278
-0
lines changed

2 files changed

+278
-0
lines changed

bitsandbytes/pipeline.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,143 @@ def wrap_stages(stage_modules, cpu_offload=True):
334334
return [CheckpointedStage(s, cpu_offload=cpu_offload) for s in stage_modules]
335335

336336

337+
class DistributedPipelineEngine:
    """Distributed 1F1B pipeline engine using torch.distributed point-to-point ops.

    Each process runs one pipeline stage: rank 0 is the first stage and
    rank ``world_size - 1`` is the last. Activations are sent forward and
    gradients sent backward between adjacent ranks via
    ``torch.distributed.send``/``recv``. Works with the NCCL backend
    (multi-GPU) and the gloo backend (tensors are staged through CPU memory
    on the wire). Designed for use with torchrun or torch.distributed.launch.

    Args:
        stage_module: The nn.Module for this process's stage.
        rank: This process's rank (equals its stage index).
        world_size: Total number of stages/processes.
        loss_fn: Loss function (only used by the last stage).
        num_micro_batches: Number of micro-batches per step.
        hidden_shape: Full shape of the inter-stage activation tensor,
            *including* the micro-batch dimension (the receive buffer is
            allocated as ``torch.empty(*hidden_shape)``). Required on every
            rank except rank 0.
        dtype: Data type of inter-stage tensors (default: float32).
    """

    def __init__(
        self,
        stage_module: nn.Module,
        rank: int,
        world_size: int,
        loss_fn=None,
        num_micro_batches: int = 4,
        hidden_shape: tuple = None,
        dtype: torch.dtype = torch.float32,
    ):
        self.stage_module = stage_module
        self.rank = rank
        self.world_size = world_size
        self.loss_fn = loss_fn
        self.num_micro_batches = num_micro_batches
        self.hidden_shape = hidden_shape
        self.dtype = dtype
        # Fix: the original always computed `rank % torch.cuda.device_count()`,
        # which raises ZeroDivisionError on CPU-only hosts even though the
        # gloo path in step() supports them. Fall back to CPU when no GPU
        # is available.
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            self.device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
        else:
            self.device = torch.device("cpu")

        # Per-rank list of ("F"/"B", micro_batch_index) ops in 1F1B order.
        schedule = generate_1f1b_schedule(world_size, num_micro_batches)
        self.my_schedule = schedule[rank]

    def step(self, micro_batch_inputs=None, micro_batch_labels=None):
        """Run one distributed training step over all micro-batches.

        Gradients accumulate into the stage module's parameters; each
        micro-batch loss is scaled by 1/M so the accumulated gradient
        matches full-batch training.

        Args:
            micro_batch_inputs: List of M input tensors (only used by rank 0).
            micro_batch_labels: List of M label tensors (only used by last rank).

        Returns:
            dict with "loss" (mean over micro-batches) and "losses"
            (per-micro-batch values). Only meaningful on the last rank;
            other ranks report 0.0 / [].
        """
        import torch.distributed as dist

        M = self.num_micro_batches
        s = self.rank
        S = self.world_size

        fwd_inputs = [None] * M
        fwd_outputs = [None] * M
        losses = [None] * M

        # gloo's send/recv only supports CPU tensors, so on non-NCCL
        # backends every transfer is staged through host memory.
        backend = dist.get_backend()
        use_cpu_comm = backend != "nccl"

        def _send(tensor, dst):
            # Blocking point-to-point send; pairs with _recv on the peer.
            if use_cpu_comm:
                dist.send(tensor.cpu(), dst=dst)
            else:
                dist.send(tensor, dst=dst)

        def _recv(shape, src, device, dtype):
            # Blocking receive into a freshly allocated buffer of `shape`.
            if use_cpu_comm:
                buf = torch.empty(*shape, dtype=dtype)
                dist.recv(buf, src=src)
                return buf.to(device)
            else:
                buf = torch.empty(*shape, device=device, dtype=dtype)
                dist.recv(buf, src=src)
                return buf

        for op, m in self.my_schedule:
            if op == "F":
                # --- Forward pass for micro-batch m ---
                if s == 0:
                    inp = micro_batch_inputs[m].to(self.device)
                else:
                    # Receive activation from previous stage.
                    inp = _recv(self.hidden_shape, src=s - 1,
                                device=self.device, dtype=self.dtype)

                # Mark the stage boundary as a leaf so inp.grad is populated
                # during backward and can be sent upstream.
                inp = inp.requires_grad_(True)
                fwd_inputs[m] = inp

                output = self.stage_module(inp)
                fwd_outputs[m] = output

                if s < S - 1:
                    # Send activation to next stage (detached: the graph
                    # does not cross process boundaries).
                    _send(output.detach(), dst=s + 1)

                # Last stage: compute the micro-batch loss.
                if s == S - 1 and self.loss_fn is not None and micro_batch_labels is not None:
                    losses[m] = self.loss_fn(output, micro_batch_labels[m].to(self.device))

            elif op == "B":
                # --- Backward pass for micro-batch m ---
                output = fwd_outputs[m]
                inp = fwd_inputs[m]

                if s == S - 1:
                    # Last stage: backward from the (1/M-scaled) loss.
                    if losses[m] is not None:
                        scaled_loss = losses[m] / M
                        scaled_loss.backward(retain_graph=False)
                else:
                    # Receive output gradient from the next stage.
                    grad = _recv(output.shape, src=s + 1,
                                 device=self.device, dtype=output.dtype)
                    output.backward(grad, retain_graph=False)

                if s > 0 and inp.grad is not None:
                    # Send input gradient to the previous stage.
                    # NOTE(review): if inp.grad were ever None here, the
                    # previous stage's recv would block forever — confirm
                    # upstream stages always produce a gradient.
                    _send(inp.grad.detach(), dst=s - 1)

        # Aggregate losses (non-empty only on the last rank).
        valid_losses = [l.item() for l in losses if l is not None]
        avg_loss = sum(valid_losses) / len(valid_losses) if valid_losses else 0.0

        return {
            "loss": avg_loss,
            "losses": valid_losses,
        }
337474
class SequentialStage(nn.Module):
338475
"""A pipeline stage that sequentially runs a list of layers.
339476

tests/test_distributed_pipeline.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
"""Distributed pipeline parallelism test.
2+
3+
Run with: torchrun --nproc_per_node=2 tests/test_distributed_pipeline.py
4+
5+
Verifies that the distributed pipeline engine produces the same
6+
gradients as single-process training with gradient accumulation.
7+
"""
8+
9+
import sys
10+
11+
import torch
12+
import torch.distributed as dist
13+
import torch.nn as nn
14+
15+
from bitsandbytes.pipeline import DistributedPipelineEngine, SequentialStage
16+
17+
18+
class SimpleLayer(nn.Module):
    """One pipeline layer: a single square, bias-free linear projection."""

    def __init__(self, dim):
        super().__init__()
        # dim -> dim projection; no bias so gradients live only in `weight`.
        self.linear = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        projected = self.linear(x)
        return projected
def run_test():
    """End-to-end check of DistributedPipelineEngine across 2 processes.

    Rank 0 runs layers 0-1, rank 1 runs layers 2-3. After one pipeline
    step, rank 1 ships its gradients and loss to rank 0, which recomputes
    the same 4-layer forward/backward locally (with 1/M loss scaling) and
    compares every layer's weight gradient within 1e-5 tolerance.
    Exits non-zero on rank 0 if any layer mismatches.
    """
    # Use gloo for point-to-point ops; NCCL send/recv can fail on single-GPU
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    assert world_size == 2, f"Requires 2 processes, got {world_size}"

    # NOTE(review): requires at least one CUDA device — device_count() of 0
    # would make this modulo divide by zero. Confirm CI always has a GPU.
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(device)

    dim = 32    # hidden width of every layer
    M = 4       # micro-batches per step
    batch = 4   # rows per micro-batch

    # Create layers with shared seeds so all ranks have the same initial weights
    torch.manual_seed(42)
    all_layers = [SimpleLayer(dim) for _ in range(4)]

    # Each rank keeps only its own half of the stack.
    if rank == 0:
        my_layers = all_layers[:2]
    else:
        my_layers = all_layers[2:]

    my_stage = SequentialStage(my_layers).to(device)
    my_stage.zero_grad()

    # Create identical inputs/labels on all ranks
    torch.manual_seed(123)
    micro_inputs = [torch.randn(batch, dim) for _ in range(M)]
    micro_labels = [torch.randn(batch, dim) for _ in range(M)]

    # Plain MSE; the engine applies the 1/M scaling itself.
    loss_fn = lambda out, labels: (out - labels).pow(2).mean()

    # Run distributed pipeline
    engine = DistributedPipelineEngine(
        stage_module=my_stage,
        rank=rank,
        world_size=world_size,
        loss_fn=loss_fn,
        num_micro_batches=M,
        hidden_shape=(batch, dim),  # full activation shape, incl. batch dim
        dtype=torch.float32,
    )

    result = engine.step(
        micro_batch_inputs=micro_inputs if rank == 0 else None,
        micro_batch_labels=micro_labels if rank == world_size - 1 else None,
    )

    # Collect per-layer gradients, keyed by global layer index (0-3).
    pipe_grads = {}
    for i, layer in enumerate(my_layers):
        layer_idx = i + (2 if rank == 1 else 0)
        if layer.linear.weight.grad is not None:
            pipe_grads[layer_idx] = layer.linear.weight.grad.clone()

    # Exchange gradients and loss: rank 1 sends to rank 0 (CPU for gloo)
    # Send/recv order and tags must match exactly on both branches.
    if rank == 0:
        for layer_idx in [2, 3]:
            buf = torch.empty(dim, dim)  # CPU tensor for gloo
            dist.recv(buf, src=1, tag=layer_idx)
            pipe_grads[layer_idx] = buf.to(device)
        # Receive loss from last rank
        loss_buf = torch.empty(1)
        dist.recv(loss_buf, src=world_size - 1, tag=100)
        pipeline_loss = loss_buf.item()
    else:
        for layer_idx in [2, 3]:
            dist.send(pipe_grads[layer_idx].cpu(), dst=0, tag=layer_idx)
        # Send loss to rank 0
        dist.send(torch.tensor([result["loss"]]), dst=0, tag=100)
        pipeline_loss = result["loss"]

    # Rank 0 computes reference and checks
    if rank == 0:
        # Re-seed with 42 so the reference layers start from the exact
        # same weights as the pipeline layers above.
        torch.manual_seed(42)
        ref_layers = [SimpleLayer(dim).to(device) for _ in range(4)]
        for ref in ref_layers:
            ref.zero_grad()

        # Plain gradient accumulation over M micro-batches, loss scaled
        # by 1/M to mirror the engine's scaling.
        for m in range(M):
            x = micro_inputs[m].to(device)
            for ref in ref_layers:
                x = ref(x)
            loss = loss_fn(x, micro_labels[m].to(device)) / M
            loss.backward()

        ref_grads = [ref.linear.weight.grad.clone() for ref in ref_layers]

        all_pass = True
        for i in range(4):
            ref_g = ref_grads[i]
            pipe_g = pipe_grads.get(i)
            if pipe_g is None:
                print(f"FAIL: Layer {i} — no gradient")
                all_pass = False
            elif not torch.allclose(ref_g, pipe_g, atol=1e-5, rtol=1e-5):
                max_diff = (ref_g - pipe_g).abs().max().item()
                print(f"FAIL: Layer {i} — max diff: {max_diff:.2e}")
                all_pass = False
            else:
                print(f"PASS: Layer {i} — gradients match")

        print(f"\nPipeline loss: {pipeline_loss:.6f}")
        print(f"Result: {'ALL PASSED' if all_pass else 'SOME FAILED'}")

        if not all_pass:
            # NOTE(review): only rank 0 exits non-zero; rank 1 still
            # reaches the barrier and exits 0 — confirm the harness keys
            # off rank 0's exit status.
            sys.exit(1)

    # Keep both ranks alive until the comparison is done, then tear down.
    dist.barrier()
    dist.destroy_process_group()
# Entry point: each torchrun-spawned process executes the full test.
if __name__ == "__main__":
    run_test()

0 commit comments

Comments
 (0)