diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml
index da455a13e2..9cc305d090 100644
--- a/src/maxtext/configs/post_train/rl.yml
+++ b/src/maxtext/configs/post_train/rl.yml
@@ -177,7 +177,7 @@ reasoning_start_token: ''
 reasoning_end_token: ''
 solution_start_token: ''
 solution_end_token: ''
-chat_template_path: 'src/maxtext/examples/chat_templates/gsm8k_rl.json'
+chat_template_path: 'maxtext/examples/chat_templates/gsm8k_rl.json'
 skip_jax_distributed_system: True
 #
 # TODO(@mazumdera): fix this
diff --git a/src/maxtext/input_pipeline/instruction_data_processing.py b/src/maxtext/input_pipeline/instruction_data_processing.py
index c8fd2be7f2..4777b323e2 100644
--- a/src/maxtext/input_pipeline/instruction_data_processing.py
+++ b/src/maxtext/input_pipeline/instruction_data_processing.py
@@ -24,6 +24,9 @@
 def load_template_from_file(template_path):
   """Loads a template from a file."""
   template_config = None
+  current_dir = os.path.dirname(os.path.abspath(__file__))
+  repo_root = os.path.abspath(os.path.join(current_dir, "..", ".."))
+  template_path = os.path.join(repo_root, template_path)
   if os.path.isfile(template_path) and template_path.endswith(".json"):
     with open(template_path, encoding="utf-8") as f:
       template_config = json.load(f)
diff --git a/tests/unit/instruction_data_processing_test.py b/tests/unit/instruction_data_processing_test.py
index 396b336c92..6080a0acf3 100644
--- a/tests/unit/instruction_data_processing_test.py
+++ b/tests/unit/instruction_data_processing_test.py
@@ -21,6 +21,22 @@
 
 
 class InstructionDataProcessingTest(unittest.TestCase):
+  def test_load_template_from_file(self):
+    template_config = instruction_data_processing.load_template_from_file("maxtext/examples/chat_templates/gsm8k_rl.json")
+    self.assertEqual(
+        template_config,
+        {
+            "SYSTEM_PROMPT": (
+                "You are given a problem. Think about the problem and provide"
+                " your reasoning. \nPlace it between {reasoning_start_token} and"
+                " {reasoning_end_token}. Then, provide the final answer (i.e.,"
+                " just one numerical value) between {solution_start_token} and"
+                " {solution_end_token}."
+            ),
+            "TEMPLATE": ("user\n{system_prompt}\n\n{question}\nmodel"),
+        },
+    )
+
   def test_map_qa_data_to_conversation_with_prompt_completion_template(self):
     template_config = {
         "PROMPT_TEMPLATE": "This is a question: {question}",