diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml
index da455a13e2..9cc305d090 100644
--- a/src/maxtext/configs/post_train/rl.yml
+++ b/src/maxtext/configs/post_train/rl.yml
@@ -177,7 +177,7 @@ reasoning_start_token: ''
reasoning_end_token: ''
solution_start_token: ''
solution_end_token: ''
-chat_template_path: 'src/maxtext/examples/chat_templates/gsm8k_rl.json'
+chat_template_path: 'maxtext/examples/chat_templates/gsm8k_rl.json'
skip_jax_distributed_system: True
# # TODO(@mazumdera): fix this
diff --git a/src/maxtext/input_pipeline/instruction_data_processing.py b/src/maxtext/input_pipeline/instruction_data_processing.py
index c8fd2be7f2..4777b323e2 100644
--- a/src/maxtext/input_pipeline/instruction_data_processing.py
+++ b/src/maxtext/input_pipeline/instruction_data_processing.py
@@ -24,6 +24,9 @@
def load_template_from_file(template_path):
"""Loads a template from a file."""
template_config = None
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ src_root = os.path.abspath(os.path.join(current_dir, "..", ".."))  # resolves to the src/ directory (two levels up), not the repo root
+ template_path = os.path.join(src_root, template_path)
if os.path.isfile(template_path) and template_path.endswith(".json"):
with open(template_path, encoding="utf-8") as f:
template_config = json.load(f)
diff --git a/tests/unit/instruction_data_processing_test.py b/tests/unit/instruction_data_processing_test.py
index 396b336c92..6080a0acf3 100644
--- a/tests/unit/instruction_data_processing_test.py
+++ b/tests/unit/instruction_data_processing_test.py
@@ -21,6 +21,22 @@
class InstructionDataProcessingTest(unittest.TestCase):
+ def test_load_template_from_file(self):
+ template_config = instruction_data_processing.load_template_from_file("maxtext/examples/chat_templates/gsm8k_rl.json")
+ self.assertEqual(
+ template_config,
+ {
+ "SYSTEM_PROMPT": (
+ "You are given a problem. Think about the problem and provide"
+ " your reasoning. Place it between {reasoning_start_token} and"
+ " {reasoning_end_token}. Then, provide the final answer (i.e.,"
+ " just one numerical value) between {solution_start_token} and"
+ " {solution_end_token}."
+ ),
+ "TEMPLATE": ("user\n{system_prompt}\n\n{question}\nmodel"),
+ },
+ )
+
def test_map_qa_data_to_conversation_with_prompt_completion_template(self):
template_config = {
"PROMPT_TEMPLATE": "This is a question: {question}",