Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions docs/tutorials/posttraining/sft.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,82 @@ python3 -m maxtext.trainers.post_train.sft.train_sft \
```

Your fine-tuned model checkpoints will be saved here: `$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints`.

## Dataset Customization & Chat Templates

Supervised Fine-Tuning in MaxText relies on tokenizing conversational datasets using chat templates. This requires the dataset structure and templates to be aligned.

### Supported Dataset Schemas

By default, MaxText SFT expects one of three conversational dataset structures:

- `["messages"]`: A single column containing a list of dictionaries with `role` and `content` (recommended).
- `["prompt", "completion"]`: Separated prompt and completion columns.
- `["question", "answer"]`: Question and answer columns (e.g., math datasets).

During data processing, MaxText converts these into a unified `messages` schema (OpenAI-like format) before feeding it to the tokenizer:

```json
[
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi there!"}
]
```

### Custom Tokenizer Chat Templates

To customize the tokenizer's chat formatting (e.g., adding special tokens like `<start_of_turn>`, `<end_of_turn>`, etc.), you can provide a custom chat template using the `chat_template` or `chat_template_path` configs:

- **`chat_template_path`**: Path to a custom Jinja2 template file (e.g., `.jinja`) or a JSON file containing the template.
- **`use_chat_template=True`**: Enables chat template formatting.

### Advanced: Custom Dataset Formatter (e.g., ShareGPT)

If your dataset is in a format not natively supported—such as **ShareGPT** (which uses a `conversations` column with `from` and `value` keys)—you can write a custom Python formatting function to convert it on-the-fly.

#### 1. Write a custom formatting function

Create a Python file in your workspace (e.g., `src/maxtext/input_pipeline/custom_formatters.py`):

```python
def format_sharegpt(example):
"""Converts ShareGPT format (from/value) to standard messages (role/content)."""
role_map = {
"human": "user",
"user": "user",
"gpt": "assistant",
"assistant": "assistant",
"system": "system",
}

messages = []
for turn in example["conversations"]:
role = role_map.get(turn["from"], "user")
messages.append(
{
"role": role,
"content": turn["value"],
}
)

