Skip to content

Commit 60bc7f9

Browse files
Merge pull request #3895 from AI-Hypercomputer:igorts/dpo-input-processing
PiperOrigin-RevId: 918547107
2 parents a1fb834 + d59a15e commit 60bc7f9

7 files changed

Lines changed: 553 additions & 27 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ checkpoint_conversion_fn: none
8686
# optional checkpoint context to use for loading. options: "orbax", "safetensors"
8787
source_checkpoint_layout: "orbax"
8888

89-
# Only applicable to Single Controller/Pathways on Cloud. Experimental feature, under testing
89+
# Only applicable to Single Controller/Pathways on Cloud. Experimental feature, under testing
9090
colocated_python_checkpointing: False
9191

9292
# enables autocheckpoint, which saves a checkpoint at the preemption step.
@@ -451,7 +451,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess'
451451
# internal_compile allows bypassing open-source topology name mappings when using internal topologies directly via get_topology_desc.
452452
internal_compile: False
453453
internal_compile_num_devices: -1 # You must specify the number of devices when using internal_compile.
454-
compile_xla_flags: "" # Compiler options e.g. compile_xla_flags="--xla_tpu_num_sparse_cores_for_gather_offloading=1 --xla_tpu_scoped_vmem_limit_kib=65536"
454+
compile_xla_flags: "" # Compiler options e.g. compile_xla_flags="--xla_tpu_num_sparse_cores_for_gather_offloading=1 --xla_tpu_scoped_vmem_limit_kib=65536"
455455

