|
26 | 26 | from forge.data.tokenizer import HuggingFaceModelTokenizer |
27 | 27 | from forge.data.utils import StopAfterOneEpoch |
28 | 28 | from forge.observability import get_or_create_metric_logger, record_metric, Reduce |
| 29 | +from forge.types import TrainBatch |
29 | 30 | from forge.util.config import parse |
30 | 31 | from monarch.actor import current_rank, current_size, endpoint |
31 | 32 | from omegaconf import DictConfig, OmegaConf |
@@ -213,16 +214,16 @@ def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader: |
213 | 214 |
|
214 | 215 | def forward_backward( |
215 | 216 | self, |
216 | | - input_dict: dict[str, torch.Tensor], |
217 | | - labels: torch.Tensor, |
| 217 | + batch: TrainBatch, |
218 | 218 | skip_backward: bool = False, |
219 | 219 | ) -> torch.Tensor: |
220 | 220 | model_parts = self.model_parts |
221 | 221 | parallel_dims = self.parallel_dims |
222 | 222 |
|
223 | 223 | # apply context parallelism if cp is enabled |
224 | 224 | # ensure CP handles the separate freqs_cis buffer for each pp stage |
225 | | - inputs = input_dict["tokens"] |
| 225 | + inputs = batch.model_inputs["tokens"] |
| 226 | + labels = batch.loss_inputs["labels"] |
226 | 227 | optional_context_parallel_ctx = ( |
227 | 228 | dist_utils.create_context_parallel_ctx( |
228 | 229 | cp_mesh=parallel_dims.world_mesh["cp"], |
@@ -283,7 +284,11 @@ def train_step(self, batch) -> None: |
283 | 284 | # ) as grad_acc: |
284 | 285 | parallel_dims = self.parallel_dims |
285 | 286 | labels = batch.pop("labels") |
286 | | - loss = self.forward_backward(batch, labels) |
| 287 | + train_batch = TrainBatch( |
| 288 | + model_inputs=batch, |
| 289 | + loss_inputs={"labels": labels}, |
| 290 | + ) |
| 291 | + loss = self.forward_backward(train_batch) |
287 | 292 |
|
288 | 293 | grad_norm = dist_utils.clip_grad_norm_( |
289 | 294 | [p for m in self.model_parts for p in m.parameters()], |
@@ -373,7 +378,11 @@ async def evaluate(self) -> None: |
373 | 378 |
|
374 | 379 | # Process batch |
375 | 380 | labels = batch.pop("labels") |
376 | | - loss = self.forward_backward(batch, labels, skip_backward=True) |
| 381 | + train_batch = TrainBatch( |
| 382 | + model_inputs=batch, |
| 383 | + loss_inputs={"labels": labels}, |
| 384 | + ) |
| 385 | + loss = self.forward_backward(train_batch, skip_backward=True) |
377 | 386 | total_loss += loss |
378 | 387 | num_steps += 1 |
379 | 388 |
|
|
0 commit comments