Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/maxtext/trainers/post_train/dpo/train_dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,22 @@ def train(mt_config, goodput_recorder=None):
return trainer, mesh


def validate_config(config):
"""Validates the configuration parameters for DPO training."""
if config.optimizer_memory_host_offload:
raise ValueError(
"optimizer_memory_host_offload=True is not supported on the post-training "
"DPO path because the underlying Tunix DPOTrainer does not support "
"host offloading of the optimizer state."
)

if config.num_vocab_tiling > 1:
raise ValueError(
f"Vocab Tiling is not supported with DPO. "
f"num_vocab_tiling was configured to {config.num_vocab_tiling}, but it must be 1 when running train_dpo."
)


def main(argv: list[str]) -> None:
"""Main function to run DPO training.

Expand All @@ -191,6 +207,7 @@ def main(argv: list[str]) -> None:
pathwaysutils.initialize()

mt_config = pyconfig.initialize_pydantic(argv)
validate_config(mt_config)
max_utils.print_system_information()

goodput_recorder = create_goodput_recorder(mt_config)
Expand Down
23 changes: 17 additions & 6 deletions src/maxtext/trainers/post_train/rl/train_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,17 +596,28 @@ def rl_train(argv: Sequence[str], kwargs: dict):
_rl_train_impl(argv, kwargs)


def validate_config(config):
"""Validates the configuration parameters for RL training."""
if config.optimizer_memory_host_offload:
raise ValueError(
"optimizer_memory_host_offload=True is not supported on the post-training "
"RL path because the underlying Tunix RLCluster/Trainer does not "
"support host offloading of the optimizer state."
)

if config.num_vocab_tiling > 1:
raise ValueError(
f"Vocab Tiling is not supported with RL. "
f"num_vocab_tiling was configured to {config.num_vocab_tiling}, but it must be 1 when running train_rl."
)


def _rl_train_impl(argv: Sequence[str], kwargs: dict):
"""rl_train body — kept separate so _tpu_inference_compat_patches wraps it cleanly."""
trainer_config, sampler_config, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices(
argv, kwargs
)

if trainer_config.num_vocab_tiling > 1:
raise ValueError(
f"Vocab Tiling is not supported with RL. "
f"num_vocab_tiling was configured to {trainer_config.num_vocab_tiling}, but it must be 1 when running train_rl."
)
validate_config(trainer_config)

# Create model tokenizer first so we can plumb its pad_id into the model
# adapter (used to synthesize segment_ids that mask pad positions from
Expand Down
11 changes: 11 additions & 0 deletions src/maxtext/trainers/post_train/sft/train_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,16 @@ def loss_func(
return trainer


def validate_config(config):
"""Validates the configuration parameters for SFT training."""
if config.optimizer_memory_host_offload:
raise ValueError(
"optimizer_memory_host_offload=True is not supported on the post-training "
"SFT path because the underlying Tunix PeftTrainer does not support "
"host offloading of the optimizer state."
)


def setup_trainer_state(mt_config, goodput_recorder=None):
"""Set up prerequisites for training loop."""
tunix_config = get_tunix_config(mt_config)
Expand Down Expand Up @@ -299,6 +309,7 @@ def main(argv: Sequence[str]) -> None:
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

mt_config = pyconfig.initialize_pydantic(argv)
validate_config(mt_config)
max_utils.print_system_information()

goodput_recorder = create_goodput_recorder(mt_config)
Expand Down
58 changes: 58 additions & 0 deletions tests/post_training/unit/train_dpo_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for train_dpo.py."""

import unittest
from types import SimpleNamespace
import pytest

from maxtext.trainers.post_train.dpo import train_dpo

pytestmark = [pytest.mark.post_training]


class TrainDPOTest(unittest.TestCase):
"""Tests for train_dpo.py."""

@pytest.mark.cpu_only
def test_validate_config_valid(self):
config = SimpleNamespace(
optimizer_memory_host_offload=False,
num_vocab_tiling=1,
)
# Should not raise any exception
train_dpo.validate_config(config)

@pytest.mark.cpu_only
def test_validate_config_invalid_offload(self):
config = SimpleNamespace(
optimizer_memory_host_offload=True,
num_vocab_tiling=1,
)
with self.assertRaisesRegex(ValueError, "optimizer_memory_host_offload=True is not supported"):
train_dpo.validate_config(config)

@pytest.mark.cpu_only
def test_validate_config_invalid_vocab_tiling(self):
config = SimpleNamespace(
optimizer_memory_host_offload=False,
num_vocab_tiling=2,
)
with self.assertRaisesRegex(ValueError, "Vocab Tiling is not supported with DPO"):
train_dpo.validate_config(config)


if __name__ == "__main__":
unittest.main()
15 changes: 14 additions & 1 deletion tests/post_training/unit/train_rl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,11 +476,24 @@ def test_prepare_datasets_without_split(self, mock_load):
def test_rl_train_invalid_vocab_tiling(self, mock_setup):
mock_config = SimpleNamespace(
num_vocab_tiling=2,
optimizer_memory_host_offload=False,
)
mock_setup.return_value = (mock_config, mock_config, [], [])

with self.assertRaisesRegex(ValueError, "Vocab Tiling is not supported with RL"):
train_rl._rl_train_impl([], {})
train_rl._rl_train_impl([], {}) # pylint: disable=protected-access

@pytest.mark.cpu_only
@mock.patch("maxtext.trainers.post_train.rl.train_rl.model_creation_utils.setup_configs_and_devices")
def test_rl_train_invalid_optimizer_memory_host_offload(self, mock_setup):
mock_config = SimpleNamespace(
num_vocab_tiling=1,
optimizer_memory_host_offload=True,
)
mock_setup.return_value = (mock_config, mock_config, [], [])

with self.assertRaisesRegex(ValueError, "optimizer_memory_host_offload=True is not supported"):
train_rl._rl_train_impl([], {}) # pylint: disable=protected-access


if __name__ == "__main__":
Expand Down
47 changes: 47 additions & 0 deletions tests/post_training/unit/train_sft_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for train_sft.py."""

import unittest
from types import SimpleNamespace
import pytest

from maxtext.trainers.post_train.sft import train_sft

pytestmark = [pytest.mark.post_training]


class TrainSFTTest(unittest.TestCase):
"""Tests for train_sft.py."""

@pytest.mark.cpu_only
def test_validate_config_valid(self):
config = SimpleNamespace(
optimizer_memory_host_offload=False,
)
# Should not raise any exception
train_sft.validate_config(config)

@pytest.mark.cpu_only
def test_validate_config_invalid_offload(self):
config = SimpleNamespace(
optimizer_memory_host_offload=True,
)
with self.assertRaisesRegex(ValueError, "optimizer_memory_host_offload=True is not supported"):
train_sft.validate_config(config)


if __name__ == "__main__":
unittest.main()
Loading