Skip to content

Commit 573a8ff

Browse files
committed
Add validation to block unsupported parameters in post-training SFT, RL, and DPO
- SFT, RL, and DPO trainers now error out on optimizer_memory_host_offload=True, since the underlying Tunix trainers do not support host offloading.
- DPO and RL trainers now error out on num_vocab_tiling > 1, since they require full vocabulary projections to compute preference/policy log-probabilities, which is incompatible with vocabulary tiling.
- Created unit tests for the SFT, RL, and DPO validation functions to verify these restrictions.
1 parent 2e6cd11 commit 573a8ff

6 files changed

Lines changed: 163 additions & 6 deletions

File tree

src/maxtext/trainers/post_train/dpo/train_dpo.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,22 @@ def train(mt_config, goodput_recorder=None):
182182
return trainer, mesh
183183

184184

185+
def validate_config(config):
  """Rejects configuration options that the DPO post-training path cannot honor.

  Args:
    config: Configuration object exposing the `optimizer_memory_host_offload`
      and `num_vocab_tiling` attributes.

  Raises:
    ValueError: If `optimizer_memory_host_offload` is truthy, or if
      `num_vocab_tiling` is greater than 1.
  """
  if config.optimizer_memory_host_offload:
    raise ValueError(
        "optimizer_memory_host_offload=True is not supported on the post-training "
        "DPO path because the underlying Tunix DPOTrainer does not support "
        "host offloading of the optimizer state."
    )

  # DPO computes preference log-probabilities from full vocabulary
  # projections, which vocabulary tiling cannot provide.
  if config.num_vocab_tiling > 1:
    raise ValueError(
        # Only the second fragment interpolates a value, so only it is an f-string.
        "Vocab Tiling is not supported with DPO. "
        f"num_vocab_tiling was configured to {config.num_vocab_tiling}, but it must be 1 when running train_dpo."
    )
199+
200+
185201
def main(argv: list[str]) -> None:
186202
"""Main function to run DPO training.
187203
@@ -191,6 +207,7 @@ def main(argv: list[str]) -> None:
191207
pathwaysutils.initialize()
192208

193209
mt_config = pyconfig.initialize_pydantic(argv)
210+
validate_config(mt_config)
194211
max_utils.print_system_information()
195212

196213
goodput_recorder = create_goodput_recorder(mt_config)

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -596,17 +596,28 @@ def rl_train(argv: Sequence[str], kwargs: dict):
596596
_rl_train_impl(argv, kwargs)
597597

598598

599+
def validate_config(config):
  """Rejects configuration options that the RL post-training path cannot honor.

  Args:
    config: Configuration object exposing the `optimizer_memory_host_offload`
      and `num_vocab_tiling` attributes.

  Raises:
    ValueError: If `optimizer_memory_host_offload` is truthy, or if
      `num_vocab_tiling` is greater than 1.
  """
  if config.optimizer_memory_host_offload:
    raise ValueError(
        "optimizer_memory_host_offload=True is not supported on the post-training "
        "RL path because the underlying Tunix RLCluster/Trainer does not "
        "support host offloading of the optimizer state."
    )

  # RL computes policy log-probabilities from full vocabulary projections,
  # which vocabulary tiling cannot provide.
  if config.num_vocab_tiling > 1:
    raise ValueError(
        # Only the second fragment interpolates a value, so only it is an f-string.
        "Vocab Tiling is not supported with RL. "
        f"num_vocab_tiling was configured to {config.num_vocab_tiling}, but it must be 1 when running train_rl."
    )
613+
614+
599615
def _rl_train_impl(argv: Sequence[str], kwargs: dict):
600616
"""rl_train body — kept separate so _tpu_inference_compat_patches wraps it cleanly."""
601617
trainer_config, sampler_config, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices(
602618
argv, kwargs
603619
)
604-
605-
if trainer_config.num_vocab_tiling > 1:
606-
raise ValueError(
607-
f"Vocab Tiling is not supported with RL. "
608-
f"num_vocab_tiling was configured to {trainer_config.num_vocab_tiling}, but it must be 1 when running train_rl."
609-
)
620+
validate_config(trainer_config)
610621

611622
# Create model tokenizer first so we can plumb its pad_id into the model
612623
# adapter (used to synthesize segment_ids that mask pad positions from

src/maxtext/trainers/post_train/sft/train_sft.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,16 @@ def loss_func(
225225
return trainer
226226

227227

228+
def validate_config(config):
  """Rejects configuration options that the SFT post-training path cannot honor.

  Args:
    config: Configuration object exposing the `optimizer_memory_host_offload`
      attribute.

  Raises:
    ValueError: If `optimizer_memory_host_offload` is truthy.
  """
  # Nothing to reject when host offloading of the optimizer state is disabled.
  if not config.optimizer_memory_host_offload:
    return

  unsupported_offload_message = (
      "optimizer_memory_host_offload=True is not supported on the post-training "
      "SFT path because the underlying Tunix PeftTrainer does not support "
      "host offloading of the optimizer state."
  )
  raise ValueError(unsupported_offload_message)
236+
237+
228238
def setup_trainer_state(mt_config, goodput_recorder=None):
229239
"""Set up prerequisites for training loop."""
230240
tunix_config = get_tunix_config(mt_config)
@@ -299,6 +309,7 @@ def main(argv: Sequence[str]) -> None:
299309
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
300310

301311
mt_config = pyconfig.initialize_pydantic(argv)
312+
validate_config(mt_config)
302313
max_utils.print_system_information()
303314

304315
goodput_recorder = create_goodput_recorder(mt_config)
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for train_dpo.py."""
16+
17+
import unittest
18+
from types import SimpleNamespace
19+
import pytest
20+
21+
from maxtext.trainers.post_train.dpo import train_dpo
22+
23+
pytestmark = [pytest.mark.post_training]
24+
25+
26+
class TrainDPOTest(unittest.TestCase):
  """Exercises train_dpo.validate_config with accepted and rejected settings."""

  @staticmethod
  def _make_config(offload=False, tiling=1):
    """Builds a minimal stand-in config carrying only the validated fields."""
    return SimpleNamespace(
        optimizer_memory_host_offload=offload,
        num_vocab_tiling=tiling,
    )

  @pytest.mark.cpu_only
  def test_validate_config_valid(self):
    """Offload disabled and a tiling factor of 1 pass without raising."""
    train_dpo.validate_config(self._make_config())

  @pytest.mark.cpu_only
  def test_validate_config_invalid_offload(self):
    """Enabling optimizer host offload is rejected with a ValueError."""
    self.assertRaisesRegex(
        ValueError,
        "optimizer_memory_host_offload=True is not supported",
        train_dpo.validate_config,
        self._make_config(offload=True),
    )

  @pytest.mark.cpu_only
  def test_validate_config_invalid_vocab_tiling(self):
    """A vocab tiling factor above 1 is rejected with a ValueError."""
    self.assertRaisesRegex(
        ValueError,
        "Vocab Tiling is not supported with DPO",
        train_dpo.validate_config,
        self._make_config(tiling=2),
    )
