feat(trainer): add dpo#1190

Merged
garrett4wade merged 2 commits into inclusionAI:main from HT-Yuan:feature/dpo-trainer
Apr 24, 2026

Conversation

@HT-Yuan
Contributor

@HT-Yuan HT-Yuan commented Apr 16, 2026

Description

Add Direct Preference Optimization (DPO) training to AReaL as a
first-class alignment method, alongside the existing Reward Model (RM)
training on Anthropic/hh-rlhf.

Highlights:

  • New trainer DPOTrainer (exported from areal) wired to the
    existing TrainController / single-controller runtime. Ref-model
    log-probs are computed online per step via a colocated ref engine
    (same pattern PPO/GRPO use), so no pre-computed logprobs need to be
    stored on disk.
  • New engine layer areal/trainer/dpo/dpo_engine.py:
    • DPOEngine (train / eval / compute_logp)
    • DPOController (single-controller) and DPOControllerV2
      (gateway / V2 controller) dispatchers, both using group_size=2
      so each (chosen, rejected) pair stays on the same DP rank.
    • compute_dpo_loss loss function compatible with the engine
      loss_fn contract (packed, interleaved [chosen_0, rejected_0, chosen_1, rejected_1, ...]).
  • New functional ops in areal/utils/functional/functional.py:
    • dpo_pair_logratios — reduces per-token logprobs to per-pair
      (policy, ref) × (chosen, rejected) sequence logprobs.
    • dpo_preference_loss — supports loss_type ∈ {"sigmoid", "ipo"}
      (Rafailov et al. 2023 / Azar et al. 2023). Both are re-exported
      from areal.utils.functional; both are sketched after this list.
  • New config DPOConfig in areal/api/cli_args.py, with a ref engine-config
    field and the DPO hyperparameters (beta, loss_type).
  • New dataset loader areal/dataset/hhrlhf.py::get_hhrlhf_dpo_dataset that
    infers the prompt/response boundary via the longest common token prefix
    between chosen and rejected, and marks only response tokens in loss_mask
    (also sketched after this list).
  • New example under examples/alignment/:
    hhrlhf_dpo.py + hhrlhf_dpo.yaml, and a ## Direct Preference Optimization (DPO) section in examples/alignment/README.md with
    the loss in LaTeX and a training curve (dpo_curve.png).
  • Tests:
    • tests/test_hhrlhf_dataset.py — offline sanity tests for the HH-RLHF
      DPO loader (mocks datasets.load_dataset; no network).
    • DPO loss / logratio unit tests covering sigmoid and ipo
      variants, zero-ref fallback, and invalid-pair masking.
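
For reference, a rough sketch of the two functional ops (the packed,
interleaved [chosen_0, rejected_0, ...] layout is the one described above,
but the argument names and helper structure are illustrative, not the exact
signatures in areal.utils.functional):

    import torch
    import torch.nn.functional as F

    def dpo_pair_logratios(logprobs, ref_logprobs, cu_seqlens, loss_mask):
        # Reduce per-token logprobs to per-sequence logprobs for a packed batch
        # interleaved as [chosen_0, rejected_0, chosen_1, rejected_1, ...].
        n_seqs = cu_seqlens.numel() - 1
        seq_ids = torch.zeros(logprobs.shape[0], dtype=torch.long, device=logprobs.device)
        seq_ids.scatter_(0, cu_seqlens[1:-1].long(), 1)
        seq_ids = seq_ids.cumsum(0)

        def per_seq(lp):
            out = torch.zeros(n_seqs, dtype=torch.float64, device=lp.device)
            out.index_add_(0, seq_ids, torch.where(loss_mask, lp, 0.0).to(torch.float64))
            return out

        policy, ref = per_seq(logprobs), per_seq(ref_logprobs)
        # Even indices are chosen sequences, odd indices are rejected ones.
        return policy[0::2] - ref[0::2], policy[1::2] - ref[1::2]

    def dpo_preference_loss(chosen_logratio, rejected_logratio, beta=0.1, loss_type="sigmoid"):
        # Pairwise preference loss (Rafailov et al. 2023 / Azar et al. 2023).
        margin = chosen_logratio - rejected_logratio
        if loss_type == "sigmoid":
            return -F.logsigmoid(beta * margin).mean()
        if loss_type == "ipo":
            return ((margin - 1.0 / (2.0 * beta)) ** 2).mean()
        raise ValueError(f"unknown loss_type: {loss_type}")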
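
And a minimal sketch of the prompt/response boundary inference used by the
HH-RLHF loader (hypothetical helper name; the real loader operates on the
tokenized chosen/rejected pair):

    def infer_prompt_len(chosen_ids: list[int], rejected_ids: list[int]) -> int:
        # The prompt is the longest common token prefix of the two sequences;
        # everything after it is response and receives loss_mask = 1.
        prompt_len = 0
        for a, b in zip(chosen_ids, rejected_ids):
            if a != b:
                break
            prompt_len += 1
        return prompt_len

    # loss_mask for the chosen sequence (the rejected one is analogous):
    # [0] * prompt_len + [1] * (len(chosen_ids) - prompt_len)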

Stats emitted per step (via stats_tracker): loss, chosen_reward,
rejected_reward, reward_accuracy, and reward_margin, each averaged with
n_pairs as the denominator, so the dashboard matches the canonical DPO
metrics from the original paper.
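
Concretely, these stats reduce from the per-pair logratios exactly as in the
original paper; a rough sketch (hypothetical helper name, stats_tracker wiring
omitted):

    import torch

    def dpo_pair_stats(chosen_logratio, rejected_logratio, beta):
        # The implicit DPO rewards are the beta-scaled policy-vs-ref logratios.
        chosen_reward = beta * chosen_logratio
        rejected_reward = beta * rejected_logratio
        margin = chosen_reward - rejected_reward
        n_pairs = chosen_logratio.numel()
        return {
            "chosen_reward": chosen_reward.sum() / n_pairs,
            "rejected_reward": rejected_reward.sum() / n_pairs,
            "reward_margin": margin.sum() / n_pairs,
            "reward_accuracy": (margin > 0).double().sum() / n_pairs,
        }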

Housekeeping: the __all__ group-label comment in
areal/utils/functional/__init__.py was also updated from
# logprobs.py to # vocab_parallel.py to match the actual module name
(the file was renamed upstream; this just fixes the now-stale comment).

Related Issue

Fixes #1137

Type of Change

  • 🐛 Bug fix
  • ✨ New feature
  • 💥 Breaking change
  • 📝 Documentation update
  • ♻️ Refactoring
  • ⚡ Performance improvement
  • ✅ Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • Pre-commit hooks pass (pre-commit run --all-files)
  • Relevant tests pass; new tests added for new functionality
  • Documentation updated (if applicable; built with ./docs/build_all.sh)
  • Branch is up to date with main
  • Self-reviewed via /review-pr command
  • This PR was created by a coding agent via /create-pr
  • This PR is a breaking change

Breaking Change Details (if applicable):

Additional Context


Need help? Check the Contributing Guide or ask in
GitHub Discussions!


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces Direct Preference Optimization (DPO) support to the framework, including the DPOTrainer, DPO-specific engine implementations for FSDP, Megatron, and Archon backends, and dataset processing for the HH-RLHF dataset. The review feedback identifies critical issues: potential crashes in DPOEngine due to incorrect batch type handling in training and evaluation paths, an incorrect dataset size calculation in DPOTrainer that would lead to improper learning rate scheduling, and an inefficient Python loop in the DPO loss computation that should be vectorized to improve performance and numerical stability.

Comment thread areal/trainer/dpo/dpo_engine.py Outdated

    def _train_dpo(self, data: dict[str, Any]) -> None:
        """Train on a batch (DPO)."""
        if _dpo_loss_weight(data) == 0:

Severity: high

The _dpo_loss_weight function expects a dictionary containing cu_seqlens, but data here is a list[dict[str, Any]] (the raw batch from the dataloader). This will cause a TypeError when trying to access data["cu_seqlens"]. Since the goal is to skip empty batches and log placeholder stats, you should check the list length instead.

Suggested change
-        if _dpo_loss_weight(data) == 0:
+        if not data:
+            _log_empty_dpo_stats(current_platform.current_device())
+            return

Comment thread areal/trainer/dpo/dpo_engine.py Outdated
        batched_call(self._evaluate_dpo, data, unpack=False)

    def _evaluate_dpo(self, data: dict[str, Any]) -> None:
        if _dpo_loss_weight(data) == 0:

Severity: high

Similar to the training path, _dpo_loss_weight will crash here because data is a list of dictionaries, not a packed dictionary with cu_seqlens.

Suggested change
-        if _dpo_loss_weight(data) == 0:
+        if not data:
+            _log_empty_dpo_stats(current_platform.current_device())
+            return


        ft_spec = FinetuneSpec(
            total_train_epochs=config.total_train_epochs,
            dataset_size=len(self.train_dataloader) * config.train_dataset.batch_size,

Severity: high

The dataset_size is calculated using the length of the sharded dataloader, which represents the number of samples per rank. This will cause FinetuneSpec to compute an incorrect total_train_steps (underestimated by a factor of world_size), leading to incorrect learning rate scheduling and premature training termination. You should use the total dataset size instead.

Suggested change
-            dataset_size=len(self.train_dataloader) * config.train_dataset.batch_size,
+            dataset_size=len(train_dataset),

Comment thread areal/trainer/dpo/dpo_engine.py Outdated
Comment on lines +174 to +185
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu()
    n_seqs = seqlens.shape[0]

    policy_logps = torch.zeros(n_seqs, dtype=torch.float64, device=device)
    ref_logps = torch.zeros(n_seqs, dtype=torch.float64, device=device)

    for i in range(n_seqs):
        start = cu_seqlens[i]
        end = cu_seqlens[i + 1]
        m = loss_mask[start:end]
        policy_logps[i] = torch.where(m, logprobs[start:end], 0.0).sum()
        ref_logps[i] = torch.where(m, ref_logprobs[start:end], 0.0).sum()

Severity: medium

This Python loop over sequences in the packed batch is inefficient and causes multiple GPU-CPU synchronizations. Furthermore, summing log-probabilities in low precision (e.g., bfloat16) can lead to numerical instability for long sequences. It is highly recommended to vectorize this operation and perform the summation in float64 to maintain precision.

    # Assign a segment id to every token: mark each interior sequence boundary
    # with a 1, then cumsum so each token inherits its sequence index.
    n_seqs = cu_seqlens.numel() - 1
    seq_ids = torch.zeros(logprobs.shape[0], dtype=torch.long, device=device)
    seq_ids.scatter_(0, cu_seqlens[1:-1].long(), 1)
    seq_ids = seq_ids.cumsum(dim=0)

    # Sum masked per-token logprobs per sequence in float64 for numerical stability.
    policy_logps = torch.zeros(n_seqs, dtype=torch.float64, device=device)
    policy_logps.index_add_(0, seq_ids, torch.where(loss_mask, logprobs, 0.0).to(torch.float64))

    ref_logps = torch.zeros(n_seqs, dtype=torch.float64, device=device)
    ref_logps.index_add_(0, seq_ids, torch.where(loss_mask, ref_logprobs, 0.0).to(torch.float64))

@HT-Yuan HT-Yuan force-pushed the feature/dpo-trainer branch from 1c85889 to 7c8f3af on April 16, 2026 06:58
@HT-Yuan HT-Yuan marked this pull request as draft April 16, 2026 07:03
@HT-Yuan HT-Yuan force-pushed the feature/dpo-trainer branch 3 times, most recently from 4ae8b57 to 04d1d05 on April 23, 2026 04:00
@HT-Yuan HT-Yuan marked this pull request as ready for review April 23, 2026 04:03
@HT-Yuan HT-Yuan changed the title [WIP]Feature/dpo trainer Feature(trainer): add dpo Apr 23, 2026
@HT-Yuan HT-Yuan changed the title Feature(trainer): add dpo feat(trainer): add dpo Apr 23, 2026
@HT-Yuan HT-Yuan force-pushed the feature/dpo-trainer branch from 04d1d05 to 641008c on April 23, 2026 04:20
Add Direct Preference Optimization (Rafailov et al. 2023) as a new
trainer. The policy is directly optimized to prefer chosen over
rejected responses via a contrastive loss on log-probability ratios
against a frozen reference model, removing the need for a separately
trained reward model.

Reference logprobs are computed online each step by a colocated ref
engine, following the PPO/GRPO pattern. FSDP is the supported backend;
Megatron and Archon variants raise NotImplementedError as placeholders.

Verified on Qwen2.5-7B-Base + Anthropic/hh-rlhf (1 epoch, no SFT):
reward_accuracy rises from 0.50 to ~0.70 and reward_margin grows
monotonically, matching the original DPO paper's HH-RLHF results.
@HT-Yuan HT-Yuan force-pushed the feature/dpo-trainer branch from 641008c to 2274616 on April 23, 2026 04:29
…ct IPO normalization

Fixes several issues found during PR review of the DPO trainer:

Key changes:
- Create DPOEngineConfig(TrainEngineConfig) embedding beta and loss_type,
  fixing silent parameter drop in single-controller mode (as_controller
  never forwarded beta/loss_type to workers)
- Make ref a required field in DPOConfig (ref_logprobs are required at
  runtime, so config should enforce this upfront)
- Remove zero-ref fallback in compute_dpo_loss; use input_["ref_logprobs"]
  directly
- Add IPO loss with per-token length normalization matching TRL author-
  confirmed convention (normalize per-sequence logratios by completion
  length before the squared loss)
- Remove all ref-is-None guard branches from DPOTrainer
- Update docs, YAML config, and tests for all changes

Refs: inclusionAI#1190
@garrett4wade
Collaborator

Review fixes pushed to feature/dpo-trainer-fixes

The following fixes from the review have been implemented and pushed to feature/dpo-trainer-fixes (commit f867819 on top of this PR's HEAD):

Changes

  1. DPOEngineConfig — Created DPOEngineConfig(TrainEngineConfig) embedding beta and loss_type, fixing a critical bug where as_controller silently dropped these parameters in single-controller mode (Ray/Slurm)
  2. ref required — Made ref a required field in DPOConfig (was optional but runtime required ref_logprobs)
  3. Zero-ref fallback removed — compute_dpo_loss now uses input_["ref_logprobs"]
     directly (KeyError if missing)
  4. IPO length normalization — Added IPO loss variant with per-token length
     normalization matching TRL's author-confirmed convention (see the sketch below)
  5. Docs/YAML/tests updated — All 26 tests pass (16 sigmoid + 10 IPO)

Please cherry-pick or merge this branch into your PR.
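
For clarity, the length-normalized IPO variant in item 4 amounts to dividing
each per-sequence logratio by its completion length before the squared term; a
rough sketch with illustrative names:

    import torch

    def ipo_loss_length_normalized(chosen_logratio, rejected_logratio,
                                   chosen_len, rejected_len, beta):
        # Per-token normalization: divide each sequence logratio by its
        # completion (response) length, then apply the IPO squared loss.
        margin = chosen_logratio / chosen_len - rejected_logratio / rejected_len
        return ((margin - 1.0 / (2.0 * beta)) ** 2).mean()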

Collaborator

@garrett4wade garrett4wade left a comment


LGTM

@garrett4wade garrett4wade merged commit 70acd22 into inclusionAI:main Apr 24, 2026
6 checks passed
