|
| 1 | +"""Pipeline parallelism training example using bitsandbytes kbit quantization. |
| 2 | +
|
| 3 | +Demonstrates distributed pipeline training across 2+ GPUs: |
| 4 | +- Loads a HuggingFace model and applies KbitLoraModel |
| 5 | +- Splits decoder layers across GPUs (first stage = embedding + first layers, |
| 6 | + last stage = remaining layers + norm + LM head) |
| 7 | +- Trains using DistributedPipelineEngine with NCCL |
| 8 | +- Reports per-GPU memory and throughput |
| 9 | +
|
| 10 | +Usage: |
| 11 | + # 2-GPU pipeline training on Qwen3-0.6B |
| 12 | + torchrun --nproc_per_node=2 examples/train_pipeline.py |
| 13 | +
|
| 14 | + # Larger model |
| 15 | + torchrun --nproc_per_node=2 examples/train_pipeline.py --model Qwen/Qwen3-4B |
| 16 | +
|
| 17 | + # More micro-batches for better pipeline utilization |
| 18 | + torchrun --nproc_per_node=2 examples/train_pipeline.py --micro-batches 8 |
| 19 | +""" |
| 20 | + |
| 21 | +import argparse |
| 22 | +import os |
| 23 | +import time |
| 24 | + |
| 25 | +import torch |
| 26 | +import torch.distributed as dist |
| 27 | +import torch.nn as nn |
| 28 | + |
# NOTE(review): this guard is currently a no-op — presumably a
# BNB_CUDA_VERSION default/override was removed here. Confirm whether the
# env var should be set before `import bitsandbytes` below, or delete it.
if "BNB_CUDA_VERSION" not in os.environ:
    pass
| 31 | + |
| 32 | +import bitsandbytes # noqa: F401 |
| 33 | +from bitsandbytes.kbit_lora import KbitLoraModel |
| 34 | +from bitsandbytes.pipeline import DistributedPipelineEngine |
| 35 | + |
| 36 | + |
def parse_args(argv=None):
    """Parse command-line options for pipeline QLoRA training.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse reads ``sys.argv[1:]`` — this preserves the
            original zero-argument call from ``main()`` while making the
            parser testable and reusable.

    Returns:
        argparse.Namespace with fields ``model``, ``lora_r``, ``k``, ``lr``,
        ``steps``, ``seq_len``, and ``micro_batches``.
    """
    parser = argparse.ArgumentParser(description="Pipeline QLoRA training")
    parser.add_argument("--model", default="Qwen/Qwen3-0.6B", help="HuggingFace model name")
    parser.add_argument("--lora-r", type=int, default=64, help="LoRA rank")
    parser.add_argument("--k", type=int, default=4, help="Quantization bit width")
    parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate")
    parser.add_argument("--steps", type=int, default=20, help="Training steps")
    parser.add_argument("--seq-len", type=int, default=256, help="Sequence length")
    parser.add_argument("--micro-batches", type=int, default=4, help="Number of micro-batches")
    return parser.parse_args(argv)
| 47 | + |
| 48 | + |
class KbitFirstStage(nn.Module):
    """First pipeline stage: token embedding plus the leading decoder layers.

    Consumes ``input_ids`` of shape [B, S] and emits hidden states of shape
    [B, S, H] for transmission to the next stage.
    """

    def __init__(self, kbit_model, layer_start, layer_end):
        super().__init__()
        self.km = kbit_model
        self.layer_start = layer_start
        self.layer_end = layer_end

    def forward(self, input_ids):
        batch, seq = input_ids.shape
        dev = input_ids.device
        # One row of positions 0..S-1, broadcast across the batch dimension.
        position_ids = torch.arange(seq, device=dev).unsqueeze(0).expand(batch, -1)
        # Ensure the model's RoPE cache covers this sequence length.
        self.km._extend_rope_cache(seq, dev)
        hidden = self.km.embed_tokens(input_ids).to(self.km.compute_dtype)
        layer_idx = self.layer_start
        while layer_idx < self.layer_end:
            hidden = self.km._layer_forward(layer_idx, hidden, position_ids)
            layer_idx += 1
        return hidden
