Skip to content

Commit 9a13bae

Browse files
TimDettmers and claude committed
feat: CPU-to-GPU streaming quantization for minimal peak GPU memory
Load HF model on CPU, then stream weights to GPU one layer at a time during quantization. Each layer's fp16 weights are moved to GPU, quantized into packed kbit format, then the CPU copy is freed. This keeps peak GPU memory at ~1 layer of fp16 + growing quantized data, instead of the entire model. - KbitLoraModel: add target_device parameter for CPU->GPU streaming - _quantize_weight: move weight to target_device, free after quantize - _quantize_and_create_lora: in streaming mode, free each layer after processing (replaces with empty nn.Module) - Only free layers when target_device is explicitly set (streaming mode); non-streaming mode preserves the original model for reuse - train_qlora.py: load model on CPU, pass target_device=cuda - train_pipeline.py: same CPU loading approach per rank Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 814e665 commit 9a13bae

File tree

3 files changed

+78
-35
lines changed

3 files changed

+78
-35
lines changed

bitsandbytes/kbit_lora.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ class KbitLoraModel(nn.Module):
5656
Set False for non-first pipeline stages.
5757
include_lm_head: Whether to quantize and keep the LM head. Default True.
5858
Set False for non-last pipeline stages.
59+
target_device: Device for quantized weights and LoRA params. If None,
60+
uses the model's device. Set this when loading the HF model on
61+
CPU to stream weights to GPU one layer at a time (minimizes
62+
peak GPU memory). Example: torch.device("cuda:0").
5963
"""
6064

6165
def __init__(
@@ -73,6 +77,7 @@ def __init__(
7377
layer_range: Optional[tuple[int, int]] = None,
7478
include_embed: bool = True,
7579
include_lm_head: bool = True,
80+
target_device: Optional[torch.device] = None,
7681
):
7782
super().__init__()
7883

@@ -123,10 +128,20 @@ def __init__(
123128
self._layer_start, self._layer_end = 0, total_layers
124129
self._num_loaded_layers = self._layer_end - self._layer_start
125130

131+
# Determine target device for quantized weights.
132+
# When target_device is explicitly set (streaming mode), we free each
133+
# layer from the source model after quantization to save memory.
134+
self._streaming = target_device is not None
135+
if target_device is not None:
136+
self._target_device = target_device
137+
else:
138+
self._target_device = next(model.parameters()).device
139+
126140
# Keep reference to original model for embeddings
127141
self.model = model
128142
if include_embed:
129-
self.embed_tokens = model.model.embed_tokens
143+
# Move embedding to target device (may be CPU->GPU transfer)
144+
self.embed_tokens = model.model.embed_tokens.to(self._target_device)
130145
else:
131146
self.embed_tokens = None
132147
self.lm_head_tied = hasattr(model, "lm_head") and (
@@ -140,7 +155,7 @@ def __init__(
140155

141156
self._quantize_and_create_lora(model)
142157

143-
# Freeze all base model parameters
158+
# Freeze all base model parameters (any that remain)
144159
for p in model.parameters():
145160
p.requires_grad_(False)
146161

@@ -151,19 +166,27 @@ def __init__(
151166
p.requires_grad_(True)
152167

153168
def _quantize_weight(self, weight: torch.Tensor, name: str, k: int | None = None):
154-
"""Quantize a weight matrix and store packed data."""
169+
"""Quantize a weight matrix and store packed data.
170+
171+
The weight is moved to _target_device for quantization (CUDA kernel),
172+
then the original weight reference is no longer needed.
173+
"""
155174
if k is None:
156175
k = self.k
176+
# Move to target device for quantization (CPU -> GPU transfer if needed)
177+
weight = weight.to(self._target_device)
157178
N, K = weight.shape
158179
N_padded = ((N + 127) // 128) * 128
159180
if N_padded != N:
160181
w_padded = torch.nn.functional.pad(weight.float(), (0, 0, 0, N_padded - N))
161182
else:
162183
w_padded = weight.float()
184+
del weight # Free the fp16 copy on GPU
163185

164186
packed, absmax, codebook = F.quantize_kbit(
165187
w_padded.reshape(-1), k=k, absmax_format="fp32",
166188
)
189+
del w_padded # Free the fp32 padded copy
167190

168191
# Store as non-trainable buffers
169192
safe_name = name.replace(".", "_")
@@ -173,9 +196,10 @@ def _quantize_weight(self, weight: torch.Tensor, name: str, k: int | None = None
173196

174197
return packed, absmax, codebook, N_padded, N, K
175198

176-
def _create_lora(self, name: str, N: int, K: int, device: torch.device):
177-
"""Create LoRA A and B parameters for a weight matrix."""
199+
def _create_lora(self, name: str, N: int, K: int):
200+
"""Create LoRA A and B parameters for a weight matrix on _target_device."""
178201
safe_name = name.replace(".", "_")
202+
device = self._target_device
179203
# A: [r, K] initialized with Kaiming uniform
180204
A = nn.Parameter(torch.empty(self.lora_r, K, dtype=self.compute_dtype, device=device))
181205
nn.init.kaiming_uniform_(A, a=math.sqrt(5))
@@ -190,8 +214,13 @@ def _quantize_and_create_lora(self, model: nn.Module):
190214
191215
Only processes layers in [_layer_start, _layer_end) and optionally
192216
skips embedding and LM head for pipeline parallelism.
217+
218+
Streams weights one layer at a time: each layer's weights are moved
219+
from the model's device (often CPU) to _target_device (GPU), quantized,
220+
then the original layer is deleted. This keeps peak GPU memory at
221+
~1 layer of fp16 weights plus the growing quantized data.
193222
"""
194-
device = next(model.parameters()).device
223+
device = self._target_device
195224

