Skip to content

Commit 7c86879

Browse files
committed
feat(sft): robust list config parsing and improved SFT docs
- Added automatic string-to-list and tuple-to-list coercion in pyconfig.py for list[str] configuration fields (train_data_columns, eval_data_columns, trainable_parameters_mask, adamw_mask) to prevent unhelpful Pydantic validation errors. - Added comprehensive unit tests in pyconfig_test.py covering 6 coercion scenarios and end-to-end integration. - Added unit tests in instruction_data_processing_test.py validating that unsupported SFT columns throw helpful, descriptive assertion errors. - Expanded docs/tutorials/posttraining/sft.md to document SFT schemas, custom chat templates, ShareGPT custom formatters, and linked a runnable example.
1 parent 2e6cd11 commit 7c86879

4 files changed

Lines changed: 200 additions & 0 deletions

File tree

docs/tutorials/posttraining/sft.md

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,83 @@ python3 -m maxtext.trainers.post_train.sft.train_sft \
107107
```
108108

109109
Your fine-tuned model checkpoints will be saved here: `$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints`.
110+
111+
## Dataset Customization & Chat Templates
112+
113+
Supervised Fine-Tuning in MaxText relies on tokenizing conversational datasets using chat templates. This requires the dataset structure and templates to be aligned.
114+
115+
### Supported Dataset Schemas
116+
117+
By default, MaxText SFT expects one of three conversational dataset structures:
118+
119+
- `["messages"]`: A single column containing a list of dictionaries with `role` and `content` (recommended).
120+
- `["prompt", "completion"]`: Separated prompt and completion columns.
121+
- `["question", "answer"]`: Question and answer columns (e.g., math datasets).
122+
123+
During data processing, MaxText converts these into a unified `messages` schema (OpenAI-like format) before feeding it to the tokenizer:
124+
125+
```json
126+
[
127+
{"role": "user", "content": "Hello!"},
128+
{"role": "assistant", "content": "Hi there!"}
129+
]
130+
```
131+
132+
### Custom Tokenizer Chat Templates
133+
134+
To customize the tokenizer's chat formatting (e.g., adding special tokens like `<start_of_turn>`, `<end_of_turn>`, etc.), you can provide a custom chat template using the `chat_template` or `chat_template_path` configs:
135+
136+
- **`chat_template`**: Use this config to specify a custom Jinja2 template string directly.
137+
- **`chat_template_path`**: Path to a custom Jinja2 template file (e.g., `.jinja`) or a JSON file containing the template.
138+
- **`use_chat_template=True`**: Enables chat template formatting.
139+
140+
### Advanced: Custom Dataset Formatter (e.g., ShareGPT)
141+
142+
If your dataset is in a format not natively supported—such as **ShareGPT** (which uses a `conversations` column with `from` and `value` keys)—you can write a custom Python formatting function to convert it on-the-fly.
143+
144+
#### 1. Write a custom formatting function
145+
146+
Create a Python file in your workspace (e.g., `src/maxtext/input_pipeline/custom_formatters.py`):
147+
148+
```python
149+
def format_sharegpt(example):
150+
"""Converts ShareGPT format (from/value) to standard messages (role/content)."""
151+
role_map = {
152+
"human": "user",
153+
"user": "user",
154+
"gpt": "assistant",
155+
"assistant": "assistant",
156+
"system": "system",
157+
}
158+
159+
messages = []
160+
for turn in example["conversations"]:
161+
role = role_map.get(turn["from"], "user")
162+
messages.append(
163+
{
164+
"role": role,
165+
"content": turn["value"],
166+
}
167+
)
168+
169+
example["messages"] = messages
170+
return example
171+
```
172+
173+
#### 2. Configure MaxText to use your formatter
174+
175+
When starting your SFT training, pass the following parameters:
176+
177+
- `train_data_columns`: Point to the original column name in the raw dataset (`"['conversations']"`).
178+
- `formatting_func_path`: Point to the python import path of your formatting function (`"maxtext.input_pipeline.custom_formatters.format_sharegpt"`).
179+
180+
```sh
181+
python3 -m maxtext.trainers.post_train.sft.train_sft \
182+
... \
183+
train_data_columns="['conversations']" \
184+
formatting_func_path="maxtext.input_pipeline.custom_formatters.format_sharegpt"
185+
```
186+
187+
### Runnable Example in the Codebase
188+
189+
For a complete, runnable SFT workflow that demonstrates how to configure the training loop and use a custom dataset formatter (`formatting_func_path` and `formatting_func_kwargs`), check out the [sft_qwen3_demo.ipynb](../../../src/maxtext/examples/sft_qwen3_demo.ipynb) Jupyter notebook.

src/maxtext/configs/pyconfig.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
# pytype: skip-file
1616
"""Pydantic-based configuration management for MaxText."""
17+
import ast
1718
import logging
1819
import os
1920
import sys
@@ -214,6 +215,29 @@ def _lists_to_tuples(l: list | Any) -> tuple | Any:
214215
return tuple(_lists_to_tuples(x) for x in l) if isinstance(l, list) else l
215216

216217

218+
def _coerce_to_list(value: Any) -> list[str] | Any:
219+
"""Coerce string/tuple inputs for list[str] configuration fields into Python lists.
220+
221+
This prevents unhelpful Pydantic validation errors when users pass string values
222+
from the CLI (e.g., train_data_columns=messages is coerced to ['messages'], and
223+
stringified lists like "['col1', 'col2']" are safely parsed to a Python list).
224+
"""
225+
if isinstance(value, str):
226+
cleaned = value.strip()
227+
if (cleaned.startswith("[") and cleaned.endswith("]")) or (cleaned.startswith("(") and cleaned.endswith(")")):
228+
try:
229+
parsed = ast.literal_eval(cleaned)
230+
if isinstance(parsed, (list, tuple)):
231+
return list(parsed)
232+
return [str(parsed)]
233+
except (ValueError, SyntaxError):
234+
return [value]
235+
return [value]
236+
if isinstance(value, tuple):
237+
return list(value)
238+
return value
239+
240+
217241
def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]:
218242
"""Prepares the raw dictionary for Pydantic model instantiation."""
219243
pydantic_kwargs = {}
@@ -236,6 +260,10 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]:
236260
if key == "data_sharding" and isinstance(new_value, list) and new_value and isinstance(new_value[0], str):
237261
new_value = [new_value]
238262

263+
# Coerce string/tuple inputs for list[str] configuration fields into Python lists.
264+
if key in ("train_data_columns", "eval_data_columns", "trainable_parameters_mask", "adamw_mask"):
265+
new_value = _coerce_to_list(new_value)
266+
239267
# An empty value provided in the configuration is treated as None
240268
if (
241269
key

tests/unit/instruction_data_processing_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import datasets
2424

2525
from maxtext.input_pipeline import instruction_data_processing
26+
from maxtext.input_pipeline import data_processing_utils
2627

2728

2829
class InstructionDataProcessingTest(unittest.TestCase):
@@ -231,5 +232,31 @@ def test_data_formatter_with_formatting_func_path_without_kwargs(self):
231232
self.assertEqual(returned_columns, ["messages"])
232233

233234

235+
class TestDataProcessingUtils(unittest.TestCase):
236+
"""Unit tests for dataset column validation (Scenario B)."""
237+
238+
def test_validate_sft_columns_valid(self):
239+
"""Verifies that valid SFT columns do not raise any error."""
240+
# These should pass without raising any exception
241+
data_processing_utils.validate_and_configure_sft_columns(["messages"], None)
242+
data_processing_utils.validate_and_configure_sft_columns(["prompt", "completion"], None)
243+
data_processing_utils.validate_and_configure_sft_columns(["question", "answer"], None)
244+
245+
def test_validate_sft_columns_invalid_raises_helpful_error(self):
246+
"""Verifies that invalid SFT columns raise a helpful AssertionError."""
247+
with self.assertRaises(AssertionError) as ctx:
248+
data_processing_utils.validate_and_configure_sft_columns(["some_invalid_column"], None)
249+
250+
# Verify that the error message is helpful and contains the expected guidance
251+
self.assertIn("Dataset column names mismatch", str(ctx.exception))
252+
self.assertIn("Expected columns to match one of", str(ctx.exception))
253+
self.assertIn("prompt", str(ctx.exception))
254+
self.assertIn("completion", str(ctx.exception))
255+
self.assertIn("messages", str(ctx.exception))
256+
self.assertIn("question", str(ctx.exception))
257+
self.assertIn("answer", str(ctx.exception))
258+
self.assertIn("some_invalid_column", str(ctx.exception))
259+
260+
234261
if __name__ == "__main__":
235262
unittest.main()

tests/unit/pyconfig_test.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from maxtext.configs import pyconfig
2222
from maxtext.configs.pyconfig import resolve_config_path, _CONFIG_FILE_MAPPING, _module_from_path
23+
from maxtext.input_pipeline import data_processing_utils
2324
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR, MAXTEXT_PKG_DIR
2425
from tests.utils.test_helpers import get_test_config_path, get_post_train_test_config_path
2526

@@ -173,6 +174,70 @@ def test_identical_override_allowed(self):
173174
)
174175
self.assertEqual(config.tokenizer_type, "huggingface")
175176

177+
def test_list_config_coercion(self):
178+
"""Verifies that string/tuple inputs for list[str] config fields are coerced to lists."""
179+
# Case 1: Plain string (coerced to single-item list)
180+
config = pyconfig.initialize(
181+
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
182+
skip_jax_distributed_system=True,
183+
train_data_columns="messages",
184+
)
185+
self.assertEqual(config.train_data_columns, ["messages"])
186+
187+
# Case 2: Stringified list literal
188+
config = pyconfig.initialize(
189+
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
190+
skip_jax_distributed_system=True,
191+
train_data_columns="['col1', 'col2']",
192+
)
193+
self.assertEqual(config.train_data_columns, ["col1", "col2"])
194+
195+
# Case 3: Stringified list literal with whitespace
196+
config = pyconfig.initialize(
197+
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
198+
skip_jax_distributed_system=True,
199+
train_data_columns="[ 'col1' , 'col2' ]",
200+
)
201+
self.assertEqual(config.train_data_columns, ["col1", "col2"])
202+
203+
# Case 4: Stringified tuple literal
204+
config = pyconfig.initialize(
205+
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
206+
skip_jax_distributed_system=True,
207+
train_data_columns="('col1', 'col2')",
208+
)
209+
self.assertEqual(config.train_data_columns, ["col1", "col2"])
210+
211+
# Case 5: Real tuple value (passed via kwargs)
212+
config = pyconfig.initialize(
213+
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
214+
skip_jax_distributed_system=True,
215+
train_data_columns=("col1", "col2"),
216+
)
217+
self.assertEqual(config.train_data_columns, ["col1", "col2"])
218+
219+
# Case 6: Malformed stringified list (falls back to wrapping as single-item list)
220+
config = pyconfig.initialize(
221+
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
222+
skip_jax_distributed_system=True,
223+
train_data_columns="[malformed, list",
224+
)
225+
self.assertEqual(config.train_data_columns, ["[malformed, list"])
226+
227+
def test_coerced_list_is_validated_successfully(self):
228+
"""Verifies that a coerced list from pyconfig is successfully validated by the dataset pipeline."""
229+
# Simulate a user passing `train_data_columns=messages` on the CLI
230+
config = pyconfig.initialize(
231+
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
232+
skip_jax_distributed_system=True,
233+
train_data_columns="messages",
234+
)
235+
# Verify coercion to list was successful
236+
self.assertEqual(config.train_data_columns, ["messages"])
237+
238+
# Verify that passing this coerced list to the SFT column validator passes without error (Scenario A)
239+
data_processing_utils.validate_and_configure_sft_columns(config.train_data_columns, None)
240+
176241

177242
if __name__ == "__main__":
178243
unittest.main()

0 commit comments

Comments
 (0)