Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,7 @@ tokenizer_path: ""
tokenizer_type: "sentencepiece" # Currently supporting: "tiktoken", "sentencepiece", "huggingface"
use_chat_template: false
chat_template_path: "" # path to chat template json file
chat_template: "" # Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template.
tokenize_train_data: true # false if the dataset is pre-tokenized
tokenize_eval_data: true # false if the dataset is pre-tokenized
add_bos: true
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/configs/post_train/rl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ reasoning_start_token: '<reasoning>'
reasoning_end_token: '</reasoning>'
solution_start_token: '<answer>'
solution_end_token: '</answer>'
chat_template_path: 'maxtext/examples/chat_templates/gsm8k_rl.json'
data_template_path: 'maxtext/examples/chat_templates/gsm8k_rl.json'
skip_jax_distributed_system: true

# ====== Dataset Configuration ======
Expand Down
4 changes: 4 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,10 @@ class Tokenizer(BaseModel):
"",
description="Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template.",
)
data_template_path: str = Field(
"",
description="Path to a chat template file to be used when tokenizing the dataset.",
)
tokenize_train_data: bool = Field(True, description="If False, assumes the training dataset is pre-tokenized.")
tokenize_eval_data: bool = Field(True, description="If False, assumes the evaluation dataset is pre-tokenized.")
add_bos: bool = Field(True, description="Whether to add a beginning-of-sentence token.")
Expand Down
31 changes: 27 additions & 4 deletions src/maxtext/trainers/post_train/rl/train_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,10 @@ def prepare_datasets(
model_tokenizer: AutoTokenizer,
) -> tuple[grain.IterDataset, grain.IterDataset | None]:
"""Setup and return train and test datasets."""
template_config = load_data_template_from_file(trainer_config.chat_template_path)
template_config = load_data_template_from_file(trainer_config.data_template_path)
if template_config is None:
raise ValueError(
f"Chat template is required for processing dataset but failed to load from {trainer_config.chat_template_path}"
f"Chat template is required for processing dataset but failed to load from {trainer_config.data_template_path}"
)

