Skip to content

Commit 463908b

Browse files
TimDettmersclaude
andcommitted
feat: Add prepare_model_for_kbit_training and CPU offload checkpointing
Infrastructure for QLoRA training:

prepare_model_for_kbit_training():
- Freezes all base parameters
- Casts normalization layers to float32
- Enables gradient checkpointing
- Pre-allocates global weight buffer for largest layer

checkpoint_cpu_offload():
- Gradient checkpoint that offloads activations to CPU
- Async non-blocking transfers overlap with GPU compute
- RNG state preservation for dropout reproducibility
- Verified to reduce GPU peak memory vs standard forward

8 new tests (3 for prepare_model, 5 for CPU offload checkpoint).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 3d2dba2 commit 463908b

File tree

5 files changed

+354
-1
lines changed

5 files changed

+354
-1
lines changed

bitsandbytes/nn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
StableEmbedding,
2121
SwitchBackLinearBnb,
2222
_GlobalWeightBuffer,
23+
prepare_model_for_kbit_training,
2324
)
2425
from .triton_based_modules import (
2526
StandardLinear,

bitsandbytes/nn/modules.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -921,6 +921,68 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
921921
return out.to(inp_dtype)
922922

923923

924+
def prepare_model_for_kbit_training(
    model: torch.nn.Module,
    use_gradient_checkpointing: bool = True,
    gradient_checkpointing_kwargs: Optional[dict] = None,
) -> torch.nn.Module:
    """Prepare a model with LinearKbit layers for QLoRA-style training.

    This function:
    1. Freezes all base model parameters (requires_grad=False)
    2. Casts LayerNorm (and RMSNorm, when available) layers to float32
    3. Enables gradient checkpointing if requested
    4. Registers the global weight buffer size from the model's largest layer

    After calling this, add LoRA adapters (or any trainable parameters) and
    those will be the only parameters that receive gradients.

    Args:
        model: A model containing LinearKbit layers.
        use_gradient_checkpointing: Enable gradient checkpointing for memory savings.
        gradient_checkpointing_kwargs: Kwargs passed to model.gradient_checkpointing_enable().

    Returns:
        The modified model (in-place).
    """
    # Freeze all parameters; adapters added afterwards stay trainable.
    for param in model.parameters():
        param.requires_grad = False

    # Cast normalization layers to float32 for training stability.
    # torch.nn.RMSNorm does not exist in older PyTorch releases, so resolve
    # it dynamically instead of referencing it unconditionally (which would
    # raise AttributeError on those versions).
    norm_types = (torch.nn.LayerNorm,)
    rmsnorm_cls = getattr(torch.nn, "RMSNorm", None)
    if rmsnorm_cls is not None:
        norm_types = norm_types + (rmsnorm_cls,)
    for module in model.modules():
        if isinstance(module, norm_types):
            module.float()

    # Enable gradient checkpointing, preferring the HF-style API if present.
    if use_gradient_checkpointing:
        if hasattr(model, "gradient_checkpointing_enable"):
            kwargs = gradient_checkpointing_kwargs or {}
            model.gradient_checkpointing_enable(**kwargs)
        elif hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
            model.is_gradient_checkpointing = True

    # Register the global weight buffer for the largest quantized LinearKbit
    # layer so the first forward pass does not pay the allocation cost.
    max_elements = 0
    compute_dtype = torch.float16
    device = None
    for module in model.modules():
        if isinstance(module, LinearKbit) and module.weight.kbit_quantized:
            w = module.weight
            n = w.N_padded * w.K_dim
            if n > max_elements:
                max_elements = n
                device = w.packed.device
            if module.compute_dtype is not None:
                compute_dtype = module.compute_dtype

    if max_elements > 0 and device is not None:
        _GlobalWeightBuffer.get_buffer(device, max_elements, compute_dtype)

    return model
984+
985+
924986
class Int8Params(torch.nn.Parameter):
925987
def __new__(
926988
cls,

bitsandbytes/training.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
"""Training utilities for kbit QLoRA.
2+
3+
Provides gradient checkpointing with CPU offload for reducing GPU memory
4+
during QLoRA fine-tuning.
5+
"""
6+
7+
from typing import Any
8+
9+
import torch
10+
11+
12+
class _CPUOffloadCheckpointFunction(torch.autograd.Function):
    """Gradient checkpoint that offloads activations to CPU during forward.

    Forward: copies tensor inputs to (pinned, when CUDA) CPU memory
    asynchronously and runs the wrapped function under ``torch.no_grad()``.
    Backward: copies the saved inputs back to their original devices,
    recomputes the forward pass with gradients enabled, and backpropagates.

    This saves GPU memory at the cost of CPU<->GPU bandwidth during backward.
    Non-blocking transfers overlap with GPU compute when possible.
    """

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state

        # Save RNG state if requested so stochastic ops (e.g. dropout) are
        # replayed identically during the backward recompute.
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.random.get_rng_state()
            # Public API instead of the private torch.cuda._initialized flag.
            ctx.had_cuda = torch.cuda.is_available() and torch.cuda.is_initialized()
            if ctx.had_cuda:
                ctx.fwd_gpu_state = torch.cuda.get_rng_state()

        # Save inputs to CPU (async). Record each tensor's device so backward
        # can restore to the original device instead of a hardcoded "cuda"
        # (which silently moved tensors from other devices to cuda:0).
        ctx.cpu_inputs = []
        ctx.input_requires_grad = []
        ctx.input_devices = []
        for arg in args:
            if isinstance(arg, torch.Tensor):
                ctx.input_requires_grad.append(arg.requires_grad)
                ctx.input_devices.append(arg.device)
                # Pinned memory enables truly async D2H copies; only pin for
                # CUDA-sourced tensors (pinning requires a CUDA runtime).
                cpu_tensor = torch.empty(
                    arg.shape, dtype=arg.dtype, device="cpu", pin_memory=arg.is_cuda,
                )
                cpu_tensor.copy_(arg, non_blocking=True)
                ctx.cpu_inputs.append(cpu_tensor)
            else:
                ctx.input_requires_grad.append(None)
                ctx.input_devices.append(None)
                ctx.cpu_inputs.append(arg)

        # Run the function without building a graph; backward recomputes it.
        with torch.no_grad():
            outputs = run_function(*args)

        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Restore inputs from CPU (async) onto their original devices.
        inputs = []
        for cpu_input, req_grad, device in zip(
            ctx.cpu_inputs, ctx.input_requires_grad, ctx.input_devices
        ):
            if isinstance(cpu_input, torch.Tensor):
                restored = cpu_input.to(device, non_blocking=True)
                if req_grad:
                    restored.requires_grad_(True)
                inputs.append(restored)
            else:
                inputs.append(cpu_input)

        # Ensure async H2D transfers have completed before recomputing.
        # Guarded so CPU-only execution does not touch the CUDA runtime.
        if torch.cuda.is_available() and torch.cuda.is_initialized():
            torch.cuda.current_stream().synchronize()

        # Restore RNG state and recompute the forward with grad enabled.
        if ctx.preserve_rng_state:
            rng_devices = []
            if ctx.had_cuda:
                rng_devices.append("cuda")
            with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
                torch.random.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda:
                    torch.cuda.set_rng_state(ctx.fwd_gpu_state)
                with torch.enable_grad():
                    outputs = ctx.run_function(*inputs)
        else:
            with torch.enable_grad():
                outputs = ctx.run_function(*inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # Compute gradients w.r.t. the restored inputs that require grad.
        input_grads = torch.autograd.grad(
            outputs,
            [inp for inp in inputs if isinstance(inp, torch.Tensor) and inp.requires_grad],
            grad_outputs=grad_outputs,
        )

        # Map gradients back to original input positions; the first two None
        # slots correspond to run_function and preserve_rng_state.
        grad_iter = iter(input_grads)
        result = [None, None]
        for cpu_input, req_grad in zip(ctx.cpu_inputs, ctx.input_requires_grad):
            if isinstance(cpu_input, torch.Tensor) and req_grad:
                result.append(next(grad_iter))
            else:
                result.append(None)

        # Drop CPU copies promptly to release (pinned) host memory.
        ctx.cpu_inputs = None

        return tuple(result)
111+
112+
113+
def checkpoint_cpu_offload(
    function: Any,
    *args: Any,
    preserve_rng_state: bool = True,
) -> Any:
    """Run ``function(*args)`` under a CPU-offloading gradient checkpoint.

    Behaves like ``torch.utils.checkpoint.checkpoint``, except the saved
    input activations are moved to pinned CPU memory during the forward pass
    and streamed back to the GPU asynchronously during backward, trading
    transfer bandwidth for reduced GPU memory.

    Args:
        function: The function to checkpoint.
        *args: Arguments to the function. Tensors will be offloaded.
        preserve_rng_state: Preserve and restore RNG state during recompute.

    Returns:
        Output of the function.
    """
    return _CPUOffloadCheckpointFunction.apply(function, preserve_rng_state, *args)

tests/test_linear_kbit.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import bitsandbytes as bnb
1818
from bitsandbytes import _ops # noqa: F401 — ensure ops are registered
19-
from bitsandbytes.nn import LinearKbit, ParamsKbit, _GlobalWeightBuffer
19+
from bitsandbytes.nn import LinearKbit, ParamsKbit, _GlobalWeightBuffer, prepare_model_for_kbit_training
2020

2121
# Skip all tests if CUDA not available
2222
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
@@ -309,3 +309,40 @@ def test_gradient_all_k(self, k):
309309
loss.backward()
310310
assert x.grad is not None
311311
assert x.grad.shape == x.shape
312+
313+
314+
class TestPrepareModelForKbitTraining:
    """Tests for prepare_model_for_kbit_training."""

    def _make_model(self):
        """Build a small CUDA model: two LinearKbit layers around a LayerNorm."""
        blocks = [
            LinearKbit(256, 128, bias=True, k=4),
            torch.nn.LayerNorm(128),
            LinearKbit(128, 64, bias=True, k=4),
        ]
        return torch.nn.Sequential(*blocks).to("cuda")

    def test_freezes_all_params(self):
        """All parameters should be frozen after prepare."""
        model = self._make_model()
        prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
        assert all(not p.requires_grad for p in model.parameters())

    def test_layernorm_float32(self):
        """LayerNorm should be cast to float32."""
        model = self._make_model()
        prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
        assert model[1].weight.dtype == torch.float32

    def test_buffer_pre_allocated(self):
        """Global weight buffer should be sized for the largest layer."""
        _GlobalWeightBuffer.clear()
        model = self._make_model()
        prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
        # Largest layer is 256→128: K_dim=256, N_padded=128, so 256*128 = 32768
        buf = _GlobalWeightBuffer._buffers.get(torch.device("cuda", 0))
        assert buf is not None
        assert buf.numel() >= 32768

tests/test_training.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
"""
2+
Tests for training utilities (gradient checkpointing with CPU offload).
3+
4+
Verifies:
5+
- Correctness: output matches standard forward/backward
6+
- Memory reduction: GPU memory is lower with CPU offload
7+
- Gradient flow: gradients propagate correctly through checkpoint
8+
"""
9+
10+
import pytest
11+
import torch
12+
import torch.nn as nn
13+
14+
from bitsandbytes.training import checkpoint_cpu_offload
15+
16+
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
17+
18+
19+
def _simple_block(x):
    """A simple compute block for testing."""
    gram = x @ x.t()
    return torch.nn.functional.gelu(gram) @ x
22+
23+
24+
class TestCPUOffloadCheckpoint:
    """Tests for checkpoint_cpu_offload."""

    def test_forward_correctness(self):
        """Output should match standard (non-checkpointed) forward."""
        x = torch.randn(4, 64, dtype=torch.float32, device="cuda", requires_grad=True)
        # Reference: run the block directly on an independent leaf copy.
        ref = _simple_block(x.detach().clone().requires_grad_(True))
        out = checkpoint_cpu_offload(_simple_block, x)
        # .item() synchronizes the device, so any pending async transfers
        # launched by the checkpoint have completed before comparison.
        diff = (out - ref).abs().max().item()
        assert diff < 1e-5, f"Forward diff: {diff}"

    def test_gradient_correctness(self):
        """Gradients should match standard backward."""
        # Standard
        x_std = torch.randn(4, 64, dtype=torch.float32, device="cuda", requires_grad=True)
        out_std = _simple_block(x_std)
        out_std.sum().backward()
        grad_std = x_std.grad.clone()

        # Checkpointed: same values, independent leaf tensor.
        x_ckpt = x_std.detach().clone().requires_grad_(True)
        out_ckpt = checkpoint_cpu_offload(_simple_block, x_ckpt)
        out_ckpt.sum().backward()
        grad_ckpt = x_ckpt.grad.clone()

        diff = (grad_std - grad_ckpt).abs().max().item()
        assert diff < 1e-5, f"Gradient diff: {diff}"

    def test_with_nn_module(self):
        """Should work with nn.Module as the function."""
        linear = nn.Linear(64, 64).cuda()

        x = torch.randn(4, 64, dtype=torch.float32, device="cuda", requires_grad=True)
        out = checkpoint_cpu_offload(linear, x)
        out.sum().backward()
        assert x.grad is not None
        assert x.grad.shape == x.shape

    def test_memory_reduction(self):
        """CPU offload should use less GPU memory than standard checkpoint."""
        dim = 1024

        # Standard forward (saves activations on GPU)
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        layers = nn.ModuleList([nn.Linear(dim, dim).cuda() for _ in range(4)])
        x = torch.randn(32, dim, device="cuda", requires_grad=True)

        # Standard: all activations stay on GPU
        h = x
        for layer in layers:
            h = torch.nn.functional.gelu(layer(h))
        h.sum().backward()
        peak_standard = torch.cuda.max_memory_allocated()

        # Reset: drop references and gradients before re-measuring so the
        # second pass starts from a comparable allocator state.
        del h, x
        for p in layers.parameters():
            p.grad = None
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # CPU offload: activations go to CPU
        x = torch.randn(32, dim, device="cuda", requires_grad=True)
        h = x
        for layer in layers:
            # `l=layer` binds the current layer at lambda-definition time,
            # avoiding the late-binding closure pitfall inside a loop.
            h = checkpoint_cpu_offload(lambda inp, l=layer: torch.nn.functional.gelu(l(inp)), h)
        h.sum().backward()
        peak_offload = torch.cuda.max_memory_allocated()

        # CPU offload should use less peak memory
        # Allow some margin since PyTorch internal allocations vary
        assert peak_offload < peak_standard, (
            f"CPU offload ({peak_offload / 1e6:.1f} MB) should use less peak memory "
            f"than standard ({peak_standard / 1e6:.1f} MB)"
        )

    def test_preserves_rng_state(self):
        """RNG state should be preserved for dropout reproducibility."""
        linear = nn.Linear(64, 64).cuda()
        dropout = nn.Dropout(0.5)

        def block(x):
            return dropout(linear(x))

        torch.manual_seed(42)
        x = torch.randn(4, 64, device="cuda", requires_grad=True)

        # Run twice with same seed — should produce same output
        torch.manual_seed(123)
        out1 = checkpoint_cpu_offload(block, x)

        torch.manual_seed(123)
        out2 = checkpoint_cpu_offload(block, x.detach().clone().requires_grad_(True))

        diff = (out1 - out2).abs().max().item()
        assert diff < 1e-6, f"RNG state not preserved: diff={diff}"

0 commit comments

Comments
 (0)