Skip to content

Commit 20d2b7e

Browse files
TimDettmers and claude committed
feat: Add pipeline training example and fix first-stage integer input bug
- examples/train_pipeline.py: pipeline parallelism training with KbitLoraModel split across 2+ GPUs using DistributedPipelineEngine with NCCL - Fix: skip requires_grad_(True) for first stage (integer input_ids) - Uses KbitFirstStage (embedding + layers) and KbitLastStage (layers + norm) - Loss computed via chunked CE on last stage Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 7b5b36d commit 20d2b7e

File tree

2 files changed

+308
-1
lines changed

2 files changed

+308
-1
lines changed

bitsandbytes/pipeline.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,10 @@ def _recv(shape, src, device, dtype):
427427
inp = _recv(self.hidden_shape, src=s - 1,
428428
device=self.device, dtype=self.dtype)
429429

430-
inp = inp.requires_grad_(True)
430+
# Only set requires_grad for non-first stages (first stage may
431+
# receive integer input_ids that can't track gradients)
432+
if s > 0:
433+
inp = inp.requires_grad_(True)
431434
fwd_inputs[m] = inp
432435

433436
# Forward

examples/train_pipeline.py

Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
"""Pipeline parallelism training example using bitsandbytes kbit quantization.
2+
3+
Demonstrates distributed pipeline training across 2+ GPUs:
4+
- Loads a HuggingFace model and applies KbitLoraModel
5+
- Splits decoder layers across GPUs (first stage = embedding + first layers,
6+
last stage = remaining layers + norm + LM head)
7+
- Trains using DistributedPipelineEngine with NCCL
8+
- Reports per-GPU memory and throughput
9+
10+
Usage:
11+
# 2-GPU pipeline training on Qwen3-0.6B
12+
torchrun --nproc_per_node=2 examples/train_pipeline.py
13+
14+
# Larger model
15+
torchrun --nproc_per_node=2 examples/train_pipeline.py --model Qwen/Qwen3-4B
16+
17+
# More micro-batches for better pipeline utilization
18+
torchrun --nproc_per_node=2 examples/train_pipeline.py --micro-batches 8
19+
"""
20+
21+
import argparse
22+
import os
23+
import time
24+
25+
import torch
26+
import torch.distributed as dist
27+
import torch.nn as nn
28+
29+
if "BNB_CUDA_VERSION" not in os.environ:
30+
pass
31+
32+
import bitsandbytes # noqa: F401
33+
from bitsandbytes.kbit_lora import KbitLoraModel
34+
from bitsandbytes.pipeline import DistributedPipelineEngine
35+
36+
37+
def parse_args():
    """Parse command-line options for the pipeline QLoRA training run."""
    p = argparse.ArgumentParser(description="Pipeline QLoRA training")
    p.add_argument("--model", default="Qwen/Qwen3-0.6B", help="HuggingFace model name")
    p.add_argument("--lora-r", type=int, default=64, help="LoRA rank")
    p.add_argument("--k", type=int, default=4, help="Quantization bit width")
    p.add_argument("--lr", type=float, default=2e-4, help="Learning rate")
    p.add_argument("--steps", type=int, default=20, help="Training steps")
    p.add_argument("--seq-len", type=int, default=256, help="Sequence length")
    p.add_argument("--micro-batches", type=int, default=4, help="Number of micro-batches")
    return p.parse_args()
47+
48+
49+
class KbitFirstStage(nn.Module):
    """First pipeline stage: token embedding plus the first slice of layers.

    Consumes integer ``input_ids`` of shape [B, S] and produces hidden
    states of shape [B, S, H] for transmission to the next stage.
    """

    def __init__(self, kbit_model, layer_start, layer_end):
        super().__init__()
        self.km = kbit_model
        self.layer_start = layer_start
        self.layer_end = layer_end

    def forward(self, input_ids):
        batch, seq = input_ids.shape
        dev = input_ids.device
        # Positions are simply 0..S-1, broadcast over the batch dimension.
        position_ids = torch.arange(seq, device=dev).unsqueeze(0).expand(batch, -1)
        self.km._extend_rope_cache(seq, dev)
        # Embed, then cast to the model's compute dtype before the layer loop.
        hidden = self.km.embed_tokens(input_ids).to(self.km.compute_dtype)
        for idx in range(self.layer_start, self.layer_end):
            hidden = self.km._layer_forward(idx, hidden, position_ids)
        return hidden
