diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml
index 2570f5c915..73e28c33a4 100644
--- a/src/maxtext/configs/base.yml
+++ b/src/maxtext/configs/base.yml
@@ -645,6 +645,7 @@ tokenizer_path: ""
tokenizer_type: "sentencepiece" # Currently supporting: "tiktoken", "sentencepiece", "huggingface"
use_chat_template: false
chat_template_path: "" # path to chat template json file
+chat_template: "" # Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template.
tokenize_train_data: true # false if the dataset is pre-tokenized
tokenize_eval_data: true # false if the dataset is pre-tokenized
add_bos: true
diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml
index 133d66c730..94011ba5c1 100644
--- a/src/maxtext/configs/post_train/rl.yml
+++ b/src/maxtext/configs/post_train/rl.yml
@@ -208,7 +208,7 @@ reasoning_start_token: ''
reasoning_end_token: ''
solution_start_token: ''
solution_end_token: ''
-chat_template_path: 'maxtext/examples/chat_templates/gsm8k_rl.json'
+data_template_path: 'maxtext/examples/chat_templates/gsm8k_rl.json'
skip_jax_distributed_system: true
# ====== Dataset Configuration ======
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index 8c17bc24dd..6cad0b6b8b 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -1081,6 +1081,10 @@ class Tokenizer(BaseModel):
"",
description="Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template.",
)
+ data_template_path: str = Field(
+ "",
+ description="Path to a chat template file to be used when tokenizing the dataset.",
+ )
tokenize_train_data: bool = Field(True, description="If False, assumes the training dataset is pre-tokenized.")
tokenize_eval_data: bool = Field(True, description="If False, assumes the evaluation dataset is pre-tokenized.")
add_bos: bool = Field(True, description="Whether to add a beginning-of-sentence token.")
diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py
index d4826f73a0..997f021d4e 100644
--- a/src/maxtext/trainers/post_train/rl/train_rl.py
+++ b/src/maxtext/trainers/post_train/rl/train_rl.py
@@ -248,10 +248,10 @@ def prepare_datasets(
model_tokenizer: AutoTokenizer,
) -> tuple[grain.IterDataset, grain.IterDataset | None]:
"""Setup and return train and test datasets."""
- template_config = load_data_template_from_file(trainer_config.chat_template_path)
+ template_config = load_data_template_from_file(trainer_config.data_template_path)
if template_config is None:
raise ValueError(
- f"Chat template is required for processing dataset but failed to load from {trainer_config.chat_template_path}"
+ f"Chat template is required for processing dataset but failed to load from {trainer_config.data_template_path}"
)
# Prepare train and test data from training data for certain datasets
@@ -548,10 +548,10 @@ def _reward_fn(**kwargs):
epsilon_high=trainer_config.rl.epsilon_high,
)
# Instantiate the custom MaxText chat parser
- template_config = load_data_template_from_file(trainer_config.chat_template_path)
+ template_config = load_data_template_from_file(trainer_config.data_template_path)
if template_config is None:
raise ValueError(
- f"Chat template is required for AgenticGRPOLearner but failed to load from {trainer_config.chat_template_path}"
+ f"Chat template is required for AgenticGRPOLearner but failed to load from {trainer_config.data_template_path}"
)
chat_parser = utils_rl.MaxTextChatParser(
model_tokenizer=model_tokenizer, template_config=template_config, tmvp_config=trainer_config
@@ -582,6 +582,28 @@ def _reward_fn(**kwargs):
return rl_cluster, rl_trainer, optimizer
+def configure_tokenizer_chat_template(model_tokenizer: Any, trainer_config: Any) -> None:
+ """Populates the tokenizer's chat_template from config if missing."""
+ if getattr(model_tokenizer, "chat_template", None) is None:
+ if getattr(trainer_config, "chat_template", None):
+ model_tokenizer.chat_template = trainer_config.chat_template
+ elif getattr(trainer_config, "chat_template_path", None):
+ from maxtext.input_pipeline.instruction_data_processing import ( # pylint: disable=import-outside-toplevel
+ load_chat_template_from_file,
+ )
+
+ model_tokenizer.chat_template = load_chat_template_from_file(trainer_config.chat_template_path)
+ else:
+ raise ValueError(
+ f"Tokenizer {getattr(trainer_config, 'tokenizer_path', None)!r} has no chat_template "
+ "and config.chat_template / config.chat_template_path "
+ "are both empty. Either pick an instruction-tuned tokenizer that "
+ "ships with a chat_template, set config.chat_template to a Jinja "
+ "string, or set config.chat_template_path to a JSON file "
+ "with a 'chat_template' key."
+ )
+
+
def rl_train(argv: Sequence[str], kwargs: dict):
"""
Run RL training with the provided configuration.
@@ -616,6 +638,7 @@ def _rl_train_impl(argv: Sequence[str], kwargs: dict):
trainer_config.tokenizer_path,
token=trainer_config.hf_access_token or None,
)
+ configure_tokenizer_chat_template(model_tokenizer, trainer_config)
reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = model_creation_utils.create_models_and_meshes(
trainer_config,
diff --git a/tests/post_training/unit/train_rl_test.py b/tests/post_training/unit/train_rl_test.py
index ede2029dea..56ba2a7dcf 100644
--- a/tests/post_training/unit/train_rl_test.py
+++ b/tests/post_training/unit/train_rl_test.py
@@ -58,7 +58,10 @@ def test_setup_configs_and_devices_pathways_split(self):
# Following the pattern in distillation_checkpointing_test.py for mocking jax objects
with (
mock.patch.object(jax, "devices", return_value=mock_devices),
- mock.patch("maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", return_value=mock_config),
+ mock.patch(
+ "maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic",
+ return_value=mock_config,
+ ),
):
trainer_config, sampler_config, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices(
["dummy", "dummy"]
@@ -87,7 +90,10 @@ def test_setup_configs_and_devices_pathways_fractional_split(self):
with (
mock.patch.object(jax, "devices", return_value=mock_devices),
- mock.patch("maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", return_value=mock_config),
+ mock.patch(
+ "maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic",
+ return_value=mock_config,
+ ),
):
_, _, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices(["dummy", "dummy"])
@@ -189,7 +195,10 @@ def test_get_rollout_kwargs_no_dp(self):
"tensor_parallel_size": 2,
"expert_parallel_size": 4,
}
- self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 16), expected_result)
+ self.assertEqual(
+ train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 16),
+ expected_result,
+ )
@pytest.mark.cpu_only
def test_get_rollout_kwargs_auto_tp(self):
@@ -204,7 +213,10 @@ def test_get_rollout_kwargs_auto_tp(self):
"tensor_parallel_size": 2,
"expert_parallel_size": 1,
}
- self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4), expected_result)
+ self.assertEqual(
+ train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4),
+ expected_result,
+ )
@pytest.mark.cpu_only
def test_get_rollout_kwargs_fixed_tp_dp(self):
@@ -219,7 +231,10 @@ def test_get_rollout_kwargs_fixed_tp_dp(self):
"tensor_parallel_size": 2,
"expert_parallel_size": 1,
}
- self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4), expected_result)
+ self.assertEqual(
+ train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4),
+ expected_result,
+ )
@pytest.mark.cpu_only
def test_get_rollout_kwargs_auto_ep(self):
@@ -235,7 +250,10 @@ def test_get_rollout_kwargs_auto_ep(self):
"tensor_parallel_size": 2,
"expert_parallel_size": 2,
}
- self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8), expected_result)
+ self.assertEqual(
+ train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8),
+ expected_result,
+ )
@pytest.mark.cpu_only
def test_get_rollout_kwargs_errors(self):
@@ -307,7 +325,10 @@ def tokenize_side_effect(text):
{"question": "short", "answer": "a3"},
{"question": "long", "answer": "a4"},
]
- test_data = [{"question": "short", "answer": "a5"}, {"question": "long", "answer": "a6"}]
+ test_data = [
+ {"question": "short", "answer": "a5"},
+ {"question": "long", "answer": "a6"},
+ ]
train_map_ds = grain.MapDataset.source(train_data)
test_map_ds = grain.MapDataset.source(test_data)
@@ -334,7 +355,7 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config
eval_split="eval",
hf_train_files=None,
hf_eval_files=None,
- chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
+ data_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
data_shuffle_seed=42,
max_prefill_predict_length=10,
batch_size=2,
@@ -346,8 +367,14 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config
)
with (
- mock.patch("maxtext.trainers.post_train.rl.train_rl.get_dataset", side_effect=get_dataset_side_effect),
- mock.patch("maxtext.trainers.post_train.rl.utils_rl.process_data", side_effect=get_filtered_data_side_effect),
+ mock.patch(
+ "maxtext.trainers.post_train.rl.train_rl.get_dataset",
+ side_effect=get_dataset_side_effect,
+ ),
+ mock.patch(
+ "maxtext.trainers.post_train.rl.utils_rl.process_data",
+ side_effect=get_filtered_data_side_effect,
+ ),
):
train_dataset, test_dataset = train_rl.prepare_datasets(trainer_config, mock_tokenizer)
@@ -378,7 +405,10 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config
def test_prepare_datasets_with_split(self, mock_load):
mock_ds = mock.MagicMock()
mock_split_result = {
- "train": [{"question": "q1", "answer": "a1"}, {"question": "q2", "answer": "a2"}],
+ "train": [
+ {"question": "q1", "answer": "a1"},
+ {"question": "q2", "answer": "a2"},
+ ],
"test": [{"question": "q3", "answer": "a3"}],
}
mock_ds.train_test_split.return_value = mock_split_result
@@ -389,7 +419,7 @@ def test_prepare_datasets_with_split(self, mock_load):
eval_dataset_name="open-r1/OpenR1-Math-220k",
train_split="train",
hf_train_files="hf://open-r1/OpenR1-Math-220k/data/dummy.parquet",
- chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
+ data_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
data_shuffle_seed=42,
num_batches=1,
batch_size=5,
@@ -435,7 +465,7 @@ def test_prepare_datasets_without_split(self, mock_load):
eval_split="test",
hf_train_files="hf://openai/gsm8k/data/dummy.parquet",
hf_eval_files="hf://openai/gsm8k/data/dummy.parquet",
- chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
+ data_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
data_shuffle_seed=42,
num_batches=1,
batch_size=5,
@@ -480,7 +510,102 @@ def test_rl_train_invalid_vocab_tiling(self, mock_setup):
mock_setup.return_value = (mock_config, mock_config, [], [])
with self.assertRaisesRegex(ValueError, "Vocab Tiling is not supported with RL"):
- train_rl._rl_train_impl([], {})
+ train_rl._rl_train_impl([], {}) # pylint: disable=protected-access
+
+
+class TokenizerChatTemplateTest(unittest.TestCase):
+ """Unit tests for configure_tokenizer_chat_template."""
+
+ @pytest.mark.cpu_only
+ def test_chat_template_populated_from_config_string(self):
+ """Test that chat_template is set from config.chat_template when tokenizer lacks one."""
+ mock_tokenizer = mock.MagicMock()
+ mock_tokenizer.chat_template = None
+ trainer_config = SimpleNamespace(
+ chat_template="{{ messages[0].content }}",
+ chat_template_path=None,
+ tokenizer_path="dummy-base-model",
+ )
+ train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config)
+ self.assertEqual(mock_tokenizer.chat_template, "{{ messages[0].content }}")
+
+ @pytest.mark.cpu_only
+ @mock.patch("maxtext.input_pipeline.instruction_data_processing.load_chat_template_from_file")
+ def test_chat_template_populated_from_config_file(self, mock_load):
+ """Test that chat_template is loaded from chat_template_path when tokenizer lacks one."""
+ mock_tokenizer = mock.MagicMock()
+ mock_tokenizer.chat_template = None
+ mock_load.return_value = "{% for message in messages %}{{ message.content }}{% endfor %}"
+ trainer_config = SimpleNamespace(
+ chat_template=None,
+ chat_template_path="/path/to/jinja_template.json",
+ tokenizer_path="dummy-base-model",
+ )
+ train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config)
+ mock_load.assert_called_once_with("/path/to/jinja_template.json")
+ self.assertEqual(
+ mock_tokenizer.chat_template,
+ "{% for message in messages %}{{ message.content }}{% endfor %}",
+ )
+
+ @pytest.mark.cpu_only
+ def test_chat_template_raises_value_error_when_empty(self):
+ """Test that ValueError is raised when tokenizer lacks chat_template and both config options are empty."""
+ mock_tokenizer = mock.MagicMock()
+ mock_tokenizer.chat_template = None
+ trainer_config = SimpleNamespace(
+ chat_template=None,
+ chat_template_path=None,
+ tokenizer_path="dummy-base-model",
+ )
+ with self.assertRaisesRegex(ValueError, "Tokenizer 'dummy-base-model' has no chat_template"):
+ train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config)
+
+ @pytest.mark.cpu_only
+ def test_chat_template_unchanged_when_already_exists(self):
+ """Test that an existing chat_template on the tokenizer is preserved (backward compatibility)."""
+ mock_tokenizer = mock.MagicMock()
+ mock_tokenizer.chat_template = "{{ existing_template }}"
+ trainer_config = SimpleNamespace(
+ chat_template="{{ overridden_template }}",
+ chat_template_path=None,
+ tokenizer_path="dummy-instruction-tuned-model",
+ )
+ train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config)
+ self.assertEqual(mock_tokenizer.chat_template, "{{ existing_template }}")
+
+ @pytest.mark.cpu_only
+ def test_apply_chat_template_works_after_configuration(self):
+ """Verifies apply_chat_template succeeds and produces the expected format after our code path runs."""
+
+ class DummyTokenizer: # pylint: disable=missing-class-docstring
+
+ def __init__(self):
+ self.chat_template = None
+
+ def apply_chat_template(self, conversation, tokenize=False):
+ if self.chat_template is None:
+ raise ValueError("Cannot apply chat template because chat_template is None")
+ import jinja2 # pylint: disable=import-outside-toplevel
+
+ env = jinja2.Environment()
+ template = env.from_string(self.chat_template)
+ return template.render(messages=conversation)
+
+ tokenizer = DummyTokenizer()
+ trainer_config = SimpleNamespace(
+ chat_template="{{ messages[0].content }}",
+ chat_template_path=None,
+ tokenizer_path="dummy-base-model",
+ )
+ # Initially, apply_chat_template fails (simulating HF tokenizer crash when chat_template is None)
+ with self.assertRaises(ValueError):
+ tokenizer.apply_chat_template([{"role": "user", "content": "Hello!"}])
+ # Run the proposed change
+ train_rl.configure_tokenizer_chat_template(tokenizer, trainer_config)
+ # Verify apply_chat_template now runs successfully and renders correct content
+ rendered = tokenizer.apply_chat_template([{"role": "user", "content": "Hello!"}])
+ self.assertEqual(rendered, "Hello!")
if __name__ == "__main__":