Skip to content

Commit add2a6d

Browse files
committed
Feature: add standalone DPO training and data hooks
1 parent 39470c9 commit add2a6d

2 files changed

Lines changed: 114 additions & 0 deletions

File tree

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
"""Training and data loading hooks for DPO"""
17+
18+
from typing import override
19+
20+
import jax
21+
import jax.numpy as jnp
22+
23+
from maxtext.trainers.post_train.hooks import BaseTrainingHooks, BaseDataHooks
24+
25+
26+
class DPOTrainingHooks(BaseTrainingHooks):
27+
"""Training hooks for DPO."""
28+
29+
@override
30+
def get_total_weights(self, batch) -> jax.Array:
31+
# For DPO, we sum both chosen and rejected masks
32+
return jnp.sum(batch["chosen_mask"] != 0) + jnp.sum(batch["rejected_mask"] != 0)
33+
34+
35+
class DPODataHooks(BaseDataHooks):
36+
"""Data hooks for DPO."""
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for training and data loading hooks for DPO"""
16+
17+
import unittest
18+
from unittest.mock import MagicMock, patch
19+
20+
import pytest
21+
22+
pytestmark = [pytest.mark.cpu_only, pytest.mark.external_training, pytest.mark.post_training]
23+
24+
import jax
25+
from jax.sharding import Mesh
26+
import numpy as np
27+
import os
28+
import shutil
29+
import tempfile
30+
31+
from maxtext.configs import pyconfig
32+
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR
33+
from maxtext.trainers.post_train.dpo import hooks as dpo_hooks
34+
from maxtext.utils import maxtext_utils
35+
36+
37+
class DPOHooksTest(unittest.TestCase):
38+
39+
def setUp(self):
40+
super().setUp()
41+
self.test_dir = tempfile.mkdtemp()
42+
self.config = pyconfig.initialize(
43+
["", os.path.join(MAXTEXT_CONFIGS_DIR, "post_train", "dpo.yml")],
44+
per_device_batch_size=1,
45+
run_name="test",
46+
base_output_directory=self.test_dir,
47+
tensorboard_dir=self.test_dir,
48+
skip_jax_distributed_system=True,
49+
)
50+
self.mesh = Mesh(maxtext_utils.create_device_mesh(self.config), self.config.mesh_axes)
51+
52+
def tearDown(self):
53+
shutil.rmtree(self.test_dir)
54+
super().tearDown()
55+
56+
@patch("maxtext.trainers.post_train.hooks.create_data_iterator")
57+
def test_dpo_data_hooks_load_next_train_batch(self, mock_create_data_iterator):
58+
expected_batch = {"inputs": np.zeros([jax.device_count(), self.config.max_target_length], dtype=int)}
59+
mock_data_iterator = MagicMock()
60+
mock_data_iterator.__next__.return_value = expected_batch
61+
mock_create_data_iterator.return_value = mock_data_iterator, None
62+
63+
data_hooks = dpo_hooks.DPODataHooks(self.config, self.mesh, goodput_recorder=None)
64+
data_hooks.load_next_train_batch(train_ctx=None)
65+
66+
self.assertIsNotNone(data_hooks.train_batch)
67+
self.assertEqual(data_hooks.train_batch["inputs"].shape, expected_batch["inputs"].shape)
68+
69+
def test_dpo_training_hooks_get_total_weights(self):
70+
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(self.config)
71+
training_hooks = dpo_hooks.DPOTrainingHooks(self.config, self.mesh, learning_rate_schedule, goodput_recorder=None)
72+
batch = {"chosen_mask": np.array([[1, 1, 0], [1, 0, 0]]), "rejected_mask": np.array([[1, 0, 0], [1, 1, 0]])}
73+
total_weights = training_hooks.get_total_weights(batch)
74+
self.assertEqual(total_weights, 6)
75+
76+
77+
if __name__ == "__main__":
78+
unittest.main()

0 commit comments

Comments
 (0)