Skip to content

Commit 6c8a1d4

Browse files
committed
Add an option for RL workloads whose models lack a chat_template to specify one
1 parent 2a38dc9 commit 6c8a1d4

2 files changed

Lines changed: 114 additions & 0 deletions

File tree

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,30 @@ def _reward_fn(**kwargs):
582582
return rl_cluster, rl_trainer, optimizer
583583

584584

585+
586+
def configure_tokenizer_chat_template(model_tokenizer: Any, trainer_config: Any) -> None:
  """Populates the tokenizer's chat_template from config if missing.

  The tokenizer is left untouched when it already carries a non-empty
  chat_template. Otherwise the template is taken from the first available
  source, in this order:
    1. `trainer_config.chat_template` (used as-is).
    2. `trainer_config.tokenizer_chat_template_path`, read through
       `load_chat_template_from_file`.

  Args:
    model_tokenizer: Tokenizer object; its `chat_template` attribute is read
      and, when missing, assigned in place.
    trainer_config: Config object; `chat_template`,
      `tokenizer_chat_template_path` and `tokenizer_path` are read with
      `getattr`, so any of them may be absent.

  Raises:
    ValueError: If the tokenizer has no chat_template and neither config
      option is set.
  """
  # An empty template ("" or an empty container) is as unusable as None, so it
  # is treated as missing too instead of silently blocking the config fallback.
  if getattr(model_tokenizer, "chat_template", None):
    return

  inline_template = getattr(trainer_config, "chat_template", None)
  if inline_template:
    model_tokenizer.chat_template = inline_template
    return

  template_path = getattr(trainer_config, "tokenizer_chat_template_path", None)
  if template_path:
    # Imported only when this branch runs, so the loader is resolved at call time.
    from maxtext.input_pipeline.instruction_data_processing import (
        load_chat_template_from_file,
    )

    model_tokenizer.chat_template = load_chat_template_from_file(template_path)
    return

  raise ValueError(
      f"Tokenizer {getattr(trainer_config, 'tokenizer_path', None)!r} has no chat_template "
      "and config.chat_template / config.tokenizer_chat_template_path "
      "are both empty. Either pick an instruction-tuned tokenizer that "
      "ships with a chat_template, set config.chat_template to a Jinja "
      "string, or set config.tokenizer_chat_template_path to a JSON file "
      "with a 'chat_template' key."
  )
607+
608+
585609
def rl_train(argv: Sequence[str], kwargs: dict):
586610
"""
587611
Run RL training with the provided configuration.
@@ -610,6 +634,7 @@ def _rl_train_impl(argv: Sequence[str], kwargs: dict):
610634
trainer_config.tokenizer_path,
611635
token=trainer_config.hf_access_token or None,
612636
)
637+
configure_tokenizer_chat_template(model_tokenizer, trainer_config)
613638

614639
reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = model_creation_utils.create_models_and_meshes(
615640
trainer_config,

tests/post_training/unit/train_rl_test.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,5 +472,94 @@ def test_prepare_datasets_without_split(self, mock_load):
472472
assert mock_load.call_count == len(expected_calls)
473473

474474

475+
class TokenizerChatTemplateTest(unittest.TestCase):
  """Covers every outcome of train_rl.configure_tokenizer_chat_template."""

  @staticmethod
  def _tokenizer_without_template():
    """Builds a MagicMock tokenizer whose chat_template is explicitly None."""
    tokenizer = mock.MagicMock()
    tokenizer.chat_template = None
    return tokenizer

  @staticmethod
  def _config(chat_template=None, template_path=None, tokenizer_path="dummy-base-model"):
    """Builds a minimal stand-in config exposing only the fields the function reads."""
    return SimpleNamespace(
        chat_template=chat_template,
        tokenizer_chat_template_path=template_path,
        tokenizer_path=tokenizer_path,
    )

  @pytest.mark.cpu_only
  def test_chat_template_populated_from_config_string(self):
    """A tokenizer lacking a template receives the inline config.chat_template string."""
    tokenizer = self._tokenizer_without_template()
    config = self._config(chat_template="{{ messages[0].content }}")

    train_rl.configure_tokenizer_chat_template(tokenizer, config)

    self.assertEqual(tokenizer.chat_template, "{{ messages[0].content }}")

  @pytest.mark.cpu_only
  @mock.patch("maxtext.input_pipeline.instruction_data_processing.load_chat_template_from_file")
  def test_chat_template_populated_from_config_file(self, mock_load):
    """A tokenizer lacking a template receives whatever the file loader returns for the configured path."""
    mock_load.return_value = "{% for message in messages %}{{ message.content }}{% endfor %}"
    tokenizer = self._tokenizer_without_template()
    config = self._config(template_path="/path/to/jinja_template.json")

    train_rl.configure_tokenizer_chat_template(tokenizer, config)

    mock_load.assert_called_once_with("/path/to/jinja_template.json")
    self.assertEqual(
        tokenizer.chat_template,
        "{% for message in messages %}{{ message.content }}{% endfor %}",
    )

  @pytest.mark.cpu_only
  def test_chat_template_raises_value_error_when_empty(self):
    """With no template on the tokenizer and none in config, a ValueError naming the tokenizer is raised."""
    tokenizer = self._tokenizer_without_template()
    config = self._config()

    with self.assertRaisesRegex(ValueError, "Tokenizer 'dummy-base-model' has no chat_template"):
      train_rl.configure_tokenizer_chat_template(tokenizer, config)

  @pytest.mark.cpu_only
  def test_chat_template_unchanged_when_already_exists(self):
    """A template already present on the tokenizer wins over the config value (backward compatibility)."""
    tokenizer = mock.MagicMock()
    tokenizer.chat_template = "{{ existing_template }}"
    config = self._config(
        chat_template="{{ overridden_template }}",
        tokenizer_path="dummy-instruction-tuned-model",
    )

    train_rl.configure_tokenizer_chat_template(tokenizer, config)

    self.assertEqual(tokenizer.chat_template, "{{ existing_template }}")

  @pytest.mark.cpu_only
  def test_apply_chat_template_works_after_configuration(self):
    """apply_chat_template fails before configuration and renders the message content afterwards."""

    class DummyTokenizer:
      """Minimal tokenizer that renders its chat_template with Jinja."""

      def __init__(self):
        self.chat_template = None

      def apply_chat_template(self, conversation, tokenize=False):
        if self.chat_template is None:
          raise ValueError("Cannot apply chat template because chat_template is None")
        import jinja2

        return jinja2.Environment().from_string(self.chat_template).render(messages=conversation)

    conversation = [{"role": "user", "content": "Hello!"}]
    tokenizer = DummyTokenizer()
    config = self._config(chat_template="{{ messages[0].content }}")

    # Before configuration the call fails, mirroring a tokenizer with no chat_template.
    with self.assertRaises(ValueError):
      tokenizer.apply_chat_template(conversation)

    # Apply the function under test.
    train_rl.configure_tokenizer_chat_template(tokenizer, config)

    # The same call now succeeds and renders the first message's content.
    rendered = tokenizer.apply_chat_template(conversation)
    self.assertEqual(rendered, "Hello!")
563+
475564
# Run the tests in this module through unittest when the file is executed directly.
if __name__ == "__main__":
  unittest.main()

0 commit comments

Comments
 (0)