| 70 | + |
| 71 | + |
class KbitLastStage(nn.Module):
    """Last pipeline stage: trailing decoder layers plus the final norm.

    Consumes hidden states of shape [B, S, H] and returns post-norm hidden
    states flattened to [B*S, H]. The loss itself is computed outside this
    module by the engine's loss_fn.
    """

    def __init__(self, kbit_model, layer_start, layer_end):
        super().__init__()
        self.km = kbit_model
        self.layer_start = layer_start
        self.layer_end = layer_end

    def forward(self, hidden):
        from bitsandbytes.autograd.training_kernels import rmsnorm

        batch, seq = hidden.shape[0], hidden.shape[1]
        dev = hidden.device
        # Positions 0..S-1 broadcast over the batch; RoPE cache must be
        # extended on this rank too, since it runs in a separate process.
        position_ids = torch.arange(seq, device=dev).unsqueeze(0).expand(batch, -1)
        self.km._extend_rope_cache(seq, dev)

        for idx in range(self.layer_start, self.layer_end):
            hidden = self.km._layer_forward(idx, hidden, position_ids)

        # Final RMSNorm applied over a flattened [B*S, H] view.
        flat = hidden.reshape(-1, self.km.hidden_size)
        return rmsnorm(
            flat,
            self.km._norm_weights["final_norm_weight"],
            eps=self.km.rms_norm_eps,
        )
| 103 | + |
| 104 | + |
def make_loss_fn(kbit_model):
    """Build the last-stage loss closure over chunked cross-entropy.

    Captures the quantized LM-head tensors from *kbit_model* so the returned
    function only needs the stage output and the labels.
    """
    from bitsandbytes.autograd.chunked_ce import chunked_cross_entropy

    km = kbit_model
    lm = km._lm_head_info

    def loss_fn(hidden_2d, labels):
        """Chunked cross-entropy over next-token-shifted positions.

        Args:
            hidden_2d: [B*S, H] hidden states from the last stage.
            labels: [B, S] target token IDs.
        """
        # Next-token shift: position t predicts token t+1.
        # NOTE(review): the shift is applied on the flattened [B*S] axis,
        # which only lines up with per-sequence shifting when B == 1 (as in
        # this example's micro-batches) — confirm before using B > 1.
        flat_labels = labels.reshape(-1)
        return chunked_cross_entropy(
            hidden_2d[:-1],
            lm["packed"], lm["absmax"], lm["codebook"],
            flat_labels[1:],
            lm["k"], lm["K"], lm["N_padded"], lm["N"],
            km.compute_dtype, km.ce_chunk_size,
        )

    return loss_fn
