Skip to content

Commit 400f1be

Browse files
committed
Error out when num_vocab_tiling > 1 is configured for RL training
1 parent d8763ef commit 400f1be

2 files changed

Lines changed: 17 additions & 0 deletions

File tree

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -602,6 +602,12 @@ def _rl_train_impl(argv: Sequence[str], kwargs: dict):
602602
argv, kwargs
603603
)
604604

605+
if trainer_config.num_vocab_tiling > 1:
606+
raise ValueError(
607+
f"Vocab Tiling is not supported with RL. "
608+
f"num_vocab_tiling was configured to {trainer_config.num_vocab_tiling}, but it must be 1 when running train_rl."
609+
)
610+
605611
# Create model tokenizer first so we can plumb its pad_id into the model
606612
# adapter (used to synthesize segment_ids that mask pad positions from
607613
# attention — without this the trainer attends to pad tokens and produces

tests/post_training/unit/train_rl_test.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -471,6 +471,17 @@ def test_prepare_datasets_without_split(self, mock_load):
471471
mock_load.assert_has_calls(expected_calls, any_order=True)
472472
assert mock_load.call_count == len(expected_calls)
473473

474+
@pytest.mark.cpu_only
475+
@mock.patch("maxtext.trainers.post_train.rl.train_rl.model_creation_utils.setup_configs_and_devices")
476+
def test_rl_train_invalid_vocab_tiling(self, mock_setup):
477+
mock_config = SimpleNamespace(
478+
num_vocab_tiling=2,
479+
)
480+
mock_setup.return_value = (mock_config, mock_config, [], [])
481+
482+
with self.assertRaisesRegex(ValueError, "Vocab Tiling is not supported with RL"):
483+
train_rl._rl_train_impl([], {})
484+
474485

475486
if __name__ == "__main__":
476487
unittest.main()

0 commit comments

Comments (0)