Train batch generic #724
felipemello1
left a comment
I don't think this class should be in trainer.py. It probably belongs in types.py or something like that. Are you also going to add it to collate and test it in this PR?
Why wouldn't this be in the trainer.py file under api? It defines the training API, of which this is part. I would vote to keep it in the trainer API.
81e475d to 34af55b
This is also used in collate_fn, and I'm not sure whether it may be used in other places. I think we would be exposed to circular dependencies, e.g. collate imports from train. Also, that's what other frameworks do, like tinker: https://github.com/thinking-machines-lab/tinker/blob/ad03d44978096b1dcae662e469293e70f509d5a8/src/tinker/types/datum.py#L25
What would X be here? I will not hold up the PR on this point but am curious b/c I have a hard time imagining what that would be.
I will leave that as an exercise for the reader. Jk, I guess it cannot happen if collate is its own file and doesn't really import from anywhere. It just makes more sense to me, given the patterns I have seen. But no big deal either way. Worst case, we refactor later.
Co-authored-by: Felipe Mello <fmellomascarenhas@gmail.com>
Co-authored-by: Hossein Kavianihamedani <hosseinkh@fb.com>
Co-authored-by: Felipe Mello <fmellomascarenhas@gmail.com>
Summary
Adds a TrainBatch dataclass that separates model_inputs from loss_inputs, enabling any training paradigm without type changes.
Motivation
The current TextTrainBatch has limitations:
Solution
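The split between model_inputs and loss_inputs described in the summary can be sketched roughly as below. This is a minimal illustration, not the code in src/forge/types.py: only the names TrainBatch, model_inputs, and loss_inputs come from this PR. The dict-of-Any field types, the default for loss_inputs, and the example keys (tokens, advantages, labels) are assumptions made for the example.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class TrainBatch:
    """Sketch of a paradigm-agnostic training batch (field types are assumed)."""

    # Everything the model's forward pass consumes, keyed by argument name.
    model_inputs: dict[str, Any]
    # Everything only the loss function needs (targets, advantages, ...).
    loss_inputs: dict[str, Any] = field(default_factory=dict)


# Hypothetical RL-style batch: the loss needs per-sample advantages.
rl_batch = TrainBatch(
    model_inputs={"tokens": [[1, 2, 3], [4, 5, 6]]},
    loss_inputs={"advantages": [0.5, -0.5]},
)

# Hypothetical supervised batch: same type, different loss keys.
sft_batch = TrainBatch(
    model_inputs={"tokens": [[1, 2, 3]]},
    loss_inputs={"labels": [[2, 3, 0]]},
)
```

Because both batches share one type, switching the training paradigm in this sketch only changes which keys appear in loss_inputs, not the batch class itself.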
Files Changed
File: src/forge/types.py
Change: Added TrainBatch dataclass
────────────────────────────────────────
File: src/forge/rl/collate.py
Change: Updated to return list[TrainBatch] with model_inputs/loss_inputs
────────────────────────────────────────
File: src/forge/actors/trainer/titan.py
Change: Updated train_step() to accept list[TrainBatch] and unpack fields
────────────────────────────────────────
File: apps/grpo/main.py
Change: Updated to pass batch directly: trainer.train_step.call(batch)
────────────────────────────────────────
File: tests/sandbox/rl_trainer/main.py
Change: Updated to pass batch directly: trainer.train_step.call(batch)
────────────────────────────────────────
File: tests/sandbox/weight_sync/main.py
Change: Updated to pass batch directly: trainer.train_step.call(batch)
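The flow implied by the file list above (collate returns list[TrainBatch], train_step accepts it and unpacks the fields) might look something like the following. This is a hedged sketch with plain Python lists standing in for tensors. The function bodies, the episode dict keys, toy_model, and toy_loss are all hypothetical; the real signatures in src/forge/rl/collate.py and src/forge/actors/trainer/titan.py may differ.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class TrainBatch:
    model_inputs: dict[str, Any]
    loss_inputs: dict[str, Any]


def collate(episodes: list[dict[str, Any]], batch_size: int) -> list[TrainBatch]:
    """Group episodes into batches, splitting model fields from loss fields."""
    batches = []
    for start in range(0, len(episodes), batch_size):
        chunk = episodes[start:start + batch_size]
        batches.append(
            TrainBatch(
                model_inputs={"tokens": [ep["tokens"] for ep in chunk]},
                loss_inputs={"advantages": [ep["advantage"] for ep in chunk]},
            )
        )
    return batches


def train_step(
    batches: list[TrainBatch],
    model: Callable[..., Any],
    loss_fn: Callable[..., float],
) -> float:
    """Run the model on model_inputs, the loss on loss_inputs; return mean loss."""
    total = 0.0
    for batch in batches:
        outputs = model(**batch.model_inputs)            # forward pass only sees model fields
        total += loss_fn(outputs, **batch.loss_inputs)   # loss only sees loss fields
    return total / len(batches)


# Toy stand-ins so the sketch runs without a real model (hypothetical).
def toy_model(tokens: list[list[int]]) -> list[float]:
    return [float(sum(seq)) for seq in tokens]


def toy_loss(outputs: list[float], advantages: list[float]) -> float:
    return sum(o * a for o, a in zip(outputs, advantages)) / len(outputs)


episodes = [
    {"tokens": [1, 2], "advantage": 1.0},
    {"tokens": [3], "advantage": -1.0},
    {"tokens": [4, 5], "advantage": 0.5},
]
batches = collate(episodes, batch_size=2)
mean_loss = train_step(batches, toy_model, toy_loss)
```

In this sketch the caller hands the collated list straight to train_step, which mirrors the "pass batch directly" change listed for the app and sandbox entry points.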
Test Plan
Rewards and Losses
Tested GRPO for 100 steps: