Skip to content

Commit fb5bcba

Browse files
TimDettmers and claude committed
feat: Add forward_streaming/backward_streaming split, remove old monolithic API
Split forward_streaming_explicit into two methods connected by StreamingContext:

- forward_streaming() → (loss, StreamingContext)
- backward_streaming(ctx) → accumulates LoRA/norm gradients, frees context

StreamingContext holds CPU-pinned checkpoints, position_ids, loss value, and the
gradient from the loss computation for the backward pass.

Removed the old forward_streaming_explicit() monolithic method after verifying
gradient equivalence between the new split API and the standard non-streaming
forward+backward path.

Tests verify:
- Gradient match between streaming and non-streaming (atol=1e-5, rtol=1e-4)
- 20-step loss curve match between both paths (<5% relative error per step)
- Context freed after backward
- Gradient accumulation across micro-batches

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 4ae1573 commit fb5bcba

File tree

2 files changed

+280
-14
lines changed

2 files changed

+280
-14
lines changed

bitsandbytes/kbit_lora.py

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"""
1111

1212
import math
13+
from dataclasses import dataclass, field
1314
from typing import Optional
1415

1516
import torch
@@ -26,6 +27,29 @@
2627
from bitsandbytes.training import checkpoint_cpu_offload
2728

2829

30+
@dataclass
class StreamingContext:
    """Holds state between forward_streaming and backward_streaming.

    Created by forward_streaming(), consumed by backward_streaming().

    Fields (populated by forward_streaming):
      checkpoints    -- per-layer activations saved to CPU pinned memory,
                        in forward order; restored one at a time during
                        the backward pass.
      position_ids   -- position ids from the forward pass, reused verbatim
                        when layers are recomputed in backward_streaming().
      loss           -- detached scalar loss from the forward pass.
      hidden_final   -- final hidden state (the requires_grad leaf that
                        seeded the loss computation).
      grad_from_loss -- d(loss)/d(hidden_final); the gradient propagated
                        down through the layers in the backward pass.
    """

    checkpoints: list[torch.Tensor] = field(default_factory=list)
    position_ids: Optional[torch.Tensor] = None
    loss: Optional[torch.Tensor] = None
    hidden_final: Optional[torch.Tensor] = None
    grad_from_loss: Optional[torch.Tensor] = None

    def free(self) -> None:
        """Explicitly free CPU pinned checkpoint memory.

        Empties the checkpoint list in place and drops the large tensors
        held for the backward pass.  `loss` and `position_ids` are kept so
        callers may still read them after backward_streaming().
        """
        self.checkpoints.clear()
        self.hidden_final = None
        self.grad_from_loss = None

    def __del__(self):
        # Best-effort cleanup: a finalizer must never raise.  During
        # interpreter shutdown attributes may already be torn down, and an
        # exception here would only produce an ignored-exception warning.
        try:
            self.free()
        except Exception:
            pass
51+
52+
2953
class KbitLoraModel(nn.Module):
3054
"""Wraps a HuggingFace CausalLM model with kbit quantization + LoRA.
3155
@@ -1157,15 +1181,18 @@ def get_layer_lora_params(self, layer_idx: int) -> list[nn.Parameter]:
11571181
params.append(info[proj]["B"])
11581182
return params
11591183

