Skip to content

Commit ecb54cc

Browse files
author
Hossein Kavianihamedani
committed
Refactor SFT forward_backward to use TrainBatch
Replace the separate input_dict and labels parameters with a unified TrainBatch dataclass, aligning SFT with GRPO's training interface:
- Import TrainBatch from forge.types
- Update the forward_backward signature to accept a TrainBatch
- Update train_step and evaluate to create a TrainBatch from the batch dict
1 parent 58bf8e3 commit ecb54cc

1 file changed

Lines changed: 14 additions & 5 deletions

File tree

apps/sft/main.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from forge.data.tokenizer import HuggingFaceModelTokenizer
2727
from forge.data.utils import StopAfterOneEpoch
2828
from forge.observability import get_or_create_metric_logger, record_metric, Reduce
29+
from forge.types import TrainBatch
2930
from forge.util.config import parse
3031
from monarch.actor import current_rank, current_size, endpoint
3132
from omegaconf import DictConfig, OmegaConf
@@ -213,16 +214,16 @@ def setup_data(self, dataset_configs: list[dict]) -> StatefulDataLoader:
213214

214215
def forward_backward(
215216
self,
216-
input_dict: dict[str, torch.Tensor],
217-
labels: torch.Tensor,
217+
batch: TrainBatch,
218218
skip_backward: bool = False,
219219
) -> torch.Tensor:
220220
model_parts = self.model_parts
221221
parallel_dims = self.parallel_dims
222222

223223
# apply context parallelism if cp is enabled
224224
# ensure CP handles the separate freqs_cis buffer for each pp stage
225-
inputs = input_dict["tokens"]
225+
inputs = batch.model_inputs["tokens"]
226+
labels = batch.loss_inputs["labels"]
226227
optional_context_parallel_ctx = (
227228
dist_utils.create_context_parallel_ctx(
228229
cp_mesh=parallel_dims.world_mesh["cp"],
@@ -283,7 +284,11 @@ def train_step(self, batch) -> None:
283284
# ) as grad_acc:
284285
parallel_dims = self.parallel_dims
285286
labels = batch.pop("labels")
286-
loss = self.forward_backward(batch, labels)
287+
train_batch = TrainBatch(
288+
model_inputs=batch,
289+
loss_inputs={"labels": labels},
290+
)
291+
loss = self.forward_backward(train_batch)
287292

288293
grad_norm = dist_utils.clip_grad_norm_(
289294
[p for m in self.model_parts for p in m.parameters()],
@@ -373,7 +378,11 @@ async def evaluate(self) -> None:
373378

374379
# Process batch
375380
labels = batch.pop("labels")
376-
loss = self.forward_backward(batch, labels, skip_backward=True)
381+
train_batch = TrainBatch(
382+
model_inputs=batch,
383+
loss_inputs={"labels": labels},
384+
)
385+
loss = self.forward_backward(train_batch, skip_backward=True)
377386
total_loss += loss
378387
num_steps += 1
379388

0 commit comments

Comments
 (0)