|
24 | 24 | from hydra import compose, initialize_config_dir |
25 | 25 | from torch.distributed.device_mesh import init_device_mesh |
26 | 26 |
|
27 | | -from dataset import create_bshd_dataloader, create_cp_dataloader, create_thd_dataloader, create_tokenized_dataset |
| 27 | +from collator import ContextParallelDataLoaderWrapper, DataCollatorForContextParallel |
| 28 | +from dataset import create_bshd_dataloader, create_thd_dataloader, create_tokenized_dataset |
28 | 29 | from distributed_config import DistributedConfig |
29 | 30 |
|
30 | 31 |
|
@@ -703,15 +704,28 @@ def test_cp_dataloader(tokenizer_path): |
703 | 704 | torch.cuda.set_device(dist_config.local_rank) |
704 | 705 | device_mesh = init_device_mesh("cuda", mesh_shape=(1, 1), mesh_dim_names=("dp", "cp")) |
705 | 706 |
|
706 | | - dataloader, _ = create_cp_dataloader( |
707 | | - distributed_config=dist_config, |
708 | | - cp_mesh=device_mesh["cp"], |
709 | | - tokenizer_name_or_path=tokenizer_path, |
710 | | - load_dataset_kwargs=load_dataset_kwargs, |
711 | | - text_column="text", |
712 | | - micro_batch_size=1, |
713 | | - max_seq_length=1024, |
714 | | - ) |
| 707 | + cp_mesh = device_mesh["cp"] |
| 708 | + |
| 709 | + # Create the context-parallel dataloader directly following the pattern in train_fsdp2_cp.py |
| 710 | + if cp_mesh.get_local_rank() == 0: |
| 711 | + train_dataloader, _ = create_thd_dataloader( |
| 712 | + distributed_config=dist_config, |
| 713 | + tokenizer_name_or_path=tokenizer_path, |
| 714 | + load_dataset_kwargs=load_dataset_kwargs, |
| 715 | + text_column="text", |
| 716 | + micro_batch_size=1, |
| 717 | + max_seq_length=1024, |
| 718 | + pad_sequences_to_be_divisible_by=cp_mesh.size() * 2, |
| 719 | + ) |
| 720 | + |
| 721 | + train_dataloader.collate_fn = DataCollatorForContextParallel( |
| 722 | + collator=train_dataloader.collate_fn, |
| 723 | + cp_world_size=cp_mesh.size(), |
| 724 | + ) |
| 725 | + else: |
| 726 | + train_dataloader = None |
| 727 | + |
| 728 | + dataloader = ContextParallelDataLoaderWrapper(train_dataloader, cp_mesh) |
715 | 729 |
|
716 | 730 | batches = list(dataloader) |
717 | 731 | assert len(batches) > 1 |
@@ -775,30 +789,39 @@ def test_cp_dataloader_multi_gpu(recipe_path, dataset_path): |
775 | 789 | parser.add_argument("--dataset_path", type=str, default="dlcm_sanity_dataset.parquet") |
776 | 790 | args = parser.parse_args() |
777 | 791 |
|
778 | | - from torch.distributed.device_mesh import init_device_mesh |
779 | | - |
780 | | - from dataset import create_cp_dataloader |
781 | | - |
782 | 792 | dist_config = DistributedConfig() |
783 | 793 | device = torch.device(f"cuda:{dist_config.local_rank}") |
784 | 794 | torch.distributed.init_process_group(backend="nccl", device_id=device) |
785 | 795 | torch.cuda.set_device(dist_config.local_rank) |
786 | 796 | device_mesh = init_device_mesh("cuda", mesh_shape=(1, 2), mesh_dim_names=("dp", "cp")) |
787 | 797 |
|
788 | | - dataloader, _ = create_cp_dataloader( |
789 | | - distributed_config=dist_config, |
790 | | - cp_mesh=device_mesh["cp"], |
791 | | - tokenizer_name_or_path="nvidia/Llama-3.1-8B-Instruct-FP8", |
792 | | - micro_batch_size=1, |
793 | | - text_column="text" if args.dataset_path == "dlcm_sanity_dataset.parquet" else "sequence", |
794 | | - load_dataset_kwargs={ |
795 | | - "path": "parquet", |
796 | | - "split": "train", |
797 | | - "data_files": args.dataset_path, |
798 | | - "streaming": True, |
799 | | - }, |
800 | | - num_workers=1, |
801 | | - ) |
| 798 | + cp_mesh = device_mesh["cp"] |
| 799 | + |
| 800 | + # Create the context-parallel dataloader directly following the pattern in train_fsdp2_cp.py |
| 801 | + if cp_mesh.get_local_rank() == 0: |
| 802 | + train_dataloader, _ = create_thd_dataloader( |
| 803 | + distributed_config=dist_config, |
| 804 | + tokenizer_name_or_path="nvidia/Llama-3.1-8B-Instruct-FP8", |
| 805 | + micro_batch_size=1, |
| 806 | + text_column="text" if args.dataset_path == "dlcm_sanity_dataset.parquet" else "sequence", |
| 807 | + load_dataset_kwargs={ |
| 808 | + "path": "parquet", |
| 809 | + "split": "train", |
| 810 | + "data_files": args.dataset_path, |
| 811 | + "streaming": True, |
| 812 | + }, |
| 813 | + num_workers=1, |
| 814 | + pad_sequences_to_be_divisible_by=cp_mesh.size() * 2, |
| 815 | + ) |
| 816 | + |
| 817 | + train_dataloader.collate_fn = DataCollatorForContextParallel( |
| 818 | + collator=train_dataloader.collate_fn, |
| 819 | + cp_world_size=cp_mesh.size(), |
| 820 | + ) |
| 821 | + else: |
| 822 | + train_dataloader = None |
| 823 | + |
| 824 | + dataloader = ContextParallelDataLoaderWrapper(train_dataloader, cp_mesh) |
802 | 825 |
|
803 | 826 | batches = list(itertools.islice(dataloader, 10)) |
804 | 827 |
|
|
0 commit comments