feat(trainer): add dpo #1190
Conversation
Code Review
This pull request introduces Direct Preference Optimization (DPO) support to the framework, including the DPOTrainer, DPO-specific engine implementations for FSDP, Megatron, and Archon backends, and dataset processing for the HH-RLHF dataset. The review feedback identifies critical issues: potential crashes in DPOEngine due to incorrect batch type handling in training and evaluation paths, an incorrect dataset size calculation in DPOTrainer that would lead to improper learning rate scheduling, and an inefficient Python loop in the DPO loss computation that should be vectorized to improve performance and numerical stability.
```python
def _train_dpo(self, data: dict[str, Any]) -> None:
    """Train on a batch (DPO)."""
    if _dpo_loss_weight(data) == 0:
```
The `_dpo_loss_weight` function expects a dictionary containing `cu_seqlens`, but `data` here is a `list[dict[str, Any]]` (the raw batch from the dataloader). This will cause a `TypeError` when trying to access `data["cu_seqlens"]`. Since the goal is to skip empty batches and log placeholder stats, you should check the list length instead.
```diff
-        if _dpo_loss_weight(data) == 0:
+        if not data:
+            _log_empty_dpo_stats(current_platform.current_device())
+            return
```
```python
batched_call(self._evaluate_dpo, data, unpack=False)
```

```python
def _evaluate_dpo(self, data: dict[str, Any]) -> None:
    if _dpo_loss_weight(data) == 0:
```
Same issue as in `_train_dpo`: `data` here is the raw `list[dict[str, Any]]` batch, so `_dpo_loss_weight(data)` will raise a `TypeError`; check `if not data:` instead.
```python
ft_spec = FinetuneSpec(
    total_train_epochs=config.total_train_epochs,
    dataset_size=len(self.train_dataloader) * config.train_dataset.batch_size,
```
The `dataset_size` is calculated using the length of the sharded dataloader, which represents the number of samples per rank. This will cause `FinetuneSpec` to compute an incorrect `total_train_steps` (underestimated by a factor of `world_size`), leading to incorrect learning rate scheduling and premature training termination. You should use the total dataset size instead.
```diff
-    dataset_size=len(self.train_dataloader) * config.train_dataset.batch_size,
+    dataset_size=len(train_dataset),
```
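For example (hypothetical numbers): with 80,000 training samples, `batch_size=8`, and `world_size=8`, the sharded dataloader holds 80,000 / (8 × 8) = 1,250 batches per rank, so `len(self.train_dataloader) * batch_size` yields 10,000 instead of 80,000, and `total_train_steps` is underestimated by the same 8× factor.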
```python
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu()
n_seqs = seqlens.shape[0]

policy_logps = torch.zeros(n_seqs, dtype=torch.float64, device=device)
ref_logps = torch.zeros(n_seqs, dtype=torch.float64, device=device)

for i in range(n_seqs):
    start = cu_seqlens[i]
    end = cu_seqlens[i + 1]
    m = loss_mask[start:end]
    policy_logps[i] = torch.where(m, logprobs[start:end], 0.0).sum()
    ref_logps[i] = torch.where(m, ref_logprobs[start:end], 0.0).sum()
```
This Python loop over sequences in the packed batch is inefficient and causes multiple GPU-CPU synchronizations. Furthermore, summing log-probabilities in low precision (e.g., `bfloat16`) can lead to numerical instability for long sequences. It is highly recommended to vectorize this operation and perform the summation in `float64` to maintain precision.
```python
n_seqs = cu_seqlens.numel() - 1
seq_ids = torch.zeros(logprobs.shape[0], dtype=torch.long, device=device)
seq_ids.scatter_(0, cu_seqlens[1:-1].long(), 1)
seq_ids = seq_ids.cumsum(dim=0)
policy_logps = torch.zeros(n_seqs, dtype=torch.float64, device=device)
policy_logps.index_add_(0, seq_ids, torch.where(loss_mask, logprobs, 0.0).to(torch.float64))
ref_logps = torch.zeros(n_seqs, dtype=torch.float64, device=device)
ref_logps.index_add_(0, seq_ids, torch.where(loss_mask, ref_logprobs, 0.0).to(torch.float64))
```
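For intuition, here is how the segment-id construction behaves on a toy packed batch (a standalone sketch; the tensor values are made up for illustration):

```python
import torch

cu_seqlens = torch.tensor([0, 3, 5, 9])  # three packed sequences of lengths 3, 2, 4
seq_ids = torch.zeros(9, dtype=torch.long)
seq_ids.scatter_(0, cu_seqlens[1:-1].long(), 1)  # mark interior boundaries: [0,0,0,1,0,1,0,0,0]
seq_ids = seq_ids.cumsum(dim=0)                  # token -> sequence id: [0,0,0,1,1,2,2,2,2]
```

Each token ends up labeled with its sequence index, so a single `index_add_` accumulates all masked per-token logprobs into the right per-sequence slot without any Python-level loop.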
Add Direct Preference Optimization (Rafailov et al. 2023) as a new trainer. The policy is directly optimized to prefer chosen over rejected responses via a contrastive loss on log-probability ratios against a frozen reference model, removing the need for a separately trained reward model. Reference logprobs are computed online each step by a colocated ref engine, following the PPO/GRPO pattern. FSDP is the supported backend; Megatron and Archon variants raise NotImplementedError as placeholders. Verified on Qwen2.5-7B-Base + Anthropic/hh-rlhf (1 epoch, no SFT): reward_accuracy rises from 0.50 to ~0.70 and reward_margin grows monotonically, matching the original DPO paper's HH-RLHF results.
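For reference, the sigmoid-variant objective this implements (Rafailov et al. 2023) is

$$
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}}) = -\,\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \left[ \log \sigma\!\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)} \right) \right],
$$

where $y_w$ and $y_l$ are the chosen and rejected responses and $\beta$ controls the deviation from the frozen reference policy.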
…ct IPO normalization

Fixes several issues found during PR review of the DPO trainer.

Key changes:
- Create `DPOEngineConfig(TrainEngineConfig)` embedding `beta` and `loss_type`, fixing a silent parameter drop in single-controller mode (`as_controller` never forwarded `beta`/`loss_type` to workers)
- Make `ref` a required field in `DPOConfig` (`ref_logprobs` are required at runtime, so the config should enforce this upfront)
- Remove the zero-ref fallback in `compute_dpo_loss`; use `input_["ref_logprobs"]` directly
- Add IPO loss with per-token length normalization matching the TRL author-confirmed convention (normalize per-sequence logratios by completion length before the squared loss)
- Remove all ref-is-None guard branches from `DPOTrainer`
- Update docs, YAML config, and tests for all changes

Refs: inclusionAI#1190
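A minimal sketch of the IPO convention described above (the names `policy_logps`, `ref_logps`, and `completion_lens` are illustrative assumptions, not the PR's exact API):

```python
import torch

def ipo_loss(policy_logps, ref_logps, completion_lens, beta):
    # Length-normalize per-sequence logratios (per-token convention, as in TRL).
    logratios = (policy_logps - ref_logps) / completion_lens
    # Packed interleaved pairs: [chosen_0, rejected_0, chosen_1, rejected_1, ...]
    chosen, rejected = logratios[0::2], logratios[1::2]
    # IPO (Azar et al. 2023): squared distance from the 1/(2*beta) margin.
    return ((chosen - rejected) - 1.0 / (2.0 * beta)).pow(2).mean()
```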
Review fixes pushed to
Description

Add Direct Preference Optimization (DPO) training to AReaL as a first-class alignment method, alongside the existing Reward Model (RM) training on Anthropic/hh-rlhf.

Highlights:

- `DPOTrainer` (exported from `areal`), wired to the existing `TrainController` / single-controller runtime. Ref-model log-probs are computed online per step via a colocated ref engine (the same pattern PPO/GRPO use), so no pre-computed logprobs need to be stored on disk.
- `areal/trainer/dpo/dpo_engine.py`: `DPOEngine` (train / eval / `compute_logp`); `DPOController` (single-controller) and `DPOControllerV2` (gateway / V2 controller) dispatchers, both using `group_size=2` so each (chosen, rejected) pair stays on the same DP rank; a `compute_dpo_loss` loss function compatible with the engine `loss_fn` contract (packed, interleaved `[chosen_0, rejected_0, chosen_1, rejected_1, ...]`).
- `areal/utils/functional/functional.py`: `dpo_pair_logratios` — reduces per-token logprobs to per-pair `(policy, ref) × (chosen, rejected)` sequence logprobs; `dpo_preference_loss` — supports `loss_type ∈ {"sigmoid", "ipo"}` (Rafailov et al. 2023 / Azar et al. 2023). Both are re-exported from `areal.utils.functional` (a minimal usage sketch follows this list).
- `DPOConfig` in `areal/api/cli_args.py`, with a `ref:` engine config and DPO hyperparameters (`beta`, `loss_type`).
- `areal/dataset/hhrlhf.py::get_hhrlhf_dpo_dataset`, which infers the prompt/response boundary via the longest common token prefix between `chosen` and `rejected` and marks only response tokens in `loss_mask` (see the prefix sketch after this list).
- `examples/alignment/`: `hhrlhf_dpo.py` + `hhrlhf_dpo.yaml`, and a `## Direct Preference Optimization (DPO)` section in `examples/alignment/README.md` with the loss in LaTeX and a training curve (`dpo_curve.png`).
- `tests/test_hhrlhf_dataset.py` — offline sanity tests for the HH-RLHF DPO loader (mocks `datasets.load_dataset`; no network). Unit tests cover the `sigmoid` and `ipo` variants, the zero-ref fallback, and invalid-pair masking.
- Stats emitted per step (via `stats_tracker`): `loss`, `chosen_reward`, `rejected_reward`, `reward_accuracy`, `reward_margin`, denominated by `n_pairs`, so the dashboard matches the canonical DPO signature from the original paper.
- Housekeeping: the `__all__` group-label comment in `areal/utils/functional/__init__.py` was also updated from `# logprobs.py` to `# vocab_parallel.py` to match the actual module name (the file was renamed upstream; this just fixes the now-stale comment).
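A minimal sketch of the longest-common-token-prefix boundary inference mentioned above (a standalone illustration; the helper name `common_prefix_len` is an assumption, not the loader's actual function):

```python
def common_prefix_len(chosen_ids: list[int], rejected_ids: list[int]) -> int:
    """Length of the longest common token prefix of the two completions."""
    n = 0
    for a, b in zip(chosen_ids, rejected_ids):
        if a != b:
            break
        n += 1
    return n

# Tokens [0, n) are treated as the shared prompt; loss_mask marks only tokens >= n.
```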
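And a hedged sketch of how the per-step reward stats are conventionally derived from policy/ref sequence logprobs (tensor names and the interleaved pairing are assumptions; the PR's `dpo_pair_logratios` / `stats_tracker` plumbing may differ in detail):

```python
import torch

def dpo_reward_stats(policy_logps, ref_logps, beta):
    # Implicit DPO rewards: beta-scaled policy/ref logratio per sequence.
    rewards = beta * (policy_logps - ref_logps)
    # Interleaved packing: [chosen_0, rejected_0, chosen_1, rejected_1, ...]
    chosen, rejected = rewards[0::2], rewards[1::2]
    return {
        "chosen_reward": chosen.mean(),
        "rejected_reward": rejected.mean(),
        "reward_margin": (chosen - rejected).mean(),
        "reward_accuracy": (chosen > rejected).float().mean(),
    }
```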
Related Issue
Fixes #1137
Type of Change
Checklist
- Pre-commit checks pass (`pre-commit run --all-files`)
- Docs build (`./docs/build_all.sh`)
- Branch is up to date with `main`
- `/review-pr` command
- `/create-pr`

Breaking Change Details (if applicable):
Additional Context
Need help? Check the Contributing Guide or ask in GitHub Discussions!