Skip to content

Commit 86f715f

Browse files
committed
Enable Tunix-based DPO input processing for Grain
1 parent 7c68a9d commit 86f715f

4 files changed

Lines changed: 208 additions & 11 deletions

File tree

src/maxtext/configs/post_train/dpo.yml

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,21 @@ dpo:
88
dpo_beta: 0.1
99
max_prompt_length: null
1010
packing: false
11-
train_data_columns: ['chosen', 'rejected']
12-
eval_data_columns: ['chosen', 'rejected']
13-
base_output_directory: 'gs://maxtext-external/logs'
1411

1512
per_device_batch_size: 2.0
1613
steps: 10
1714
max_target_length: 512
1815
eval_interval: 5 # test eval once, in the middle of 10 training steps
1916
eval_steps: 2
2017

21-
# TFDS Pipeline ----------------------
22-
dataset_type: tfds
23-
dataset_path: 'gs://maxtext-dataset/dpo/anthropic_rlhf'
24-
dataset_name: 'tfds:1.0.0'
25-
eval_dataset_name: 'tfds:1.0.0'
26-
eval_split: 'test'
27-
28-
# HF Pipeline -------------------------
18+
# Some reasonable defaults for running DPO without extra config params.
19+
model_name: "qwen3-0.6b"
20+
tokenizer_path: "src/maxtext/assets/tokenizers/qwen3-tokenizer"
21+
tokenizer_type: "huggingface"
22+
dataset_type: hf
23+
hf_path: 'Anthropic/hh-rlhf'
24+
train_data_columns: ['chosen', 'rejected']
25+
eval_data_columns: ['chosen', 'rejected']
2926
hf_eval_split: 'test'
3027

3128
gradient_clipping_threshold: 10.0

src/maxtext/configs/types.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3043,6 +3043,12 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
30433043
logger.warning(
30443044
"tfds pipeline is deprecated. Use dataset_type=grain, grain_file_type=tfrecord, and provide grain_train_files."
30453045
)
3046+
if self.use_dpo:
3047+
raise ValueError(
3048+
"TFDS dataset_type=tfds is not supported for DPO training"
3049+
" (config.use_dpo=True). Please use dataset_type=grain or"
3050+
" dataset_type=hf instead."
3051+
)
30463052
if not self.dataset_name:
30473053
raise ValueError("dataset_name can't be empty when dataset_type=tfds")
30483054
if self.eval_interval > 0 and not self.eval_split:

