Skip to content

Commit 1998c43

Browse files
committed
feat(sft): robust list config parsing and improved SFT docs
- Added automatic string-to-list and tuple-to-list coercion in pyconfig.py for list[str] configuration fields (train_data_columns, eval_data_columns, trainable_parameters_mask, adamw_mask) to prevent unhelpful Pydantic validation errors. - Added comprehensive unit tests in pyconfig_test.py covering 6 coercion scenarios and end-to-end integration. - Added unit tests in instruction_data_processing_test.py validating that unsupported SFT columns throw helpful, descriptive assertion errors. - Expanded docs/tutorials/posttraining/sft.md to document SFT schemas, custom chat templates, ShareGPT custom formatters, and linked a runnable example.
1 parent 2e6cd11 commit 1998c43

4 files changed

Lines changed: 186 additions & 0 deletions

File tree

docs/tutorials/posttraining/sft.md

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,75 @@ python3 -m maxtext.trainers.post_train.sft.train_sft \
107107
```
108108

109109
Your fine-tuned model checkpoints will be saved here: `$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints`.
110+
111+
## Dataset Customization & Chat Templates
112+
113+
Supervised Fine-Tuning in MaxText relies on tokenizing conversational datasets using chat templates. This requires the dataset structure and templates to be aligned.
114+
115+
### Supported Dataset Schemas
116+
117+
By default, MaxText SFT expects one of three conversational dataset structures:
118+
- `["messages"]`: A single column containing a list of dictionaries with `role` and `content` (recommended).
119+
- `["prompt", "completion"]`: Separated prompt and completion columns.
120+
- `["question", "answer"]`: Question and answer columns (e.g., math datasets).
121+
122+
During data processing, MaxText converts these into a unified `messages` schema (OpenAI-like format) before feeding it to the tokenizer:
123+
```json
124+
[
125+
{"role": "user", "content": "Hello!"},
126+
{"role": "assistant", "content": "Hi there!"}
127+
]
128+
```
129+
130+
### Custom Tokenizer Chat Templates
131+
132+
To customize the tokenizer's chat formatting (e.g., adding special tokens like `<start_of_turn>`, `<end_of_turn>`, etc.), you can provide a custom chat template using the `chat_template` or `chat_template_path` configs:
133+
134+
- **`chat_template_path`**: Path to a custom Jinja2 template file (e.g., `.jinja`) or a JSON file containing the template.
135+
- **`use_chat_template=True`**: Enables chat template formatting.
136+
137+
### Advanced: Custom Dataset Formatter (e.g., ShareGPT)
138+
139+
If your dataset is in a format not natively supported—such as **ShareGPT** (which uses a `conversations` column with `from` and `value` keys)—you can write a custom Python formatting function to convert it on-the-fly.
140+
141+
#### 1. Write a custom formatting function
142+
Create a Python file in your workspace (e.g., `src/maxtext/input_pipeline/custom_formatters.py`):
143+
144+
```python
145+
def format_sharegpt(example):
146+
"""Converts ShareGPT format (from/value) to standard messages (role/content)."""
147+
role_map = {
148+
"human": "user",
149+
"user": "user",
150+
"gpt": "assistant",
151+
"assistant": "assistant",
152+
"system": "system",
153+
}
154+
155+
messages = []
156+
for turn in example["conversations"]:
157+
role = role_map.get(turn["from"], "user")
158+
messages.append({
159+
"role": role,
160+
"content": turn["value"],
161+
})
162+
163+
example["messages"] = messages
164+
return example
165+
```
166+
167+
#### 2. Configure MaxText to use your formatter
168+
When starting your SFT training, pass the following parameters:
169+
- `train_data_columns`: Point to the original column name in the raw dataset (`"['conversations']"`).
170+
- `formatting_func_path`: Point to the python import path of your formatting function (`"maxtext.input_pipeline.custom_formatters.format_sharegpt"`).
171+
172+
```sh
173+
python3 -m maxtext.trainers.post_train.sft.train_sft \
174+
... \
175+
train_data_columns="['conversations']" \
176+
formatting_func_path="maxtext.input_pipeline.custom_formatters.format_sharegpt"
177+
```
178+
179+
### Runnable Example in the Codebase
180+
181+
For a complete, runnable SFT workflow that demonstrates how to configure the training loop and use a custom dataset formatter (`formatting_func_path` and `formatting_func_kwargs`), check out the [sft_qwen3_demo.ipynb](../../../src/maxtext/examples/sft_qwen3_demo.ipynb) Jupyter notebook.

src/maxtext/configs/pyconfig.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
# pytype: skip-file
1616
"""Pydantic-based configuration management for MaxText."""
17+
import ast
1718
import logging
1819
import os
1920
import sys
@@ -236,6 +237,27 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]:
236237
if key == "data_sharding" and isinstance(new_value, list) and new_value and isinstance(new_value[0], str):
237238
new_value = [new_value]
238239

240+
# Coerce string/tuple inputs for list[str] configuration fields into Python lists.
241+
# This prevents unhelpful Pydantic validation errors when users pass string values
242+
# from the CLI (e.g., train_data_columns=messages is coerced to ['messages'], and
243+
# stringified lists like "['col1', 'col2']" are safely parsed to a Python list).
244+
if key in ("train_data_columns", "eval_data_columns", "trainable_parameters_mask", "adamw_mask"):
245+
if isinstance(new_value, str):
246+
cleaned = new_value.strip()
247+
if (cleaned.startswith("[") and cleaned.endswith("]")) or (cleaned.startswith("(") and cleaned.endswith(")")):
248+
try:
249+
parsed = ast.literal_eval(cleaned)
250+
if isinstance(parsed, (list, tuple)):
251+
new_value = list(parsed)
252+
else:
253+
new_value = [str(parsed)]
254+
except Exception:
255+
new_value = [new_value]
256+
else:
257+
new_value = [new_value]
258+
elif isinstance(new_value, tuple):
259+
new_value = list(new_value)
260+
239261
# An empty value provided in the configuration is treated as None
240262
if (
241263
key

tests/unit/instruction_data_processing_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import datasets
2424

2525
from maxtext.input_pipeline import instruction_data_processing
26+
from maxtext.input_pipeline import data_processing_utils
2627

2728

2829
class InstructionDataProcessingTest(unittest.TestCase):
@@ -231,5 +232,31 @@ def test_data_formatter_with_formatting_func_path_without_kwargs(self):
231232
self.assertEqual(returned_columns, ["messages"])
232233

233234

235+
class TestDataProcessingUtils(unittest.TestCase):
236+
"""Unit tests for dataset column validation (Scenario B)."""
237+
238+
def test_validate_sft_columns_valid(self):
239+
"""Verifies that valid SFT columns do not raise any error."""
240+
# These should pass without raising any exception
241+
data_processing_utils.validate_and_configure_sft_columns(["messages"], None)
242+
data_processing_utils.validate_and_configure_sft_columns(["prompt", "completion"], None)
243+
data_processing_utils.validate_and_configure_sft_columns(["question", "answer"], None)
244+
245+
def test_validate_sft_columns_invalid_raises_helpful_error(self):
246+
"""Verifies that invalid SFT columns raise a helpful AssertionError."""
247+
with self.assertRaises(AssertionError) as ctx:
248+
data_processing_utils.validate_and_configure_sft_columns(["some_invalid_column"], None)
249+
250+
# Verify that the error message is helpful and contains the expected guidance
251+
self.assertIn("Dataset column names mismatch", str(ctx.exception))
252+
self.assertIn("Expected columns to match one of", str(ctx.exception))
253+
self.assertIn("prompt", str(ctx.exception))
254+
self.assertIn("completion", str(ctx.exception))
255+
self.assertIn("messages", str(ctx.exception))
256+
self.assertIn("question", str(ctx.exception))
257+
self.assertIn("answer", str(ctx.exception))
258+
self.assertIn("some_invalid_column", str(ctx.exception))
259+
260+
234261
if __name__ == "__main__":
235262
unittest.main()

tests/unit/pyconfig_test.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from maxtext.configs import pyconfig
2222
from maxtext.configs.pyconfig import resolve_config_path, _CONFIG_FILE_MAPPING, _module_from_path
23+
from maxtext.input_pipeline import data_processing_utils
2324
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR, MAXTEXT_PKG_DIR
2425
from tests.utils.test_helpers import get_test_config_path, get_post_train_test_config_path
2526

@@ -173,6 +174,70 @@ def test_identical_override_allowed(self):
173174
)
174175
self.assertEqual(config.tokenizer_type, "huggingface")
175176

177+
def test_list_config_coercion(self):
178+
"""Verifies that string/tuple inputs for list[str] config fields are coerced to lists."""
179+
# Case 1: Plain string (coerced to single-item list)
180+
config = pyconfig.initialize(
181+
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
182+
skip_jax_distributed_system=True,
183+
train_data_columns="messages",
184+
)
185+
self.assertEqual(config.train_data_columns, ["messages"])
186+
187+
# Case 2: Stringified list literal
188+
config = pyconfig.initialize(
189+
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
190+
skip_jax_distributed_system=True,
191+
train_data_columns="['col1', 'col2']",
192+
)
193+
self.assertEqual(config.train_data_columns, ["col1", "col2"])
194+
195+
# Case 3: Stringified list literal with whitespace
196+
config = pyconfig.initialize(
197+
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
198+
skip_jax_distributed_system=True,
199+
train_data_columns="[ 'col1' , 'col2' ]",
200+
)
201+
self.assertEqual(config.train_data_columns, ["col1", "col2"])
202+
203+
# Case 4: Stringified tuple literal
204+
config = pyconfig.initialize(
205+
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
206+
skip_jax_distributed_system=True,
207+
train_data_columns="('col1', 'col2')",
208+
)
209+
self.assertEqual(config.train_data_columns, ["col1", "col2"])
210+
211+
# Case 5: Real tuple value (passed via kwargs)
212+
config = pyconfig.initialize(
213+
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
214+
skip_jax_distributed_system=True,
215+
train_data_columns=("col1", "col2"),
216+
)
217+
self.assertEqual(config.train_data_columns, ["col1", "col2"])
218+
219+
# Case 6: Malformed stringified list (falls back to wrapping as single-item list)
220+
config = pyconfig.initialize(
221+
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
222+
skip_jax_distributed_system=True,
223+
train_data_columns="[malformed, list",
224+
)
225+
self.assertEqual(config.train_data_columns, ["[malformed, list"])
226+
227+
def test_coerced_list_is_validated_successfully(self):
228+
"""Verifies that a coerced list from pyconfig is successfully validated by the dataset pipeline."""
229+
# Simulate a user passing `train_data_columns=messages` on the CLI
230+
config = pyconfig.initialize(
231+
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
232+
skip_jax_distributed_system=True,
233+
train_data_columns="messages",
234+
)
235+
# Verify coercion to list was successful
236+
self.assertEqual(config.train_data_columns, ["messages"])
237+
238+
# Verify that passing this coerced list to the SFT column validator passes without error (Scenario A)
239+
data_processing_utils.validate_and_configure_sft_columns(config.train_data_columns, None)
240+
176241

177242
if __name__ == "__main__":
178243
unittest.main()

0 commit comments

Comments
 (0)