| 130 | + |
| 131 | + |
def main():
    """Run one pipeline-parallel QLoRA training process (one rank per GPU).

    Launched via torchrun; every rank executes this function. Rank 0 hosts
    the embedding + first layer span, the last rank hosts the remaining
    layers, the final norm, and the loss computation.
    """
    args = parse_args()

    # NCCL process group; the rank index doubles as the CUDA device index.
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    # Configuration banner (rank 0 only).
    if rank == 0:
        print(f"{'=' * 60}")
        print(f"Pipeline QLoRA Training ({world_size} GPUs)")
        print(f"{'=' * 60}")
        print(f"Model: {args.model}")
        print(f"LoRA rank: {args.lora_r}, k={args.k}")
        print(f"Seq len: {args.seq_len}, Micro-batches: {args.micro_batches}")
        print(f"Steps: {args.steps}")
        print()

    # Load model.
    # NOTE(review): every rank loads the full fp16 model onto its own GPU
    # and quantizes it, then runs only its own layer span — peak memory at
    # load time is NOT reduced by the pipeline split; confirm intended.
    from transformers import AutoModelForCausalLM

    if rank == 0:
        print("Loading base model...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        dtype=torch.float16,
        device_map={"": device},
        trust_remote_code=True,
    )

    # Quantize weights and attach LoRA adapters.
    if rank == 0:
        print("Quantizing and creating LoRA adapters...")
    kbit_model = KbitLoraModel(
        model,
        lora_r=args.lora_r,
        lora_alpha=16.0,
        k=args.k,
        compute_dtype=torch.bfloat16,
    )
    del model  # release the unquantized fp16 weights
    torch.cuda.empty_cache()

    # Contiguous layer span per rank; the last rank absorbs the remainder
    # when num_layers is not divisible by world_size.
    num_layers = kbit_model.num_layers
    layers_per_stage = num_layers // world_size
    layer_start = rank * layers_per_stage
    layer_end = (rank + 1) * layers_per_stage if rank < world_size - 1 else num_layers

    is_first = (rank == 0)
    is_last = (rank == world_size - 1)

    # NOTE(review): with world_size > 2 the middle ranks also get
    # KbitLastStage, whose forward applies the final norm — this looks
    # correct only for a 2-stage pipeline; confirm behavior for >2 GPUs.
    if is_first:
        stage = KbitFirstStage(kbit_model, layer_start, layer_end)
    else:
        stage = KbitLastStage(kbit_model, layer_start, layer_end)

    if rank == 0:
        print(f"  Total layers: {num_layers}")
        print(f"  Trainable params: {kbit_model.num_trainable_parameters():,}")

    # Print each rank's layer assignment in rank order (barrier serializes
    # the output so lines do not interleave).
    for r in range(world_size):
        if r == rank:
            ls = r * layers_per_stage
            le = (r + 1) * layers_per_stage if r < world_size - 1 else num_layers
            role = "first" if r == 0 else ("last" if r == world_size - 1 else "mid")
            print(f"  GPU {r}: layers {ls}-{le-1} ({role} stage)")
        dist.barrier()

    # Loss function exists only on the last stage.
    loss_fn = make_loss_fn(kbit_model) if is_last else None

    # Hidden shape for inter-stage communication: [B, S, H] with B == 1.
    hidden_shape = (1, args.seq_len, kbit_model.hidden_size)

    # Pipeline engine (handles micro-batch scheduling and NCCL send/recv).
    engine = DistributedPipelineEngine(
        stage_module=stage,
        rank=rank,
        world_size=world_size,
        loss_fn=loss_fn,
        num_micro_batches=args.micro_batches,
        hidden_shape=hidden_shape,
        dtype=torch.bfloat16,
    )

    # Optimizer — each rank has its own view of the parameters.
    trainable_params = kbit_model.get_trainable_parameters()
    optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=0.01)

    # Training loop
    if rank == 0:
        print(f"\n{'=' * 60}")
        print("Training")
        print(f"{'=' * 60}")

    vocab_size = kbit_model.vocab_size
    losses = []
    torch.cuda.reset_peak_memory_stats()

    for step in range(args.steps):
        t_step = time.time()
        optimizer.zero_grad()

        # Generate micro-batches (all ranks generate same data for labels).
        # Deterministic seed per step so the last rank derives the exact
        # labels matching the inputs fed by the first rank.
        torch.manual_seed(step * 1000 + 42)
        micro_batch_inputs = []
        micro_batch_labels = []
        for mb in range(args.micro_batches):
            input_ids = torch.randint(0, vocab_size, (1, args.seq_len), device=device)
            labels = input_ids.clone()
            labels[:, :1] = -100  # mask position 0 (no previous token to predict it)
            micro_batch_inputs.append(input_ids)
            micro_batch_labels.append(labels)

        # Run pipeline step: only the first rank feeds inputs, only the
        # last rank consumes labels.
        result = engine.step(
            micro_batch_inputs=micro_batch_inputs if is_first else None,
            micro_batch_labels=micro_batch_labels if is_last else None,
        )

        # Get loss from last rank and share it with all ranks for logging.
        loss_val = result["loss"] if is_last else 0.0
        loss_tensor = torch.tensor([loss_val], device=device)
        dist.broadcast(loss_tensor, src=world_size - 1)
        loss_val = loss_tensor.item()

        optimizer.step()
        losses.append(loss_val)

        dt = time.time() - t_step
        tokens = args.micro_batches * args.seq_len

        if rank == 0 and (step % 5 == 0 or step == args.steps - 1):
            peak_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
            print(
                f"  Step {step:3d}/{args.steps} | "
                f"Loss: {loss_val:.4f} | "
                f"Time: {dt:.2f}s | "
                f"Tok/s: {tokens/dt:.0f} | "
                f"Peak mem: {peak_mb:.0f} MB"
            )

    # Results
    if rank == 0:
        print(f"\n{'=' * 60}")
        print("Results")
        print(f"{'=' * 60}")
        print(f"  Initial loss: {losses[0]:.4f}")
        print(f"  Final loss: {losses[-1]:.4f}")

        # Sanity check: average of the first 5 vs last 5 losses.
        if len(losses) >= 10:
            early = sum(losses[:5]) / 5
            late = sum(losses[-5:]) / 5
            if late < early:
                print(f"  Loss DECREASED from {early:.4f} to {late:.4f} (PASS)")
            else:
                print(f"  WARNING: Loss did not decrease ({early:.4f} -> {late:.4f})")

    # Report per-GPU peak memory (gathered from every rank).
    peak = torch.tensor([torch.cuda.max_memory_allocated() / 1024 / 1024], device=device)
    peaks = [torch.zeros(1, device=device) for _ in range(world_size)]
    dist.all_gather(peaks, peak)
    if rank == 0:
        for r, p in enumerate(peaks):
            print(f"  GPU {r} peak memory: {p.item():.0f} MB")

    dist.destroy_process_group()
| 301 | + |
| 302 | + |
# Entry point: torchrun launches this script once per rank.
if __name__ == "__main__":
    main()
0 commit comments