|
| 1 | +"""1F1B Pipeline Parallelism Engine. |
| 2 | +
|
| 3 | +Custom implementation of one-forward-one-backward pipeline schedule for |
| 4 | +training large models across multiple stages. Each stage processes a |
| 5 | +subset of model layers. |
| 6 | +
|
| 7 | +Supports: |
| 8 | +- Single-process mode: all stages on one GPU (for testing) |
| 9 | +- Multi-process NCCL mode: one stage per GPU, activation transfer via isend/irecv |
| 10 | +
|
| 11 | +The 1F1B schedule minimizes peak activation memory by interleaving forward |
| 12 | +and backward passes, keeping at most (num_stages) micro-batches in flight. |
| 13 | +""" |
| 14 | + |
| 15 | +import torch |
| 16 | +import torch.nn as nn |
| 17 | + |
| 18 | + |
| 19 | +def generate_1f1b_schedule(num_stages, num_micro_batches): |
| 20 | + """Generate the 1F1B (one-forward-one-backward) pipeline schedule. |
| 21 | +
|
| 22 | + The schedule for each stage consists of: |
| 23 | + 1. Warmup phase: (num_stages - 1 - stage_id) forward passes |
| 24 | + 2. Steady state: alternating backward+forward (non-last) or forward+backward (last) |
| 25 | + 3. Cooldown phase: remaining backward passes |
| 26 | +
|
| 27 | + Args: |
| 28 | + num_stages: Number of pipeline stages. |
| 29 | + num_micro_batches: Number of micro-batches per training step. |
| 30 | + Must be >= num_stages. |
| 31 | +
|
| 32 | + Returns: |
| 33 | + List of lists: schedule[stage_id] = [(op, micro_batch_id), ...] |
| 34 | + where op is 'F' (forward) or 'B' (backward). |
| 35 | + """ |
| 36 | + assert num_micro_batches >= num_stages, ( |
| 37 | + f"Need at least {num_stages} micro-batches for {num_stages} stages, " |
| 38 | + f"got {num_micro_batches}" |
| 39 | + ) |
| 40 | + |
| 41 | + S = num_stages |
| 42 | + M = num_micro_batches |
| 43 | + schedules = [[] for _ in range(S)] |
| 44 | + |
| 45 | + for s in range(S): |
| 46 | + warmup_forwards = S - 1 - s |
| 47 | + is_last_stage = s == S - 1 |
| 48 | + |
| 49 | + # Warmup: forward-only passes to fill the pipeline |
| 50 | + for m in range(warmup_forwards): |
| 51 | + schedules[s].append(("F", m)) |
| 52 | + |
| 53 | + # Steady state: interleave F and B |
| 54 | + f_idx = warmup_forwards |
| 55 | + b_idx = 0 |
| 56 | + num_steady = M - warmup_forwards |
| 57 | + |
| 58 | + for _ in range(num_steady): |
| 59 | + if is_last_stage: |
| 60 | + # Last stage: F then B (must receive activation before backward) |
| 61 | + schedules[s].append(("F", f_idx)) |
| 62 | + f_idx += 1 |
| 63 | + schedules[s].append(("B", b_idx)) |
| 64 | + b_idx += 1 |
| 65 | + else: |
| 66 | + # Non-last stages: B then F (drain before filling) |
| 67 | + schedules[s].append(("B", b_idx)) |
| 68 | + b_idx += 1 |
| 69 | + schedules[s].append(("F", f_idx)) |
| 70 | + f_idx += 1 |
| 71 | + |
| 72 | + # Cooldown: remaining backward passes |
| 73 | + while b_idx < M: |
| 74 | + schedules[s].append(("B", b_idx)) |
| 75 | + b_idx += 1 |
| 76 | + |
| 77 | + return schedules |
| 78 | + |
| 79 | + |
| 80 | +class PipelineEngine: |
| 81 | + """1F1B pipeline parallelism engine. |
| 82 | +
|
| 83 | + Splits a model into pipeline stages and executes them using the 1F1B |
| 84 | + schedule. Supports single-process mode for testing and multi-process |
| 85 | + NCCL mode for multi-GPU training. |
| 86 | +
|
| 87 | + The model must be provided as a list of stage callables. Each stage |
| 88 | + takes a hidden state tensor and returns the next hidden state. |
| 89 | + The last stage should include the loss computation. |
| 90 | +
|
| 91 | + Args: |
| 92 | + stage_modules: List of nn.Module instances, one per stage. |
| 93 | + Each module's forward takes (hidden_states,) and returns hidden_states. |
| 94 | + loss_fn: Loss function taking (last_stage_output, labels) -> scalar loss. |
| 95 | + Used only at the last stage. If None, the last stage must return the loss. |
| 96 | + num_micro_batches: Number of micro-batches per training step. |
| 97 | + device: Device for all stages (single-process mode). |
| 98 | + """ |
| 99 | + |
| 100 | + def __init__( |
| 101 | + self, |
| 102 | + stage_modules: list[nn.Module], |
| 103 | + loss_fn=None, |
| 104 | + num_micro_batches: int = 4, |
| 105 | + device: torch.device = None, |
| 106 | + ): |
| 107 | + self.stage_modules = stage_modules |
| 108 | + self.loss_fn = loss_fn |
| 109 | + self.num_stages = len(stage_modules) |
| 110 | + self.num_micro_batches = num_micro_batches |
| 111 | + self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| 112 | + |
| 113 | + self.schedule = generate_1f1b_schedule(self.num_stages, num_micro_batches) |
| 114 | + |
| 115 | + def step(self, micro_batch_inputs, micro_batch_labels=None): |
| 116 | + """Run one training step with 1F1B schedule (single-process mode). |
| 117 | +
|
| 118 | + Executes all stages sequentially in a single process, following the |
| 119 | + 1F1B schedule for correct ordering. At each schedule index: |
| 120 | + - Forward operations are processed left-to-right (stage 0 first) |
| 121 | + - Backward operations are processed right-to-left (last stage first) |
| 122 | +
|
| 123 | + This respects data dependencies: forward outputs flow left-to-right, |
| 124 | + backward gradients flow right-to-left. |
| 125 | +
|
| 126 | + Args: |
| 127 | + micro_batch_inputs: List of M input tensors, one per micro-batch. |
| 128 | + micro_batch_labels: List of M label tensors. Required if loss_fn is set. |
| 129 | +
|
| 130 | + Returns: |
| 131 | + dict with: |
| 132 | + loss: Average loss across micro-batches (float). |
| 133 | + losses: List of per-micro-batch losses. |
| 134 | + """ |
| 135 | + S = self.num_stages |
| 136 | + M = self.num_micro_batches |
| 137 | + |
| 138 | + assert len(micro_batch_inputs) == M, ( |
| 139 | + f"Expected {M} micro-batch inputs, got {len(micro_batch_inputs)}" |
| 140 | + ) |
| 141 | + |
| 142 | + # Storage for intermediate activations |
| 143 | + # fwd_inputs[s][m] = input tensor to stage s for micro-batch m (requires_grad) |
| 144 | + # fwd_outputs[s][m] = output tensor from stage s for micro-batch m |
| 145 | + fwd_inputs = [[None] * M for _ in range(S)] |
| 146 | + fwd_outputs = [[None] * M for _ in range(S)] |
| 147 | + losses = [None] * M |
| 148 | + grad_inputs = [[None] * M for _ in range(S)] # gradients from backward |
| 149 | + |
| 150 | + # Execute the 1F1B schedule with proper dependency ordering |
| 151 | + max_ops = max(len(sched) for sched in self.schedule) |
| 152 | + |
| 153 | + for op_idx in range(max_ops): |
| 154 | + # Collect operations at this schedule index |
| 155 | + forward_ops = [] |
| 156 | + backward_ops = [] |
| 157 | + for s in range(S): |
| 158 | + if op_idx >= len(self.schedule[s]): |
| 159 | + continue |
| 160 | + op, m = self.schedule[s][op_idx] |
| 161 | + if op == "F": |
| 162 | + forward_ops.append((s, m)) |
| 163 | + else: |
| 164 | + backward_ops.append((s, m)) |
| 165 | + |
| 166 | + # Process forward operations left-to-right (stage 0 first) |
| 167 | + for s, m in sorted(forward_ops, key=lambda x: x[0]): |
| 168 | + self._forward_step(s, m, micro_batch_inputs, micro_batch_labels, |
| 169 | + fwd_inputs, fwd_outputs, losses) |
| 170 | + |
| 171 | + # Process backward operations right-to-left (last stage first) |
| 172 | + for s, m in sorted(backward_ops, key=lambda x: -x[0]): |
| 173 | + self._backward_step(s, m, fwd_inputs, fwd_outputs, losses, |
| 174 | + grad_inputs) |
| 175 | + |
| 176 | + # Compute average loss |
| 177 | + valid_losses = [l.item() for l in losses if l is not None] |
| 178 | + avg_loss = sum(valid_losses) / len(valid_losses) if valid_losses else 0.0 |
| 179 | + |
| 180 | + return { |
| 181 | + "loss": avg_loss, |
| 182 | + "losses": valid_losses, |
| 183 | + } |
| 184 | + |
| 185 | + def _forward_step(self, stage, micro_batch, inputs, labels, |
| 186 | + fwd_inputs, fwd_outputs, losses): |
| 187 | + """Execute one forward step for a stage and micro-batch.""" |
| 188 | + S = self.num_stages |
| 189 | + |
| 190 | + # Get input |
| 191 | + if stage == 0: |
| 192 | + # First stage: use the micro-batch input directly |
| 193 | + inp = inputs[micro_batch] |
| 194 | + else: |
| 195 | + # Get output from previous stage (detached for pipeline boundary) |
| 196 | + inp = fwd_outputs[stage - 1][micro_batch].detach() |
| 197 | + |
| 198 | + # Enable gradient tracking at stage boundaries |
| 199 | + inp = inp.requires_grad_(True) |
| 200 | + fwd_inputs[stage][micro_batch] = inp |
| 201 | + |
| 202 | + # Run forward through this stage's layers |
| 203 | + output = self.stage_modules[stage](inp) |
| 204 | + fwd_outputs[stage][micro_batch] = output |
| 205 | + |
| 206 | + # Last stage: compute loss |
| 207 | + if stage == S - 1 and self.loss_fn is not None and labels is not None: |
| 208 | + loss = self.loss_fn(output, labels[micro_batch]) |
| 209 | + losses[micro_batch] = loss |
| 210 | + |
| 211 | + def _backward_step(self, stage, micro_batch, fwd_inputs, fwd_outputs, |
| 212 | + losses, grad_inputs): |
| 213 | + """Execute one backward step for a stage and micro-batch.""" |
| 214 | + S = self.num_stages |
| 215 | + |
| 216 | + output = fwd_outputs[stage][micro_batch] |
| 217 | + inp = fwd_inputs[stage][micro_batch] |
| 218 | + |
| 219 | + if stage == S - 1: |
| 220 | + # Last stage: backward from loss |
| 221 | + if losses[micro_batch] is not None: |
| 222 | + # Scale loss by 1/M for gradient accumulation |
| 223 | + scaled_loss = losses[micro_batch] / self.num_micro_batches |
| 224 | + scaled_loss.backward(retain_graph=False) |
| 225 | + else: |
| 226 | + # If no loss_fn, backward on output directly |
| 227 | + output.backward( |
| 228 | + torch.ones_like(output) / self.num_micro_batches, |
| 229 | + retain_graph=False, |
| 230 | + ) |
| 231 | + else: |
| 232 | + # Non-last stage: backward using gradient from next stage |
| 233 | + grad_from_next = grad_inputs[stage + 1][micro_batch] |
| 234 | + if grad_from_next is not None: |
| 235 | + output.backward(grad_from_next, retain_graph=False) |
| 236 | + |
| 237 | + # Save input gradient for the previous stage |
| 238 | + if inp.grad is not None: |
| 239 | + grad_inputs[stage][micro_batch] = inp.grad.detach() |
| 240 | + |
| 241 | + def parameters(self): |
| 242 | + """Return all trainable parameters across all stages.""" |
| 243 | + for stage_module in self.stage_modules: |
| 244 | + yield from stage_module.parameters() |
| 245 | + |
| 246 | + @staticmethod |
| 247 | + def split_model_layers(layers, num_stages): |
| 248 | + """Split a list of layers evenly across stages. |
| 249 | +
|
| 250 | + Args: |
| 251 | + layers: List of nn.Module layers. |
| 252 | + num_stages: Number of pipeline stages. |
| 253 | +
|
| 254 | + Returns: |
| 255 | + List of lists: stage_layers[stage_id] = [layer1, layer2, ...] |
| 256 | + """ |
| 257 | + n = len(layers) |
| 258 | + assert n >= num_stages, ( |
| 259 | + f"Cannot split {n} layers into {num_stages} stages" |
| 260 | + ) |
| 261 | + |
| 262 | + # Even split with remainder going to earlier stages |
| 263 | + base = n // num_stages |
| 264 | + remainder = n % num_stages |
| 265 | + |
| 266 | + stage_layers = [] |
| 267 | + idx = 0 |
| 268 | + for s in range(num_stages): |
| 269 | + count = base + (1 if s < remainder else 0) |
| 270 | + stage_layers.append(layers[idx:idx + count]) |
| 271 | + idx += count |
| 272 | + |
| 273 | + return stage_layers |
| 274 | + |
| 275 | + |
| 276 | +class SequentialStage(nn.Module): |
| 277 | + """A pipeline stage that sequentially runs a list of layers. |
| 278 | +
|
| 279 | + Simple wrapper that takes a list of nn.Module layers and runs them |
| 280 | + in sequence. Used as the default stage module when splitting a model. |
| 281 | + """ |
| 282 | + |
| 283 | + def __init__(self, layers): |
| 284 | + super().__init__() |
| 285 | + self.layers = nn.ModuleList(layers) |
| 286 | + |
| 287 | + def forward(self, x): |
| 288 | + for layer in self.layers: |
| 289 | + x = layer(x) |
| 290 | + return x |
0 commit comments