70+
71+
72+
class KbitLastStage(nn.Module):
    """Last pipeline stage: remaining decoder layers plus the final RMS norm.

    Consumes hidden states of shape [B, S, H] and returns normalized hidden
    states flattened to [B*S, H]. The loss is computed externally by the
    engine's ``loss_fn``.
    """

    def __init__(self, kbit_model, layer_start, layer_end):
        super().__init__()
        self.km = kbit_model
        self.layer_start = layer_start
        self.layer_end = layer_end

    def forward(self, hidden):
        # Local import, resolved at call time rather than module import time.
        from bitsandbytes.autograd.training_kernels import rmsnorm

        batch, seq, _ = hidden.shape
        dev = hidden.device
        position_ids = torch.arange(seq, device=dev).unsqueeze(0).expand(batch, -1)
        self.km._extend_rope_cache(seq, dev)

        for idx in range(self.layer_start, self.layer_end):
            hidden = self.km._layer_forward(idx, hidden, position_ids)

        # Flatten to [B*S, H] and apply the final RMS norm.
        flat = hidden.reshape(-1, self.km.hidden_size)
        return rmsnorm(
            flat,
            self.km._norm_weights["final_norm_weight"],
            eps=self.km.rms_norm_eps,
        )
103+
104+
105+
def make_loss_fn(kbit_model):
    """Build a loss closure computing chunked cross-entropy over the LM head."""
    from bitsandbytes.autograd.chunked_ce import chunked_cross_entropy

    km = kbit_model
    # Quantized LM-head tensors and metadata, captured once at closure creation.
    head = km._lm_head_info

    def loss_fn(hidden_2d, labels):
        """Chunked cross-entropy loss.

        Args:
            hidden_2d: [B*S, H] hidden states from the last stage.
            labels: [B, S] target token IDs.
        """
        # Next-token shift: drop the final position's hidden state and the
        # first label so position t predicts token t+1.
        shifted_hidden = hidden_2d[:-1]
        shifted_labels = labels.reshape(-1)[1:]
        return chunked_cross_entropy(
            shifted_hidden,
            head["packed"],
            head["absmax"],
            head["codebook"],
            shifted_labels,
            head["k"],
            head["K"],
            head["N_padded"],
            head["N"],
            km.compute_dtype,
            km.ce_chunk_size,
        )

    return loss_fn
