|
23 | 23 | import pytest |
24 | 24 | import transformers |
25 | 25 |
|
| 26 | +import grain.python as grain |
26 | 27 | from maxtext.configs import pyconfig |
27 | 28 | from maxtext.input_pipeline import dpo_utils |
28 | 29 | from maxtext.input_pipeline import hf_data_processing |
| 30 | +from maxtext.input_pipeline import grain_data_processing |
29 | 31 | from maxtext.input_pipeline import input_pipeline_interface |
30 | 32 | from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_CONFIGS_DIR, MAXTEXT_PKG_DIR |
31 | 33 |
|
@@ -389,5 +391,154 @@ def test_dpo_non_positive_max_prompt_length(self): |
389 | 391 | ) |
390 | 392 |
|
391 | 393 |
|
| 394 | +@pytest.mark.external_training |
| 395 | +class TestGrainDPOPipelineProcessing(unittest.TestCase): |
| 396 | + """End-to-end Grain DPO pipeline processing tests.""" |
| 397 | + |
| 398 | + def setUp(self): |
| 399 | + super().setUp() |
| 400 | + self.config = pyconfig.initialize_pydantic( |
| 401 | + [ |
| 402 | + os.path.join(MAXTEXT_PKG_DIR, "dpo_trainer"), |
| 403 | + os.path.join(MAXTEXT_CONFIGS_DIR, "post_train", "dpo.yml"), |
| 404 | + ], |
| 405 | + per_device_batch_size=2, |
| 406 | + run_name="test", |
| 407 | + mesh_axes=["data"], |
| 408 | + logical_axis_rules=[["batch", "data"]], |
| 409 | + data_sharding=["data"], |
| 410 | + base_output_directory="gs://max-experiments/", |
| 411 | + tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "qwen3-tokenizer"), |
| 412 | + train_split="train", |
| 413 | + enable_checkpointing=False, |
| 414 | + use_dpo=True, |
| 415 | + enable_data_shuffling=False, |
| 416 | + max_target_length=64, |
| 417 | + grain_file_type="parquet", # to trigger KeepFeatures in parse_and_keep_features |
| 418 | + tokenizer_type="huggingface", |
| 419 | + dataset_type="grain", |
| 420 | + grain_train_files="dummy", |
| 421 | + eval_interval=0, |
| 422 | + ) |
| 423 | + self.mesh_shape_1d = (len(jax.devices()),) |
| 424 | + self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) |
| 425 | + self.process_indices = input_pipeline_interface.get_process_loading_real_data( |
| 426 | + self.config.data_sharding, |
| 427 | + self.config.global_batch_size_to_load, |
| 428 | + self.config.global_batch_size_to_train_on, |
| 429 | + self.config.max_target_length, |
| 430 | + self.mesh, |
| 431 | + ) |
| 432 | + self.tokenizer = transformers.AutoTokenizer.from_pretrained( |
| 433 | + self.config.tokenizer_path, |
| 434 | + add_bos_token=False, |
| 435 | + add_eos_token=False, |
| 436 | + legacy=False, |
| 437 | + ) |
| 438 | + self.pad_id = hf_data_processing._get_pad_id(self.tokenizer) # pylint: disable=protected-access |
| 439 | + |
| 440 | + def get_data_iterator(self, list_of_dicts, data_columns): |
| 441 | + """Helper to initialize the Grain preprocessing pipeline.""" |
| 442 | + dataset = grain.MapDataset.source(list_of_dicts) |
| 443 | + dataset = dataset[self.process_indices.index(jax.process_index()) :: len(self.process_indices)] |
| 444 | + dataset = dataset.to_iter_dataset() |
| 445 | + |
| 446 | + iter_ds = grain_data_processing.dpo_preprocessing_pipeline( |
| 447 | + dataset=dataset, |
| 448 | + config=self.config, |
| 449 | + data_columns=data_columns, |
| 450 | + tokenize=self.config.tokenize_train_data, |
| 451 | + grain_worker_count=0, |
| 452 | + grain_per_worker_buffer_size=1, |
| 453 | + ) |
| 454 | + return iter(iter_ds) |
| 455 | + |
| 456 | + def test_dpo_format_3_columns(self): |
| 457 | + """Verify that the 3-column explicit DPO dataset is processed correctly.""" |
| 458 | + prompt_str = "Question: What is 2+2?" |
| 459 | + chosen_str = "Answer: 4" |
| 460 | + rejected_str = "Answer: 5" |
| 461 | + |
| 462 | + list_of_dicts = [ |
| 463 | + { |
| 464 | + "input": prompt_str, |
| 465 | + "chosen": chosen_str, |
| 466 | + "rejected": rejected_str, |
| 467 | + } |
| 468 | + for _ in range(10) |
| 469 | + ] |
| 470 | + |
| 471 | + data_iter = self.get_data_iterator(list_of_dicts, ["input", "chosen", "rejected"]) |
| 472 | + batch = next(data_iter) |
| 473 | + |
| 474 | + # Verify expected keys |
| 475 | + for key in ( |
| 476 | + "prompt_ids", |
| 477 | + "chosen_ids", |
| 478 | + "rejected_ids", |
| 479 | + "prompt_mask", |
| 480 | + "chosen_mask", |
| 481 | + "rejected_mask", |
| 482 | + ): |
| 483 | + self.assertIn(key, batch) |
| 484 | + |
| 485 | + # Verify batch dimensions match global batch size and split max_target_length |
| 486 | + max_prompt_len = self.config.max_target_length // 2 |
| 487 | + max_response_len = self.config.max_target_length - max_prompt_len |
| 488 | + self.assertEqual( |
| 489 | + batch["prompt_ids"].shape, |
| 490 | + (self.config.global_batch_size_to_load, max_prompt_len), |
| 491 | + ) |
| 492 | + self.assertEqual( |
| 493 | + batch["chosen_ids"].shape, |
| 494 | + (self.config.global_batch_size_to_load, max_response_len), |
| 495 | + ) |
| 496 | + self.assertEqual( |
| 497 | + batch["rejected_ids"].shape, |
| 498 | + (self.config.global_batch_size_to_load, max_response_len), |
| 499 | + ) |
| 500 | + |
| 501 | + # Verify decoded content directly |
| 502 | + decoded_prompt = self.tokenizer.decode(batch["prompt_ids"][0], skip_special_tokens=True) |
| 503 | + decoded_chosen = self.tokenizer.decode(batch["chosen_ids"][0], skip_special_tokens=True) |
| 504 | + decoded_rejected = self.tokenizer.decode(batch["rejected_ids"][0], skip_special_tokens=True) |
| 505 | + |
| 506 | + self.assertEqual(decoded_prompt, prompt_str) |
| 507 | + self.assertEqual(decoded_chosen, chosen_str) |
| 508 | + self.assertEqual(decoded_rejected, rejected_str) |
| 509 | + |
| 510 | + # Verify mask structure (left padding for prompt -> 1s at the end; right padding for responses -> 1s at start) |
| 511 | + self.assertEqual(batch["prompt_mask"][0][-1], 1) |
| 512 | + self.assertEqual(batch["chosen_mask"][0][0], 1) |
| 513 | + self.assertEqual(batch["rejected_mask"][0][0], 1) |
| 514 | + |
| 515 | + def test_dpo_format_2_columns(self): |
| 516 | + """Verify that 2-column DPO datasets correctly extract common prefixes.""" |
| 517 | + # We use a clear common prefix and different suffixes |
| 518 | + prefix = "Common prompt context for DPO:" |
| 519 | + chosen_suffix = " the chosen completion" |
| 520 | + rejected_suffix = " the rejected completion" |
| 521 | + |
| 522 | + list_of_dicts = [ |
| 523 | + { |
| 524 | + "chosen": prefix + chosen_suffix, |
| 525 | + "rejected": prefix + rejected_suffix, |
| 526 | + } |
| 527 | + for _ in range(10) |
| 528 | + ] |
| 529 | + |
| 530 | + data_iter = self.get_data_iterator(list_of_dicts, ["chosen", "rejected"]) |
| 531 | + batch = next(data_iter) |
| 532 | + |
| 533 | + # Verify decoded extracted prefix and completions robustly against BPE token boundary quirks |
| 534 | + decoded_prompt = self.tokenizer.decode(batch["prompt_ids"][0], skip_special_tokens=True) |
| 535 | + decoded_chosen = self.tokenizer.decode(batch["chosen_ids"][0], skip_special_tokens=True) |
| 536 | + decoded_rejected = self.tokenizer.decode(batch["rejected_ids"][0], skip_special_tokens=True) |
| 537 | + |
| 538 | + self.assertIn("Common prompt context", decoded_prompt) |
| 539 | + self.assertIn("chosen", decoded_chosen) |
| 540 | + self.assertIn("rejected", decoded_rejected) |
| 541 | + |
| 542 | + |
392 | 543 | if __name__ == "__main__": |
393 | 544 | unittest.main() |
0 commit comments