Skip to content

Commit f3e108a

Browse files
committed
Add data_template_path for the dataset template, and use chat_template_path for the tokenizer chat template
1 parent ab3b935 commit f3e108a

4 files changed

Lines changed: 20 additions & 20 deletions

File tree

src/maxtext/configs/post_train/rl.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -208,7 +208,7 @@ reasoning_start_token: '<reasoning>'
208208
reasoning_end_token: '</reasoning>'
209209
solution_start_token: '<answer>'
210210
solution_end_token: '</answer>'
211-
chat_template_path: 'maxtext/examples/chat_templates/gsm8k_rl.json'
211+
data_template_path: 'maxtext/examples/chat_templates/gsm8k_rl.json'
212212
skip_jax_distributed_system: true
213213

214214
# ====== Dataset Configuration ======

src/maxtext/configs/types.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1081,9 +1081,9 @@ class Tokenizer(BaseModel):
10811081
"",
10821082
description="Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template.",
10831083
)
1084-
tokenizer_chat_template_path: str = Field(
1084+
data_template_path: str = Field(
10851085
"",
1086-
description="Path to a chat template file to be loaded into the tokenizer if missing.",
1086+
description="Path to a chat template file to be used when tokenizing the dataset. Used in RL workloads to provide the conversation.",
10871087
)
10881088
tokenize_train_data: bool = Field(True, description="If False, assumes the training dataset is pre-tokenized.")
10891089
tokenize_eval_data: bool = Field(True, description="If False, assumes the evaluation dataset is pre-tokenized.")

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -248,10 +248,10 @@ def prepare_datasets(
248248
model_tokenizer: AutoTokenizer,
249249
) -> tuple[grain.IterDataset, grain.IterDataset | None]:
250250
"""Setup and return train and test datasets."""
251-
template_config = load_data_template_from_file(trainer_config.chat_template_path)
251+
template_config = load_data_template_from_file(trainer_config.data_template_path)
252252
if template_config is None:
253253
raise ValueError(
254-
f"Chat template is required for processing dataset but failed to load from {trainer_config.chat_template_path}"
254+
f"Chat template is required for processing dataset but failed to load from {trainer_config.data_template_path}"
255255
)
256256

257257
# Prepare train and test data from training data for certain datasets
@@ -548,10 +548,10 @@ def _reward_fn(**kwargs):
548548
epsilon_high=trainer_config.rl.epsilon_high,
549549
)
550550
# Instantiate the custom MaxText chat parser
551-
template_config = load_data_template_from_file(trainer_config.chat_template_path)
551+
template_config = load_data_template_from_file(trainer_config.data_template_path)
552552
if template_config is None:
553553
raise ValueError(
554-
f"Chat template is required for AgenticGRPOLearner but failed to load from {trainer_config.chat_template_path}"
554+
f"Chat template is required for AgenticGRPOLearner but failed to load from {trainer_config.data_template_path}"
555555
)
556556
chat_parser = utils_rl.MaxTextChatParser(
557557
model_tokenizer=model_tokenizer, template_config=template_config, tmvp_config=trainer_config
@@ -588,20 +588,20 @@ def configure_tokenizer_chat_template(model_tokenizer: Any, trainer_config: Any)
588588
if getattr(model_tokenizer, "chat_template", None) is None:
589589
if getattr(trainer_config, "chat_template", None):
590590
model_tokenizer.chat_template = trainer_config.chat_template
591-
elif getattr(trainer_config, "tokenizer_chat_template_path", None):
591+
elif getattr(trainer_config, "chat_template_path", None):
592592
from maxtext.input_pipeline.instruction_data_processing import (
593593
load_chat_template_from_file,
594594
)
595595
model_tokenizer.chat_template = load_chat_template_from_file(
596-
trainer_config.tokenizer_chat_template_path
596+
trainer_config.chat_template_path
597597
)
598598
else:
599599
raise ValueError(
600600
f"Tokenizer {getattr(trainer_config, 'tokenizer_path', None)!r} has no chat_template "
601-
"and config.chat_template / config.tokenizer_chat_template_path "
601+
"and config.chat_template / config.chat_template_path "
602602
"are both empty. Either pick an instruction-tuned tokenizer that "
603603
"ships with a chat_template, set config.chat_template to a Jinja "
604-
"string, or set config.tokenizer_chat_template_path to a JSON file "
604+
"string, or set config.chat_template_path to a JSON file "
605605
"with a 'chat_template' key."
606606
)
607607

tests/post_training/unit/train_rl_test.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config
334334
eval_split="eval",
335335
hf_train_files=None,
336336
hf_eval_files=None,
337-
chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
337+
data_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
338338
data_shuffle_seed=42,
339339
max_prefill_predict_length=10,
340340
batch_size=2,
@@ -389,7 +389,7 @@ def test_prepare_datasets_with_split(self, mock_load):
389389
eval_dataset_name="open-r1/OpenR1-Math-220k",
390390
train_split="train",
391391
hf_train_files="hf://open-r1/OpenR1-Math-220k/data/dummy.parquet",
392-
chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
392+
data_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
393393
data_shuffle_seed=42,
394394
num_batches=1,
395395
batch_size=5,
@@ -435,7 +435,7 @@ def test_prepare_datasets_without_split(self, mock_load):
435435
eval_split="test",
436436
hf_train_files="hf://openai/gsm8k/data/dummy.parquet",
437437
hf_eval_files="hf://openai/gsm8k/data/dummy.parquet",
438-
chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
438+
data_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
439439
data_shuffle_seed=42,
440440
num_batches=1,
441441
batch_size=5,
@@ -482,7 +482,7 @@ def test_chat_template_populated_from_config_string(self):
482482
mock_tokenizer.chat_template = None
483483
trainer_config = SimpleNamespace(
484484
chat_template="{{ messages[0].content }}",
485-
tokenizer_chat_template_path=None,
485+
chat_template_path=None,
486486
tokenizer_path="dummy-base-model",
487487
)
488488
train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config)
@@ -491,13 +491,13 @@ def test_chat_template_populated_from_config_string(self):
491491
@pytest.mark.cpu_only
492492
@mock.patch("maxtext.input_pipeline.instruction_data_processing.load_chat_template_from_file")
493493
def test_chat_template_populated_from_config_file(self, mock_load):
494-
"""Test that chat_template is loaded from tokenizer_chat_template_path when tokenizer lacks one."""
494+
"""Test that chat_template is loaded from chat_template_path when tokenizer lacks one."""
495495
mock_tokenizer = mock.MagicMock()
496496
mock_tokenizer.chat_template = None
497497
mock_load.return_value = "{% for message in messages %}{{ message.content }}{% endfor %}"
498498
trainer_config = SimpleNamespace(
499499
chat_template=None,
500-
tokenizer_chat_template_path="/path/to/jinja_template.json",
500+
chat_template_path="/path/to/jinja_template.json",
501501
tokenizer_path="dummy-base-model",
502502
)
503503
train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config)
@@ -513,7 +513,7 @@ def test_chat_template_raises_value_error_when_empty(self):
513513
mock_tokenizer.chat_template = None
514514
trainer_config = SimpleNamespace(
515515
chat_template=None,
516-
tokenizer_chat_template_path=None,
516+
chat_template_path=None,
517517
tokenizer_path="dummy-base-model",
518518
)
519519
with self.assertRaisesRegex(ValueError, "Tokenizer 'dummy-base-model' has no chat_template"):
@@ -526,7 +526,7 @@ def test_chat_template_unchanged_when_already_exists(self):
526526
mock_tokenizer.chat_template = "{{ existing_template }}"
527527
trainer_config = SimpleNamespace(
528528
chat_template="{{ overridden_template }}",
529-
tokenizer_chat_template_path=None,
529+
chat_template_path=None,
530530
tokenizer_path="dummy-instruction-tuned-model",
531531
)
532532
train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config)
@@ -549,7 +549,7 @@ def apply_chat_template(self, conversation, tokenize=False):
549549
tokenizer = DummyTokenizer()
550550
trainer_config = SimpleNamespace(
551551
chat_template="{{ messages[0].content }}",
552-
tokenizer_chat_template_path=None,
552+
chat_template_path=None,
553553
tokenizer_path="dummy-base-model",
554554
)
555555
# Initially, apply_chat_template fails (simulating HF tokenizer crash when chat_template is None)

0 commit comments

Comments
 (0)