1160-
def forward_streaming_explicit(
1184+
# ─── Separated streaming forward/backward ───
1185+
1186+
def forward_streaming(
11611187
self,
11621188
input_ids: torch.Tensor,
11631189
labels: torch.Tensor,
11641190
position_ids: Optional[torch.Tensor] = None,
1165-
):
1166-
"""Forward + backward with explicit per-layer autograd.grad() control.
1191+
) -> tuple[torch.Tensor, StreamingContext]:
1192+
"""Forward pass with weight streaming. Returns (loss, context).
11671193
1168-
Returns loss value. Gradients are accumulated on LoRA params.
1194+
The context must be passed to backward_streaming() to compute
1195+
gradients. This separation enables clean gradient accumulation.
11691196
"""
11701197
B, S = input_ids.shape
11711198
device = input_ids.device
@@ -1175,7 +1202,7 @@ def forward_streaming_explicit(
11751202

11761203
self._extend_rope_cache(S, device)
11771204

1178-
# ─── FORWARD: save checkpoints at block boundaries ───
1205+
# Embed
11791206
if self.embed_tokens is not None:
11801207
hidden = self.embed_tokens(input_ids).to(self.compute_dtype)
11811208
else:
@@ -1192,6 +1219,7 @@ def forward_streaming_explicit(
11921219
# Pre-load layer 0
11931220
self._stream_load_layer(0, slot=0, sync=True)
11941221

1222+
# Double-buffered forward (no grad — just checkpointing)
11951223
for i in range(n):
11961224
next_slot = 1 - (i % 2)
11971225
if i + 1 < n:
@@ -1200,15 +1228,14 @@ def forward_streaming_explicit(
12001228
with torch.no_grad():
12011229
hidden = self._layer_forward(i, hidden, position_ids)
12021230

1203-
# Save checkpoint
12041231
ckpt = torch.empty(hidden.shape, dtype=hidden.dtype, device="cpu", pin_memory=True)
12051232
ckpt.copy_(hidden, non_blocking=True)
12061233
checkpoints.append(ckpt)
12071234

12081235
if i + 1 < n:
12091236
torch.cuda.current_stream().wait_stream(self._copy_stream)
12101237

1211-
# ─── LOSS (with grad) ───
1238+
# Compute loss (with grad)
12121239
hidden_final = checkpoints[-1].to(device, non_blocking=True).requires_grad_(True)
12131240
torch.cuda.current_stream().synchronize()
12141241

@@ -1230,22 +1257,39 @@ def forward_streaming_explicit(
12301257
self.compute_dtype, self.ce_chunk_size,
12311258
)
12321259

1233-
# Also get grad for final norm weights
1260+
# Compute grad w.r.t. hidden_final and final norm
12341261
norm_params = [self._norm_weights["final_norm_weight"]]
12351262
all_grads = torch.autograd.grad(
12361263
loss, [hidden_final] + norm_params,
12371264
retain_graph=False,
12381265
)
1239-
grad = all_grads[0]
1266+
grad_from_loss = all_grads[0]
1267+
1268+
# Accumulate final norm gradients
12401269
for param, g in zip(norm_params, all_grads[1:]):
12411270
if param.grad is None:
12421271
param.grad = g.detach()
12431272
else:
12441273
param.grad.add_(g.detach())
12451274

1246-
loss_val = loss.detach()
1275+
ctx = StreamingContext(
1276+
checkpoints=checkpoints,
1277+
position_ids=position_ids,
1278+
loss=loss.detach(),
1279+
hidden_final=hidden_final,
1280+
grad_from_loss=grad_from_loss,
1281+
)
1282+
return loss.detach(), ctx
1283+
1284+
def backward_streaming(self, ctx: StreamingContext):
1285+
"""Backward pass with weight streaming. Accumulates LoRA gradients.
1286+
1287+
Consumes and frees the context.
1288+
"""
1289+
device = ctx.position_ids.device
1290+
n = self._num_loaded_layers
1291+
grad = ctx.grad_from_loss
12471292

1248-
# ─── BACKWARD: reverse layer order, double-buffered ───
12491293
# Pre-load last layer
12501294
last_slot = (n - 1) % 2
12511295
self._stream_load_layer(n - 1, slot=last_slot, sync=True)
@@ -1259,12 +1303,12 @@ def forward_streaming_explicit(
12591303
self._stream_load_layer(i - 1, slot=next_bwd_slot, sync=False)
12601304

12611305
# Restore checkpoint and recompute forward with grad
1262-
input_act = checkpoints[i].to(device, non_blocking=True)
1306+
input_act = ctx.checkpoints[i].to(device, non_blocking=True)
12631307
torch.cuda.current_stream().synchronize()
12641308
input_act = input_act.requires_grad_(True)
12651309

12661310
with torch.enable_grad():
1267-
output = self._layer_forward(i, input_act, position_ids)
1311+
output = self._layer_forward(i, input_act, ctx.position_ids)
12681312

12691313
# Get LoRA params + norm params for this layer
12701314
lora_params = self.get_layer_lora_params(i)
@@ -1301,7 +1345,7 @@ def forward_streaming_explicit(
13011345
if i > 0:
13021346
torch.cuda.current_stream().wait_stream(self._copy_stream)
13031347

1304-
return loss_val
1348+
ctx.free()
13051349

13061350
# ─── Standard forward ───
13071351

tests/test_streaming_fwd_bwd.py

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
"""Tests for separated forward_streaming / backward_streaming API.
2+
3+
Verifies gradient correctness against a non-streaming reference model,
4+
and tests gradient accumulation and training convergence.
5+
"""
6+
7+
import os
8+
import tempfile
9+
10+
import pytest
11+
import torch
12+
13+
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
14+
15+
16+
def _make_model_pair():
    """Build a (non-streaming, streaming) model pair from one quantized checkpoint.

    Returns (reference_model, streaming_model, workdir); the caller is
    responsible for removing workdir.
    """
    from transformers import LlamaConfig, LlamaForCausalLM

    from bitsandbytes.checkpoint import save_quantized, save_lora
    from bitsandbytes.kbit_lora import KbitLoraModel

    cfg = LlamaConfig(
        hidden_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        intermediate_size=512,
        vocab_size=1000,
        max_position_embeddings=256,
    )
    base = LlamaForCausalLM(cfg).to(torch.float16).cuda()

    wrapped = KbitLoraModel(
        base, lora_r=4, lora_alpha=8.0, k=4,
        attn_chunk_size=64, mlp_chunk_size=64, ce_chunk_size=256,
        compute_dtype=torch.bfloat16,
    )

    # Persist quantized weights + LoRA so both models reload identical state.
    workdir = tempfile.mkdtemp()
    quant_path = os.path.join(workdir, "quant.safetensors")
    lora_path = os.path.join(workdir, "lora.safetensors")
    save_quantized(wrapped, quant_path)
    save_lora(wrapped, lora_path)

    # Identical reload settings for both models; only weight_streaming differs.
    common = dict(
        lora_r=4, lora_alpha=8.0,
        attn_chunk_size=64, mlp_chunk_size=64, ce_chunk_size=256,
        compute_dtype=torch.bfloat16,
        lora_checkpoint=lora_path,
    )

    # Non-streaming reference (standard autograd works correctly)
    reference = KbitLoraModel.from_quantized(
        quant_path, weight_streaming=False, **common,
    )

    # Streaming model under test
    candidate = KbitLoraModel.from_quantized(
        quant_path, weight_streaming=True, **common,
    )

    return reference, candidate, workdir
65+
66+
67+
@pytest.fixture(scope="module")
def model_pair():
    """Module-scoped (non-streaming, streaming) model pair; temp dir removed on teardown."""
    reference, candidate, workdir = _make_model_pair()
    yield reference, candidate
    # Teardown: drop the checkpoint files written by _make_model_pair.
    import shutil
    shutil.rmtree(workdir, ignore_errors=True)
74+
75+
class TestForwardBackwardSeparation:
    """Tests for the split forward_streaming / backward_streaming API.

    All tests compare the streaming model against a non-streaming reference
    built from the same quantized checkpoint (see the model_pair fixture).
    """

    def test_gradient_match(self, model_pair):
        """Streaming gradients must match non-streaming standard forward+backward."""
        non_streaming, streaming = model_pair
        input_ids = torch.randint(0, 100, (1, 32), device="cuda")
        labels = input_ids.clone()

        # ─── Reference: non-streaming forward() + loss.backward() ───
        non_streaming.train()
        for p in non_streaming.get_trainable_parameters():
            p.grad = None

        result = non_streaming(input_ids, labels=labels)
        result["loss"].backward()

        grads_ref = {}
        for name, p in non_streaming._lora_params.items():
            if p.grad is not None:
                grads_ref[name] = p.grad.clone()
        for name, p in non_streaming._norm_weights.items():
            if p.grad is not None:
                grads_ref[f"norm_{name}"] = p.grad.clone()

        loss_ref = result["loss"].detach()

        # ─── Streaming: forward_streaming + backward_streaming ───
        for p in streaming.get_trainable_parameters():
            p.grad = None

        loss_stream, ctx = streaming.forward_streaming(input_ids, labels)
        streaming.backward_streaming(ctx)

        grads_stream = {}
        for name, p in streaming._lora_params.items():
            if p.grad is not None:
                grads_stream[name] = p.grad.clone()
        for name, p in streaming._norm_weights.items():
            if p.grad is not None:
                grads_stream[f"norm_{name}"] = p.grad.clone()

        # Compare losses
        assert torch.allclose(loss_ref, loss_stream, atol=1e-5), \
            f"Loss mismatch: {loss_ref.item()} vs {loss_stream.item()}"

        # Compare gradients: same parameter set, same values.
        assert set(grads_ref.keys()) == set(grads_stream.keys()), \
            f"Gradient key mismatch: {set(grads_ref) - set(grads_stream)} vs {set(grads_stream) - set(grads_ref)}"

        for name in grads_ref:
            assert torch.allclose(grads_ref[name], grads_stream[name], atol=1e-5, rtol=1e-4), \
                f"Gradient mismatch for {name}: max diff {(grads_ref[name] - grads_stream[name]).abs().max().item()}"

    def test_loss_curve_match(self, model_pair):
        """Loss curves must match between non-streaming and streaming over 20 steps."""
        non_streaming, streaming = model_pair
        lr = 1e-3

        # Set both models to the same initial state.  Seed ONCE before the
        # loop: re-seeding with 42 inside the loop (the previous behavior)
        # made every same-shaped LoRA parameter start from identical values.
        # Drawing successive values from one seeded stream keeps the two
        # models identical to each other while giving each parameter
        # distinct initial values.
        torch.manual_seed(42)
        for (n1, p1), (n2, p2) in zip(
            non_streaming._lora_params.items(), streaming._lora_params.items()
        ):
            val = torch.randn_like(p1.data) * 0.01
            p1.data.copy_(val)
            p2.data.copy_(val)
        for (n1, p1), (n2, p2) in zip(
            non_streaming._norm_weights.items(), streaming._norm_weights.items()
        ):
            p1.data.fill_(1.0)
            p2.data.fill_(1.0)

        losses_ref = []
        losses_stream = []

        for step in range(20):
            # Same batch for both models at every step.
            torch.manual_seed(step + 1000)
            input_ids = torch.randint(0, 100, (1, 32), device="cuda")
            labels = input_ids.clone()

            # Non-streaming step: forward, backward, manual SGD update.
            non_streaming.train()
            for p in non_streaming.get_trainable_parameters():
                p.grad = None
            result = non_streaming(input_ids, labels=labels)
            result["loss"].backward()
            losses_ref.append(result["loss"].item())
            for p in non_streaming.get_trainable_parameters():
                if p.grad is not None:
                    p.data.add_(p.grad, alpha=-lr)

            # Streaming step: same update rule through the split API.
            for p in streaming.get_trainable_parameters():
                p.grad = None
            loss_s, ctx = streaming.forward_streaming(input_ids, labels)
            streaming.backward_streaming(ctx)
            losses_stream.append(loss_s.item())
            for p in streaming.get_trainable_parameters():
                if p.grad is not None:
                    p.data.add_(p.grad, alpha=-lr)

        # Losses should match at each step (skip exact zeros to avoid a
        # division by zero in the relative-error check).
        for i, (ref_val, stream_val) in enumerate(zip(losses_ref, losses_stream)):
            if ref_val == 0:
                continue
            rel_diff = abs(ref_val - stream_val) / abs(ref_val)
            assert rel_diff < 0.05, \
                f"Step {i}: ref loss {ref_val:.6f} vs stream loss {stream_val:.6f} (rel diff {rel_diff:.4f})"

    def test_context_freed_after_backward(self, model_pair):
        """backward_streaming should free the context's checkpoint memory."""
        _, streaming = model_pair
        input_ids = torch.randint(0, 100, (1, 32), device="cuda")
        labels = input_ids.clone()

        for p in streaming.get_trainable_parameters():
            p.grad = None

        _, ctx = streaming.forward_streaming(input_ids, labels)
        # One checkpoint per layer was saved during the forward pass.
        assert len(ctx.checkpoints) > 0

        streaming.backward_streaming(ctx)
        # backward_streaming calls ctx.free(); large tensors must be gone.
        assert len(ctx.checkpoints) == 0
        assert ctx.hidden_final is None
        assert ctx.grad_from_loss is None

    def test_gradient_accumulation(self, model_pair):
        """Multiple forward_streaming + backward_streaming calls should accumulate gradients."""
        _, streaming = model_pair

        for p in streaming.get_trainable_parameters():
            p.grad = None

        # Two micro-batches; grads are NOT zeroed between them.
        for _ in range(2):
            input_ids = torch.randint(0, 100, (1, 32), device="cuda")
            labels = input_ids.clone()

            _, ctx = streaming.forward_streaming(input_ids, labels)
            streaming.backward_streaming(ctx)

        # At least some parameters should have nonzero accumulated gradients.
        has_grad = False
        for p in streaming.get_trainable_parameters():
            if p.grad is not None and p.grad.abs().sum() > 0:
                has_grad = True
                break
        assert has_grad, "No gradients after 2 micro-batches"

0 commit comments

Comments
 (0)