Skip to content

Commit 814e665

Browse files
TimDettmers and claude committed
feat: Per-rank model loading for pipeline parallelism
KbitLoraModel now supports partial layer loading via layer_range, include_embed, and include_lm_head parameters. Each pipeline rank only quantizes and stores the layers it needs, reducing per-GPU memory by roughly 1/num_stages compared to loading the full model.

- layer_range=(start, end): only load decoder layers [start, end)
- include_embed=False: skip embedding (non-first stages)
- include_lm_head=False: skip LM head + final norm (non-last stages)
- _layer_forward uses local 0-based indexing within loaded range
- Updated train_pipeline.py to use per-rank loading

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 20d2b7e commit 814e665

File tree

2 files changed

+118
-76
lines changed

2 files changed

+118
-76
lines changed

bitsandbytes/kbit_lora.py

Lines changed: 65 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,13 @@ class KbitLoraModel(nn.Module):
4949
cpu_offload: If True, offload inter-layer activations to CPU during
5050
forward and reload during backward. Saves GPU memory at cost
5151
of CPU<->GPU bandwidth. Default False.
52+
layer_range: Optional tuple (start, end) to only load decoder layers
53+
[start, end). Used for pipeline parallelism so each rank only
54+
loads its assigned layers. Default None (all layers).
55+
include_embed: Whether to keep the embedding layer. Default True.
56+
Set False for non-first pipeline stages.
57+
include_lm_head: Whether to quantize and keep the LM head. Default True.
58+
Set False for non-last pipeline stages.
5259
"""
5360

5461
def __init__(
@@ -63,6 +70,9 @@ def __init__(
6370
ce_chunk_size: int = 8192,
6471
compute_dtype: torch.dtype = torch.bfloat16,
6572
cpu_offload: bool = False,
73+
layer_range: Optional[tuple[int, int]] = None,
74+
include_embed: bool = True,
75+
include_lm_head: bool = True,
6676
):
6777
super().__init__()
6878

@@ -87,6 +97,8 @@ def __init__(
8797
self.ce_chunk_size = ce_chunk_size
8898
self.compute_dtype = compute_dtype
8999
self.cpu_offload = cpu_offload
100+
self.include_embed = include_embed
101+
self.include_lm_head = include_lm_head
90102

91103
# Extract model dimensions from config
92104
self.hidden_size = config.hidden_size
@@ -102,9 +114,21 @@ def __init__(
102114
self.rope_theta = getattr(config, "rope_theta", 10000.0)
103115
self.has_qk_norm = self.model_type == "qwen3"
104116

117+
# Determine layer range
118+
total_layers = config.num_hidden_layers
119+
if layer_range is not None:
120+
self._layer_start, self._layer_end = layer_range
121+
assert 0 <= self._layer_start < self._layer_end <= total_layers
122+
else:
123+
self._layer_start, self._layer_end = 0, total_layers
124+
self._num_loaded_layers = self._layer_end - self._layer_start
125+
105126
# Keep reference to original model for embeddings
106127
self.model = model
107-
self.embed_tokens = model.model.embed_tokens
128+
if include_embed:
129+
self.embed_tokens = model.model.embed_tokens
130+
else:
131+
self.embed_tokens = None
108132
self.lm_head_tied = hasattr(model, "lm_head") and (
109133
model.lm_head.weight.data_ptr() == model.model.embed_tokens.weight.data_ptr()
110134
)
@@ -162,14 +186,19 @@ def _create_lora(self, name: str, N: int, K: int, device: torch.device):
162186
return A, B
163187

164188
def _quantize_and_create_lora(self, model: nn.Module):
165-
"""Walk model, quantize weights, create LoRA adapters."""
189+
"""Walk model, quantize weights, create LoRA adapters.
190+
191+
Only processes layers in [_layer_start, _layer_end) and optionally
192+
skips embedding and LM head for pipeline parallelism.
193+
"""
166194
device = next(model.parameters()).device
167195

168-
# Process each decoder layer
196+
# Process only the decoder layers in our range
169197
layers = model.model.layers
170198
self._layer_data = []
171199

172-
for i, layer in enumerate(layers):
200+
for i in range(self._layer_start, self._layer_end):
201+
layer = layers[i]
173202
attn = layer.self_attn
174203
mlp = layer.mlp
175204
prefix = f"layers_{i}"
@@ -225,23 +254,26 @@ def _quantize_and_create_lora(self, model: nn.Module):
225254

226255
self._layer_data.append(layer_info)
227256

228-
# Final norm
229-
final_norm = model.model.norm
230-
self._norm_weights["final_norm_weight"] = nn.Parameter(
231-
final_norm.weight.data.to(self.compute_dtype).clone()
232-
)
257+
# Final norm (only needed by last stage or full model)
258+
if self.include_lm_head:
259+
final_norm = model.model.norm
260+
self._norm_weights["final_norm_weight"] = nn.Parameter(
261+
final_norm.weight.data.to(self.compute_dtype).clone()
262+
)
233263

234-
# LM head (use k_lm_head)
235-
lm_weight = model.lm_head.weight.data.to(device)
236-
name = "lm_head"
237-
packed, absmax, codebook, N_padded, N, K = self._quantize_weight(
238-
lm_weight, name, k=self.k_lm_head,
239-
)
240-
self._lm_head_info = {
241-
"packed": packed, "absmax": absmax, "codebook": codebook,
242-
"N_padded": N_padded, "N": N, "K": K,
243-
"k": self.k_lm_head,
244-
}
264+
# LM head (only needed by last stage or full model)
265+
self._lm_head_info = None
266+
if self.include_lm_head:
267+
lm_weight = model.lm_head.weight.data.to(device)
268+
name = "lm_head"
269+
packed, absmax, codebook, N_padded, N, K = self._quantize_weight(
270+
lm_weight, name, k=self.k_lm_head,
271+
)
272+
self._lm_head_info = {
273+
"packed": packed, "absmax": absmax, "codebook": codebook,
274+
"N_padded": N_padded, "N": N, "K": K,
275+
"k": self.k_lm_head,
276+
}
245277

246278
# Precompute RoPE cos/sin cache
247279
self._build_rope_cache(device)
@@ -271,7 +303,7 @@ def _layer_forward(self, layer_idx: int, hidden: torch.Tensor, position_ids: tor
271303
"""Forward pass for one decoder layer.
272304
273305
Args:
274-
layer_idx: Index of the decoder layer.
306+
layer_idx: Local index (0-based within this model's loaded layers).
275307
hidden: Input hidden states [B, S, H].
276308
position_ids: Position IDs [B, S].
277309
@@ -417,11 +449,15 @@ def forward(
417449
# Extend RoPE cache if needed
418450
self._extend_rope_cache(S, device)
419451

420-
# Embedding
421-
hidden = self.embed_tokens(input_ids).to(self.compute_dtype)
452+
# Embedding (only if this model has the embedding layer)
453+
if self.embed_tokens is not None:
454+
hidden = self.embed_tokens(input_ids).to(self.compute_dtype)
455+
else:
456+
# input_ids is actually hidden states from previous pipeline stage
457+
hidden = input_ids
422458

423-
# Decoder layers
424-
for i in range(self.num_layers):
459+
# Decoder layers (local indices, 0-based)
460+
for i in range(self._num_loaded_layers):
425461
if self.cpu_offload and self.training:
426462
# Wrap each layer with CPU offload: saves inter-layer
427463
# activations to CPU during forward, reloads during backward
@@ -433,7 +469,10 @@ def _fn(h):
433469
else:
434470
hidden = self._layer_forward(i, hidden, position_ids)
435471

436-
# Final norm
472+
# Final norm + LM head (only if this model has the LM head)
473+
if not self.include_lm_head:
474+
return {"hidden": hidden}
475+
437476
hidden_2d = hidden.reshape(-1, self.hidden_size)
438477
hidden_2d = rmsnorm(
439478
hidden_2d, self._norm_weights["final_norm_weight"], eps=self.rms_norm_eps,

examples/train_pipeline.py

Lines changed: 53 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""Pipeline parallelism training example using bitsandbytes kbit quantization.
22
3-
Demonstrates distributed pipeline training across 2+ GPUs:
4-
- Loads a HuggingFace model and applies KbitLoraModel
5-
- Splits decoder layers across GPUs (first stage = embedding + first layers,
6-
last stage = remaining layers + norm + LM head)
7-
- Trains using DistributedPipelineEngine with NCCL
8-
- Reports per-GPU memory and throughput
3+
Demonstrates distributed pipeline training across 2+ GPUs with per-rank
4+
model loading — each GPU only loads the decoder layers it needs:
5+
- First stage: embedding + first half of layers
6+
- Last stage: remaining layers + final norm + LM head (loss)
7+
8+
This reduces per-GPU memory compared to loading the full model everywhere.
99
1010
Usage:
1111
# 2-GPU pipeline training on Qwen3-0.6B
@@ -50,21 +50,20 @@ class KbitFirstStage(nn.Module):
5050
"""First pipeline stage: embedding + first layers.
5151
5252
Takes input_ids [B, S], returns hidden states [B, S, H].
53+
The KbitLoraModel has already been created with only this stage's layers.
5354
"""
5455

55-
def __init__(self, kbit_model, layer_start, layer_end):
56+
def __init__(self, kbit_model):
5657
super().__init__()
5758
self.km = kbit_model
58-
self.layer_start = layer_start
59-
self.layer_end = layer_end
6059

6160
def forward(self, input_ids):
6261
B, S = input_ids.shape
6362
device = input_ids.device
6463
position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)
6564
self.km._extend_rope_cache(S, device)
6665
hidden = self.km.embed_tokens(input_ids).to(self.km.compute_dtype)
67-
for i in range(self.layer_start, self.layer_end):
66+
for i in range(self.km._num_loaded_layers):
6867
hidden = self.km._layer_forward(i, hidden, position_ids)
6968
return hidden
7069

@@ -74,13 +73,13 @@ class KbitLastStage(nn.Module):
7473
7574
Takes hidden states [B, S, H], returns hidden states after norm [B*S, H].
7675
Loss is computed externally by the engine's loss_fn.
76+
The KbitLoraModel has already been created with only this stage's layers
77+
plus the final norm and LM head.
7778
"""
7879

79-
def __init__(self, kbit_model, layer_start, layer_end):
80+
def __init__(self, kbit_model):
8081
super().__init__()
8182
self.km = kbit_model
82-
self.layer_start = layer_start
83-
self.layer_end = layer_end
8483

8584
def forward(self, hidden):
8685
from bitsandbytes.autograd.training_kernels import rmsnorm
@@ -90,7 +89,7 @@ def forward(self, hidden):
9089
position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)
9190
self.km._extend_rope_cache(S, device)
9291

93-
for i in range(self.layer_start, self.layer_end):
92+
for i in range(self.km._num_loaded_layers):
9493
hidden = self.km._layer_forward(i, hidden, position_ids)
9594

9695
# Final norm
@@ -110,12 +109,6 @@ def make_loss_fn(kbit_model):
110109
lm = km._lm_head_info
111110

112111
def loss_fn(hidden_2d, labels):
113-
"""Compute chunked cross-entropy loss.
114-
115-
Args:
116-
hidden_2d: [B*S, H] hidden states from last stage.
117-
labels: [B, S] target token IDs.
118-
"""
119112
shift_hidden = hidden_2d[:-1]
120113
shift_labels = labels.reshape(-1)[1:]
121114
loss = chunked_cross_entropy(
@@ -138,65 +131,76 @@ def main():
138131
device = torch.device(f"cuda:{rank}")
139132
torch.cuda.set_device(device)
140133

134+
is_first = (rank == 0)
135+
is_last = (rank == world_size - 1)
136+
141137
if rank == 0:
142138
print(f"{'=' * 60}")
143-
print(f"Pipeline QLoRA Training ({world_size} GPUs)")
139+
print(f"Pipeline QLoRA Training ({world_size} GPUs, per-rank loading)")
144140
print(f"{'=' * 60}")
145141
print(f"Model: {args.model}")
146142
print(f"LoRA rank: {args.lora_r}, k={args.k}")
147143
print(f"Seq len: {args.seq_len}, Micro-batches: {args.micro_batches}")
148144
print(f"Steps: {args.steps}")
149145
print()
150146

151-
# Load model
152-
from transformers import AutoModelForCausalLM
147+
# Load model — each rank loads the full HF model temporarily to extract
148+
# its layer weights. We immediately delete the original after quantization.
149+
from transformers import AutoModelForCausalLM, AutoConfig
150+
151+
config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
152+
num_layers = config.num_hidden_layers
153+
layers_per_stage = num_layers // world_size
154+
layer_start = rank * layers_per_stage
155+
layer_end = (rank + 1) * layers_per_stage if rank < world_size - 1 else num_layers
156+
157+
role = "first" if is_first else ("last" if is_last else "mid")
158+
print(f" GPU {rank}: layers {layer_start}-{layer_end-1} ({role} stage)")
153159

154160
if rank == 0:
155-
print("Loading base model...")
161+
print(f"\nLoading and quantizing (per-rank)...")
162+
torch.cuda.reset_peak_memory_stats()
163+
156164
model = AutoModelForCausalLM.from_pretrained(
157165
args.model,
158166
dtype=torch.float16,
159167
device_map={"": device},
160168
trust_remote_code=True,
161169
)
162170

163-
# Quantize
164-
if rank == 0:
165-
print("Quantizing and creating LoRA adapters...")
171+
mem_after_load = torch.cuda.memory_allocated() / 1024 / 1024
172+
print(f" GPU {rank}: {mem_after_load:.0f} MB after HF model load")
173+
174+
# Create KbitLoraModel with ONLY this rank's layers
166175
kbit_model = KbitLoraModel(
167176
model,
168177
lora_r=args.lora_r,
169178
lora_alpha=16.0,
170179
k=args.k,
171180
compute_dtype=torch.bfloat16,
181+
layer_range=(layer_start, layer_end),
182+
include_embed=is_first,
183+
include_lm_head=is_last,
172184
)
185+
186+
# Delete the original HF model to free memory
173187
del model
174188
torch.cuda.empty_cache()
175189

176-
num_layers = kbit_model.num_layers
177-
layers_per_stage = num_layers // world_size
178-
layer_start = rank * layers_per_stage
179-
layer_end = (rank + 1) * layers_per_stage if rank < world_size - 1 else num_layers
190+
mem_after_quant = torch.cuda.memory_allocated() / 1024 / 1024
191+
print(f" GPU {rank}: {mem_after_quant:.0f} MB after quantize + cleanup "
192+
f"({kbit_model._num_loaded_layers} layers, "
193+
f"embed={'yes' if is_first else 'no'}, "
194+
f"lm_head={'yes' if is_last else 'no'})")
180195

181-
is_first = (rank == 0)
182-
is_last = (rank == world_size - 1)
196+
if rank == 0:
197+
print(f" Trainable params (rank 0): {kbit_model.num_trainable_parameters():,}")
183198

199+
# Create pipeline stage wrappers
184200
if is_first:
185-
stage = KbitFirstStage(kbit_model, layer_start, layer_end)
201+
stage = KbitFirstStage(kbit_model)
186202
else:
187-
stage = KbitLastStage(kbit_model, layer_start, layer_end)
188-
189-
if rank == 0:
190-
print(f" Total layers: {num_layers}")
191-
print(f" Trainable params: {kbit_model.num_trainable_parameters():,}")
192-
193-
for r in range(world_size):
194-
if r == rank:
195-
ls = r * layers_per_stage
196-
le = (r + 1) * layers_per_stage if r < world_size - 1 else num_layers
197-
role = "first" if r == 0 else ("last" if r == world_size - 1 else "mid")
198-
print(f" GPU {r}: layers {ls}-{le-1} ({role} stage)")
199-
dist.barrier()
203+
stage = KbitLastStage(kbit_model)
200204

201205
# Loss function for the last stage
202206
loss_fn = make_loss_fn(kbit_model) if is_last else None
@@ -215,7 +219,7 @@ def main():
215219
dtype=torch.bfloat16,
216220
)
217221

218-
# Optimizer — each rank has its own view of the parameters
222+
# Optimizer — each rank optimizes only its own trainable parameters
219223
trainable_params = kbit_model.get_trainable_parameters()
220224
optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=0.01)
221225

@@ -233,8 +237,7 @@ def main():
233237
t_step = time.time()
234238
optimizer.zero_grad()
235239

236-
# Generate micro-batches (all ranks generate same data for labels)
237-
# Use deterministic seed per step so last rank has correct labels
240+
# All ranks generate same data with same seed (for label consistency)
238241
torch.manual_seed(step * 1000 + 42)
239242
micro_batch_inputs = []
240243
micro_batch_labels = []

0 commit comments

Comments (0)