example["messages"] = messages
return example
```

#### 2. Configure MaxText to use your formatter

When starting your SFT training, pass the following parameters:

- `train_data_columns`: Point to the original column name in the raw dataset (`"['conversations']"`).
- `formatting_func_path`: Point to the python import path of your formatting function (`"maxtext.input_pipeline.custom_formatters.format_sharegpt"`).

```sh
python3 -m maxtext.trainers.post_train.sft.train_sft \
... \
train_data_columns="['conversations']" \
formatting_func_path="maxtext.input_pipeline.custom_formatters.format_sharegpt"
```

### Runnable Example in the Codebase

For a complete, runnable SFT workflow that demonstrates how to configure the training loop and use a custom dataset formatter (`formatting_func_path` and `formatting_func_kwargs`), check out the [sft_qwen3_demo.ipynb](../../../src/maxtext/examples/sft_qwen3_demo.ipynb) Jupyter notebook.
28 changes: 28 additions & 0 deletions src/maxtext/configs/pyconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

# pytype: skip-file
"""Pydantic-based configuration management for MaxText."""
import ast
import logging
import os
import sys
Expand Down Expand Up @@ -214,6 +215,29 @@ def _lists_to_tuples(l: list | Any) -> tuple | Any:
return tuple(_lists_to_tuples(x) for x in l) if isinstance(l, list) else l


def _coerce_to_list(value: Any) -> list[str] | Any:
"""Coerce string/tuple inputs for list[str] configuration fields into Python lists.

This prevents unhelpful Pydantic validation errors when users pass string values
from the CLI (e.g., train_data_columns=messages is coerced to ['messages'], and
stringified lists like "['col1', 'col2']" are safely parsed to a Python list).
"""
if isinstance(value, str):
cleaned = value.strip()
if (cleaned.startswith("[") and cleaned.endswith("]")) or (cleaned.startswith("(") and cleaned.endswith(")")):
try:
parsed = ast.literal_eval(cleaned)
if isinstance(parsed, (list, tuple)):
return list(parsed)
return [str(parsed)]
except (ValueError, SyntaxError):
return [value]
return [value]
if isinstance(value, tuple):
return list(value)
return value


def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]:
"""Prepares the raw dictionary for Pydantic model instantiation."""
pydantic_kwargs = {}
Expand All @@ -236,6 +260,10 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]:
if key == "data_sharding" and isinstance(new_value, list) and new_value and isinstance(new_value[0], str):
new_value = [new_value]

# Coerce string/tuple inputs for list[str] configuration fields into Python lists.
if key in ("train_data_columns", "eval_data_columns", "trainable_parameters_mask", "adamw_mask"):
new_value = _coerce_to_list(new_value)

# An empty value provided in the configuration is treated as None
if (
key
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/instruction_data_processing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import datasets

from maxtext.input_pipeline import instruction_data_processing
from maxtext.input_pipeline import data_processing_utils


class InstructionDataProcessingTest(unittest.TestCase):
Expand Down Expand Up @@ -231,5 +232,31 @@ def test_data_formatter_with_formatting_func_path_without_kwargs(self):
self.assertEqual(returned_columns, ["messages"])


class TestDataProcessingUtils(unittest.TestCase):
"""Unit tests for dataset column validation (Scenario B)."""

def test_validate_sft_columns_valid(self):
"""Verifies that valid SFT columns do not raise any error."""
# These should pass without raising any exception
data_processing_utils.validate_and_configure_sft_columns(["messages"], None)
data_processing_utils.validate_and_configure_sft_columns(["prompt", "completion"], None)
data_processing_utils.validate_and_configure_sft_columns(["question", "answer"], None)

def test_validate_sft_columns_invalid_raises_helpful_error(self):
"""Verifies that invalid SFT columns raise a helpful AssertionError."""
with self.assertRaises(AssertionError) as ctx:
data_processing_utils.validate_and_configure_sft_columns(["some_invalid_column"], None)

# Verify that the error message is helpful and contains the expected guidance
self.assertIn("Dataset column names mismatch", str(ctx.exception))
self.assertIn("Expected columns to match one of", str(ctx.exception))
self.assertIn("prompt", str(ctx.exception))
self.assertIn("completion", str(ctx.exception))
self.assertIn("messages", str(ctx.exception))
self.assertIn("question", str(ctx.exception))
self.assertIn("answer", str(ctx.exception))
self.assertIn("some_invalid_column", str(ctx.exception))


if __name__ == "__main__":
unittest.main()
65 changes: 65 additions & 0 deletions tests/unit/pyconfig_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from maxtext.configs import pyconfig
from maxtext.configs.pyconfig import resolve_config_path, _CONFIG_FILE_MAPPING, _module_from_path
from maxtext.input_pipeline import data_processing_utils
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR, MAXTEXT_PKG_DIR
from tests.utils.test_helpers import get_test_config_path, get_post_train_test_config_path

Expand Down Expand Up @@ -173,6 +174,70 @@ def test_identical_override_allowed(self):
)
self.assertEqual(config.tokenizer_type, "huggingface")

def test_list_config_coercion(self):
"""Verifies that string/tuple inputs for list[str] config fields are coerced to lists."""
# Case 1: Plain string (coerced to single-item list)
config = pyconfig.initialize(
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
skip_jax_distributed_system=True,
train_data_columns="messages",
)
self.assertEqual(config.train_data_columns, ["messages"])

# Case 2: Stringified list literal
config = pyconfig.initialize(
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
skip_jax_distributed_system=True,
train_data_columns="['col1', 'col2']",
)
self.assertEqual(config.train_data_columns, ["col1", "col2"])

# Case 3: Stringified list literal with whitespace
config = pyconfig.initialize(
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
skip_jax_distributed_system=True,
train_data_columns="[ 'col1' , 'col2' ]",
)
self.assertEqual(config.train_data_columns, ["col1", "col2"])

# Case 4: Stringified tuple literal
config = pyconfig.initialize(
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
skip_jax_distributed_system=True,
train_data_columns="('col1', 'col2')",
)
self.assertEqual(config.train_data_columns, ["col1", "col2"])

# Case 5: Real tuple value (passed via kwargs)
config = pyconfig.initialize(
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
skip_jax_distributed_system=True,
train_data_columns=("col1", "col2"),
)
self.assertEqual(config.train_data_columns, ["col1", "col2"])

# Case 6: Malformed stringified list (falls back to wrapping as single-item list)
config = pyconfig.initialize(
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
skip_jax_distributed_system=True,
train_data_columns="[malformed, list",
)
self.assertEqual(config.train_data_columns, ["[malformed, list"])

def test_coerced_list_is_validated_successfully(self):
"""Verifies that a coerced list from pyconfig is successfully validated by the dataset pipeline."""
# Simulate a user passing `train_data_columns=messages` on the CLI
config = pyconfig.initialize(
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
skip_jax_distributed_system=True,
train_data_columns="messages",
)
# Verify coercion to list was successful
self.assertEqual(config.train_data_columns, ["messages"])

# Verify that passing this coerced list to the SFT column validator passes without error (Scenario A)
data_processing_utils.validate_and_configure_sft_columns(config.train_data_columns, None)


if __name__ == "__main__":
unittest.main()
Loading