Skip to content

Commit e408054

Browse files
TimDettmers and claude committed
feat: Add pipeline-aware gradient checkpointing
Adds CheckpointedStage and PipelineCheckpointer that wrap pipeline stages with checkpoint_cpu_offload. Stage boundary activations stay on GPU for inter-stage communication; internal layer activations are offloaded to CPU during forward and reloaded during backward. Also fixes checkpoint_cpu_offload backward to use torch.autograd.backward instead of torch.autograd.grad, which properly accumulates gradients into nn.Module parameters (not just input tensors). Updates the memory test to use lightweight layers with large intermediate activations where savings are clearly measurable. 4 new tests: checkpointed gradient correctness (CPU offload and standard), memory reduction verification, eval mode bypass. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 70fef4f commit e408054

File tree

4 files changed

+287
-23
lines changed

4 files changed

+287
-23
lines changed

bitsandbytes/pipeline.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,67 @@ def split_model_layers(layers, num_stages):
273273
return stage_layers
274274

275275

276+
class CheckpointedStage(nn.Module):
    """Pipeline stage with gradient checkpointing and optional CPU offload.

    Wraps a stage module's forward with checkpoint_cpu_offload, so that
    intermediate activations within the stage are offloaded to CPU during
    forward and reloaded+recomputed during backward. Stage boundary
    activations (input/output tensors) stay on GPU — they're managed by
    the PipelineEngine for inter-stage communication.

    Args:
        stage_module: The stage module to wrap.
        cpu_offload: If True, use checkpoint_cpu_offload (offloads to CPU).
            If False, use torch.utils.checkpoint (GPU-only recomputation).
    """

    def __init__(self, stage_module, cpu_offload=True):
        super().__init__()
        self.stage_module = stage_module
        self.cpu_offload = cpu_offload

    def forward(self, x):
        # Checkpointing only pays off when a backward pass will follow.
        # Guarding on torch.is_grad_enabled() additionally skips the useless
        # recomputation machinery when the stage runs in training mode inside
        # a torch.no_grad() block, where torch.utils.checkpoint would
        # otherwise warn that nothing requires grad.
        if self.training and torch.is_grad_enabled():
            if self.cpu_offload:
                # Imported lazily; presumably to avoid a circular import
                # between bitsandbytes.pipeline and bitsandbytes.training
                # — NOTE(review): confirm, and hoist to module level if not.
                from bitsandbytes.training import checkpoint_cpu_offload

                return checkpoint_cpu_offload(self.stage_module, x)
            # use_reentrant=False handles inputs that do not require grad
            # and accumulates parameter gradients correctly.
            return torch.utils.checkpoint.checkpoint(
                self.stage_module, x, use_reentrant=False,
            )
        # Eval (or no-grad) path: plain forward, no checkpointing.
        return self.stage_module(x)
306+
307+
308+
class PipelineCheckpointer:
    """Wraps pipeline stages with gradient checkpointing.

    Exposes a single static helper that replaces each stage module with a
    CheckpointedStage. Activations passed between stages stay on the GPU
    for pipeline communication; only each stage's internal layer
    activations are checkpointed.

    Usage:
        stages = [SequentialStage(layers[:2]), SequentialStage(layers[2:])]
        stages = PipelineCheckpointer.wrap_stages(stages, cpu_offload=True)
        engine = PipelineEngine(stages, loss_fn=loss_fn, ...)
    """

    @staticmethod
    def wrap_stages(stage_modules, cpu_offload=True):
        """Wrap each stage with gradient checkpointing.

        Args:
            stage_modules: List of nn.Module stage modules.
            cpu_offload: If True, offload activations to CPU. If False,
                use standard gradient checkpointing (GPU recomputation only).

        Returns:
            List of CheckpointedStage modules, in the same order as the input.
        """
        wrapped = []
        for stage in stage_modules:
            wrapped.append(CheckpointedStage(stage, cpu_offload=cpu_offload))
        return wrapped