130+
131+
132+
def main():
    """Run pipeline-parallel QLoRA training across all ranks in the process group.

    Each rank builds its own stage (first stage on rank 0, last-stage module on
    every other rank), wraps it in a DistributedPipelineEngine, and trains on
    synthetic random-token batches, reporting loss, throughput, and peak memory.
    """
    args = parse_args()

    # One process per GPU; rank determines both the device and the stage role.
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    if rank == 0:
        print(f"{'=' * 60}")
        print(f"Pipeline QLoRA Training ({world_size} GPUs)")
        print(f"{'=' * 60}")
        print(f"Model: {args.model}")
        print(f"LoRA rank: {args.lora_r}, k={args.k}")
        print(f"Seq len: {args.seq_len}, Micro-batches: {args.micro_batches}")
        print(f"Steps: {args.steps}")
        print()

    # Load model. Deferred import keeps transformers out of module import time.
    from transformers import AutoModelForCausalLM

    if rank == 0:
        print("Loading base model...")
    # NOTE(review): every rank loads the FULL base model onto its own GPU
    # before quantization; only the stage's layer slice is used afterwards.
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        dtype=torch.float16,
        device_map={"": device},
        trust_remote_code=True,
    )

    # Quantize the base weights and attach LoRA adapters.
    if rank == 0:
        print("Quantizing and creating LoRA adapters...")
    kbit_model = KbitLoraModel(
        model,
        lora_r=args.lora_r,
        lora_alpha=16.0,
        k=args.k,
        compute_dtype=torch.bfloat16,
    )
    # Free the fp16 base model; the quantized copy lives in kbit_model.
    del model
    torch.cuda.empty_cache()

    # Even layer split; the last rank absorbs any remainder layers.
    num_layers = kbit_model.num_layers
    layers_per_stage = num_layers // world_size
    layer_start = rank * layers_per_stage
    layer_end = (rank + 1) * layers_per_stage if rank < world_size - 1 else num_layers

    is_first = (rank == 0)
    is_last = (rank == world_size - 1)

    # NOTE(review): for world_size > 2 every non-first rank gets KbitLastStage,
    # which applies the final norm and flattens to [B*S, H] — middle stages
    # would then not match hidden_shape for the next recv. Confirm whether
    # this example is intended for exactly 2 GPUs or needs a mid-stage module.
    if is_first:
        stage = KbitFirstStage(kbit_model, layer_start, layer_end)
    else:
        stage = KbitLastStage(kbit_model, layer_start, layer_end)

    if rank == 0:
        print(f" Total layers: {num_layers}")
        print(f" Trainable params: {kbit_model.num_trainable_parameters():,}")

    # Print each rank's layer assignment in rank order (barrier serializes).
    for r in range(world_size):
        if r == rank:
            ls = r * layers_per_stage
            le = (r + 1) * layers_per_stage if r < world_size - 1 else num_layers
            role = "first" if r == 0 else ("last" if r == world_size - 1 else "mid")
            print(f" GPU {r}: layers {ls}-{le-1} ({role} stage)")
        dist.barrier()

    # Loss function exists only on the last stage; other ranks pass None.
    loss_fn = make_loss_fn(kbit_model) if is_last else None

    # Hidden shape for inter-stage communication: [B, S, H] with micro-batch B=1.
    hidden_shape = (1, args.seq_len, kbit_model.hidden_size)

    # Pipeline engine handles send/recv of activations and gradients over NCCL.
    engine = DistributedPipelineEngine(
        stage_module=stage,
        rank=rank,
        world_size=world_size,
        loss_fn=loss_fn,
        num_micro_batches=args.micro_batches,
        hidden_shape=hidden_shape,
        dtype=torch.bfloat16,
    )

    # Optimizer — each rank optimizes only its own stage's trainable parameters.
    trainable_params = kbit_model.get_trainable_parameters()
    optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=0.01)

    # Training loop
    if rank == 0:
        print(f"\n{'=' * 60}")
        print("Training")
        print(f"{'=' * 60}")

    vocab_size = kbit_model.vocab_size
    losses = []
    torch.cuda.reset_peak_memory_stats()

    for step in range(args.steps):
        t_step = time.time()
        optimizer.zero_grad()

        # Generate micro-batches. A deterministic per-step seed makes every
        # rank produce identical data, so the last rank's labels match the
        # first rank's inputs without any extra communication.
        torch.manual_seed(step * 1000 + 42)
        micro_batch_inputs = []
        micro_batch_labels = []
        for mb in range(args.micro_batches):
            input_ids = torch.randint(0, vocab_size, (1, args.seq_len), device=device)
            labels = input_ids.clone()
            # Mask the first position; there is no previous token to predict it.
            labels[:, :1] = -100
            micro_batch_inputs.append(input_ids)
            micro_batch_labels.append(labels)

        # Run one pipeline step: inputs only feed the first stage, labels only
        # feed the last stage.
        result = engine.step(
            micro_batch_inputs=micro_batch_inputs if is_first else None,
            micro_batch_labels=micro_batch_labels if is_last else None,
        )

        # Only the last rank computes a loss; broadcast it so all ranks can log.
        loss_val = result["loss"] if is_last else 0.0
        loss_tensor = torch.tensor([loss_val], device=device)
        dist.broadcast(loss_tensor, src=world_size - 1)
        loss_val = loss_tensor.item()

        optimizer.step()
        losses.append(loss_val)

        dt = time.time() - t_step
        tokens = args.micro_batches * args.seq_len

        if rank == 0 and (step % 5 == 0 or step == args.steps - 1):
            peak_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
            print(
                f" Step {step:3d}/{args.steps} | "
                f"Loss: {loss_val:.4f} | "
                f"Time: {dt:.2f}s | "
                f"Tok/s: {tokens/dt:.0f} | "
                f"Peak mem: {peak_mb:.0f} MB"
            )

    # Results
    if rank == 0:
        print(f"\n{'=' * 60}")
        print("Results")
        print(f"{'=' * 60}")
        print(f" Initial loss: {losses[0]:.4f}")
        print(f" Final loss: {losses[-1]:.4f}")

        # Smoke check: compare mean of first five vs last five step losses.
        if len(losses) >= 10:
            early = sum(losses[:5]) / 5
            late = sum(losses[-5:]) / 5
            if late < early:
                print(f" Loss DECREASED from {early:.4f} to {late:.4f} (PASS)")
            else:
                print(f" WARNING: Loss did not decrease ({early:.4f} -> {late:.4f})")

    # Report per-GPU peak memory. all_gather is a collective: every rank must
    # participate even though only rank 0 prints.
    peak = torch.tensor([torch.cuda.max_memory_allocated() / 1024 / 1024], device=device)
    peaks = [torch.zeros(1, device=device) for _ in range(world_size)]
    dist.all_gather(peaks, peak)
    if rank == 0:
        for r, p in enumerate(peaks):
            print(f" GPU {r} peak memory: {p.item():.0f} MB")

    dist.destroy_process_group()
301+
302+
303+
# Entry point: launch with `torchrun --nproc_per_node=N examples/train_pipeline.py`.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)