Skip to content

Commit 7b5b36d

Browse files
TimDettmers and claude
committed
feat: Add CPU offload, Alpaca dataset, and benchmarking to training script
- KbitLoraModel: add cpu_offload option that wraps per-layer forward with checkpoint_cpu_offload for inter-layer activation offloading
- train_qlora.py: support Alpaca dataset (tatsu-lab/alpaca) with tokenizer
- train_qlora.py: report tokens/sec, avg step time
- train_qlora.py: add --compare-memory mode for chunked vs unchunked
- train_qlora.py: add --cpu-offload and --grad-accum options

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a60011f commit 7b5b36d

File tree

2 files changed

+296
-72
lines changed

2 files changed

+296
-72
lines changed

bitsandbytes/kbit_lora.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from bitsandbytes.autograd.lora_kbit import LoRA_MLP_Kbit, LoRA_W_Kbit
2222
from bitsandbytes.autograd.training_kernels import rmsnorm, rope
2323
from bitsandbytes.chunked import chunked_mlp_forward
24+
from bitsandbytes.training import checkpoint_cpu_offload
2425

2526
SUPPORTED_MODEL_TYPES = {"llama", "mistral", "qwen2", "qwen3"}
2627

@@ -45,6 +46,9 @@ class KbitLoraModel(nn.Module):
4546
mlp_chunk_size: Sequence chunk size for MLP. Default 4096.
4647
ce_chunk_size: Vocab chunk size for cross-entropy. Default 8192.
4748
compute_dtype: Computation dtype. Default bf16.
49+
cpu_offload: If True, offload inter-layer activations to CPU during
50+
forward and reload during backward. Saves GPU memory at cost
51+
of CPU<->GPU bandwidth. Default False.
4852
"""
4953

5054
def __init__(
@@ -58,6 +62,7 @@ def __init__(
5862
mlp_chunk_size: int = 4096,
5963
ce_chunk_size: int = 8192,
6064
compute_dtype: torch.dtype = torch.bfloat16,
65+
cpu_offload: bool = False,
6166
):
6267
super().__init__()
6368

@@ -81,6 +86,7 @@ def __init__(
8186
self.mlp_chunk_size = mlp_chunk_size
8287
self.ce_chunk_size = ce_chunk_size
8388
self.compute_dtype = compute_dtype
89+
self.cpu_offload = cpu_offload
8490

8591
# Extract model dimensions from config
8692
self.hidden_size = config.hidden_size
@@ -416,7 +422,16 @@ def forward(
416422

417423
# Decoder layers
418424
for i in range(self.num_layers):
419-
hidden = self._layer_forward(i, hidden, position_ids)
425+
if self.cpu_offload and self.training:
426+
# Wrap each layer with CPU offload: saves inter-layer
427+
# activations to CPU during forward, reloads during backward
428+
def _make_layer_fn(layer_idx, pos_ids):
429+
def _fn(h):
430+
return self._layer_forward(layer_idx, h, pos_ids)
431+
return _fn
432+
hidden = checkpoint_cpu_offload(_make_layer_fn(i, position_ids), hidden)
433+
else:
434+
hidden = self._layer_forward(i, hidden, position_ids)
420435

421436
# Final norm
422437
hidden_2d = hidden.reshape(-1, self.hidden_size)

0 commit comments

Comments (0)