
Commit 70acd22

feat(trainer): add dpo (#1190)
* feat(trainer): add DPO trainer with FSDP backend

  Add Direct Preference Optimization (Rafailov et al. 2023) as a new trainer.
  The policy is directly optimized to prefer chosen over rejected responses
  via a contrastive loss on log-probability ratios against a frozen reference
  model, removing the need for a separately trained reward model. Reference
  logprobs are computed online each step by a colocated ref engine, following
  the PPO/GRPO pattern. FSDP is the supported backend; Megatron and Archon
  variants raise NotImplementedError as placeholders.

  Verified on Qwen2.5-7B-Base + Anthropic/hh-rlhf (1 epoch, no SFT):
  reward_accuracy rises from 0.50 to ~0.70 and reward_margin grows
  monotonically, matching the original DPO paper's HH-RLHF results.

* fix(trainer): fix DPO config forwarding, require ref model, and correct IPO normalization

  Fixes several issues found during PR review of the DPO trainer.

  Key changes:
  - Create DPOEngineConfig(TrainEngineConfig) embedding beta and loss_type,
    fixing a silent parameter drop in single-controller mode (as_controller
    never forwarded beta/loss_type to workers)
  - Make ref a required field in DPOConfig (ref_logprobs are required at
    runtime, so the config should enforce this upfront)
  - Remove the zero-ref fallback in compute_dpo_loss; use
    input_["ref_logprobs"] directly
  - Add IPO loss with per-token length normalization matching the
    TRL author-confirmed convention (normalize per-sequence logratios by
    completion length before the squared loss)
  - Remove all ref-is-None guard branches from DPOTrainer
  - Update docs, YAML config, and tests for all changes

  Refs: #1190

---------

Co-authored-by: 博惟 <bowei.fw@antgroup.com>
1 parent 5bd7a18 commit 70acd22
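For reference, a minimal sketch of the two loss variants the commit message describes. This is not the repository's compute_dpo_loss: the tensor names and the assumption that per-sequence log-probabilities are already summed over response tokens are illustrative only.

import torch
import torch.nn.functional as F

def dpo_loss_sketch(
    chosen_logps: torch.Tensor,        # policy log p(response | prompt), summed over response tokens
    rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,    # same quantities under the frozen reference model
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,
    loss_type: str = "sigmoid",
    chosen_lens: torch.Tensor | None = None,    # float completion lengths, needed for 'ipo'
    rejected_lens: torch.Tensor | None = None,
) -> torch.Tensor:
    # Log-probability ratios of the policy against the frozen reference; shape (batch,).
    chosen_logratios = chosen_logps - ref_chosen_logps
    rejected_logratios = rejected_logps - ref_rejected_logps
    if loss_type == "sigmoid":
        # Original DPO (Rafailov et al. 2023): logistic loss on the scaled margin.
        margin = beta * (chosen_logratios - rejected_logratios)
        return -F.logsigmoid(margin).mean()
    if loss_type == "ipo":
        # IPO (Azar et al. 2023) with per-token length normalization: average each
        # sequence's logratio over its completion length, then regress the
        # difference toward 1 / (2 * beta) with a squared loss.
        margin = chosen_logratios / chosen_lens - rejected_logratios / rejected_lens
        return ((margin - 1 / (2 * beta)) ** 2).mean()
    raise ValueError(f"Unsupported loss_type: {loss_type}")

With loss_type="sigmoid" this is the original DPO logistic objective; with "ipo" each sequence's logratio is first averaged over its completion length, matching the per-token normalization convention the fix commit adopts.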

26 files changed

Lines changed: 2158 additions & 13 deletions

README.md

Lines changed: 2 additions & 1 deletion
@@ -207,7 +207,8 @@ All RL algorithms support both asynchronous and synchronous versions by setting
 | **RLOO** | [📖 Docs](docs/en/algorithms/grpo_series.md) | [📄 Paper](https://arxiv.org/pdf/2402.14740v1) | [🔗 GSM8K Example](examples/math/gsm8k_rloo.yaml) |
 | **SAPO** | [📖 Docs](docs/en/algorithms/grpo_series.md) | [📄 Paper](https://arxiv.org/abs/2511.20347) | [🔗 GSM8K Example](examples/math/gsm8k_sapo.yaml) |
 | **M2PO** | [📖 Docs](docs/algorithms/m2po.md) | [📄 Paper](https://arxiv.org/abs/2510.01161) | [🔗 GSM8K Example](examples/math/gsm8k_m2po.yaml) |
-| **RLHF Reward Modeling** | - | - | [🔗 RLHF Example](examples/alignment/) |
+| **DPO** | [📖 Docs](docs/en/algorithms/dpo.md) | [📄 Paper](https://arxiv.org/abs/2305.18290) | [🔗 HH-RLHF Example](examples/alignment/hhrlhf_dpo.yaml) |
+| **RLHF Reward Modeling** | - | - | [🔗 RLHF Example](examples/alignment/hhrlhf_rw.yaml) |
 | **SFT** | - | - | [🔗 GSM8K Example](examples/math/gsm8k_sft.py) |
 | **Distillation** | [📖 Docs](docs/en/algorithms/distillation.md) | [📄 Paper](https://arxiv.org/pdf/2506.02208) | [🔗 GSM8K Example](examples/distillation/gsm8k_grpo_distill.yaml) |

areal/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -15,10 +15,11 @@


 def __getattr__(name: str):
-    if name in ("PPOTrainer", "RWTrainer", "SFTTrainer"):
-        from .trainer import PPOTrainer, RWTrainer, SFTTrainer
+    if name in ("DPOTrainer", "PPOTrainer", "RWTrainer", "SFTTrainer"):
+        from .trainer import DPOTrainer, PPOTrainer, RWTrainer, SFTTrainer

         _map = {
+            "DPOTrainer": DPOTrainer,
             "PPOTrainer": PPOTrainer,
             "RWTrainer": RWTrainer,
             "SFTTrainer": SFTTrainer,
@@ -29,6 +30,7 @@ def __getattr__(name: str):


 __all__ = [
+    "DPOTrainer",
     "PPOTrainer",
     "RolloutController",
     "RWTrainer",

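As a usage note, the module-level __getattr__ above (PEP 562) defers the trainer import until first attribute access. A minimal illustration, assuming the package is installed:

import areal

# Nothing from areal.trainer has been imported yet; this first attribute
# access runs __getattr__, which performs `from .trainer import DPOTrainer, ...`.
TrainerCls = areal.DPOTrainer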
areal/api/cli_args.py

Lines changed: 46 additions & 0 deletions
@@ -2669,6 +2669,52 @@ def __post_init__(self):
         )


+@dataclass
+class DPOEngineConfig(TrainEngineConfig):
+    """Engine configuration for DPO training, extending TrainEngineConfig with DPO-specific fields."""
+
+    beta: float = field(
+        default=0.1,
+        metadata={"help": "KL penalty coefficient for DPO loss."},
+    )
+
+    loss_type: str = field(
+        default="sigmoid",
+        metadata={
+            "help": "DPO loss variant. "
+            "'sigmoid': original DPO loss (Rafailov et al. 2023). "
+            "'ipo': Identity Preference Optimization with per-token length normalization (Azar et al. 2023).",
+            "choices": ["sigmoid", "ipo"],
+        },
+    )
+
+    def __post_init__(self):
+        super().__post_init__()
+        _valid = {"sigmoid", "ipo"}
+        if self.loss_type not in _valid:
+            raise ValueError(
+                f"Unsupported DPO loss_type '{self.loss_type}'. "
+                f"Must be one of {sorted(_valid)}."
+            )
+
+
+@dataclass
+class DPOConfig(BaseExperimentConfig):
+    """Configuration for Direct Preference Optimization (DPO) experiments."""
+
+    actor: DPOEngineConfig = field(default_factory=DPOEngineConfig)
+
+    ref: DPOEngineConfig = field(default_factory=DPOEngineConfig)
+
+    def __post_init__(self):
+        super().__post_init__()
+        if getattr(self.actor, "is_critic", False):
+            raise ValueError(
+                "DPOConfig requires a language model (is_critic=False). "
+                "Remove 'actor.is_critic: true' from your YAML config."
+            )
+
+
 @dataclass
 class TeacherConfig(PPOActorConfig):
     rl_loss_weight: float = field(
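To illustrate the new validation, a small hedged sketch (constructing a full DPOConfig needs the remaining BaseExperimentConfig fields, so only DPOEngineConfig is exercised here):

from areal.api.cli_args import DPOEngineConfig

cfg = DPOEngineConfig(beta=0.1, loss_type="ipo")   # accepted variant
try:
    DPOEngineConfig(loss_type="hinge")             # fails fast in __post_init__
except ValueError as err:
    print(err)  # Unsupported DPO loss_type 'hinge'. Must be one of ['ipo', 'sigmoid'].

Because beta and loss_type now live on the engine config itself, they travel with the rest of TrainEngineConfig when as_controller hands the config to workers, which closes the silent-drop path described in the commit message.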

areal/dataset/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -113,6 +113,16 @@ def _get_custom_dataset(
             max_length=max_length,
             **kwargs,
         )
+    elif "hh-rlhf" in path and type == "dpo":
+        from .hhrlhf import get_hhrlhf_dpo_dataset
+
+        return get_hhrlhf_dpo_dataset(
+            path=path,
+            split=split,
+            tokenizer=tokenizer,
+            max_length=max_length,
+            **kwargs,
+        )
     elif "torl_data" in path and type == "rl":
         from .torl_data import get_torl_data_rl_dataset

areal/dataset/hhrlhf.py

Lines changed: 48 additions & 0 deletions
@@ -28,3 +28,51 @@ def process(sample):
     )

     return dataset
+
+
+def get_hhrlhf_dpo_dataset(
+    path: str,
+    split: str,
+    tokenizer,
+    max_length: int | None = None,
+):
+    """Load HH-RLHF dataset for DPO training.
+
+    Each sample will contain:
+    - ``chosen_ids`` / ``rejected_ids``: full token ids (prompt + response).
+    - ``chosen_loss_mask`` / ``rejected_loss_mask``: boolean mask where ``True``
+      marks the response tokens that participate in the loss.
+
+    Reference log-probabilities are computed online by the ref engine during
+    training (configured via the ``ref`` field in ``DPOConfig``).
+    """
+    dataset = load_dataset(path=path, split=split)
+
+    def process(sample):
+        chosen_ids = tokenizer.encode(sample["chosen"] + tokenizer.eos_token)
+        rejected_ids = tokenizer.encode(sample["rejected"] + tokenizer.eos_token)
+
+        prompt_len = 0
+        for c, r in zip(chosen_ids, rejected_ids):
+            if c == r:
+                prompt_len += 1
+            else:
+                break
+
+        return {
+            "chosen_ids": chosen_ids,
+            "rejected_ids": rejected_ids,
+            "chosen_loss_mask": [0] * prompt_len + [1] * (len(chosen_ids) - prompt_len),
+            "rejected_loss_mask": [0] * prompt_len
+            + [1] * (len(rejected_ids) - prompt_len),
+        }
+
+    dataset = dataset.map(process).remove_columns(["chosen", "rejected"])
+
+    if max_length is not None:
+        dataset = dataset.filter(
+            lambda x: (len(x["chosen_ids"]) <= max_length)
+            and (len(x["rejected_ids"]) <= max_length)
+        )
+
+    return dataset
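A small self-contained check of the shared-prefix logic in process() above, with toy token ids standing in for tokenizer output:

# The prompt is recovered as the longest common prefix of the two token id
# sequences; everything after it counts as response and enters the loss.
chosen_ids = [1, 2, 3, 10, 11]   # prompt [1, 2, 3] + chosen response [10, 11]
rejected_ids = [1, 2, 3, 20]     # prompt [1, 2, 3] + rejected response [20]

prompt_len = 0
for c, r in zip(chosen_ids, rejected_ids):
    if c != r:
        break
    prompt_len += 1

assert prompt_len == 3
assert [0] * prompt_len + [1] * (len(chosen_ids) - prompt_len) == [0, 0, 0, 1, 1]

Note the implicit assumption: if the chosen and rejected responses begin with the same token, that token is absorbed into the prompt prefix, so the heuristic relies on the two responses diverging at their first token.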

areal/engine/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -6,11 +6,13 @@
     "FSDPPPOCritic",
     "FSDPLMEngine",
     "FSDPRWEngine",
+    "FSDPDPOEngine",
     "MegatronEngine",
     "MegatronPPOActor",
     "MegatronPPOCritic",
     "MegatronLMEngine",
     "MegatronRWEngine",
+    "MegatronDPOEngine",
     "RemoteSGLangEngine",
     "RemotevLLMEngine",
 ]
@@ -21,11 +23,13 @@
     "FSDPPPOCritic": "areal.engine.fsdp_engine",
     "FSDPLMEngine": "areal.engine.fsdp_engine",
     "FSDPRWEngine": "areal.engine.fsdp_engine",
+    "FSDPDPOEngine": "areal.engine.fsdp_engine",
     "MegatronEngine": "areal.engine.megatron_engine",
     "MegatronPPOActor": "areal.engine.megatron_engine",
     "MegatronPPOCritic": "areal.engine.megatron_engine",
     "MegatronLMEngine": "areal.engine.megatron_engine",
     "MegatronRWEngine": "areal.engine.megatron_engine",
+    "MegatronDPOEngine": "areal.engine.megatron_engine",
     "RemoteSGLangEngine": "areal.engine.sglang_remote",
     "RemotevLLMEngine": "areal.engine.vllm_remote",
 }

areal/engine/fsdp_engine.py

Lines changed: 42 additions & 1 deletion
@@ -131,7 +131,7 @@

 if TYPE_CHECKING:
     from areal.api import Scheduler
-    from areal.api.cli_args import PPOActorConfig, PPOCriticConfig
+    from areal.api.cli_args import DPOEngineConfig, PPOActorConfig, PPOCriticConfig


 @dataclasses.dataclass
@@ -1966,3 +1966,44 @@ def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler):
         from areal.trainer.rw.rw_engine import RWController

         return RWController(train_engine=cls, config=config, scheduler=scheduler)
+
+
+class FSDPDPOEngine(FSDPEngine):
+    """DPO training engine using FSDP backend."""
+
+    def __init__(self, config: DPOEngineConfig):
+        from copy import deepcopy
+
+        from areal.trainer.dpo.dpo_engine import DPOEngine
+
+        super().__init__(config)
+        self.dpo_engine = DPOEngine(self)
+        if self.config.mb_spec.granularity != 2:
+            dpo_logger = logging.getLogger("DPOEngine")
+            dpo_logger.warning("mb_spec.granularity must be 2 for DPO training")
+            self.config = deepcopy(self.config)
+            self.config.mb_spec.granularity = 2
+
+    def train_dpo(self, data):
+        return self.dpo_engine.train_dpo(data)
+
+    def evaluate_dpo(self, data):
+        return self.dpo_engine.evaluate_dpo(data)
+
+    def compute_logp(self, data: list[dict[str, Any]]) -> list[torch.Tensor] | None:
+        return self.dpo_engine.compute_logp(data)
+
+    @classmethod
+    def as_controller(
+        cls,
+        config: DPOEngineConfig,
+        scheduler: Scheduler,
+    ):
+        if config._version == "v2":
+            from areal.trainer.dpo.dpo_engine import DPOControllerV2
+
+            return DPOControllerV2(train_engine=cls, config=config, scheduler=scheduler)
+
+        from areal.trainer.dpo.dpo_engine import DPOController
+
+        return DPOController(train_engine=cls, config=config, scheduler=scheduler)
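For context on the granularity override: the DPO batch presumably interleaves each pair's chosen and rejected sequences, and both must land in the same microbatch for the pairwise loss, so microbatch splits may only fall on even boundaries. A toy sketch of that invariant (not the engine's actual splitter):

# With granularity=2, the per-microbatch sequence count is rounded down to a
# multiple of 2, so a chosen sequence is never separated from its rejected pair.
def split_pairs(seqs, max_seqs_per_mb, granularity=2):
    step = max(granularity, (max_seqs_per_mb // granularity) * granularity)
    return [seqs[i:i + step] for i in range(0, len(seqs), step)]

seqs = ["c0", "r0", "c1", "r1", "c2", "r2"]  # chosen/rejected interleaved
print(split_pairs(seqs, max_seqs_per_mb=3))
# [['c0', 'r0'], ['c1', 'r1'], ['c2', 'r2']]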

areal/engine/megatron_engine.py

Lines changed: 42 additions & 1 deletion
@@ -113,7 +113,7 @@

 if TYPE_CHECKING:
     from areal.api import Scheduler
-    from areal.api.cli_args import PPOActorConfig, PPOCriticConfig
+    from areal.api.cli_args import DPOEngineConfig, PPOActorConfig, PPOCriticConfig


 class _MegatronModelList(list):
@@ -1987,3 +1987,44 @@ def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler):
         from areal.trainer.rw.rw_engine import RWController

         return RWController(train_engine=cls, config=config, scheduler=scheduler)
+
+
+class MegatronDPOEngine(MegatronEngine):
+    """DPO training engine using Megatron backend."""
+
+    def __init__(self, config: DPOEngineConfig):
+        from copy import deepcopy
+
+        from areal.trainer.dpo.dpo_engine import DPOEngine
+
+        super().__init__(config)
+        self.dpo_engine = DPOEngine(self)
+        if self.config.mb_spec.granularity != 2:
+            dpo_logger = logging.getLogger("DPOEngine")
+            dpo_logger.warning("mb_spec.granularity must be 2 for DPO training")
+            self.config = deepcopy(self.config)
+            self.config.mb_spec.granularity = 2
+
+    def train_dpo(self, data):
+        return self.dpo_engine.train_dpo(data)
+
+    def evaluate_dpo(self, data):
+        return self.dpo_engine.evaluate_dpo(data)
+
+    def compute_logp(self, data: list[dict[str, Any]]) -> list[torch.Tensor] | None:
+        return self.dpo_engine.compute_logp(data)
+
+    @classmethod
+    def as_controller(
+        cls,
+        config: DPOEngineConfig,
+        scheduler: Scheduler,
+    ):
+        if config._version == "v2":
+            from areal.trainer.dpo.dpo_engine import DPOControllerV2
+
+            return DPOControllerV2(train_engine=cls, config=config, scheduler=scheduler)
+
+        from areal.trainer.dpo.dpo_engine import DPOController
+
+        return DPOController(train_engine=cls, config=config, scheduler=scheduler)

areal/experimental/engine/archon_engine.py

Lines changed: 42 additions & 1 deletion
@@ -113,7 +113,7 @@
 from torchdata.stateful_dataloader import StatefulDataLoader

 from areal.api import InferenceEngine, Scheduler, WorkflowLike
-from areal.api.cli_args import PerfTracerConfig, TrainEngineConfig
+from areal.api.cli_args import DPOEngineConfig, PerfTracerConfig, TrainEngineConfig
 from areal.experimental.engine.archon_runner import ForwardBackwardRunner


@@ -1476,3 +1476,44 @@ def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler):
         from areal.trainer.rw.rw_engine import RWController

         return RWController(train_engine=cls, config=config, scheduler=scheduler)
+
+
+class ArchonDPOEngine(ArchonEngine):
+    """Archon-based DPO Engine for direct preference optimization."""
+
+    def __init__(self, config: DPOEngineConfig):
+        from copy import deepcopy
+
+        from areal.trainer.dpo.dpo_engine import DPOEngine
+
+        super().__init__(config)
+        self.dpo_engine = DPOEngine(self)
+        if self.config.mb_spec.granularity != 2:
+            dpo_logger = logging.getLogger("DPOEngine")
+            dpo_logger.warning("mb_spec.granularity must be 2 for DPO training")
+            self.config = deepcopy(self.config)
+            self.config.mb_spec.granularity = 2
+
+    def train_dpo(self, data):
+        return self.dpo_engine.train_dpo(data)
+
+    def evaluate_dpo(self, data):
+        return self.dpo_engine.evaluate_dpo(data)
+
+    def compute_logp(self, data: list[dict[str, Any]]) -> list[torch.Tensor] | None:
+        return self.dpo_engine.compute_logp(data)
+
+    @classmethod
+    def as_controller(
+        cls,
+        config: DPOEngineConfig,
+        scheduler: Scheduler,
+    ):
+        if config._version == "v2":
+            from areal.trainer.dpo.dpo_engine import DPOControllerV2
+
+            return DPOControllerV2(train_engine=cls, config=config, scheduler=scheduler)
+
+        from areal.trainer.dpo.dpo_engine import DPOController
+
+        return DPOController(train_engine=cls, config=config, scheduler=scheduler)

areal/trainer/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0

+from .dpo_trainer import DPOTrainer
 from .rl_trainer import PPOTrainer
 from .rw_trainer import RWTrainer
 from .sft_trainer import SFTTrainer

-__all__ = ["PPOTrainer", "RWTrainer", "SFTTrainer"]
+__all__ = ["DPOTrainer", "PPOTrainer", "RWTrainer", "SFTTrainer"]
