
Train batch generic#724

Merged
felipemello1 merged 8 commits into
meta-pytorch:mainfrom
HosseinKaviani-H:TrainBatch_Generic
Jan 26, 2026
Conversation

@HosseinKaviani-H HosseinKaviani-H commented Jan 22, 2026

Summary

Adds a TrainBatch dataclass that separates model_inputs from loss_inputs, so new training paradigms can be supported without changing the batch type.

Motivation

The current TextTrainBatch has limitations:

  • Hardcoded fields require changes for each new training mode
  • Text-only naming doesn't fit multimodal inputs
  • Every new paradigm (DPO, distillation, etc.) needs type updates

Solution

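from dataclasses import dataclass, field
from typing import Any

@dataclass
class TrainBatch:
    model_inputs: dict[str, Any]  # keyword arguments for the model forward pass
    loss_inputs: dict[str, Any]   # keyword arguments for the loss function
    meta: dict[str, Any] = field(default_factory=dict)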

# Usage:
logits = model(**batch.model_inputs)
loss = loss_fn(logits, **batch.loss_inputs)
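
To illustrate why the split helps, here is a minimal sketch that runs two differently shaped batches through the same step function. The toy `model`, both loss functions, and the field names (`tokens`, `targets`, `advantages`, `policy_version`) are assumptions made up for this example, not forge APIs; plain Python lists stand in for tensors.

```python
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class TrainBatch:
    model_inputs: dict[str, Any]
    loss_inputs: dict[str, Any]
    meta: dict[str, Any] = field(default_factory=dict)


def model(tokens: list[int]) -> list[float]:
    # Toy stand-in for a forward pass: one "logit" per token.
    return [0.5 * t for t in tokens]


def sft_loss(logits: list[float], targets: list[float]) -> float:
    # Hypothetical supervised loss: mean squared error against targets.
    return sum((y - t) ** 2 for y, t in zip(logits, targets)) / len(logits)


def weighted_loss(logits: list[float], advantages: list[float]) -> float:
    # Hypothetical RL-style loss: negative advantage-weighted mean.
    return -sum(y * a for y, a in zip(logits, advantages)) / len(logits)


def step(batch: TrainBatch, loss_fn: Callable[..., float]) -> float:
    # The step never names paradigm-specific fields; it only unpacks dicts.
    logits = model(**batch.model_inputs)
    return loss_fn(logits, **batch.loss_inputs)


sft_batch = TrainBatch(
    model_inputs={"tokens": [2, 4]},
    loss_inputs={"targets": [1.0, 1.0]},
)
rl_batch = TrainBatch(
    model_inputs={"tokens": [2, 4]},
    loss_inputs={"advantages": [2.0, 0.0]},
    meta={"policy_version": 3},  # hypothetical bookkeeping field
)

sft_value = step(sft_batch, sft_loss)
rl_value = step(rl_batch, weighted_loss)
```

In this sketch, switching paradigms only changes the keys in `loss_inputs` and the loss function passed in; `TrainBatch` and `step` stay the same.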

Files Changed

File: src/forge/types.py
Change: Added TrainBatch dataclass
────────────────────────────────────────
File: src/forge/rl/collate.py
Change: Updated to return list[TrainBatch] with model_inputs/loss_inputs
────────────────────────────────────────
File: src/forge/actors/trainer/titan.py
Change: Updated train_step() to accept list[TrainBatch] and unpack fields
────────────────────────────────────────
File: apps/grpo/main.py
Change: Updated to pass batch directly: trainer.train_step.call(batch)
────────────────────────────────────────
File: tests/sandbox/rl_trainer/main.py
Change: Updated to pass batch directly: trainer.train_step.call(batch)
────────────────────────────────────────
File: tests/sandbox/weight_sync/main.py
Change: Updated to pass batch directly: trainer.train_step.call(batch)
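
A rough sketch of the data flow these changes describe: a collate function that returns `list[TrainBatch]`, and a `train_step` that accepts that list and unpacks each batch. `Episode`, `toy_model`, `toy_loss`, the `batch_size` parameter, and every field name are hypothetical; the actual signatures in `collate.py` and `titan.py` may differ.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class TrainBatch:
    model_inputs: dict[str, Any]
    loss_inputs: dict[str, Any]
    meta: dict[str, Any] = field(default_factory=dict)


@dataclass
class Episode:
    # Hypothetical rollout record used only for this sketch.
    tokens: list[int]
    advantage: float


def collate(episodes: list[Episode], batch_size: int) -> list[TrainBatch]:
    # Group episodes into fixed-size chunks and route each field to the
    # dict of the component that consumes it.
    batches = []
    for start in range(0, len(episodes), batch_size):
        chunk = episodes[start:start + batch_size]
        batches.append(TrainBatch(
            model_inputs={"tokens": [e.tokens for e in chunk]},
            loss_inputs={"advantages": [e.advantage for e in chunk]},
            meta={"num_episodes": len(chunk)},
        ))
    return batches


def toy_model(tokens: list[list[int]]) -> list[float]:
    # Stand-in forward pass: one score per sequence.
    return [float(sum(seq)) for seq in tokens]


def toy_loss(outputs: list[float], advantages: list[float]) -> float:
    # Stand-in loss: advantage-weighted mean of the scores.
    return sum(o * a for o, a in zip(outputs, advantages)) / len(outputs)


def train_step(batches: list[TrainBatch]) -> float:
    # Accepts the collated list directly and returns the mean loss.
    total = 0.0
    for batch in batches:
        outputs = toy_model(**batch.model_inputs)
        total += toy_loss(outputs, **batch.loss_inputs)
    return total / len(batches)


episodes = [Episode([1, 2], 1.0), Episode([3], 2.0), Episode([4, 5], 0.0)]
batches = collate(episodes, batch_size=2)
mean_loss = train_step(batches)
```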

Test Plan

  • Core implementation: types.py, collate.py, titan.py, main.py
  • Update test files (tests/sandbox/)

Rewards and Losses

Ran GRPO for 100 steps:

[Two images: rewards and loss plots from the run]

@meta-cla meta-cla Bot added the CLA Signed label Jan 22, 2026
Comment thread src/forge/api/trainer.py Outdated

@felipemello1 felipemello1 left a comment

I don't think this class should be in trainer.py. It probably belongs in types.py or something like that. Are you also going to add it to collate and test it in this PR?

@joecummings

I don't think this class should be in trainer.py. It probably belongs in types.py or something like that. Are you also going to add it to collate and test it in this PR?

Why wouldn't this be in the trainer.py file under api? It defines the training API of which this is part. I would vote to keep it in the trainer API.

@felipemello1

Why wouldn't this be in the trainer.py file under api?

This is also used in collate_fn, and I'm not sure whether it may be used in other places. I think we would be exposed to circular dependencies.

e.g. collate imports from train
train imports from X
X imports from collate

Also, that's what other frameworks do, like tinker: https://github.com/thinking-machines-lab/tinker/blob/ad03d44978096b1dcae662e469293e70f509d5a8/src/tinker/types/datum.py#L25

@joecummings

e.g. collate imports from train
train imports from X
X imports from collate

What would X be here? I will not hold up the PR on this point but am curious b/c I have a hard time imagining what that would be.

@felipemello1

felipemello1 commented Jan 23, 2026

What would X be here?

I will leave that as an exercise for the reader

jk, I guess it cannot happen if collate is its own file and doesn't really import from anywhere. It just makes more sense to me, given the patterns I have seen. But no big deal either way. Worst case we refactor later.
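
For reference, the import chain discussed in this thread can be checked mechanically. The sketch below models modules as a dependency graph and searches for a cycle; the module names (including `x`) and both graphs are hypothetical stand-ins for the layouts being debated, not the actual forge import structure.

```python
from typing import Optional


def find_cycle(graph: dict[str, list[str]]) -> Optional[list[str]]:
    """Return one import cycle as a list of module names, or None."""
    visiting: list[str] = []  # modules on the current DFS path
    done: set[str] = set()    # modules fully explored without a cycle

    def visit(node: str) -> Optional[list[str]]:
        if node in visiting:
            # Reached a module already on the path: that is a cycle.
            return visiting[visiting.index(node):] + [node]
        if node in done:
            return None
        visiting.append(node)
        for dep in graph.get(node, []):
            cycle = visit(dep)
            if cycle:
                return cycle
        visiting.pop()
        done.add(node)
        return None

    for module in graph:
        cycle = visit(module)
        if cycle:
            return cycle
    return None


# TrainBatch lives in the trainer module: the chain from the thread.
in_trainer = {"collate": ["trainer"], "trainer": ["x"], "x": ["collate"]}

# TrainBatch lives in a leaf types module that imports nothing else.
in_types = {
    "collate": ["types"],
    "trainer": ["types", "x"],
    "x": ["collate"],
    "types": [],
}

cycle_a = find_cycle(in_trainer)
cycle_b = find_cycle(in_types)
```

Under these assumed graphs the first layout has a cycle and the second does not, because a leaf module with no project imports can never sit on an import loop.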

Comment thread src/forge/types.py Outdated
Comment thread src/forge/types.py Outdated

@felipemello1 felipemello1 left a comment

LGTM, ty! But I would like to see main.py run for 25-50 steps. Could you run it and share the rewards and loss sections?

HosseinKaviani-H and others added 2 commits January 26, 2026 08:55
Co-authored-by: Felipe Mello <fmellomascarenhas@gmail.com>
Co-authored-by: Felipe Mello <fmellomascarenhas@gmail.com>
@felipemello1 felipemello1 merged commit a3ae18b into meta-pytorch:main Jan 26, 2026
10 checks passed
HosseinKaviani-H added a commit to HosseinKaviani-H/forge that referenced this pull request Feb 9, 2026
Co-authored-by: Hossein Kavianihamedani <hosseinkh@fb.com>
Co-authored-by: Felipe Mello <fmellomascarenhas@gmail.com>
3 participants