Skip to content

Commit 641008c

Browse files
committed
feat(trainer): add DPO trainer with FSDP backend
Add Direct Preference Optimization (Rafailov et al. 2023) as a new trainer. The policy is directly optimized to prefer chosen over rejected responses via a contrastive loss on log-probability ratios against a frozen reference model, removing the need for a separately trained reward model. Reference logprobs are computed online each step by a colocated ref engine, following the PPO/GRPO pattern. FSDP is the supported backend; Megatron and Archon variants raise NotImplementedError as placeholders. Verified on Qwen2.5-7B-Base + Anthropic/hh-rlhf (1 epoch, no SFT): reward_accuracy rises from 0.50 to ~0.70 and reward_margin grows monotonically, matching the original DPO paper's HH-RLHF results.
1 parent e6f3c3c commit 641008c

24 files changed

Lines changed: 2051 additions & 10 deletions

File tree

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,8 @@ All RL algorithms support both asynchronous and synchronous versions by setting
197197
| **RLOO** | [📖 Docs](docs/en/algorithms/grpo_series.md) | [📄 Paper](https://arxiv.org/pdf/2402.14740v1) | [🔗 GSM8K Example](examples/math/gsm8k_rloo.yaml) |
198198
| **SAPO** | [📖 Docs](docs/en/algorithms/grpo_series.md) | [📄 Paper](https://arxiv.org/abs/2511.20347) | [🔗 GSM8K Example](examples/math/gsm8k_sapo.yaml) |
199199
| **M2PO** | [📖 Docs](docs/algorithms/m2po.md) | [📄 Paper](https://arxiv.org/abs/2510.01161) | [🔗 GSM8K Example](examples/math/gsm8k_m2po.yaml) |
200-
| **RLHF Reward Modeling** | - | - | [🔗 RLHF Example](examples/alignment/) |
200+
| **DPO** | [📖 Docs](docs/en/algorithms/dpo.md) | [📄 Paper](https://arxiv.org/abs/2305.18290) | [🔗 HH-RLHF Example](examples/alignment/hhrlhf_dpo.yaml) |
201+
| **RLHF Reward Modeling** | - | - | [🔗 RLHF Example](examples/alignment/hhrlhf_rw.yaml) |
201202
| **SFT** | - | - | [🔗 GSM8K Example](examples/math/gsm8k_sft.py) |
202203
| **Distillation** | [📖 Docs](docs/en/algorithms/distillation.md) | [📄 Paper](https://arxiv.org/pdf/2506.02208) | [🔗 GSM8K Example](examples/distillation/gsm8k_grpo_distill.yaml) |
203204

areal/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515

1616

1717
def __getattr__(name: str):
18-
if name in ("PPOTrainer", "RWTrainer", "SFTTrainer"):
19-
from .trainer import PPOTrainer, RWTrainer, SFTTrainer
18+
if name in ("DPOTrainer", "PPOTrainer", "RWTrainer", "SFTTrainer"):
19+
from .trainer import DPOTrainer, PPOTrainer, RWTrainer, SFTTrainer
2020

2121
_map = {
22+
"DPOTrainer": DPOTrainer,
2223
"PPOTrainer": PPOTrainer,
2324
"RWTrainer": RWTrainer,
2425
"SFTTrainer": SFTTrainer,
@@ -29,6 +30,7 @@ def __getattr__(name: str):
2930

3031

3132
__all__ = [
33+
"DPOTrainer",
3234
"PPOTrainer",
3335
"RolloutController",
3436
"RWTrainer",

areal/api/cli_args.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2637,6 +2637,51 @@ def __post_init__(self):
26372637
)
26382638

26392639

2640+
@dataclass
2641+
class DPOConfig(BaseExperimentConfig):
2642+
"""Configuration for Direct Preference Optimization (DPO) experiments."""
2643+
2644+
actor: TrainEngineConfig = field(default_factory=TrainEngineConfig)
2645+
2646+
ref: TrainEngineConfig | None = field(
2647+
default=None,
2648+
metadata={
2649+
"help": "Reference model configuration for DPO. "
2650+
"The ref model computes reference log-probabilities online during training. "
2651+
"If None, ref_logprobs default to zeros (degenerates to contrastive logprob loss)."
2652+
},
2653+
)
2654+
2655+
beta: float = field(
2656+
default=0.1,
2657+
metadata={"help": "KL penalty coefficient for DPO loss."},
2658+
)
2659+
2660+
loss_type: str = field(
2661+
default="sigmoid",
2662+
metadata={
2663+
"help": "DPO loss variant. "
2664+
"'sigmoid': original DPO loss (Rafailov et al. 2023). "
2665+
"'ipo': Identity Preference Optimization, uses squared loss (Azar et al. 2023).",
2666+
"choices": ["sigmoid", "ipo"],
2667+
},
2668+
)
2669+
2670+
def __post_init__(self):
2671+
super().__post_init__()
2672+
if getattr(self.actor, "is_critic", False):
2673+
raise ValueError(
2674+
"DPOConfig requires a language model (is_critic=False). "
2675+
"Remove 'actor.is_critic: true' from your YAML config."
2676+
)
2677+
_valid_loss_types = {"sigmoid", "ipo"}
2678+
if self.loss_type not in _valid_loss_types:
2679+
raise ValueError(
2680+
f"Unsupported DPO loss_type '{self.loss_type}'. "
2681+
f"Must be one of {sorted(_valid_loss_types)}."
2682+
)
2683+
2684+
26402685
@dataclass
26412686
class TeacherConfig(PPOActorConfig):
26422687
rl_loss_weight: float = field(

areal/dataset/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,16 @@ def _get_custom_dataset(
113113
max_length=max_length,
114114
**kwargs,
115115
)
116+
elif "hh-rlhf" in path and type == "dpo":
117+
from .hhrlhf import get_hhrlhf_dpo_dataset
118+
119+
return get_hhrlhf_dpo_dataset(
120+
path=path,
121+
split=split,
122+
tokenizer=tokenizer,
123+
max_length=max_length,
124+
**kwargs,
125+
)
116126
elif "torl_data" in path and type == "rl":
117127
from .torl_data import get_torl_data_rl_dataset
118128

areal/dataset/hhrlhf.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,51 @@ def process(sample):
2626
)
2727

2828
return dataset
29+
30+
31+
def get_hhrlhf_dpo_dataset(
32+
path: str,
33+
split: str,
34+
tokenizer,
35+
max_length: int | None = None,
36+
):
37+
"""Load HH-RLHF dataset for DPO training.
38+
39+
Each sample will contain:
40+
- ``chosen_ids`` / ``rejected_ids``: full token ids (prompt + response).
41+
- ``chosen_loss_mask`` / ``rejected_loss_mask``: boolean mask where ``True``
42+
marks the response tokens that participate in the loss.
43+
44+
Reference log-probabilities are computed online by the ref engine during
45+
training (configured via the ``ref`` field in ``DPOConfig``).
46+
"""
47+
dataset = load_dataset(path=path, split=split)
48+
49+
def process(sample):
50+
chosen_ids = tokenizer.encode(sample["chosen"] + tokenizer.eos_token)
51+
rejected_ids = tokenizer.encode(sample["rejected"] + tokenizer.eos_token)
52+
53+
prompt_len = 0
54+
for c, r in zip(chosen_ids, rejected_ids):
55+
if c == r:
56+
prompt_len += 1
57+
else:
58+
break
59+
60+
return {
61+
"chosen_ids": chosen_ids,
62+
"rejected_ids": rejected_ids,
63+
"chosen_loss_mask": [0] * prompt_len + [1] * (len(chosen_ids) - prompt_len),
64+
"rejected_loss_mask": [0] * prompt_len
65+
+ [1] * (len(rejected_ids) - prompt_len),
66+
}
67+
68+
dataset = dataset.map(process).remove_columns(["chosen", "rejected"])
69+
70+
if max_length is not None:
71+
dataset = dataset.filter(
72+
lambda x: (len(x["chosen_ids"]) <= max_length)
73+
and (len(x["rejected_ids"]) <= max_length)
74+
)
75+
76+
return dataset

areal/engine/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
"FSDPPPOCritic",
77
"FSDPLMEngine",
88
"FSDPRWEngine",
9+
"FSDPDPOEngine",
910
"MegatronEngine",
1011
"MegatronPPOActor",
1112
"MegatronPPOCritic",
1213
"MegatronLMEngine",
1314
"MegatronRWEngine",
15+
"MegatronDPOEngine",
1416
"RemoteSGLangEngine",
1517
"RemotevLLMEngine",
1618
]
@@ -21,11 +23,13 @@
2123
"FSDPPPOCritic": "areal.engine.fsdp_engine",
2224
"FSDPLMEngine": "areal.engine.fsdp_engine",
2325
"FSDPRWEngine": "areal.engine.fsdp_engine",
26+
"FSDPDPOEngine": "areal.engine.fsdp_engine",
2427
"MegatronEngine": "areal.engine.megatron_engine",
2528
"MegatronPPOActor": "areal.engine.megatron_engine",
2629
"MegatronPPOCritic": "areal.engine.megatron_engine",
2730
"MegatronLMEngine": "areal.engine.megatron_engine",
2831
"MegatronRWEngine": "areal.engine.megatron_engine",
32+
"MegatronDPOEngine": "areal.engine.megatron_engine",
2933
"RemoteSGLangEngine": "areal.engine.sglang_remote",
3034
"RemotevLLMEngine": "areal.engine.vllm_remote",
3135
}

areal/engine/fsdp_engine.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1958,3 +1958,51 @@ def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler):
19581958
from areal.trainer.rw.rw_engine import RWController
19591959

19601960
return RWController(train_engine=cls, config=config, scheduler=scheduler)
1961+
1962+
1963+
class FSDPDPOEngine(FSDPEngine):
1964+
"""DPO training engine using FSDP backend."""
1965+
1966+
def __init__(
1967+
self,
1968+
config: TrainEngineConfig,
1969+
beta: float = 0.1,
1970+
loss_type: str = "sigmoid",
1971+
):
1972+
from copy import deepcopy
1973+
1974+
from areal.trainer.dpo.dpo_engine import DPOEngine
1975+
1976+
super().__init__(config)
1977+
self.dpo_engine = DPOEngine(self, beta=beta, loss_type=loss_type)
1978+
if self.config.mb_spec.granularity != 2:
1979+
dpo_logger = logging.getLogger("DPOEngine")
1980+
dpo_logger.warning("mb_spec.granularity must be 2 for DPO training")
1981+
self.config = deepcopy(self.config)
1982+
self.config.mb_spec.granularity = 2
1983+
1984+
def train_dpo(self, data):
1985+
return self.dpo_engine.train_dpo(data)
1986+
1987+
def evaluate_dpo(self, data):
1988+
return self.dpo_engine.evaluate_dpo(data)
1989+
1990+
def compute_logp(self, data: list[dict[str, Any]]) -> list[torch.Tensor] | None:
1991+
return self.dpo_engine.compute_logp(data)
1992+
1993+
@classmethod
1994+
def as_controller(
1995+
cls,
1996+
config: TrainEngineConfig,
1997+
scheduler: Scheduler,
1998+
beta: float = 0.1,
1999+
loss_type: str = "sigmoid",
2000+
):
2001+
if config._version == "v2":
2002+
from areal.trainer.dpo.dpo_engine import DPOControllerV2
2003+
2004+
return DPOControllerV2(train_engine=cls, config=config, scheduler=scheduler)
2005+
2006+
from areal.trainer.dpo.dpo_engine import DPOController
2007+
2008+
return DPOController(train_engine=cls, config=config, scheduler=scheduler)

areal/engine/megatron_engine.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1930,3 +1930,51 @@ def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler):
19301930
from areal.trainer.rw.rw_engine import RWController
19311931

19321932
return RWController(train_engine=cls, config=config, scheduler=scheduler)
1933+
1934+
1935+
class MegatronDPOEngine(MegatronEngine):
1936+
"""DPO training engine using Megatron backend."""
1937+
1938+
def __init__(
1939+
self,
1940+
config: TrainEngineConfig,
1941+
beta: float = 0.1,
1942+
loss_type: str = "sigmoid",
1943+
):
1944+
from copy import deepcopy
1945+
1946+
from areal.trainer.dpo.dpo_engine import DPOEngine
1947+
1948+
super().__init__(config)
1949+
self.dpo_engine = DPOEngine(self, beta=beta, loss_type=loss_type)
1950+
if self.config.mb_spec.granularity != 2:
1951+
dpo_logger = logging.getLogger("DPOEngine")
1952+
dpo_logger.warning("mb_spec.granularity must be 2 for DPO training")
1953+
self.config = deepcopy(self.config)
1954+
self.config.mb_spec.granularity = 2
1955+
1956+
def train_dpo(self, data):
1957+
return self.dpo_engine.train_dpo(data)
1958+
1959+
def evaluate_dpo(self, data):
1960+
return self.dpo_engine.evaluate_dpo(data)
1961+
1962+
def compute_logp(self, data: list[dict[str, Any]]) -> list[torch.Tensor] | None:
1963+
return self.dpo_engine.compute_logp(data)
1964+
1965+
@classmethod
1966+
def as_controller(
1967+
cls,
1968+
config: TrainEngineConfig,
1969+
scheduler: Scheduler,
1970+
beta: float = 0.1,
1971+
loss_type: str = "sigmoid",
1972+
):
1973+
if config._version == "v2":
1974+
from areal.trainer.dpo.dpo_engine import DPOControllerV2
1975+
1976+
return DPOControllerV2(train_engine=cls, config=config, scheduler=scheduler)
1977+
1978+
from areal.trainer.dpo.dpo_engine import DPOController
1979+
1980+
return DPOController(train_engine=cls, config=config, scheduler=scheduler)

areal/experimental/engine/archon_engine.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1468,3 +1468,51 @@ def as_controller(cls, config: TrainEngineConfig, scheduler: Scheduler):
14681468
from areal.trainer.rw.rw_engine import RWController
14691469

14701470
return RWController(train_engine=cls, config=config, scheduler=scheduler)
1471+
1472+
1473+
class ArchonDPOEngine(ArchonEngine):
1474+
"""Archon-based DPO Engine for direct preference optimization."""
1475+
1476+
def __init__(
1477+
self,
1478+
config: TrainEngineConfig,
1479+
beta: float = 0.1,
1480+
loss_type: str = "sigmoid",
1481+
):
1482+
from copy import deepcopy
1483+
1484+
from areal.trainer.dpo.dpo_engine import DPOEngine
1485+
1486+
super().__init__(config)
1487+
self.dpo_engine = DPOEngine(self, beta=beta, loss_type=loss_type)
1488+
if self.config.mb_spec.granularity != 2:
1489+
dpo_logger = logging.getLogger("DPOEngine")
1490+
dpo_logger.warning("mb_spec.granularity must be 2 for DPO training")
1491+
self.config = deepcopy(self.config)
1492+
self.config.mb_spec.granularity = 2
1493+
1494+
def train_dpo(self, data):
1495+
return self.dpo_engine.train_dpo(data)
1496+
1497+
def evaluate_dpo(self, data):
1498+
return self.dpo_engine.evaluate_dpo(data)
1499+
1500+
def compute_logp(self, data: list[dict[str, Any]]) -> list[torch.Tensor] | None:
1501+
return self.dpo_engine.compute_logp(data)
1502+
1503+
@classmethod
1504+
def as_controller(
1505+
cls,
1506+
config: TrainEngineConfig,
1507+
scheduler: Scheduler,
1508+
beta: float = 0.1,
1509+
loss_type: str = "sigmoid",
1510+
):
1511+
if config._version == "v2":
1512+
from areal.trainer.dpo.dpo_engine import DPOControllerV2
1513+
1514+
return DPOControllerV2(train_engine=cls, config=config, scheduler=scheduler)
1515+
1516+
from areal.trainer.dpo.dpo_engine import DPOController
1517+
1518+
return DPOController(train_engine=cls, config=config, scheduler=scheduler)

areal/trainer/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3+
from .dpo_trainer import DPOTrainer
34
from .rl_trainer import PPOTrainer
45
from .rw_trainer import RWTrainer
56
from .sft_trainer import SFTTrainer
67

7-
__all__ = ["PPOTrainer", "RWTrainer", "SFTTrainer"]
8+
__all__ = ["DPOTrainer", "PPOTrainer", "RWTrainer", "SFTTrainer"]

0 commit comments

Comments
 (0)