11"""Pipeline parallelism training example using bitsandbytes kbit quantization.
22
3- Demonstrates distributed pipeline training across 2+ GPUs:
4- - Loads a HuggingFace model and applies KbitLoraModel
5- - Splits decoder layers across GPUs (first stage = embedding + first layers,
6- last stage = remaining layers + norm + LM head)
7- - Trains using DistributedPipelineEngine with NCCL
8- - Reports per-GPU memory and throughput
3+ Demonstrates distributed pipeline training across 2+ GPUs with per-rank
4+ model construction — each GPU only quantizes and keeps the decoder layers it needs:
5+ - First stage: embedding + first half of layers
6+ - Last stage: remaining layers + final norm + LM head (used for the loss)
7+
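8+ This reduces steady-state per-GPU memory compared to keeping the full quantized model on every GPU; each rank still loads the full HF model temporarily before quantization, so peak memory during loading is not reduced.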
99
1010Usage:
1111 # 2-GPU pipeline training on Qwen3-0.6B
@@ -50,21 +50,20 @@ class KbitFirstStage(nn.Module):
5050 """First pipeline stage: embedding + first layers.
5151
5252 Takes input_ids [B, S], returns hidden states [B, S, H].
53+ The KbitLoraModel has already been created with only this stage's layers.
5354 """
5455
55- def __init__ (self , kbit_model , layer_start , layer_end ):
56+ def __init__ (self , kbit_model ):
5657 super ().__init__ ()
5758 self .km = kbit_model
58- self .layer_start = layer_start
59- self .layer_end = layer_end
6059
6160 def forward (self , input_ids ):
6261 B , S = input_ids .shape
6362 device = input_ids .device
6463 position_ids = torch .arange (S , device = device ).unsqueeze (0 ).expand (B , - 1 )
6564 self .km ._extend_rope_cache (S , device )
6665 hidden = self .km .embed_tokens (input_ids ).to (self .km .compute_dtype )
67- for i in range (self .layer_start , self . layer_end ):
66+ for i in range (self .km . _num_loaded_layers ):
6867 hidden = self .km ._layer_forward (i , hidden , position_ids )
6968 return hidden
7069
@@ -74,13 +73,13 @@ class KbitLastStage(nn.Module):
7473
7574 Takes hidden states [B, S, H], returns hidden states after norm [B*S, H].
7675 Loss is computed externally by the engine's loss_fn.
76+ The KbitLoraModel has already been created with only this stage's layers
77+ plus the final norm and LM head.
7778 """
7879
79- def __init__ (self , kbit_model , layer_start , layer_end ):
80+ def __init__ (self , kbit_model ):
8081 super ().__init__ ()
8182 self .km = kbit_model
82- self .layer_start = layer_start
83- self .layer_end = layer_end
8483
8584 def forward (self , hidden ):
8685 from bitsandbytes .autograd .training_kernels import rmsnorm
@@ -90,7 +89,7 @@ def forward(self, hidden):
9089 position_ids = torch .arange (S , device = device ).unsqueeze (0 ).expand (B , - 1 )
9190 self .km ._extend_rope_cache (S , device )
9291
93- for i in range (self .layer_start , self . layer_end ):
92+ for i in range (self .km . _num_loaded_layers ):
9493 hidden = self .km ._layer_forward (i , hidden , position_ids )
9594
9695 # Final norm
@@ -110,12 +109,6 @@ def make_loss_fn(kbit_model):
110109 lm = km ._lm_head_info
111110
112111 def loss_fn (hidden_2d , labels ):
113- """Compute chunked cross-entropy loss.
114-
115- Args:
116- hidden_2d: [B*S, H] hidden states from last stage.
117- labels: [B, S] target token IDs.
118- """
119112 shift_hidden = hidden_2d [:- 1 ]
120113 shift_labels = labels .reshape (- 1 )[1 :]
121114 loss = chunked_cross_entropy (
@@ -138,65 +131,76 @@ def main():
138131 device = torch .device (f"cuda:{ rank } " )
139132 torch .cuda .set_device (device )
140133
134+ is_first = (rank == 0 )
135+ is_last = (rank == world_size - 1 )
136+
141137 if rank == 0 :
142138 print (f"{ '=' * 60 } " )
143- print (f"Pipeline QLoRA Training ({ world_size } GPUs)" )
139+ print (f"Pipeline QLoRA Training ({ world_size } GPUs, per-rank loading )" )
144140 print (f"{ '=' * 60 } " )
145141 print (f"Model: { args .model } " )
146142 print (f"LoRA rank: { args .lora_r } , k={ args .k } " )
147143 print (f"Seq len: { args .seq_len } , Micro-batches: { args .micro_batches } " )
148144 print (f"Steps: { args .steps } " )
149145 print ()
150146
151- # Load model
152- from transformers import AutoModelForCausalLM
147+ # Load model — each rank loads the full HF model temporarily to extract
148+ # its layer weights. We immediately delete the original after quantization.
149+ from transformers import AutoModelForCausalLM , AutoConfig
150+
151+ config = AutoConfig .from_pretrained (args .model , trust_remote_code = True )
152+ num_layers = config .num_hidden_layers
153+ layers_per_stage = num_layers // world_size
154+ layer_start = rank * layers_per_stage
155+ layer_end = (rank + 1 ) * layers_per_stage if rank < world_size - 1 else num_layers
156+
157+ role = "first" if is_first else ("last" if is_last else "mid" )
158+ print (f" GPU { rank } : layers { layer_start } -{ layer_end - 1 } ({ role } stage)" )
153159
154160 if rank == 0 :
155- print ("Loading base model..." )
161+ print (f"\n Loading and quantizing (per-rank)..." )
162+ torch .cuda .reset_peak_memory_stats ()
163+
156164 model = AutoModelForCausalLM .from_pretrained (
157165 args .model ,
158166 dtype = torch .float16 ,
159167 device_map = {"" : device },
160168 trust_remote_code = True ,
161169 )
162170
163- # Quantize
164- if rank == 0 :
165- print ("Quantizing and creating LoRA adapters..." )
171+ mem_after_load = torch .cuda .memory_allocated () / 1024 / 1024
172+ print (f" GPU { rank } : { mem_after_load :.0f} MB after HF model load" )
173+
174+ # Create KbitLoraModel with ONLY this rank's layers
166175 kbit_model = KbitLoraModel (
167176 model ,
168177 lora_r = args .lora_r ,
169178 lora_alpha = 16.0 ,
170179 k = args .k ,
171180 compute_dtype = torch .bfloat16 ,
181+ layer_range = (layer_start , layer_end ),
182+ include_embed = is_first ,
183+ include_lm_head = is_last ,
172184 )
185+
186+ # Delete the original HF model to free memory
173187 del model
174188 torch .cuda .empty_cache ()
175189
176- num_layers = kbit_model .num_layers
177- layers_per_stage = num_layers // world_size
178- layer_start = rank * layers_per_stage
179- layer_end = (rank + 1 ) * layers_per_stage if rank < world_size - 1 else num_layers
190+ mem_after_quant = torch .cuda .memory_allocated () / 1024 / 1024
191+ print (f" GPU { rank } : { mem_after_quant :.0f} MB after quantize + cleanup "
192+ f"({ kbit_model ._num_loaded_layers } layers, "
193+ f"embed={ 'yes' if is_first else 'no' } , "
194+ f"lm_head={ 'yes' if is_last else 'no' } )" )
180195
181- is_first = ( rank == 0 )
182- is_last = ( rank == world_size - 1 )
196+ if rank == 0 :
197+ print ( f" Trainable params (rank 0): { kbit_model . num_trainable_parameters ():, } " )
183198
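199+ # Create pipeline stage wrappers. NOTE(review): every non-first rank gets KbitLastStage (which includes the final norm), so middle ranks with world_size > 2 appear unsupported — confirm.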
184200 if is_first :
185- stage = KbitFirstStage (kbit_model , layer_start , layer_end )
201+ stage = KbitFirstStage (kbit_model )
186202 else :
187- stage = KbitLastStage (kbit_model , layer_start , layer_end )
188-
189- if rank == 0 :
190- print (f" Total layers: { num_layers } " )
191- print (f" Trainable params: { kbit_model .num_trainable_parameters ():,} " )
192-
193- for r in range (world_size ):
194- if r == rank :
195- ls = r * layers_per_stage
196- le = (r + 1 ) * layers_per_stage if r < world_size - 1 else num_layers
197- role = "first" if r == 0 else ("last" if r == world_size - 1 else "mid" )
198- print (f" GPU { r } : layers { ls } -{ le - 1 } ({ role } stage)" )
199- dist .barrier ()
203+ stage = KbitLastStage (kbit_model )
200204
201205 # Loss function for the last stage
202206 loss_fn = make_loss_fn (kbit_model ) if is_last else None
@@ -215,7 +219,7 @@ def main():
215219 dtype = torch .bfloat16 ,
216220 )
217221
218- # Optimizer — each rank has its own view of the parameters
222+ # Optimizer — each rank optimizes only its own trainable parameters
219223 trainable_params = kbit_model .get_trainable_parameters ()
220224 optimizer = torch .optim .AdamW (trainable_params , lr = args .lr , weight_decay = 0.01 )
221225
@@ -233,8 +237,7 @@ def main():
233237 t_step = time .time ()
234238 optimizer .zero_grad ()
235239
236- # Generate micro-batches (all ranks generate same data for labels)
237- # Use deterministic seed per step so last rank has correct labels
240+ # All ranks generate the same data from the same per-step seed, so the last rank's labels match the first rank's inputs
238241 torch .manual_seed (step * 1000 + 42 )
239242 micro_batch_inputs = []
240243 micro_batch_labels = []
0 commit comments