335+
336+
276337
class SequentialStage(nn.Module):
277338
"""A pipeline stage that sequentially runs a list of layers.
278339

bitsandbytes/training.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -88,19 +88,16 @@ def backward(ctx, *grad_outputs):
8888
if isinstance(outputs, torch.Tensor):
8989
outputs = (outputs,)
9090

91-
# Compute gradients
92-
input_grads = torch.autograd.grad(
93-
outputs,
94-
[inp for inp in inputs if isinstance(inp, torch.Tensor) and inp.requires_grad],
95-
grad_outputs=grad_outputs,
96-
)
97-
98-
# Map gradients back to original input positions
99-
grad_iter = iter(input_grads)
91+
# Use backward() to accumulate gradients into all leaf parameters
92+
# (not just inputs). This is needed when the checkpointed function
93+
# is an nn.Module with trainable parameters.
94+
torch.autograd.backward(outputs, grad_outputs)
95+
96+
# Collect input gradients
10097
result = [None, None] # for run_function and preserve_rng_state
101-
for cpu_input, req_grad in zip(ctx.cpu_inputs, ctx.input_requires_grad):
102-
if isinstance(cpu_input, torch.Tensor) and req_grad:
103-
result.append(next(grad_iter))
98+
for inp, req_grad in zip(inputs, ctx.input_requires_grad):
99+
if isinstance(inp, torch.Tensor) and req_grad:
100+
result.append(inp.grad if inp.grad is not None else torch.zeros_like(inp))
104101
else:
105102
result.append(None)
106103

tests/test_pipeline.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import torch.nn as nn
1313

1414
from bitsandbytes.pipeline import (
15+
CheckpointedStage,
16+
PipelineCheckpointer,
1517
PipelineEngine,
1618
SequentialStage,
1719
generate_1f1b_schedule,
@@ -481,5 +483,189 @@ def test_multiple_steps(self):
481483
assert (layer.linear.weight.grad != 0).any()
482484

483485

486+
# ─── Pipeline Checkpointing Tests ─────────────────────────────────────────
487+
488+
class WideLayer(nn.Module):
    """Two-layer MLP block whose hidden activation is deliberately large.

    The (batch, intermediate) hidden tensor dominates activation memory,
    which makes checkpointing savings measurable in the memory tests.
    """

    def __init__(self, dim, intermediate):
        super().__init__()
        self.up = nn.Linear(dim, intermediate, bias=False)
        self.down = nn.Linear(intermediate, dim, bias=False)

    def forward(self, x):
        # Expand to the wide hidden size, apply ReLU, project back down.
        hidden = self.up(x)
        activated = torch.relu(hidden)
        return self.down(activated)
498+
499+
500+
class TestPipelineCheckpointer:
    """Integration tests for CheckpointedStage / PipelineCheckpointer.

    NOTE(review): every test here allocates CUDA tensors; presumably the
    module applies a CUDA skip marker elsewhere — confirm.
    """

    @staticmethod
    def _loss_fn(out, labels):
        """Mean-squared-error loss shared by every test in this class."""
        return (out - labels).pow(2).mean()

    @staticmethod
    def _reference_grads(dim, layers, micro_inputs, micro_labels, loss_fn):
        """Reference gradients via single-device gradient accumulation.

        Copies each layer's weight into a fresh SimpleLayer, runs all
        micro-batches sequentially with the loss scaled by 1/M (matching the
        pipeline's micro-batch averaging), and returns the accumulated
        weight gradients, one tensor per layer.
        """
        M = len(micro_inputs)
        ref_layers = [SimpleLayer(dim).cuda() for _ in layers]
        for ref, orig in zip(ref_layers, layers):
            ref.linear.weight.data.copy_(orig.linear.weight.data)
        for ref in ref_layers:
            ref.zero_grad()
        for m in range(M):
            x = micro_inputs[m]
            for ref in ref_layers:
                x = ref(x)
            loss = loss_fn(x, micro_labels[m]) / M
            loss.backward()
        return [ref.linear.weight.grad.clone() for ref in ref_layers]

    @staticmethod
    def _assert_grads_match(layers, ref_grads, context):
        """Assert each layer's accumulated gradient matches the reference."""
        for i, layer in enumerate(layers):
            assert layer.linear.weight.grad is not None, f"Layer {i}: no gradient"
            torch.testing.assert_close(
                ref_grads[i], layer.linear.weight.grad,
                atol=1e-5, rtol=1e-5,
                msg=f"Layer {i}: gradient mismatch {context}",
            )

    def _run_checkpointed_pipeline(self, layers, micro_inputs, micro_labels, cpu_offload):
        """Build a 2-stage checkpointed pipeline over `layers` and run one step."""
        for layer in layers:
            layer.zero_grad()
        stages = [SequentialStage(layers[:2]).cuda(), SequentialStage(layers[2:]).cuda()]
        stages = PipelineCheckpointer.wrap_stages(stages, cpu_offload=cpu_offload)
        engine = PipelineEngine(
            stages, loss_fn=self._loss_fn, num_micro_batches=len(micro_inputs)
        )
        # Checkpointing is only active in training mode.
        for s in stages:
            s.train()
        engine.step(micro_inputs, micro_labels)

    def test_checkpointed_gradient_correctness(self):
        """Checkpointed pipeline should produce identical gradients to reference."""
        dim = 32
        M = 4
        torch.manual_seed(42)

        layers = [SimpleLayer(dim).cuda() for _ in range(4)]
        micro_inputs = [torch.randn(4, dim, device="cuda") for _ in range(M)]
        micro_labels = [torch.randn(4, dim, device="cuda") for _ in range(M)]

        ref_grads = self._reference_grads(
            dim, layers, micro_inputs, micro_labels, self._loss_fn
        )
        self._run_checkpointed_pipeline(
            layers, micro_inputs, micro_labels, cpu_offload=True
        )
        self._assert_grads_match(layers, ref_grads, "with checkpointing")

    def test_checkpointed_no_cpu_offload(self):
        """Checkpointing without CPU offload should also produce correct gradients."""
        dim = 32
        M = 4
        torch.manual_seed(42)

        layers = [SimpleLayer(dim).cuda() for _ in range(4)]
        micro_inputs = [torch.randn(4, dim, device="cuda") for _ in range(M)]
        micro_labels = [torch.randn(4, dim, device="cuda") for _ in range(M)]

        ref_grads = self._reference_grads(
            dim, layers, micro_inputs, micro_labels, self._loss_fn
        )
        self._run_checkpointed_pipeline(
            layers, micro_inputs, micro_labels, cpu_offload=False
        )
        self._assert_grads_match(layers, ref_grads, "without CPU offload")

    def test_checkpointed_memory_reduction(self):
        """Checkpointing should reduce peak GPU memory for wide layers."""
        dim = 64
        intermediate = 4096  # Large intermediate to make memory difference visible
        M = 4
        batch = 32
        torch.manual_seed(42)

        def run_pipeline(use_checkpoint):
            """Run one pipeline step; return peak GPU memory allocated."""
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

            layers = [WideLayer(dim, intermediate).cuda() for _ in range(4)]
            micro_inputs = [torch.randn(batch, dim, device="cuda") for _ in range(M)]
            micro_labels = [torch.randn(batch, dim, device="cuda") for _ in range(M)]

            for layer in layers:
                layer.zero_grad()

            stages = [
                SequentialStage(layers[:2]).cuda(),
                SequentialStage(layers[2:]).cuda(),
            ]
            if use_checkpoint:
                stages = PipelineCheckpointer.wrap_stages(stages, cpu_offload=True)
                for s in stages:
                    s.train()

            engine = PipelineEngine(
                stages, loss_fn=self._loss_fn, num_micro_batches=M
            )
            engine.step(micro_inputs, micro_labels)

            peak_mem = torch.cuda.max_memory_allocated()

            # Verify gradients exist — the run must have done real work.
            for layer in layers:
                for p in layer.parameters():
                    assert p.grad is not None

            # Drop references before the next run so its empty_cache()
            # can actually release this run's allocations.
            del layers, micro_inputs, micro_labels, stages, engine
            torch.cuda.empty_cache()

            return peak_mem

        peak_no_ckpt = run_pipeline(use_checkpoint=False)
        peak_with_ckpt = run_pipeline(use_checkpoint=True)

        # Checkpointing should use less peak memory.
        assert peak_with_ckpt < peak_no_ckpt, (
            f"Checkpointing should reduce memory: "
            f"without={peak_no_ckpt / 1e6:.1f}MB, with={peak_with_ckpt / 1e6:.1f}MB"
        )

    def test_eval_mode_skips_checkpointing(self):
        """In eval mode, checkpointed stages should skip checkpointing."""
        dim = 32
        torch.manual_seed(42)

        layers = [SimpleLayer(dim).cuda() for _ in range(4)]
        stage = SequentialStage(layers[:2]).cuda()
        ckpt_stage = CheckpointedStage(stage, cpu_offload=True)

        x = torch.randn(4, dim, device="cuda")

        # Training mode: forward goes through the checkpointing path.
        ckpt_stage.train()
        out_train = ckpt_stage(x)

        # Eval mode: plain forward, no checkpointing.
        ckpt_stage.eval()
        out_eval = ckpt_stage(x)

        # Both paths must compute the same function.
        torch.testing.assert_close(out_train, out_eval, atol=1e-6, rtol=1e-6)
484670
if __name__ == "__main__":
    # pytest.main returns the exit status; propagate it so that running
    # this file directly reports failures to the shell/CI correctly.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))

tests/test_training.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,41 @@ def test_with_nn_module(self):
6060
assert x.grad.shape == x.shape
6161

6262
def test_memory_reduction(self):
63-
"""CPU offload should use less GPU memory than standard checkpoint."""
64-
dim = 1024
65-
66-
# Standard forward (saves activations on GPU)
63+
"""CPU offload should reduce GPU memory by offloading activations.
64+
65+
Uses lightweight parameterized functions that produce large
66+
intermediate activations so the saved-activation memory
67+
dominates over parameter gradient memory.
68+
"""
69+
dim = 64
70+
expand = 2048
71+
n_layers = 8
72+
73+
class ExpandLayer(nn.Module):
74+
"""Lightweight params but large intermediate activations."""
75+
76+
def __init__(self):
77+
super().__init__()
78+
self.w = nn.Parameter(torch.randn(dim) * 0.01)
79+
80+
def forward(self, x):
81+
# x: [batch, dim]. Expand to [batch, dim, expand], sum back.
82+
# The expanded tensor is large and saved for backward.
83+
h = x * self.w # element-wise, saves x and w for backward
84+
h = h.unsqueeze(-1).expand(-1, -1, expand) # large activation
85+
h = h.mean(-1) # back to [batch, dim]
86+
return h
87+
88+
# Standard forward (saves all expanded activations on GPU)
6789
torch.cuda.empty_cache()
6890
torch.cuda.reset_peak_memory_stats()
6991

70-
layers = nn.ModuleList([nn.Linear(dim, dim).cuda() for _ in range(4)])
71-
x = torch.randn(32, dim, device="cuda", requires_grad=True)
92+
layers = nn.ModuleList([ExpandLayer().cuda() for _ in range(n_layers)])
93+
x = torch.randn(512, dim, device="cuda", requires_grad=True)
7294

73-
# Standard: all activations stay on GPU
7495
h = x
7596
for layer in layers:
76-
h = torch.nn.functional.gelu(layer(h))
97+
h = layer(h)
7798
h.sum().backward()
7899
peak_standard = torch.cuda.max_memory_allocated()
79100

@@ -85,15 +106,14 @@ def test_memory_reduction(self):
85106
torch.cuda.reset_peak_memory_stats()
86107

87108
# CPU offload: activations go to CPU
88-
x = torch.randn(32, dim, device="cuda", requires_grad=True)
109+
x = torch.randn(512, dim, device="cuda", requires_grad=True)
89110
h = x
90111
for layer in layers:
91-
h = checkpoint_cpu_offload(lambda inp, l=layer: torch.nn.functional.gelu(l(inp)), h)
112+
h = checkpoint_cpu_offload(layer, h)
92113
h.sum().backward()
93114
peak_offload = torch.cuda.max_memory_allocated()
94115

95116
# CPU offload should use less peak memory
96-
# Allow some margin since PyTorch internal allocations vary
97117
assert peak_offload < peak_standard, (
98118
f"CPU offload ({peak_offload / 1e6:.1f} MB) should use less peak memory "
99119
f"than standard ({peak_standard / 1e6:.1f} MB)"

0 commit comments

Comments
 (0)