Skip to content

Commit 98c79e4

Browse files
Merge pull request #4049 from AI-Hypercomputer:davidsotomora-rl-chat-template
PiperOrigin-RevId: 933980665
2 parents 6eacf4e + 0ad94d0 commit 98c79e4

5 files changed

Lines changed: 171 additions & 18 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -658,6 +658,7 @@ tokenizer_path: ""
658658
tokenizer_type: "sentencepiece" # Currently supporting: "tiktoken", "sentencepiece", "huggingface"
659659
use_chat_template: false
660660
chat_template_path: "" # path to chat template json file
661+
chat_template: "" # Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template.
661662
tokenize_train_data: true # false if the dataset is pre-tokenized
662663
tokenize_eval_data: true # false if the dataset is pre-tokenized
663664
add_bos: true

src/maxtext/configs/post_train/rl.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -213,7 +213,7 @@ reasoning_start_token: '<reasoning>'
213213
reasoning_end_token: '</reasoning>'
214214
solution_start_token: '<answer>'
215215
solution_end_token: '</answer>'
216-
chat_template_path: 'maxtext/examples/chat_templates/gsm8k_rl.json'
216+
data_template_path: 'maxtext/examples/chat_templates/gsm8k_rl.json'
217217
skip_jax_distributed_system: true
218218

219219
# ====== Dataset Configuration ======

src/maxtext/configs/types.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1102,6 +1102,10 @@ class Tokenizer(BaseModel):
11021102
"",
11031103
description="Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template.",
11041104
)
1105+
data_template_path: str = Field(
1106+
"",
1107+
description="Path to a chat template file to be used when tokenizing the dataset.",
1108+
)
11051109
tokenize_train_data: bool = Field(True, description="If False, assumes the training dataset is pre-tokenized.")
11061110
tokenize_eval_data: bool = Field(True, description="If False, assumes the evaluation dataset is pre-tokenized.")
11071111
add_bos: bool = Field(True, description="Whether to add a beginning-of-sentence token.")

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 27 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -248,10 +248,10 @@ def prepare_datasets(
248248
model_tokenizer: AutoTokenizer,
249249
) -> tuple[grain.IterDataset, grain.IterDataset | None]:
250250
"""Setup and return train and test datasets."""
251-
template_config = load_data_template_from_file(trainer_config.chat_template_path)
251+
template_config = load_data_template_from_file(trainer_config.data_template_path)
252252
if template_config is None:
253253
raise ValueError(
254-
f"Chat template is required for processing dataset but failed to load from {trainer_config.chat_template_path}"
254+
f"Data template is required for processing dataset but failed to load from {trainer_config.data_template_path}"
255255
)
256256