456456
# Parallelism
457457
shard_mode: "auto" # can be either auto or explicit
@@ -564,8 +564,8 @@ logical_axis_rules: [
564564
# ==========================================
565565
# Deprecated / Scheduled for Removal
566566
# ==========================================
567-
['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
568-
['embed_tensor_transpose', ['tensor_transpose']],
567+
['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
568+
['embed_tensor_transpose', ['tensor_transpose']],
569569
['exp_with_fsdp', 'fsdp'],
570570
]
571571
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
@@ -680,8 +680,6 @@ global_rampup_samples: 500
680680

681681
# direct preference optimization (DPO)
682682
use_dpo: False
683-
dpo_label_smoothing: 0.0
684-
dpo_beta: 0.1
685683

686684
# Supervised Fine-Tuning (SFT)
687685
use_sft: False
@@ -1206,7 +1204,7 @@ use_jax_splash: false
12061204
# Path to the HuggingFace-style config directory for the adapter (e.g. src/maxtext/integration/vllm/maxtext_vllm_adapter)
12071205
vllm_hf_config_path: ""
12081206
# A JSON string of overrides to apply to the HuggingFace-style config for the vLLM adapter.
1209-
# This can be used to override specific settings without modifying the original config file.
1207+
# This can be used to override specific settings without modifying the original config file.
12101208
vllm_hf_overrides: {}
12111209
# JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}')
12121210
vllm_additional_config: {}
@@ -1221,7 +1219,7 @@ sinkhorn_iterations: 20
12211219

12221220
################################## DeepSeek Engram ##################################
12231221
# Indices of transformer layers where Engram are integrated; leave empty [] to disable.
1224-
# Example: [1, 4] attaches to the 2nd and 5th layer.
1222+
# Example: [1, 4] attaches to the 2nd and 5th layer.
12251223
engram_layers: []
12261224
# The max 'n' in N-gram. Example: n=3 means it covers both 2-grams and 3-grams.
12271225
engram_max_ngram_size: 3

src/maxtext/configs/post_train/dpo.yml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
base_config: "base.yml"
22

33
use_dpo: true
4+
dpo:
5+
algo: 'dpo'
6+
orpo_lambda: 0.1
7+
dpo_label_smoothing: 0.0
8+
dpo_beta: 0.1
9+
max_prompt_length: null
410
packing: false
511
train_data_columns: ['chosen', 'rejected']
612
eval_data_columns: ['chosen', 'rejected']
@@ -24,8 +30,6 @@ hf_eval_split: 'test'
2430

2531
gradient_clipping_threshold: 10.0
2632
learning_rate: 5.0e-7
27-
dpo_label_smoothing: 0.0
28-
dpo_beta: 0.1
2933

3034
enable_goodput_recording: false
3135
monitor_goodput: false

src/maxtext/configs/types.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,12 +1206,24 @@ class OlmoGrainDataset(BaseModel):
12061206
olmo_apply_ngram_filter: bool = Field(True, description="Mask repetitive instances per OLMo-core's repetition filter.")
12071207

12081208

1209+
class DPO(BaseModel):
1210+
"""Configuration for DPO and ORPO preference optimization algorithms."""
1211+
1212+
algo: Literal["dpo", "orpo"] = Field("dpo", description="Alignment algorithm to use.")
1213+
dpo_beta: float = Field(0.1, description="Beta parameter for DPO.")
1214+
orpo_lambda: float = Field(0.1, description="Weight for preference loss in ORPO.")
1215+
dpo_label_smoothing: float = Field(0.0, ge=0.0, le=1.0, description="Label smoothing for DPO.")
1216+
max_prompt_length: int | None = Field(
1217+
None,
1218+
gt=0,
1219+
description="Maximum length for prompt. If None, defaults to half of max_target_length.",
1220+
)
1221+
1222+
12091223
class FineTuning(BaseModel):
12101224
"""Configuration for fine-tuning methods like DPO, SFT, and GRPO."""
12111225

12121226
use_dpo: bool = Field(False, description="If True, enables Direct Preference Optimization training.")
1213-
dpo_label_smoothing: float = Field(0.0, ge=0.0, le=1.0, description="Label smoothing for DPO.")
1214-
dpo_beta: float = Field(0.1, description="Beta parameter for DPO.")
12151227
use_sft: bool = Field(False, description="If True, enables Supervised Fine-Tuning.")
12161228
sft_train_on_completion_only: bool = Field(
12171229
False, description="If True, trains only on the completion part of the text."
@@ -2303,6 +2315,10 @@ class MaxTextConfig(
23032315
"""
23042316

23052317
debug: Debug = Field(default_factory=Debug, description="Configuration for debugging options.")
2318+
dpo: DPO = Field(
2319+
default_factory=DPO,
2320+
description="Configuration for DPO and ORPO alignment algorithms.",
2321+
)
23062322
rl: RL = Field(
23072323
default_factory=RL,
23082324
description="Configuration for RL algorithms like Group Relative Policy Optimization (GRPO).",
@@ -2889,6 +2905,16 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
28892905
raise ValueError("For multimodal SFT, `sft_train_on_completion_only` must be True.")
28902906
if self.packing:
28912907
raise ValueError("For multimodal SFT, `packing` is not yet supported.")
2908+
if self.use_dpo:
2909+
if self.packing:
2910+
raise ValueError("For DPO/ORPO, `packing` is not supported.")
2911+
if self.dpo.max_prompt_length is not None and self.dpo.max_prompt_length >= self.max_target_length:
2912+
raise ValueError(
2913+
f"dpo.max_prompt_length ({self.dpo.max_prompt_length}) must be less than max_target_length"
2914+
f" ({self.max_target_length})."
2915+
)
2916+
if self.use_sft and self.use_dpo:
2917+
raise ValueError("Only one of `use_sft` or `use_dpo` can be True.")
28922918
if self.shard_mode == ShardMode.EXPLICIT:
28932919
supported_decoders = {"simple", "simple_mlp", "llama2", "deepseek"}
28942920
if self.decoder_block.value not in supported_decoders:
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""DPO specific input pipeline utilities."""
16+
17+
import dataclasses
18+
import grain.python as grain
19+
import numpy as np
20+
21+
22+
@dataclasses.dataclass
23+
class DPODataFormatting(grain.MapTransform):
24+
"""Prepares DPO data.
25+
Renames input columns, extracts common prefix if needed, generates masks, and performs
26+
DPO-aware padding (left-padded prompts, right-padded responses).
27+
"""
28+
29+
pad_id: int
30+
max_target_length: int
31+
data_column_names: tuple[str, ...]
32+
max_prompt_length: int | None = None
33+
34+
def map(self, element):
35+
"Apply the dataset transformations for DPO."
36+
# 1. Reformat/Extract Columns
37+
try:
38+
if len(self.data_column_names) == 3:
39+
input_ids = element[self.data_column_names[0]]
40+
chosen_ids = element[self.data_column_names[1]]
41+
rejected_ids = element[self.data_column_names[2]]
42+
elif len(self.data_column_names) == 2:
43+
# Support for datasets like Anthropic/hh-rlhf where prompt is a common prefix
44+
full_chosen = element[self.data_column_names[0]]
45+
full_rejected = element[self.data_column_names[1]]
46+
47+
# Find common prefix length
48+
prefix_len = 0
49+
for c, r in zip(full_chosen, full_rejected):
50+
if c != r:
51+
break
52+
prefix_len += 1
53+
input_ids = full_chosen[:prefix_len]
54+
chosen_ids = full_chosen[prefix_len:]
55+
rejected_ids = full_rejected[prefix_len:]
56+
else:
57+
raise ValueError(f"DPODataFormatting expects 2 or 3 columns, got {len(self.data_column_names)}")
58+
except KeyError as e:
59+
raise KeyError(
60+
f"Column '{e.args[0]}' not found in the dataset. "
61+
f"Expected columns: {self.data_column_names}. "
62+
f"Available columns: {list(element.keys())}. "
63+
"Please verify that 'train_data_columns' and 'eval_data_columns' match your dataset."
64+
) from e
65+
66+
# 2. Padding and Masking
67+
max_prompt_length = self.max_prompt_length or (self.max_target_length // 2)
68+
max_response_length = self.max_target_length - max_prompt_length
69+
70+
assert max_prompt_length > 0, (
71+
"max_prompt_length must be positive. " "Check the configs for 'max_prompt_length' and 'max_target_length'."
72+
)
73+
assert max_response_length > 0, (
74+
"max_response_length must be positive. " "Check the configs for 'max_prompt_length' and 'max_target_length'."
75+
)
76+
77+
prompt_ids = self._pad(input_ids, max_prompt_length, left=True)
78+
chosen_ids = self._pad(chosen_ids, max_response_length, left=False)
79+
rejected_ids = self._pad(rejected_ids, max_response_length, left=False)
80+
81+
# Remove old columns if they exist
82+
for key in self.data_column_names:
83+
if key in element:
84+
del element[key]
85+
86+
element["prompt_ids"] = prompt_ids
87+
element["chosen_ids"] = chosen_ids
88+
element["rejected_ids"] = rejected_ids
89+
element["prompt_mask"] = (prompt_ids != self.pad_id).astype(np.int32)
90+
element["chosen_mask"] = (chosen_ids != self.pad_id).astype(np.int32)
91+
element["rejected_mask"] = (rejected_ids != self.pad_id).astype(np.int32)
92+
return element
93+
94+
def _pad(self, x, length, left=False):
95+
"""Pads or trims an array to a specific length.
96+
97+
When left=True (for prompts), trims from the left to keep the suffix (closest context).
98+
When left=False (for responses), trims from the right to keep the prefix.
99+
"""
100+
x = np.asarray(x)
101+
pad_amount = max(length - x.shape[0], 0)
102+
if left:
103+
pad_width = ((pad_amount, 0),)
104+
x_trimmed = x[-length:]
105+
else:
106+
pad_width = ((0, pad_amount),)
107+
x_trimmed = x[:length]
108+
return np.pad(x_trimmed, pad_width, constant_values=self.pad_id).astype(np.int32)

src/maxtext/input_pipeline/hf_data_processing.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023–2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -24,9 +24,8 @@
2424

2525
import grain.python as grain
2626

27-
import numpy as np
28-
2927
from maxtext.input_pipeline import data_processing_utils
28+
from maxtext.input_pipeline import dpo_utils
3029
from maxtext.input_pipeline import input_pipeline_utils
3130
from maxtext.input_pipeline import instruction_data_processing
3231
from maxtext.input_pipeline import multihost_dataloading
@@ -214,7 +213,7 @@ def preprocessing_pipeline(
214213
num_threads=1,
215214
drop_remainder=True,
216215
generate_padding_batch=False,
217-
use_dpo=None,
216+
use_dpo=False,
218217
use_sft=None,
219218
use_tunix_gradient_accumulation=False,
220219
num_microbatches=1,
@@ -330,19 +329,12 @@ def preprocessing_pipeline(
330329
)
331330
)
332331
data_column_names = ("inputs", "targets")
333-
elif use_dpo:
334-
335-
def lists2array(x):
336-
"""Convert lists/tuples to array"""
337-
return jax.tree.map(np.asarray, x, is_leaf=lambda y: isinstance(y, (list, tuple)))
338-
339-
operations.append(grain.MapOperation(lists2array))
340-
else:
332+
elif not use_dpo:
341333
assert len(data_column_names) == 1
342334
operations.append(input_pipeline_utils.HFNormalizeFeatures(data_column_names[0]))
343335
data_column_names = ("inputs", "targets")
344336

345-
if packing and not use_dpo:
337+
if packing:
346338
length_struct = {col: max_target_length for col in data_column_names}
347339
max_segments = max_segments_per_seq
348340
if max_segments is not None and max_segments <= 0:
@@ -356,7 +348,12 @@ def lists2array(x):
356348
)
357349
operations.append(input_pipeline_utils.ReformatPacking(data_column_names))
358350
else:
359-
operations.append(input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id))
351+
if use_dpo:
352+
# Renames arbitrary DPO columns and performs DPO-aware padding.
353+
max_prompt_length = config.dpo.max_prompt_length
354+
operations.append(dpo_utils.DPODataFormatting(pad_id, max_target_length, data_column_names, max_prompt_length))
355+
else:
356+
operations.append(input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id))
360357
operations.append(grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder))
361358

362359
if shift and not use_dpo:

0 commit comments

Comments
 (0)