196225
# Process only the decoder layers in our range
197226
layers = model.model.layers
@@ -207,12 +236,12 @@ def _quantize_and_create_lora(self, model: nn.Module):
207236

208237
# Attention projections (use k_attention)
209238
for proj_name in ["q_proj", "k_proj", "v_proj", "o_proj"]:
210-
weight = getattr(attn, proj_name).weight.data.to(device)
239+
weight = getattr(attn, proj_name).weight.data
211240
name = f"{prefix}_attn_{proj_name}"
212241
packed, absmax, codebook, N_padded, N, K = self._quantize_weight(
213242
weight, name, k=self.k_attention,
214243
)
215-
A, B = self._create_lora(name, N, K, device)
244+
A, B = self._create_lora(name, N, K)
216245
layer_info[proj_name] = {
217246
"packed": packed, "absmax": absmax, "codebook": codebook,
218247
"N_padded": N_padded, "N": N, "K": K, "A": A, "B": B,
@@ -221,24 +250,24 @@ def _quantize_and_create_lora(self, model: nn.Module):
221250

222251
# MLP projections (use k_mlp)
223252
for proj_name in ["gate_proj", "up_proj", "down_proj"]:
224-
weight = getattr(mlp, proj_name).weight.data.to(device)
253+
weight = getattr(mlp, proj_name).weight.data
225254
name = f"{prefix}_mlp_{proj_name}"
226255
packed, absmax, codebook, N_padded, N, K = self._quantize_weight(
227256
weight, name, k=self.k_mlp,
228257
)
229-
A, B = self._create_lora(name, N, K, device)
258+
A, B = self._create_lora(name, N, K)
230259
layer_info[proj_name] = {
231260
"packed": packed, "absmax": absmax, "codebook": codebook,
232261
"N_padded": N_padded, "N": N, "K": K, "A": A, "B": B,
233262
"k": self.k_mlp,
234263
}
235264

236-
# Norm weights (trainable, not quantized)
265+
# Norm weights (trainable, not quantized) — move to target device
237266
for norm_name in ["input_layernorm", "post_attention_layernorm"]:
238267
norm = getattr(layer, norm_name)
239268
safe = f"{prefix}_{norm_name}_weight"
240269
self._norm_weights[safe] = nn.Parameter(
241-
norm.weight.data.to(self.compute_dtype).clone()
270+
norm.weight.data.to(device=device, dtype=self.compute_dtype).clone()
242271
)
243272
layer_info[norm_name] = self._norm_weights[safe]
244273

@@ -248,23 +277,31 @@ def _quantize_and_create_lora(self, model: nn.Module):
248277
norm = getattr(attn, norm_name)
249278
safe = f"{prefix}_attn_{norm_name}_weight"
250279
self._norm_weights[safe] = nn.Parameter(
251-
norm.weight.data.to(self.compute_dtype).clone()
280+
norm.weight.data.to(device=device, dtype=self.compute_dtype).clone()
252281
)
253282
layer_info[norm_name] = self._norm_weights[safe]
254283

255284
self._layer_data.append(layer_info)
256285

286+
# In streaming mode, free each layer from the source model
287+
# after quantization to release memory (typically CPU RAM).
288+
if self._streaming:
289+
layers[i] = nn.Module()
290+
del layer
291+
if device.type == "cuda":
292+
torch.cuda.empty_cache()
293+
257294
# Final norm (only needed by last stage or full model)
258295
if self.include_lm_head:
259296
final_norm = model.model.norm
260297
self._norm_weights["final_norm_weight"] = nn.Parameter(
261-
final_norm.weight.data.to(self.compute_dtype).clone()
298+
final_norm.weight.data.to(device=device, dtype=self.compute_dtype).clone()
262299
)
263300

264301
# LM head (only needed by last stage or full model)
265302
self._lm_head_info = None
266303
if self.include_lm_head:
267-
lm_weight = model.lm_head.weight.data.to(device)
304+
lm_weight = model.lm_head.weight.data
268305
name = "lm_head"
269306
packed, absmax, codebook, N_padded, N, K = self._quantize_weight(
270307
lm_weight, name, k=self.k_lm_head,

examples/train_pipeline.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,9 @@ def main():
144144
print(f"Steps: {args.steps}")
145145
print()
146146

147-
# Load model — each rank loads the full HF model temporarily to extract
148-
# its layer weights. We immediately delete the original after quantization.
147+
# Load model on CPU, then stream weights to GPU layer by layer.
148+
# This avoids the full model ever being on GPU — peak GPU memory is
149+
# just ~1 fp16 layer at a time plus the growing quantized data.
149150
from transformers import AutoModelForCausalLM, AutoConfig
150151

151152
config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
@@ -158,20 +159,21 @@ def main():
158159
print(f" GPU {rank}: layers {layer_start}-{layer_end-1} ({role} stage)")
159160

160161
if rank == 0:
161-
print(f"\nLoading and quantizing (per-rank)...")
162+
print(f"\nLoading HF model on CPU, streaming to GPU...")
162163
torch.cuda.reset_peak_memory_stats()
163164

165+
# Load on CPU — no GPU memory used yet
164166
model = AutoModelForCausalLM.from_pretrained(
165167
args.model,
166168
dtype=torch.float16,
167-
device_map={"": device},
169+
device_map="cpu",
168170
trust_remote_code=True,
169171
)
170172

171-
mem_after_load = torch.cuda.memory_allocated() / 1024 / 1024
172-
print(f" GPU {rank}: {mem_after_load:.0f} MB after HF model load")
173+
mem_before = torch.cuda.memory_allocated() / 1024 / 1024
174+
print(f" GPU {rank}: {mem_before:.0f} MB after HF model load (model on CPU)")
173175

174-
# Create KbitLoraModel with ONLY this rank's layers
176+
# Create KbitLoraModel: streams weights CPU->GPU one layer at a time
175177
kbit_model = KbitLoraModel(
176178
model,
177179
lora_r=args.lora_r,
@@ -181,15 +183,17 @@ def main():
181183
layer_range=(layer_start, layer_end),
182184
include_embed=is_first,
183185
include_lm_head=is_last,
186+
target_device=device,
184187
)
185188

186-
# Delete the original HF model to free memory
189+
# Delete the original HF model to free CPU memory
187190
del model
188-
torch.cuda.empty_cache()
189191

190192
mem_after_quant = torch.cuda.memory_allocated() / 1024 / 1024
191-
print(f" GPU {rank}: {mem_after_quant:.0f} MB after quantize + cleanup "
192-
f"({kbit_model._num_loaded_layers} layers, "
193+
peak_during_quant = torch.cuda.max_memory_allocated() / 1024 / 1024
194+
print(f" GPU {rank}: {mem_after_quant:.0f} MB after quantize "
195+
f"(peak during load: {peak_during_quant:.0f} MB, "
196+
f"{kbit_model._num_loaded_layers} layers, "
193197
f"embed={'yes' if is_first else 'no'}, "
194198
f"lm_head={'yes' if is_last else 'no'})")
195199

examples/train_qlora.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -293,20 +293,22 @@ def main():
293293
if tokenizer.pad_token_id is None:
294294
tokenizer.pad_token_id = tokenizer.eos_token_id
295295

296-
# Load base model
297-
print("Loading base model...")
296+
# Load base model on CPU, then stream weights to GPU during quantization.
297+
# This keeps peak GPU memory minimal — only 1 layer of fp16 weights at a
298+
# time on GPU, plus the growing quantized data.
299+
print("Loading base model on CPU...")
298300
t0 = time.time()
299301
model = AutoModelForCausalLM.from_pretrained(
300302
args.model,
301303
dtype=torch.float16,
302-
device_map="cuda",
304+
device_map="cpu",
303305
trust_remote_code=True,
304306
)
305307
print(f" Loaded in {time.time() - t0:.1f}s")
306-
print(f" GPU memory after load: {get_gpu_memory_mb():.0f} MB")
308+
print(f" GPU memory after load: {get_gpu_memory_mb():.0f} MB (model on CPU)")
307309

308-
# Apply KbitLoraModel
309-
print("\nQuantizing and creating LoRA adapters...")
310+
# Apply KbitLoraModel — streams weights CPU->GPU one layer at a time
311+
print("\nQuantizing and streaming to GPU...")
310312
t0 = time.time()
311313
kbit_model = KbitLoraModel(
312314
model,
@@ -318,15 +320,15 @@ def main():
318320
ce_chunk_size=args.ce_chunk,
319321
compute_dtype=torch.bfloat16,
320322
cpu_offload=args.cpu_offload,
323+
target_device=torch.device("cuda"),
321324
)
322325
print(f" Quantized in {time.time() - t0:.1f}s")
323326
print(f" Trainable parameters: {kbit_model.num_trainable_parameters():,}")
324327
print(f" GPU memory after quantization: {get_gpu_memory_mb():.0f} MB")
328+
print(f" Peak GPU memory during load: {get_gpu_peak_mb():.0f} MB")
325329

326-
# Free the original model weights (they're now quantized)
330+
# Free the original model (CPU memory)
327331
del model
328-
torch.cuda.empty_cache()
329-
print(f" GPU memory after cleanup: {get_gpu_memory_mb():.0f} MB")
330332

331333
# Prepare dataset
332334
if not args.synthetic:

0 commit comments

Comments (0)