257257
# Optional user-provided `process_data(dataset_name, tokenizer, template, config, x) -> dict`.
@@ -559,10 +559,10 @@ def _reward_fn(**kwargs):
559559
epsilon_high=trainer_config.rl.epsilon_high,
560560
)
561561
# Instantiate the custom MaxText chat parser
562-
template_config = load_data_template_from_file(trainer_config.chat_template_path)
562+
template_config = load_data_template_from_file(trainer_config.data_template_path)
563563
if template_config is None:
564564
raise ValueError(
565-
f"Chat template is required for AgenticGRPOLearner but failed to load from {trainer_config.chat_template_path}"
565+
f"Data template is required for AgenticGRPOLearner but failed to load from {trainer_config.data_template_path}"
566566
)
567567
chat_parser = utils_rl.MaxTextChatParser(
568568
model_tokenizer=model_tokenizer, template_config=template_config, tmvp_config=trainer_config
@@ -593,6 +593,28 @@ def _reward_fn(**kwargs):
593593
return rl_cluster, rl_trainer, optimizer, reward_fns
594594

595595

596+
def configure_tokenizer_chat_template(model_tokenizer: Any, trainer_config: Any) -> None:
  """Populates the tokenizer's chat_template from config if it is missing.

  A non-empty template already present on the tokenizer is never overridden.
  Otherwise the template is resolved in this order:
    1. `trainer_config.chat_template` (inline template string).
    2. `trainer_config.chat_template_path` (loaded with `load_chat_template_from_file`).

  Args:
    model_tokenizer: Tokenizer object whose `chat_template` attribute is read and, if
      missing or empty, assigned.
    trainer_config: Config object; `chat_template`, `chat_template_path` and
      `tokenizer_path` are read via `getattr`, so absent attributes are tolerated.

  Raises:
    ValueError: If the tokenizer has no usable template and neither config option
      supplies one, or if loading from `chat_template_path` yields an empty template.
  """
  # An empty template is as unusable as a missing one, so test truthiness instead of `is None`.
  if getattr(model_tokenizer, "chat_template", None):
    return

  inline_template = getattr(trainer_config, "chat_template", None)
  if inline_template:
    model_tokenizer.chat_template = inline_template
    return

  template_path = getattr(trainer_config, "chat_template_path", None)
  if template_path:
    from maxtext.input_pipeline.instruction_data_processing import (  # pylint: disable=import-outside-toplevel
        load_chat_template_from_file,
    )

    loaded_template = load_chat_template_from_file(template_path)
    # NOTE(review): the loader's failure mode is not visible here; guard against an empty/None
    # result so a bad file fails at startup instead of at the first apply_chat_template call.
    if not loaded_template:
      raise ValueError(f"Chat template is required for the tokenizer but failed to load from {template_path!r}")
    model_tokenizer.chat_template = loaded_template
    return

  raise ValueError(
      f"Tokenizer {getattr(trainer_config, 'tokenizer_path', None)!r} has no chat_template "
      "and config.chat_template / config.chat_template_path "
      "are both empty. Either pick an instruction-tuned tokenizer that "
      "ships with a chat_template, set config.chat_template to a Jinja "
      "string, or set config.chat_template_path to a JSON file "
      "with a 'chat_template' key."
  )
616+
617+
596618
def rl_train(argv: Sequence[str], kwargs: dict):
597619
"""
598620
Run RL training with the provided configuration.
@@ -638,6 +660,7 @@ def _rl_train_impl(argv: Sequence[str], kwargs: dict):
638660
trainer_config.tokenizer_path,
639661
token=trainer_config.hf_access_token or None,
640662
)
663+
configure_tokenizer_chat_template(model_tokenizer, trainer_config)
641664

642665
reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = model_creation_utils.create_models_and_meshes(
643666
trainer_config,

tests/post_training/unit/train_rl_test.py

Lines changed: 138 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,10 @@ def test_setup_configs_and_devices_pathways_split(self):
5858
# Following the pattern in distillation_checkpointing_test.py for mocking jax objects
5959
with (
6060
mock.patch.object(jax, "devices", return_value=mock_devices),
61-
mock.patch("maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", return_value=mock_config),
61+
mock.patch(
62+
"maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic",
63+
return_value=mock_config,
64+
),
6265
):
6366
trainer_config, sampler_config, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices(
6467
["dummy", "dummy"]
@@ -87,7 +90,10 @@ def test_setup_configs_and_devices_pathways_fractional_split(self):
8790

8891
with (
8992
mock.patch.object(jax, "devices", return_value=mock_devices),
90-
mock.patch("maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", return_value=mock_config),
93+
mock.patch(
94+
"maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic",
95+
return_value=mock_config,
96+
),
9197
):
9298
_, _, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices(["dummy", "dummy"])
9399

@@ -189,7 +195,10 @@ def test_get_rollout_kwargs_no_dp(self):
189195
"tensor_parallel_size": 2,
190196
"expert_parallel_size": 4,
191197
}
192-
self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 16), expected_result)
198+
self.assertEqual(
199+
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 16),
200+
expected_result,
201+
)
193202

194203
@pytest.mark.cpu_only
195204
def test_get_rollout_kwargs_auto_tp(self):
@@ -204,7 +213,10 @@ def test_get_rollout_kwargs_auto_tp(self):
204213
"tensor_parallel_size": 2,
205214
"expert_parallel_size": 1,
206215
}
207-
self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4), expected_result)
216+
self.assertEqual(
217+
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4),
218+
expected_result,
219+
)
208220

209221
@pytest.mark.cpu_only
210222
def test_get_rollout_kwargs_fixed_tp_dp(self):
@@ -219,7 +231,10 @@ def test_get_rollout_kwargs_fixed_tp_dp(self):
219231
"tensor_parallel_size": 2,
220232
"expert_parallel_size": 1,
221233
}
222-
self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4), expected_result)
234+
self.assertEqual(
235+
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4),
236+
expected_result,
237+
)
223238

224239
@pytest.mark.cpu_only
225240
def test_get_rollout_kwargs_auto_ep(self):
@@ -235,7 +250,10 @@ def test_get_rollout_kwargs_auto_ep(self):
235250
"tensor_parallel_size": 2,
236251
"expert_parallel_size": 2,
237252
}
238-
self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8), expected_result)
253+
self.assertEqual(
254+
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8),
255+
expected_result,
256+
)
239257

240258
@pytest.mark.cpu_only
241259
def test_get_rollout_kwargs_errors(self):
@@ -307,7 +325,10 @@ def tokenize_side_effect(text):
307325
{"question": "short", "answer": "a3"},
308326
{"question": "long", "answer": "a4"},
309327
]
310-
test_data = [{"question": "short", "answer": "a5"}, {"question": "long", "answer": "a6"}]
328+
test_data = [
329+
{"question": "short", "answer": "a5"},
330+
{"question": "long", "answer": "a6"},
331+
]
311332
train_map_ds = grain.MapDataset.source(train_data)
312333
test_map_ds = grain.MapDataset.source(test_data)
313334

@@ -334,7 +355,7 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config
334355
eval_split="eval",
335356
hf_train_files=None,
336357
hf_eval_files=None,
337-
chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
358+
data_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
338359
data_shuffle_seed=42,
339360
max_prefill_predict_length=10,
340361
batch_size=2,
@@ -346,8 +367,14 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config
346367
)
347368

348369
with (
349-
mock.patch("maxtext.trainers.post_train.rl.train_rl.get_dataset", side_effect=get_dataset_side_effect),
350-
mock.patch("maxtext.trainers.post_train.rl.utils_rl.process_data", side_effect=get_filtered_data_side_effect),
370+
mock.patch(
371+
"maxtext.trainers.post_train.rl.train_rl.get_dataset",
372+
side_effect=get_dataset_side_effect,
373+
),
374+
mock.patch(
375+
"maxtext.trainers.post_train.rl.utils_rl.process_data",
376+
side_effect=get_filtered_data_side_effect,
377+
),
351378
):
352379
train_dataset, test_dataset = train_rl.prepare_datasets(trainer_config, mock_tokenizer)
353380

@@ -378,7 +405,10 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config
378405
def test_prepare_datasets_with_split(self, mock_load):
379406
mock_ds = mock.MagicMock()
380407
mock_split_result = {
381-
"train": [{"question": "q1", "answer": "a1"}, {"question": "q2", "answer": "a2"}],
408+
"train": [
409+
{"question": "q1", "answer": "a1"},
410+
{"question": "q2", "answer": "a2"},
411+
],
382412
"test": [{"question": "q3", "answer": "a3"}],
383413
}
384414
mock_ds.train_test_split.return_value = mock_split_result
@@ -389,7 +419,7 @@ def test_prepare_datasets_with_split(self, mock_load):
389419
eval_dataset_name="open-r1/OpenR1-Math-220k",
390420
train_split="train",
391421
hf_train_files="hf://open-r1/OpenR1-Math-220k/data/dummy.parquet",
392-
chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
422+
data_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
393423
data_shuffle_seed=42,
394424
num_batches=1,
395425
batch_size=5,
@@ -435,7 +465,7 @@ def test_prepare_datasets_without_split(self, mock_load):
435465
eval_split="test",
436466
hf_train_files="hf://openai/gsm8k/data/dummy.parquet",
437467
hf_eval_files="hf://openai/gsm8k/data/dummy.parquet",
438-
chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
468+
data_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
439469
data_shuffle_seed=42,
440470
num_batches=1,
441471
batch_size=5,
@@ -496,5 +526,100 @@ def test_rl_train_invalid_optimizer_memory_host_offload(self, mock_setup):
496526
train_rl._rl_train_impl([], {}) # pylint: disable=protected-access
497527

498528

529+
class TokenizerChatTemplateTest(unittest.TestCase):
  """Covers how configure_tokenizer_chat_template resolves the tokenizer's chat template."""

  @staticmethod
  def _tokenizer_with(template):
    """Returns a mock tokenizer whose chat_template attribute is preset to `template`."""
    fake_tokenizer = mock.MagicMock()
    fake_tokenizer.chat_template = template
    return fake_tokenizer

  @staticmethod
  def _config_with(chat_template=None, chat_template_path=None, tokenizer_path="dummy-base-model"):
    """Returns a namespace carrying only the three config attributes used by these tests."""
    return SimpleNamespace(
        chat_template=chat_template,
        chat_template_path=chat_template_path,
        tokenizer_path=tokenizer_path,
    )

  @pytest.mark.cpu_only
  def test_chat_template_populated_from_config_string(self):
    """An inline config.chat_template fills in a tokenizer that has none."""
    tokenizer = self._tokenizer_with(None)
    config = self._config_with(chat_template="{{ messages[0].content }}")

    train_rl.configure_tokenizer_chat_template(tokenizer, config)

    self.assertEqual(tokenizer.chat_template, "{{ messages[0].content }}")

  @pytest.mark.cpu_only
  def test_chat_template_populated_from_config_file(self):
    """With no inline template, the template is loaded from config.chat_template_path."""
    expected_template = "{% for message in messages %}{{ message.content }}{% endfor %}"
    tokenizer = self._tokenizer_with(None)
    config = self._config_with(chat_template_path="/path/to/jinja_template.json")

    loader_target = "maxtext.input_pipeline.instruction_data_processing.load_chat_template_from_file"
    with mock.patch(loader_target) as fake_loader:
      fake_loader.return_value = expected_template
      train_rl.configure_tokenizer_chat_template(tokenizer, config)

    fake_loader.assert_called_once_with("/path/to/jinja_template.json")
    self.assertEqual(tokenizer.chat_template, expected_template)

  @pytest.mark.cpu_only
  def test_chat_template_raises_value_error_when_empty(self):
    """A ValueError naming the tokenizer is raised when no template source is available."""
    tokenizer = self._tokenizer_with(None)
    config = self._config_with()

    with self.assertRaisesRegex(ValueError, "Tokenizer 'dummy-base-model' has no chat_template"):
      train_rl.configure_tokenizer_chat_template(tokenizer, config)

  @pytest.mark.cpu_only
  def test_chat_template_unchanged_when_already_exists(self):
    """A template already on the tokenizer wins over config (backward compatibility)."""
    tokenizer = self._tokenizer_with("{{ existing_template }}")
    config = self._config_with(
        chat_template="{{ overridden_template }}",
        tokenizer_path="dummy-instruction-tuned-model",
    )

    train_rl.configure_tokenizer_chat_template(tokenizer, config)

    self.assertEqual(tokenizer.chat_template, "{{ existing_template }}")

  @pytest.mark.cpu_only
  def test_apply_chat_template_works_after_configuration(self):
    """apply_chat_template fails before configuration and renders the message afterwards."""

    class DummyTokenizer:  # pylint: disable=missing-class-docstring

      def __init__(self):
        self.chat_template = None

      def apply_chat_template(self, conversation, tokenize=False):
        if self.chat_template is None:
          raise ValueError("Cannot apply chat template because chat_template is None")
        import jinja2  # pylint: disable=import-outside-toplevel

        return jinja2.Environment().from_string(self.chat_template).render(messages=conversation)

    conversation = [{"role": "user", "content": "Hello!"}]
    tokenizer = DummyTokenizer()
    config = self._config_with(chat_template="{{ messages[0].content }}")

    # Before configuration the stand-in raises, simulating an HF tokenizer with no chat_template.
    with self.assertRaises(ValueError):
      tokenizer.apply_chat_template(conversation)

    # Run the code under test.
    train_rl.configure_tokenizer_chat_template(tokenizer, config)

    # The same call now succeeds and renders the first message's content.
    self.assertEqual(tokenizer.apply_chat_template(conversation), "Hello!")
622+
623+
499624
if __name__ == "__main__":
500625
unittest.main()

0 commit comments

Comments
 (0)