From 6c8a1d450a93457294a8f2b3a5a9ad5922932697 Mon Sep 17 00:00:00 2001 From: David Soto Mora Date: Tue, 2 Jun 2026 17:14:44 +0000 Subject: [PATCH 1/7] Adding option for rl workloads that use models without chat_templates to specify one --- .../trainers/post_train/rl/train_rl.py | 25 ++++++ tests/post_training/unit/train_rl_test.py | 89 +++++++++++++++++++ 2 files changed, 114 insertions(+) diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index a7df87c8e5..927289edc8 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -582,6 +582,30 @@ def _reward_fn(**kwargs): return rl_cluster, rl_trainer, optimizer + +def configure_tokenizer_chat_template(model_tokenizer: Any, trainer_config: Any) -> None: + """Populates the tokenizer's chat_template from config if missing.""" + if getattr(model_tokenizer, "chat_template", None) is None: + if getattr(trainer_config, "chat_template", None): + model_tokenizer.chat_template = trainer_config.chat_template + elif getattr(trainer_config, "tokenizer_chat_template_path", None): + from maxtext.input_pipeline.instruction_data_processing import ( + load_chat_template_from_file, + ) + model_tokenizer.chat_template = load_chat_template_from_file( + trainer_config.tokenizer_chat_template_path + ) + else: + raise ValueError( + f"Tokenizer {getattr(trainer_config, 'tokenizer_path', None)!r} has no chat_template " + "and config.chat_template / config.tokenizer_chat_template_path " + "are both empty. Either pick an instruction-tuned tokenizer that " + "ships with a chat_template, set config.chat_template to a Jinja " + "string, or set config.tokenizer_chat_template_path to a JSON file " + "with a 'chat_template' key." + ) + + def rl_train(argv: Sequence[str], kwargs: dict): """ Run RL training with the provided configuration. @@ -610,6 +634,7 @@ def _rl_train_impl(argv: Sequence[str], kwargs: dict): trainer_config.tokenizer_path, token=trainer_config.hf_access_token or None, ) + configure_tokenizer_chat_template(model_tokenizer, trainer_config) reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = model_creation_utils.create_models_and_meshes( trainer_config, diff --git a/tests/post_training/unit/train_rl_test.py b/tests/post_training/unit/train_rl_test.py index adce3bf9e1..8c9cbae8ec 100644 --- a/tests/post_training/unit/train_rl_test.py +++ b/tests/post_training/unit/train_rl_test.py @@ -472,5 +472,94 @@ def test_prepare_datasets_without_split(self, mock_load): assert mock_load.call_count == len(expected_calls) +class TokenizerChatTemplateTest(unittest.TestCase): + """Unit tests for configure_tokenizer_chat_template.""" + + @pytest.mark.cpu_only + def test_chat_template_populated_from_config_string(self): + """Test that chat_template is set from config.chat_template when tokenizer lacks one.""" + mock_tokenizer = mock.MagicMock() + mock_tokenizer.chat_template = None + trainer_config = SimpleNamespace( + chat_template="{{ messages[0].content }}", + tokenizer_chat_template_path=None, + tokenizer_path="dummy-base-model", + ) + train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config) + self.assertEqual(mock_tokenizer.chat_template, "{{ messages[0].content }}") + + @pytest.mark.cpu_only + @mock.patch("maxtext.input_pipeline.instruction_data_processing.load_chat_template_from_file") + def test_chat_template_populated_from_config_file(self, mock_load): + """Test that chat_template is loaded from tokenizer_chat_template_path when tokenizer lacks one.""" + mock_tokenizer = mock.MagicMock() + mock_tokenizer.chat_template = None + mock_load.return_value = "{% for message in messages %}{{ message.content }}{% endfor %}" + trainer_config = SimpleNamespace( + chat_template=None, + tokenizer_chat_template_path="/path/to/jinja_template.json", + tokenizer_path="dummy-base-model", + ) + train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config) + mock_load.assert_called_once_with("/path/to/jinja_template.json") + self.assertEqual( + mock_tokenizer.chat_template, + "{% for message in messages %}{{ message.content }}{% endfor %}", + ) + @pytest.mark.cpu_only + def test_chat_template_raises_value_error_when_empty(self): + """Test that ValueError is raised when tokenizer lacks chat_template and both config options are empty.""" + mock_tokenizer = mock.MagicMock() + mock_tokenizer.chat_template = None + trainer_config = SimpleNamespace( + chat_template=None, + tokenizer_chat_template_path=None, + tokenizer_path="dummy-base-model", + ) + with self.assertRaisesRegex(ValueError, "Tokenizer 'dummy-base-model' has no chat_template"): + train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config) + + @pytest.mark.cpu_only + def test_chat_template_unchanged_when_already_exists(self): + """Test that an existing chat_template on the tokenizer is preserved (backward compatibility).""" + mock_tokenizer = mock.MagicMock() + mock_tokenizer.chat_template = "{{ existing_template }}" + trainer_config = SimpleNamespace( + chat_template="{{ overridden_template }}", + tokenizer_chat_template_path=None, + tokenizer_path="dummy-instruction-tuned-model", + ) + train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config) + self.assertEqual(mock_tokenizer.chat_template, "{{ existing_template }}") + + @pytest.mark.cpu_only + def test_apply_chat_template_works_after_configuration(self): + """Verifies apply_chat_template succeeds and produces the expected format after our code path runs.""" + class DummyTokenizer: + def __init__(self): + self.chat_template = None + + def apply_chat_template(self, conversation, tokenize=False): + if self.chat_template is None: + raise ValueError("Cannot apply chat template because chat_template is None") + import jinja2 + env = jinja2.Environment() + template = env.from_string(self.chat_template) + return template.render(messages=conversation) + tokenizer = DummyTokenizer() + trainer_config = SimpleNamespace( + chat_template="{{ messages[0].content }}", + tokenizer_chat_template_path=None, + tokenizer_path="dummy-base-model", + ) + # Initially, apply_chat_template fails (simulating HF tokenizer crash when chat_template is None) + with self.assertRaises(ValueError): + tokenizer.apply_chat_template([{"role": "user", "content": "Hello!"}]) + # Run the proposed change + train_rl.configure_tokenizer_chat_template(tokenizer, trainer_config) + # Verify apply_chat_template now runs successfully and renders correct content + rendered = tokenizer.apply_chat_template([{"role": "user", "content": "Hello!"}]) + self.assertEqual(rendered, "Hello!") + if __name__ == "__main__": unittest.main() From ab3b93546aa6b13ed5f812363eb938ef923e1fe0 Mon Sep 17 00:00:00 2001 From: David Soto Mora Date: Tue, 2 Jun 2026 20:37:14 +0000 Subject: [PATCH 2/7] Adding tokenizer_chat_template_path to valid inputs for config --- src/maxtext/configs/base.yml | 2 ++ src/maxtext/configs/types.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 2570f5c915..3b6b9bc62c 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -645,6 +645,8 @@ tokenizer_path: "" tokenizer_type: "sentencepiece" # Currently supporting: "tiktoken", "sentencepiece", "huggingface" use_chat_template: false chat_template_path: "" # path to chat template json file +chat_template: "" # Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template. +tokenizer_chat_template_path: "" # Path to a chat template file to be loaded into the tokenizer if missing. tokenize_train_data: true # false if the dataset is pre-tokenized tokenize_eval_data: true # false if the dataset is pre-tokenized add_bos: true diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 55febb6f54..6ae0ffdcbf 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1081,6 +1081,10 @@ class Tokenizer(BaseModel): "", description="Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template.", ) + tokenizer_chat_template_path: str = Field( + "", + description="Path to a chat template file to be loaded into the tokenizer if missing.", + ) tokenize_train_data: bool = Field(True, description="If False, assumes the training dataset is pre-tokenized.") tokenize_eval_data: bool = Field(True, description="If False, assumes the evaluation dataset is pre-tokenized.") add_bos: bool = Field(True, description="Whether to add a beginning-of-sentence token.") From f3e108a302490bd9392d04671d9c0b69b32bae8c Mon Sep 17 00:00:00 2001 From: David Soto Mora Date: Wed, 3 Jun 2026 05:00:11 +0000 Subject: [PATCH 3/7] Adding data_template_path, and using chat_template_path for tokenizer template --- src/maxtext/configs/post_train/rl.yml | 2 +- src/maxtext/configs/types.py | 4 ++-- src/maxtext/trainers/post_train/rl/train_rl.py | 16 ++++++++-------- tests/post_training/unit/train_rl_test.py | 18 +++++++++--------- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml index 133d66c730..94011ba5c1 100644 --- a/src/maxtext/configs/post_train/rl.yml +++ b/src/maxtext/configs/post_train/rl.yml @@ -208,7 +208,7 @@ reasoning_start_token: '' reasoning_end_token: '' solution_start_token: '' solution_end_token: '' -chat_template_path: 'maxtext/examples/chat_templates/gsm8k_rl.json' +data_template_path: 'maxtext/examples/chat_templates/gsm8k_rl.json' skip_jax_distributed_system: true # ====== Dataset Configuration ====== diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 6ae0ffdcbf..dfc55e0342 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1081,9 +1081,9 @@ class Tokenizer(BaseModel): "", description="Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template.", ) - tokenizer_chat_template_path: str = Field( + data_template_path: str = Field( "", - description="Path to a chat template file to be loaded into the tokenizer if missing.", + description="Path to a chat template file to be used when tokenizing the dataset. Used in RL workloads to provide the conversation.", ) tokenize_train_data: bool = Field(True, description="If False, assumes the training dataset is pre-tokenized.") tokenize_eval_data: bool = Field(True, description="If False, assumes the evaluation dataset is pre-tokenized.") diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index 927289edc8..b52239fbce 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -248,10 +248,10 @@ def prepare_datasets( model_tokenizer: AutoTokenizer, ) -> tuple[grain.IterDataset, grain.IterDataset | None]: """Setup and return train and test datasets.""" - template_config = load_data_template_from_file(trainer_config.chat_template_path) + template_config = load_data_template_from_file(trainer_config.data_template_path) if template_config is None: raise ValueError( - f"Chat template is required for processing dataset but failed to load from {trainer_config.chat_template_path}" + f"Chat template is required for processing dataset but failed to load from {trainer_config.data_template_path}" ) # Prepare train and test data from training data for certain datasets @@ -548,10 +548,10 @@ def _reward_fn(**kwargs): epsilon_high=trainer_config.rl.epsilon_high, ) # Instantiate the custom MaxText chat parser - template_config = load_data_template_from_file(trainer_config.chat_template_path) + template_config = load_data_template_from_file(trainer_config.data_template_path) if template_config is None: raise ValueError( - f"Chat template is required for AgenticGRPOLearner but failed to load from {trainer_config.chat_template_path}" + f"Chat template is required for AgenticGRPOLearner but failed to load from {trainer_config.data_template_path}" ) chat_parser = utils_rl.MaxTextChatParser( model_tokenizer=model_tokenizer, template_config=template_config, tmvp_config=trainer_config @@ -588,20 +588,20 @@ def configure_tokenizer_chat_template(model_tokenizer: Any, trainer_config: Any) if getattr(model_tokenizer, "chat_template", None) is None: if getattr(trainer_config, "chat_template", None): model_tokenizer.chat_template = trainer_config.chat_template - elif getattr(trainer_config, "tokenizer_chat_template_path", None): + elif getattr(trainer_config, "chat_template_path", None): from maxtext.input_pipeline.instruction_data_processing import ( load_chat_template_from_file, ) model_tokenizer.chat_template = load_chat_template_from_file( - trainer_config.tokenizer_chat_template_path + trainer_config.chat_template_path ) else: raise ValueError( f"Tokenizer {getattr(trainer_config, 'tokenizer_path', None)!r} has no chat_template " - "and config.chat_template / config.tokenizer_chat_template_path " + "and config.chat_template / config.chat_template_path " "are both empty. Either pick an instruction-tuned tokenizer that " "ships with a chat_template, set config.chat_template to a Jinja " - "string, or set config.tokenizer_chat_template_path to a JSON file " + "string, or set config.chat_template_path to a JSON file " "with a 'chat_template' key." ) diff --git a/tests/post_training/unit/train_rl_test.py b/tests/post_training/unit/train_rl_test.py index 8c9cbae8ec..b8accaf4c7 100644 --- a/tests/post_training/unit/train_rl_test.py +++ b/tests/post_training/unit/train_rl_test.py @@ -334,7 +334,7 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config eval_split="eval", hf_train_files=None, hf_eval_files=None, - chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json", + data_template_path="maxtext/examples/chat_templates/gsm8k_rl.json", data_shuffle_seed=42, max_prefill_predict_length=10, batch_size=2, @@ -389,7 +389,7 @@ def test_prepare_datasets_with_split(self, mock_load): eval_dataset_name="open-r1/OpenR1-Math-220k", train_split="train", hf_train_files="hf://open-r1/OpenR1-Math-220k/data/dummy.parquet", - chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json", + data_template_path="maxtext/examples/chat_templates/gsm8k_rl.json", data_shuffle_seed=42, num_batches=1, batch_size=5, @@ -435,7 +435,7 @@ def test_prepare_datasets_without_split(self, mock_load): eval_split="test", hf_train_files="hf://openai/gsm8k/data/dummy.parquet", hf_eval_files="hf://openai/gsm8k/data/dummy.parquet", - chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json", + data_template_path="maxtext/examples/chat_templates/gsm8k_rl.json", data_shuffle_seed=42, num_batches=1, batch_size=5, @@ -482,7 +482,7 @@ def test_chat_template_populated_from_config_string(self): mock_tokenizer.chat_template = None trainer_config = SimpleNamespace( chat_template="{{ messages[0].content }}", - tokenizer_chat_template_path=None, + chat_template_path=None, tokenizer_path="dummy-base-model", ) train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config) @@ -491,13 +491,13 @@ def test_chat_template_populated_from_config_string(self): @pytest.mark.cpu_only @mock.patch("maxtext.input_pipeline.instruction_data_processing.load_chat_template_from_file") def test_chat_template_populated_from_config_file(self, mock_load): - """Test that chat_template is loaded from tokenizer_chat_template_path when tokenizer lacks one.""" + """Test that chat_template is loaded from chat_template_path when tokenizer lacks one.""" mock_tokenizer = mock.MagicMock() mock_tokenizer.chat_template = None mock_load.return_value = "{% for message in messages %}{{ message.content }}{% endfor %}" trainer_config = SimpleNamespace( chat_template=None, - tokenizer_chat_template_path="/path/to/jinja_template.json", + chat_template_path="/path/to/jinja_template.json", tokenizer_path="dummy-base-model", ) train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config) @@ -513,7 +513,7 @@ def test_chat_template_raises_value_error_when_empty(self): mock_tokenizer.chat_template = None trainer_config = SimpleNamespace( chat_template=None, - tokenizer_chat_template_path=None, + chat_template_path=None, tokenizer_path="dummy-base-model", ) with self.assertRaisesRegex(ValueError, "Tokenizer 'dummy-base-model' has no chat_template"): @@ -526,7 +526,7 @@ def test_chat_template_unchanged_when_already_exists(self): mock_tokenizer.chat_template = "{{ existing_template }}" trainer_config = SimpleNamespace( chat_template="{{ overridden_template }}", - tokenizer_chat_template_path=None, + chat_template_path=None, tokenizer_path="dummy-instruction-tuned-model", ) train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config) @@ -549,7 +549,7 @@ def apply_chat_template(self, conversation, tokenize=False): tokenizer = DummyTokenizer() trainer_config = SimpleNamespace( chat_template="{{ messages[0].content }}", - tokenizer_chat_template_path=None, + chat_template_path=None, tokenizer_path="dummy-base-model", ) # Initially, apply_chat_template fails (simulating HF tokenizer crash when chat_template is None) From 4957551e0fb959c23e8e86a66a60f4f5989da38c Mon Sep 17 00:00:00 2001 From: David Soto Mora Date: Wed, 3 Jun 2026 05:20:32 +0000 Subject: [PATCH 4/7] removing old key --- src/maxtext/configs/base.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 3b6b9bc62c..73e28c33a4 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -646,7 +646,6 @@ tokenizer_type: "sentencepiece" # Currently supporting: "tiktoken", "sentencepie use_chat_template: false chat_template_path: "" # path to chat template json file chat_template: "" # Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template. -tokenizer_chat_template_path: "" # Path to a chat template file to be loaded into the tokenizer if missing. tokenize_train_data: true # false if the dataset is pre-tokenized tokenize_eval_data: true # false if the dataset is pre-tokenized add_bos: true From 6aeaba5e7e0ec6a9a3924d204bf6dfb7df767378 Mon Sep 17 00:00:00 2001 From: David Soto Mora Date: Wed, 3 Jun 2026 18:57:41 +0000 Subject: [PATCH 5/7] Linting --- src/maxtext/configs/types.py | 2 +- src/maxtext/trainers/post_train/rl/train_rl.py | 2 +- tests/post_training/unit/train_rl_test.py | 14 +++++++------- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index dfc55e0342..a97fc252af 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1083,7 +1083,7 @@ class Tokenizer(BaseModel): ) data_template_path: str = Field( "", - description="Path to a chat template file to be used when tokenizing the dataset. Used in RL workloads to provide the conversation.", + description="Path to a chat template file to be used when tokenizing the dataset.", ) tokenize_train_data: bool = Field(True, description="If False, assumes the training dataset is pre-tokenized.") tokenize_eval_data: bool = Field(True, description="If False, assumes the evaluation dataset is pre-tokenized.") diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index b52239fbce..9fb622cc88 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -589,7 +589,7 @@ def configure_tokenizer_chat_template(model_tokenizer: Any, trainer_config: Any) if getattr(trainer_config, "chat_template", None): model_tokenizer.chat_template = trainer_config.chat_template elif getattr(trainer_config, "chat_template_path", None): - from maxtext.input_pipeline.instruction_data_processing import ( + from maxtext.input_pipeline.instruction_data_processing import ( # pylint: disable=import-outside-toplevel load_chat_template_from_file, ) model_tokenizer.chat_template = load_chat_template_from_file( diff --git a/tests/post_training/unit/train_rl_test.py b/tests/post_training/unit/train_rl_test.py index b8accaf4c7..8ef3513561 100644 --- a/tests/post_training/unit/train_rl_test.py +++ b/tests/post_training/unit/train_rl_test.py @@ -474,7 +474,7 @@ def test_prepare_datasets_without_split(self, mock_load): class TokenizerChatTemplateTest(unittest.TestCase): """Unit tests for configure_tokenizer_chat_template.""" - + @pytest.mark.cpu_only def test_chat_template_populated_from_config_string(self): """Test that chat_template is set from config.chat_template when tokenizer lacks one.""" @@ -487,7 +487,7 @@ def test_chat_template_populated_from_config_string(self): ) train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config) self.assertEqual(mock_tokenizer.chat_template, "{{ messages[0].content }}") - + @pytest.mark.cpu_only @mock.patch("maxtext.input_pipeline.instruction_data_processing.load_chat_template_from_file") def test_chat_template_populated_from_config_file(self, mock_load): @@ -518,7 +518,7 @@ def test_chat_template_raises_value_error_when_empty(self): ) with self.assertRaisesRegex(ValueError, "Tokenizer 'dummy-base-model' has no chat_template"): train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config) - + @pytest.mark.cpu_only def test_chat_template_unchanged_when_already_exists(self): """Test that an existing chat_template on the tokenizer is preserved (backward compatibility).""" @@ -531,18 +531,18 @@ def test_chat_template_unchanged_when_already_exists(self): ) train_rl.configure_tokenizer_chat_template(mock_tokenizer, trainer_config) self.assertEqual(mock_tokenizer.chat_template, "{{ existing_template }}") - + @pytest.mark.cpu_only def test_apply_chat_template_works_after_configuration(self): """Verifies apply_chat_template succeeds and produces the expected format after our code path runs.""" - class DummyTokenizer: + class DummyTokenizer: # pylint: disable=missing-class-docstring def __init__(self): self.chat_template = None - + def apply_chat_template(self, conversation, tokenize=False): if self.chat_template is None: raise ValueError("Cannot apply chat template because chat_template is None") - import jinja2 + import jinja2 # pylint: disable=import-outside-toplevel env = jinja2.Environment() template = env.from_string(self.chat_template) return template.render(messages=conversation) From 885d96bcfdcb275c688f1a7ed7be2a74672cf7a1 Mon Sep 17 00:00:00 2001 From: David Soto Mora Date: Wed, 3 Jun 2026 19:14:49 +0000 Subject: [PATCH 6/7] Apply pyink formatting --- src/maxtext/trainers/post_train/rl/train_rl.py | 6 ++---- tests/post_training/unit/train_rl_test.py | 6 ++++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index 9fb622cc88..fe73268203 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -582,7 +582,6 @@ def _reward_fn(**kwargs): return rl_cluster, rl_trainer, optimizer - def configure_tokenizer_chat_template(model_tokenizer: Any, trainer_config: Any) -> None: """Populates the tokenizer's chat_template from config if missing.""" if getattr(model_tokenizer, "chat_template", None) is None: @@ -592,9 +591,8 @@ def configure_tokenizer_chat_template(model_tokenizer: Any, trainer_config: Any) from maxtext.input_pipeline.instruction_data_processing import ( # pylint: disable=import-outside-toplevel load_chat_template_from_file, ) - model_tokenizer.chat_template = load_chat_template_from_file( - trainer_config.chat_template_path - ) + + model_tokenizer.chat_template = load_chat_template_from_file(trainer_config.chat_template_path) else: raise ValueError( f"Tokenizer {getattr(trainer_config, 'tokenizer_path', None)!r} has no chat_template " diff --git a/tests/post_training/unit/train_rl_test.py b/tests/post_training/unit/train_rl_test.py index 8ef3513561..0e71adb1e1 100644 --- a/tests/post_training/unit/train_rl_test.py +++ b/tests/post_training/unit/train_rl_test.py @@ -506,6 +506,7 @@ def test_chat_template_populated_from_config_file(self, mock_load): mock_tokenizer.chat_template, "{% for message in messages %}{{ message.content }}{% endfor %}", ) + @pytest.mark.cpu_only def test_chat_template_raises_value_error_when_empty(self): """Test that ValueError is raised when tokenizer lacks chat_template and both config options are empty.""" @@ -535,7 +536,9 @@ def test_chat_template_unchanged_when_already_exists(self): @pytest.mark.cpu_only def test_apply_chat_template_works_after_configuration(self): """Verifies apply_chat_template succeeds and produces the expected format after our code path runs.""" + class DummyTokenizer: # pylint: disable=missing-class-docstring + def __init__(self): self.chat_template = None @@ -543,9 +546,11 @@ def apply_chat_template(self, conversation, tokenize=False): if self.chat_template is None: raise ValueError("Cannot apply chat template because chat_template is None") import jinja2 # pylint: disable=import-outside-toplevel + env = jinja2.Environment() template = env.from_string(self.chat_template) return template.render(messages=conversation) + tokenizer = DummyTokenizer() trainer_config = SimpleNamespace( chat_template="{{ messages[0].content }}", @@ -561,5 +566,6 @@ def apply_chat_template(self, conversation, tokenize=False): rendered = tokenizer.apply_chat_template([{"role": "user", "content": "Hello!"}]) self.assertEqual(rendered, "Hello!") + if __name__ == "__main__": unittest.main() From 792a8eeba82542a831ae921cf6c5fe078b4c091e Mon Sep 17 00:00:00 2001 From: David Soto Mora Date: Wed, 3 Jun 2026 19:28:11 +0000 Subject: [PATCH 7/7] Fix pylint protected-access on _rl_train_impl test from merged main --- tests/post_training/unit/train_rl_test.py | 52 ++++++++++++++++++----- 1 file changed, 41 insertions(+), 11 deletions(-) diff --git a/tests/post_training/unit/train_rl_test.py b/tests/post_training/unit/train_rl_test.py index 7ea05156c9..56ba2a7dcf 100644 --- a/tests/post_training/unit/train_rl_test.py +++ b/tests/post_training/unit/train_rl_test.py @@ -58,7 +58,10 @@ def test_setup_configs_and_devices_pathways_split(self): # Following the pattern in distillation_checkpointing_test.py for mocking jax objects with ( mock.patch.object(jax, "devices", return_value=mock_devices), - mock.patch("maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", return_value=mock_config), + mock.patch( + "maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", + return_value=mock_config, + ), ): trainer_config, sampler_config, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices( ["dummy", "dummy"] @@ -87,7 +90,10 @@ def test_setup_configs_and_devices_pathways_fractional_split(self): with ( mock.patch.object(jax, "devices", return_value=mock_devices), - mock.patch("maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", return_value=mock_config), + mock.patch( + "maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", + return_value=mock_config, + ), ): _, _, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices(["dummy", "dummy"]) @@ -189,7 +195,10 @@ def test_get_rollout_kwargs_no_dp(self): "tensor_parallel_size": 2, "expert_parallel_size": 4, } - self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 16), expected_result) + self.assertEqual( + train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 16), + expected_result, + ) @pytest.mark.cpu_only def test_get_rollout_kwargs_auto_tp(self): @@ -204,7 +213,10 @@ def test_get_rollout_kwargs_auto_tp(self): "tensor_parallel_size": 2, "expert_parallel_size": 1, } - self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4), expected_result) + self.assertEqual( + train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4), + expected_result, + ) @pytest.mark.cpu_only def test_get_rollout_kwargs_fixed_tp_dp(self): @@ -219,7 +231,10 @@ def test_get_rollout_kwargs_fixed_tp_dp(self): "tensor_parallel_size": 2, "expert_parallel_size": 1, } - self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4), expected_result) + self.assertEqual( + train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4), + expected_result, + ) @pytest.mark.cpu_only def test_get_rollout_kwargs_auto_ep(self): @@ -235,7 +250,10 @@ def test_get_rollout_kwargs_auto_ep(self): "tensor_parallel_size": 2, "expert_parallel_size": 2, } - self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8), expected_result) + self.assertEqual( + train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8), + expected_result, + ) @pytest.mark.cpu_only def test_get_rollout_kwargs_errors(self): @@ -307,7 +325,10 @@ def tokenize_side_effect(text): {"question": "short", "answer": "a3"}, {"question": "long", "answer": "a4"}, ] - test_data = [{"question": "short", "answer": "a5"}, {"question": "long", "answer": "a6"}] + test_data = [ + {"question": "short", "answer": "a5"}, + {"question": "long", "answer": "a6"}, + ] train_map_ds = grain.MapDataset.source(train_data) test_map_ds = grain.MapDataset.source(test_data) @@ -346,8 +367,14 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config ) with ( - mock.patch("maxtext.trainers.post_train.rl.train_rl.get_dataset", side_effect=get_dataset_side_effect), - mock.patch("maxtext.trainers.post_train.rl.utils_rl.process_data", side_effect=get_filtered_data_side_effect), + mock.patch( + "maxtext.trainers.post_train.rl.train_rl.get_dataset", + side_effect=get_dataset_side_effect, + ), + mock.patch( + "maxtext.trainers.post_train.rl.utils_rl.process_data", + side_effect=get_filtered_data_side_effect, + ), ): train_dataset, test_dataset = train_rl.prepare_datasets(trainer_config, mock_tokenizer) @@ -378,7 +405,10 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config def test_prepare_datasets_with_split(self, mock_load): mock_ds = mock.MagicMock() mock_split_result = { - "train": [{"question": "q1", "answer": "a1"}, {"question": "q2", "answer": "a2"}], + "train": [ + {"question": "q1", "answer": "a1"}, + {"question": "q2", "answer": "a2"}, + ], "test": [{"question": "q3", "answer": "a3"}], } mock_ds.train_test_split.return_value = mock_split_result @@ -480,7 +510,7 @@ def test_rl_train_invalid_vocab_tiling(self, mock_setup): mock_setup.return_value = (mock_config, mock_config, [], []) with self.assertRaisesRegex(ValueError, "Vocab Tiling is not supported with RL"): - train_rl._rl_train_impl([], {}) + train_rl._rl_train_impl([], {}) # pylint: disable=protected-access class TokenizerChatTemplateTest(unittest.TestCase):