# Prepare train and test data from training data for certain datasets
Expand Down Expand Up @@ -548,10 +548,10 @@ def _reward_fn(**kwargs):
epsilon_high=trainer_config.rl.epsilon_high,
)
# Instantiate the custom MaxText chat parser
template_config = load_data_template_from_file(trainer_config.chat_template_path)
template_config = load_data_template_from_file(trainer_config.data_template_path)
if template_config is None:
raise ValueError(
f"Chat template is required for AgenticGRPOLearner but failed to load from {trainer_config.chat_template_path}"
f"Chat template is required for AgenticGRPOLearner but failed to load from {trainer_config.data_template_path}"
)
chat_parser = utils_rl.MaxTextChatParser(
model_tokenizer=model_tokenizer, template_config=template_config, tmvp_config=trainer_config
Expand Down Expand Up @@ -582,6 +582,28 @@ def _reward_fn(**kwargs):
return rl_cluster, rl_trainer, optimizer


def configure_tokenizer_chat_template(model_tokenizer: Any, trainer_config: Any) -> None:
"""Populates the tokenizer's chat_template from config if missing."""
if getattr(model_tokenizer, "chat_template", None) is None:
if getattr(trainer_config, "chat_template", None):
model_tokenizer.chat_template = trainer_config.chat_template
elif getattr(trainer_config, "chat_template_path", None):
from maxtext.input_pipeline.instruction_data_processing import ( # pylint: disable=import-outside-toplevel
load_chat_template_from_file,
)

model_tokenizer.chat_template = load_chat_template_from_file(trainer_config.chat_template_path)
else:
raise ValueError(
f"Tokenizer {getattr(trainer_config, 'tokenizer_path', None)!r} has no chat_template "
"and config.chat_template / config.chat_template_path "
"are both empty. Either pick an instruction-tuned tokenizer that "
"ships with a chat_template, set config.chat_template to a Jinja "
"string, or set config.chat_template_path to a JSON file "
"with a 'chat_template' key."
)


def rl_train(argv: Sequence[str], kwargs: dict):
"""
Run RL training with the provided configuration.
Expand Down Expand Up @@ -616,6 +638,7 @@ def _rl_train_impl(argv: Sequence[str], kwargs: dict):
trainer_config.tokenizer_path,
token=trainer_config.hf_access_token or None,
)
configure_tokenizer_chat_template(model_tokenizer, trainer_config)

reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = model_creation_utils.create_models_and_meshes(
trainer_config,
Expand Down
153 changes: 139 additions & 14 deletions tests/post_training/unit/train_rl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ def test_setup_configs_and_devices_pathways_split(self):
# Following the pattern in distillation_checkpointing_test.py for mocking jax objects
with (
mock.patch.object(jax, "devices", return_value=mock_devices),
mock.patch("maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", return_value=mock_config),
mock.patch(
"maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic",
return_value=mock_config,
),
):
trainer_config, sampler_config, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices(
["dummy", "dummy"]
Expand Down Expand Up @@ -87,7 +90,10 @@ def test_setup_configs_and_devices_pathways_fractional_split(self):

with (
mock.patch.object(jax, "devices", return_value=mock_devices),
mock.patch("maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", return_value=mock_config),
mock.patch(
"maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic",
return_value=mock_config,
),
):
_, _, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices(["dummy", "dummy"])

Expand Down Expand Up @@ -189,7 +195,10 @@ def test_get_rollout_kwargs_no_dp(self):
"tensor_parallel_size": 2,
"expert_parallel_size": 4,
}
self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 16), expected_result)
self.assertEqual(
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 16),
expected_result,
)

@pytest.mark.cpu_only
def test_get_rollout_kwargs_auto_tp(self):
Expand All @@ -204,7 +213,10 @@ def test_get_rollout_kwargs_auto_tp(self):
"tensor_parallel_size": 2,
"expert_parallel_size": 1,
}
self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4), expected_result)
self.assertEqual(
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4),
expected_result,
)

@pytest.mark.cpu_only
def test_get_rollout_kwargs_fixed_tp_dp(self):
Expand All @@ -219,7 +231,10 @@ def test_get_rollout_kwargs_fixed_tp_dp(self):
"tensor_parallel_size": 2,
"expert_parallel_size": 1,
}
self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4), expected_result)
self.assertEqual(
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4),
expected_result,
)

@pytest.mark.cpu_only
def test_get_rollout_kwargs_auto_ep(self):
Expand All @@ -235,7 +250,10 @@ def test_get_rollout_kwargs_auto_ep(self):
"tensor_parallel_size": 2,
"expert_parallel_size": 2,
}
self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8), expected_result)
self.assertEqual(
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8),
expected_result,
)

@pytest.mark.cpu_only
def test_get_rollout_kwargs_errors(self):
Expand Down Expand Up @@ -307,7 +325,10 @@ def tokenize_side_effect(text):
{"question": "short", "answer": "a3"},
{"question": "long", "answer": "a4"},
]
test_data = [{"question": "short", "answer": "a5"}, {"question": "long", "answer": "a6"}]
test_data = [
{"question": "short", "answer": "a5"},
{"question": "long", "answer": "a6"},
]
train_map_ds = grain.MapDataset.source(train_data)
test_map_ds = grain.MapDataset.source(test_data)

Expand All @@ -334,7 +355,7 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config
eval_split="eval",
hf_train_files=None,
hf_eval_files=None,
chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
data_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
data_shuffle_seed=42,
max_prefill_predict_length=10,
batch_size=2,
Expand All @@ -346,8 +367,14 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config
)

with (
mock.patch("maxtext.trainers.post_train.rl.train_rl.get_dataset", side_effect=get_dataset_side_effect),
mock.patch("maxtext.trainers.post_train.rl.utils_rl.process_data", side_effect=get_filtered_data_side_effect),
mock.patch(
"maxtext.trainers.post_train.rl.train_rl.get_dataset",
side_effect=get_dataset_side_effect,
),
mock.patch(
"maxtext.trainers.post_train.rl.utils_rl.process_data",
side_effect=get_filtered_data_side_effect,
),
):
train_dataset, test_dataset = train_rl.prepare_datasets(trainer_config, mock_tokenizer)

Expand Down Expand Up @@ -378,7 +405,10 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config
def test_prepare_datasets_with_split(self, mock_load):
mock_ds = mock.MagicMock()
mock_split_result = {
"train": [{"question": "q1", "answer": "a1"}, {"question": "q2", "answer": "a2"}],
"train": [
{"question": "q1", "answer": "a1"},
{"question": "q2", "answer": "a2"},
],
"test": [{"question": "q3", "answer": "a3"}],
}
mock_ds.train_test_split.return_value = mock_split_result
Expand All @@ -389,7 +419,7 @@ def test_prepare_datasets_with_split(self, mock_load):
eval_dataset_name="open-r1/OpenR1-Math-220k",
train_split="train",
hf_train_files="hf://open-r1/OpenR1-Math-220k/data/dummy.parquet",
chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
data_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
data_shuffle_seed=42,
num_batches=1,
batch_size=5,
Expand Down Expand Up @@ -435,7 +465,7 @@ def test_prepare_datasets_without_split(self, mock_load):
eval_split="test",
hf_train_files="hf://openai/gsm8k/data/dummy.parquet",
hf_eval_files="hf://openai/gsm8k/data/dummy.parquet",
chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
data_template_path="maxtext/examples/chat_templates/gsm8k_rl.json",
data_shuffle_seed=42,
num_batches=1,
batch_size=5,
Expand Down Expand Up @@ -480,7 +510,102 @@ def test_rl_train_invalid_vocab_tiling(self, mock_setup):
mock_setup.return_value = (mock_config, mock_config, [], [])

with self.assertRaisesRegex(ValueError, "Vocab Tiling is not supported with RL"):
train_rl._rl_train_impl([], {})
train_rl._rl_train_impl([], {}) # pylint: disable=protected-access


class TokenizerChatTemplateTest(unittest.TestCase):
"""Unit tests for configure_tokenizer_chat_template."""

@pytest.mark.cpu_only
def test_chat_template_populated_from_config_string(self):
"""Test that chat_template is set from config.chat_template when tokenizer lacks one."""
mock_tokenizer = mock.MagicMock()
mock_tokenizer.chat_template = None
trainer_config = SimpleNamespace(
chat_template="{{ messages[0].content }}",
chat_template_path=None,
tokenizer_path="dummy-base-model",
)
train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config)
self.assertEqual(mock_tokenizer.chat_template, "{{ messages[0].content }}")

@pytest.mark.cpu_only
@mock.patch("maxtext.input_pipeline.instruction_data_processing.load_chat_template_from_file")
def test_chat_template_populated_from_config_file(self, mock_load):
"""Test that chat_template is loaded from chat_template_path when tokenizer lacks one."""
mock_tokenizer = mock.MagicMock()
mock_tokenizer.chat_template = None
mock_load.return_value = "{% for message in messages %}{{ message.content }}{% endfor %}"
trainer_config = SimpleNamespace(
chat_template=None,
chat_template_path="/path/to/jinja_template.json",
tokenizer_path="dummy-base-model",
)
train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config)
mock_load.assert_called_once_with("/path/to/jinja_template.json")
self.assertEqual(
mock_tokenizer.chat_template,
"{% for message in messages %}{{ message.content }}{% endfor %}",
)

@pytest.mark.cpu_only
def test_chat_template_raises_value_error_when_empty(self):
"""Test that ValueError is raised when tokenizer lacks chat_template and both config options are empty."""
mock_tokenizer = mock.MagicMock()
mock_tokenizer.chat_template = None
trainer_config = SimpleNamespace(
chat_template=None,
chat_template_path=None,
tokenizer_path="dummy-base-model",
)
with self.assertRaisesRegex(ValueError, "Tokenizer 'dummy-base-model' has no chat_template"):
train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config)

@pytest.mark.cpu_only
def test_chat_template_unchanged_when_already_exists(self):
"""Test that an existing chat_template on the tokenizer is preserved (backward compatibility)."""
mock_tokenizer = mock.MagicMock()
mock_tokenizer.chat_template = "{{ existing_template }}"
trainer_config = SimpleNamespace(
chat_template="{{ overridden_template }}",
chat_template_path=None,
tokenizer_path="dummy-instruction-tuned-model",
)
train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config)
self.assertEqual(mock_tokenizer.chat_template, "{{ existing_template }}")

@pytest.mark.cpu_only
def test_apply_chat_template_works_after_configuration(self):
"""Verifies apply_chat_template succeeds and produces the expected format after our code path runs."""

class DummyTokenizer: # pylint: disable=missing-class-docstring

def __init__(self):
self.chat_template = None

def apply_chat_template(self, conversation, tokenize=False):
if self.chat_template is None:
raise ValueError("Cannot apply chat template because chat_template is None")
import jinja2 # pylint: disable=import-outside-toplevel

env = jinja2.Environment()
template = env.from_string(self.chat_template)
return template.render(messages=conversation)

tokenizer = DummyTokenizer()
trainer_config = SimpleNamespace(
chat_template="{{ messages[0].content }}",
chat_template_path=None,
tokenizer_path="dummy-base-model",
)
# Initially, apply_chat_template fails (simulating HF tokenizer crash when chat_template is None)
with self.assertRaises(ValueError):
tokenizer.apply_chat_template([{"role": "user", "content": "Hello!"}])
# Run the proposed change
train_rl.configure_tokenizer_chat_template(tokenizer, trainer_config)
# Verify apply_chat_template now runs successfully and renders correct content
rendered = tokenizer.apply_chat_template([{"role": "user", "content": "Hello!"}])
self.assertEqual(rendered, "Hello!")


if __name__ == "__main__":
Expand Down
Loading