55+
56+
57+
if __name__ == "__main__":
58+
unittest.main()

tests/post_training/unit/train_rl_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,12 +476,25 @@ def test_prepare_datasets_without_split(self, mock_load):
476476
def test_rl_train_invalid_vocab_tiling(self, mock_setup):
477477
mock_config = SimpleNamespace(
478478
num_vocab_tiling=2,
479+
optimizer_memory_host_offload=False,
479480
)
480481
mock_setup.return_value = (mock_config, mock_config, [], [])
481482

482483
with self.assertRaisesRegex(ValueError, "Vocab Tiling is not supported with RL"):
483484
train_rl._rl_train_impl([], {})
484485

486+
@pytest.mark.cpu_only
@mock.patch("maxtext.trainers.post_train.rl.train_rl.model_creation_utils.setup_configs_and_devices")
def test_rl_train_invalid_optimizer_memory_host_offload(self, mock_setup):
  """Checks that _rl_train_impl raises when optimizer host offload is enabled."""
  # The same stand-in config is returned for both config slots of the patched
  # setup helper; the device lists are left empty.
  offload_config = SimpleNamespace(
      optimizer_memory_host_offload=True,
      num_vocab_tiling=1,
  )
  mock_setup.return_value = (offload_config, offload_config, [], [])

  self.assertRaisesRegex(
      ValueError,
      "optimizer_memory_host_offload=True is not supported",
      train_rl._rl_train_impl,
      [],
      {},
  )
497+
485498

486499
if __name__ == "__main__":
487500
unittest.main()
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for train_sft.py."""
16+
17+
import unittest
18+
from types import SimpleNamespace
19+
import pytest
20+
21+
from maxtext.trainers.post_train.sft import train_sft
22+
23+
pytestmark = [pytest.mark.post_training]
24+
25+
26+
class TrainSFTTest(unittest.TestCase):
  """Exercises train_sft.validate_config with accepted and rejected settings."""

  @staticmethod
  def _make_config(offload):
    """Builds a minimal stand-in config carrying only the validated field."""
    return SimpleNamespace(optimizer_memory_host_offload=offload)

  @pytest.mark.cpu_only
  def test_validate_config_valid(self):
    """Disabled optimizer host offload passes without raising."""
    train_sft.validate_config(self._make_config(False))

  @pytest.mark.cpu_only
  def test_validate_config_invalid_offload(self):
    """Enabling optimizer host offload is rejected with a ValueError."""
    self.assertRaisesRegex(
        ValueError,
        "optimizer_memory_host_offload=True is not supported",
        train_sft.validate_config,
        self._make_config(True),
    )
44+
45+
46+
if __name__ == "__main__":
47+
unittest.main()

0 commit comments

Comments
 (0)