src/maxtext/input_pipeline/grain_data_processing.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from maxtext.input_pipeline import data_processing_utils
3030
from maxtext.input_pipeline import input_pipeline_utils
3131
from maxtext.input_pipeline import grain_tokenizer
32+
from maxtext.input_pipeline import dpo_utils
3233
from maxtext.input_pipeline import multihost_dataloading
3334
from maxtext.utils import gcs_utils
3435
from maxtext.utils import max_logging
@@ -263,6 +264,46 @@ def pretrain_preprocessing_pipeline(
263264
return dataset
264265

265266

267+
def dpo_preprocessing_pipeline(
268+
dataset,
269+
config,
270+
data_columns,
271+
tokenize,
272+
grain_worker_count,
273+
grain_per_worker_buffer_size,
274+
):
275+
"""Use grain to pre-process the dataset and return iterators for dpo fine-tuning"""
276+
dataset = data_processing_utils.parse_and_keep_features(dataset, config, data_columns, tokenize)
277+
tokenizer_model, pad_id = data_processing_utils.get_tokenizer_and_pad_id(config)
278+
279+
if tokenize:
280+
dataset = dataset.map(grain_tokenizer.TokenizeAndTrim(data_columns, config.max_target_length, tokenizer_model))
281+
282+
# Renames arbitrary DPO columns and performs DPO-aware padding.
283+
max_prompt_length = config.dpo.max_prompt_length
284+
dataset = dataset.map(
285+
dpo_utils.DPODataFormatting(
286+
pad_id=pad_id,
287+
max_target_length=config.max_target_length,
288+
data_column_names=data_columns,
289+
max_prompt_length=max_prompt_length,
290+
)
291+
)
292+
293+
batch_size = data_processing_utils.get_local_batch_size(config)
294+
if config.grain_use_elastic_iterator:
295+
# ElasticIterator batches internally, so return the pre-batch dataset.
296+
pass
297+
else:
298+
batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
299+
dataset = dataset.batch(batch_size, batch_fn=batch_fn)
300+
301+
dataset = data_processing_utils.apply_multiprocessing_and_prefetch(
302+
dataset, config, grain_worker_count, grain_per_worker_buffer_size
303+
)
304+
return dataset
305+
306+
266307
def _format_chat_template_grain(element, data_columns, tokenizer_model):
267308
"""Grain-compatible mapping function to format raw columns into conversational messages."""
268309
# Convert raw columns to conversational messages
@@ -350,6 +391,8 @@ def sft_preprocessing_pipeline(
350391

351392
def _get_pipeline_fn(config):
352393
"""Returns the appropriate preprocessing pipeline function based on config."""
394+
if config.use_dpo:
395+
return dpo_preprocessing_pipeline
353396
if config.use_sft:
354397
return sft_preprocessing_pipeline
355398
return pretrain_preprocessing_pipeline

tests/post_training/unit/dpo_data_processing_test.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@
2323
import pytest
2424
import transformers
2525

26+
import grain.python as grain
2627
from maxtext.configs import pyconfig
2728
from maxtext.input_pipeline import dpo_utils
2829
from maxtext.input_pipeline import hf_data_processing
30+
from maxtext.input_pipeline import grain_data_processing
2931
from maxtext.input_pipeline import input_pipeline_interface
3032
from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_CONFIGS_DIR, MAXTEXT_PKG_DIR
3133

@@ -389,5 +391,154 @@ def test_dpo_non_positive_max_prompt_length(self):
389391
)
390392

391393

394+
@pytest.mark.external_training
395+
class TestGrainDPOPipelineProcessing(unittest.TestCase):
396+
"""End-to-end Grain DPO pipeline processing tests."""
397+
398+
def setUp(self):
399+
super().setUp()
400+
self.config = pyconfig.initialize_pydantic(
401+
[
402+
os.path.join(MAXTEXT_PKG_DIR, "dpo_trainer"),
403+
os.path.join(MAXTEXT_CONFIGS_DIR, "post_train", "dpo.yml"),
404+
],
405+
per_device_batch_size=2,
406+
run_name="test",
407+
mesh_axes=["data"],
408+
logical_axis_rules=[["batch", "data"]],
409+
data_sharding=["data"],
410+
base_output_directory="gs://max-experiments/",
411+
tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "qwen3-tokenizer"),
412+
train_split="train",
413+
enable_checkpointing=False,
414+
use_dpo=True,
415+
enable_data_shuffling=False,
416+
max_target_length=64,
417+
grain_file_type="parquet", # to trigger KeepFeatures in parse_and_keep_features
418+
tokenizer_type="huggingface",
419+
dataset_type="grain",
420+
grain_train_files="dummy",
421+
eval_interval=0,
422+
)
423+
self.mesh_shape_1d = (len(jax.devices()),)
424+
self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes)
425+
self.process_indices = input_pipeline_interface.get_process_loading_real_data(
426+
self.config.data_sharding,
427+
self.config.global_batch_size_to_load,
428+
self.config.global_batch_size_to_train_on,
429+
self.config.max_target_length,
430+
self.mesh,
431+
)
432+
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
433+
self.config.tokenizer_path,
434+
add_bos_token=False,
435+
add_eos_token=False,
436+
legacy=False,
437+
)
438+
self.pad_id = hf_data_processing._get_pad_id(self.tokenizer) # pylint: disable=protected-access
439+
440+
def get_data_iterator(self, list_of_dicts, data_columns):
441+
"""Helper to initialize the Grain preprocessing pipeline."""
442+
dataset = grain.MapDataset.source(list_of_dicts)
443+
dataset = dataset[self.process_indices.index(jax.process_index()) :: len(self.process_indices)]
444+
dataset = dataset.to_iter_dataset()
445+
446+
iter_ds = grain_data_processing.dpo_preprocessing_pipeline(
447+
dataset=dataset,
448+
config=self.config,
449+
data_columns=data_columns,
450+
tokenize=self.config.tokenize_train_data,
451+
grain_worker_count=0,
452+
grain_per_worker_buffer_size=1,
453+
)
454+
return iter(iter_ds)
455+
456+
def test_dpo_format_3_columns(self):
457+
"""Verify that the 3-column explicit DPO dataset is processed correctly."""
458+
prompt_str = "Question: What is 2+2?"
459+
chosen_str = "Answer: 4"
460+
rejected_str = "Answer: 5"
461+
462+
list_of_dicts = [
463+
{
464+
"input": prompt_str,
465+
"chosen": chosen_str,
466+
"rejected": rejected_str,
467+
}
468+
for _ in range(10)
469+
]
470+
471+
data_iter = self.get_data_iterator(list_of_dicts, ["input", "chosen", "rejected"])
472+
batch = next(data_iter)
473+
474+
# Verify expected keys
475+
for key in (
476+
"prompt_ids",
477+
"chosen_ids",
478+
"rejected_ids",
479+
"prompt_mask",
480+
"chosen_mask",
481+
"rejected_mask",
482+
):
483+
self.assertIn(key, batch)
484+
485+
# Verify batch dimensions match global batch size and split max_target_length
486+
max_prompt_len = self.config.max_target_length // 2
487+
max_response_len = self.config.max_target_length - max_prompt_len
488+
self.assertEqual(
489+
batch["prompt_ids"].shape,
490+
(self.config.global_batch_size_to_load, max_prompt_len),
491+
)
492+
self.assertEqual(
493+
batch["chosen_ids"].shape,
494+
(self.config.global_batch_size_to_load, max_response_len),
495+
)
496+
self.assertEqual(
497+
batch["rejected_ids"].shape,
498+
(self.config.global_batch_size_to_load, max_response_len),
499+
)
500+
501+
# Verify decoded content directly
502+
decoded_prompt = self.tokenizer.decode(batch["prompt_ids"][0], skip_special_tokens=True)
503+
decoded_chosen = self.tokenizer.decode(batch["chosen_ids"][0], skip_special_tokens=True)
504+
decoded_rejected = self.tokenizer.decode(batch["rejected_ids"][0], skip_special_tokens=True)
505+
506+
self.assertEqual(decoded_prompt, prompt_str)
507+
self.assertEqual(decoded_chosen, chosen_str)
508+
self.assertEqual(decoded_rejected, rejected_str)
509+
510+
# Verify mask structure (left padding for prompt -> 1s at the end; right padding for responses -> 1s at start)
511+
self.assertEqual(batch["prompt_mask"][0][-1], 1)
512+
self.assertEqual(batch["chosen_mask"][0][0], 1)
513+
self.assertEqual(batch["rejected_mask"][0][0], 1)
514+
515+
def test_dpo_format_2_columns(self):
516+
"""Verify that 2-column DPO datasets correctly extract common prefixes."""
517+
# We use a clear common prefix and different suffixes
518+
prefix = "Common prompt context for DPO:"
519+
chosen_suffix = " the chosen completion"
520+
rejected_suffix = " the rejected completion"
521+
522+
list_of_dicts = [
523+
{
524+
"chosen": prefix + chosen_suffix,
525+
"rejected": prefix + rejected_suffix,
526+
}
527+
for _ in range(10)
528+
]
529+
530+
data_iter = self.get_data_iterator(list_of_dicts, ["chosen", "rejected"])
531+
batch = next(data_iter)
532+
533+
# Verify decoded extracted prefix and completions robustly against BPE token boundary quirks
534+
decoded_prompt = self.tokenizer.decode(batch["prompt_ids"][0], skip_special_tokens=True)
535+
decoded_chosen = self.tokenizer.decode(batch["chosen_ids"][0], skip_special_tokens=True)
536+
decoded_rejected = self.tokenizer.decode(batch["rejected_ids"][0], skip_special_tokens=True)
537+
538+
self.assertIn("Common prompt context", decoded_prompt)
539+
self.assertIn("chosen", decoded_chosen)
540+
self.assertIn("rejected", decoded_rejected)
541+
542+
392543
if __name__ == "__main__":
393544
unittest.main()

0 commit comments

Comments
 (0)