Skip to content

Commit f93627f

Browse files
Merge pull request #4031 from AI-Hypercomputer:pr/dataset-processor-path
PiperOrigin-RevId: 928781201
2 parents da8b70a + f2d4f3b commit f93627f

5 files changed

Lines changed: 180 additions & 16 deletions

File tree

src/maxtext/configs/post_train/rl.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,11 @@ skip_jax_distributed_system: true
227227
# Loads separate dataset for training and evaluation (e.g., train on OpenMathInstruct-2, eval on GSM8K).
228228
dataset_name: 'openai/gsm8k'
229229
eval_dataset_name: 'openai/gsm8k'
230+
# Optional: path to a user-provided Python file with a custom `process_data`
231+
# function. Signature: process_data(dataset_name, model_tokenizer, template_config,
232+
# tmvp_config, x) -> dict with keys {prompts, question, answer}. When empty
233+
# (default), the built-in utils_rl.process_data is used.
234+
dataset_processor_path: ''
230235
train_split: 'train'
231236
eval_split: 'test'
232237
hf_name: 'main' # subset of Hugging Face dataset

src/maxtext/configs/types.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2024,6 +2024,13 @@ class RLDataset(BaseModel):
20242024
train_fraction: float = Field(1.0, description="Fraction of the dataset to be used for training.")
20252025
train_micro_batch_size: int = Field(-1, description="Micro batch size for training.")
20262026
rollout_micro_batch_size: int = Field(-1, description="Micro batch size for rollout.")
2027+
dataset_processor_path: str = Field(
2028+
"",
2029+
description=(
2030+
"Optional path to a user-provided Python file with a `process_data` function. "
2031+
"When set, replaces the built-in dataset processor for custom datasets."
2032+
),
2033+
)
20272034

20282035

20292036
class RLEvaluation(BaseModel):

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,17 @@ def prepare_datasets(
254254
f"Chat template is required for processing dataset but failed to load from {trainer_config.chat_template_path}"
255255
)
256256

