Skip to content

Commit 27784e7

Browse files
committed
Raise a clear error when a chat template fails to load (is None), instead of failing later with an opaque TypeError.
1 parent 24a3c79 commit 27784e7

3 files changed

Lines changed: 48 additions & 2 deletions

File tree

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ def get_dataset(
114114
)
115115

116116
template_config = load_data_template_from_file(tmvp_config.chat_template_path)
117+
if template_config is None:
118+
raise ValueError(
119+
f"Chat template is required for processing dataset but failed to load from {tmvp_config.chat_template_path}"
120+
)
117121

118122
loaded_dataset = (
119123
grain.MapDataset.source(data)
@@ -231,6 +235,10 @@ def prepare_openinstructmath2_dataset(
231235
split_name = trainer_config.train_split if trainer_config.train_split != "train" else "train_1M"
232236
splits = prepare_openinstructmath2_dataset(split=split_name)
233237
template_config = load_data_template_from_file(trainer_config.chat_template_path)
238+
if template_config is None:
239+
raise ValueError(
240+
f"Chat template is required for processing dataset but failed to load from {trainer_config.chat_template_path}"
241+
)
234242

235243
train_dataset = (
236244
grain.MapDataset.source(splits["train"])
@@ -401,7 +409,6 @@ def create_rl_components(
401409
rollout_vllm_model_version=trainer_config.tokenizer_path,
402410
rollout_vllm_hbm_utilization=trainer_config.hbm_utilization_vllm,
403411
rollout_vllm_tpu_backend_type="jax",
404-
rollout_vllm_swap_space_size_gb=trainer_config.swap_space_vllm_gb,
405412
rollout_vllm_hf_config_path=trainer_config.vllm_hf_config_path,
406413
rollout_vllm_additional_config=rollout_additional_config,
407414
rollout_vllm_init_with_random_weights=True,
@@ -495,6 +502,10 @@ def _reward_fn(**kwargs):
495502
)
496503
# Instantiate the custom MaxText chat parser
497504
template_config = load_data_template_from_file(trainer_config.chat_template_path)
505+
if template_config is None:
506+
raise ValueError(
507+
f"Chat template is required for AgenticGRPOLearner but failed to load from {trainer_config.chat_template_path}"
508+
)
498509
chat_parser = utils_rl.MaxTextChatParser(
499510
model_tokenizer=model_tokenizer, template_config=template_config, tmvp_config=trainer_config
500511
)

src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,8 +526,11 @@ def make_optimizer(learning_rate):
526526
return optax.inject_hyperparams(make_optimizer)(learning_rate=schedule)
527527

528528

529-
def format_maxtext_messages(messages: list[dict[str, str]], template_config: dict, tmvp_config) -> list[dict[str, str]]:
529+
def format_maxtext_messages(messages: list[str], template_config: dict, tmvp_config) -> list[dict[str, str]]:
530530
"""Helper to inject MaxText's system prompt into the input user messages."""
531+
if template_config is None:
532+
raise ValueError("template_config cannot be None for format_maxtext_messages.")
533+
531534
formatted_messages = []
532535
for msg in messages:
533536
formatted_content = template_config["TEMPLATE"].format(

tests/post_training/unit/rl_utils_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,5 +370,37 @@ def test_returns_optimizer_with_clipping(self):
370370
self.assertIn("learning_rate", state.hyperparams)
371371

372372

373+
class TestFormatMaxTextMessages(unittest.TestCase):
374+
"""Tests for utils_rl.format_maxtext_messages."""
375+
376+
def setUp(self):
377+
self.config = _make_config()
378+
self.template_config = {
379+
"SYSTEM_PROMPT": "Reason between {reasoning_start_token} and {reasoning_end_token}. "
380+
+ "Solution between {solution_start_token} and {solution_end_token}.",
381+
"TEMPLATE": "system: {system_prompt}\nquestion: {question}",
382+
}
383+
384+
@pytest.mark.cpu_only
385+
def test_format_with_template(self):
386+
"""Test formatting when a template is provided."""
387+
messages = ["What is 2+2?"]
388+
formatted = utils_rl.format_maxtext_messages(messages, self.template_config, self.config)
389+
self.assertEqual(len(formatted), 1)
390+
self.assertEqual(formatted[0]["role"], "user")
391+
expected_content = (
392+
"system: Reason between <reasoning> and </reasoning>. "
393+
"Solution between <answer> and </answer>.\n"
394+
"question: What is 2+2?"
395+
)
396+
self.assertEqual(formatted[0]["content"], expected_content)
397+
398+
@pytest.mark.cpu_only
399+
def test_format_without_template(self):
400+
"""Test formatting when template_config is None (the fix)."""
401+
messages = ["What is 2+2?"]
402+
self.assertRaises(ValueError, lambda: utils_rl.format_maxtext_messages(messages, None, self.config))
403+
404+
373405
if __name__ == "__main__":
374406
unittest.main()

0 commit comments

Comments
 (0)