257+
# Optional user-provided `process_data(dataset_name, tokenizer, template, config, x) -> dict`.
258+
# When `dataset_processor_path` is set in config, load that file's `process_data`
259+
# and use it instead of the built-in utils_rl.process_data. Lets users adapt
260+
# custom datasets (with non-standard answer columns / cleaning) without editing maxtext.
261+
_custom_processor_path = getattr(trainer_config, "dataset_processor_path", "") or ""
262+
if _custom_processor_path:
263+
_process_data = utils_rl.load_custom_callable(_custom_processor_path, "process_data")
264+
max_logging.log(f"prepare_datasets: using custom process_data from {_custom_processor_path}")
265+
else:
266+
_process_data = utils_rl.process_data
267+
257268
# Prepare train and test data from training data for certain datasets
258269
eval_dataset_name = getattr(trainer_config, "eval_dataset_name", None)
259270
test_dataset = None
@@ -272,22 +283,14 @@ def prepare_datasets(
272283
train_dataset = (
273284
grain.MapDataset.source(splits["train"])
274285
.shuffle(seed=trainer_config.data_shuffle_seed)
275-
.map(
276-
lambda x: utils_rl.process_data(
277-
trainer_config.dataset_name, model_tokenizer, template_config, trainer_config, x
278-
)
279-
)
286+
.map(lambda x: _process_data(trainer_config.dataset_name, model_tokenizer, template_config, trainer_config, x))
280287
)
281288

282289
if trainer_config.num_test_batches > 0:
283290
test_dataset = (
284291
grain.MapDataset.source(splits["validation"])
285292
.shuffle(seed=trainer_config.data_shuffle_seed)
286-
.map(
287-
lambda x: utils_rl.process_data(
288-
trainer_config.dataset_name, model_tokenizer, template_config, trainer_config, x
289-
)
290-
)
293+
.map(lambda x: _process_data(trainer_config.dataset_name, model_tokenizer, template_config, trainer_config, x))
291294
)
292295
else:
293296
if not eval_dataset_name:
@@ -302,11 +305,7 @@ def prepare_datasets(
302305
train_dataset = (
303306
grain.MapDataset.source(train_dataset)
304307
.shuffle(seed=trainer_config.data_shuffle_seed)
305-
.map(
306-
lambda x: utils_rl.process_data(
307-
trainer_config.dataset_name, model_tokenizer, template_config, trainer_config, x
308-
)
309-
)
308+
.map(lambda x: _process_data(trainer_config.dataset_name, model_tokenizer, template_config, trainer_config, x))
310309
)
311310

312311
if trainer_config.num_test_batches > 0:
@@ -319,7 +318,7 @@ def prepare_datasets(
319318
test_dataset = (
320319
grain.MapDataset.source(test_dataset)
321320
.shuffle(seed=trainer_config.data_shuffle_seed)
322-
.map(lambda x: utils_rl.process_data(eval_dataset_name, model_tokenizer, template_config, trainer_config, x))
321+
.map(lambda x: _process_data(eval_dataset_name, model_tokenizer, template_config, trainer_config, x))
323322
)
324323

325324
def _filter_long_prompts(x):

src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414

1515
# pylint: disable=bare-except, consider-using-generator, chained-comparison, broad-exception-caught
1616
"""RL Utils Module."""
17+
import importlib.util
1718
import itertools
1819
import json
20+
import os
1921
import re
2022
import uuid
2123
from typing import Any, Callable, Optional
@@ -813,3 +815,24 @@ def install_training_hooks(
813815
)
814816
except Exception as e: # pylint: disable=broad-exception-caught
815817
max_logging.warning(f"[intermediate-eval] install failed: {e!r}")
818+
819+
820+
def load_custom_callable(module_path: str, function_name: str) -> Callable:
  """Load a callable from a user-provided Python file via importlib.

  `module_path` is an absolute or relative filesystem path to a `.py` file.
  The file is executed as a fresh module object under the synthetic name
  `_user_module_<function_name>`; this helper neither modifies `sys.path` nor
  registers the module in `sys.modules`. Used to plug in user-defined
  `process_data` (for custom datasets) and reward functions without editing
  maxtext.

  Args:
    module_path: Path to the user's Python source file.
    function_name: Name of the module-level attribute to return.

  Returns:
    The callable bound to `function_name` in the loaded module.

  Raises:
    ValueError: If the file does not exist, no import spec/loader can be built
      for it, the attribute is missing (or is None), or the attribute is not
      callable.

  Any exception raised while executing the user file (e.g. SyntaxError or an
  error in its top-level code) propagates unchanged.
  """
  if not os.path.isfile(module_path):
    raise ValueError(f"Cannot import {module_path!r}: file does not exist")
  spec = importlib.util.spec_from_file_location(f"_user_module_{function_name}", module_path)
  if spec is None or spec.loader is None:
    raise ValueError(f"Cannot import {module_path!r}: not a valid python file")
  module = importlib.util.module_from_spec(spec)
  # NOTE: this executes arbitrary user code; the path is expected to come from trusted config.
  spec.loader.exec_module(module)
  fn = getattr(module, function_name, None)
  if fn is None:
    raise ValueError(f"{module_path!r} does not define a function named {function_name!r}")
  # Fail here, with a clear message, rather than later when the pipeline first calls the object.
  if not callable(fn):
    raise ValueError(f"{module_path!r} defines {function_name!r} but it is not callable (got {type(fn).__name__})")
  return fn
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for load_custom_callable (used by dataset_processor_path knob)."""
16+
17+
import os
18+
import sys
19+
import tempfile
20+
import textwrap
21+
import unittest
22+
23+
import pytest
24+
25+
from maxtext.trainers.post_train.rl.utils_rl import load_custom_callable
26+
27+
28+
pytestmark = [pytest.mark.post_training]
29+
30+
31+
_USER_PROCESS_DATA_SOURCE = textwrap.dedent(
32+
"""
33+
# Simulated user-provided dataset processor file.
34+
def process_data(dataset_name, model_tokenizer, template_config, tmvp_config, x):
35+
# Minimal stand-in for utils_rl.process_data: returns a dict shaped like
36+
# what the RL data pipeline expects, with a marker so the test can verify
37+
# that THIS function (not the built-in) was actually invoked.
38+
return {
39+
"prompts": f"USER_PROCESSOR<{x.get('question', '')}>",
40+
"question": x.get("question", ""),
41+
"answer": x.get("answer", ""),
42+
"_marker": "loaded_from_user_file",
43+
}
44+
45+
46+
def another_helper(x):
47+
return x * 2
48+
"""
49+
).strip()
50+
51+
52+
def _write_user_file(tmpdir):
  """Create `user_processor.py` inside tmpdir with the sample processor source.

  Returns the path of the written file (tmpdir joined with the file name).
  """
  user_file_path = os.path.join(tmpdir, "user_processor.py")
  with open(user_file_path, mode="w", encoding="utf-8") as handle:
    handle.write(_USER_PROCESS_DATA_SOURCE)
  return user_file_path
58+
59+
60+
class LoadCustomCallableTest(unittest.TestCase):
  """Verify load_custom_callable loads a function from a user .py file."""

  @pytest.mark.cpu_only
  def test_loads_function_from_user_file(self):
    """Returns a callable that behaves like the function in the user file."""
    with tempfile.TemporaryDirectory() as tmpdir:
      user_file = _write_user_file(tmpdir)
      fn = load_custom_callable(user_file, "process_data")

      self.assertTrue(callable(fn))
      # pylint: disable-next=not-callable
      result = fn(
          "dataset_name",
          model_tokenizer=None,
          template_config=None,
          tmvp_config=None,
          x={"question": "2+2?", "answer": "4"},
      )
      # The marker key proves the user-file function ran, not a built-in processor.
      self.assertEqual(result["_marker"], "loaded_from_user_file")
      self.assertEqual(result["prompts"], "USER_PROCESSOR<2+2?>")
      self.assertEqual(result["question"], "2+2?")
      self.assertEqual(result["answer"], "4")

  @pytest.mark.cpu_only
  def test_loads_any_named_function(self):
    """function_name argument selects which symbol to return."""
    with tempfile.TemporaryDirectory() as tmpdir:
      user_file = _write_user_file(tmpdir)
      fn = load_custom_callable(user_file, "another_helper")
      self.assertEqual(fn(5), 10)  # pylint: disable=not-callable

  @pytest.mark.cpu_only
  def test_raises_when_file_does_not_exist(self):
    """Nonexistent path -> ValueError, not a cryptic ImportError."""
    with tempfile.TemporaryDirectory() as tmpdir:
      bogus = os.path.join(tmpdir, "does_not_exist.py")
      with self.assertRaises(ValueError):
        load_custom_callable(bogus, "process_data")

  @pytest.mark.cpu_only
  def test_raises_when_function_not_defined(self):
    """File exists but doesn't define the named function -> ValueError."""
    with tempfile.TemporaryDirectory() as tmpdir:
      user_file = _write_user_file(tmpdir)
      with self.assertRaises(ValueError):
        load_custom_callable(user_file, "no_such_function")

  @pytest.mark.cpu_only
  def test_does_not_pollute_sys_path(self):
    """Loading the file must not append its directory to sys.path."""
    sys_path_before = list(sys.path)
    with tempfile.TemporaryDirectory() as tmpdir:
      user_file = _write_user_file(tmpdir)
      load_custom_callable(user_file, "process_data")
    self.assertEqual(sys.path, sys_path_before)

  @pytest.mark.cpu_only
  def test_does_not_pollute_sys_modules_globally(self):
    """Loading the file must not register the user module in sys.modules.

    The helper builds the module under the synthetic name
    '_user_module_<function_name>' (shared by every file loaded with the same
    function name, so not unique). Neither that synthetic name nor the file's
    basename ('user_processor') should be present in sys.modules afterwards.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
      user_file = _write_user_file(tmpdir)
      load_custom_callable(user_file, "process_data")
      # The file's basename must not shadow a real module of the same name.
      self.assertNotIn("user_processor", sys.modules)
      # The synthetic name the helper actually uses must not be registered either.
      self.assertNotIn("_user_module_process_data", sys.modules)
127+
128+
129+
if __name__ == "__main__":
130+
unittest.main()

0 commit comments

Comments
 (0)