|
17 | 17 | from typing import Dict, Iterator, List |
18 | 18 | from unittest import mock |
19 | 19 |
|
| 20 | +import pytest |
20 | 21 | import torch |
21 | 22 | from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import pad_thd_sequences_for_cp |
22 | 23 | from transformers import DataCollatorForLanguageModeling |
23 | 24 |
|
24 | 25 | from esm.collator import ( |
| 26 | + BatchType, |
25 | 27 | ContextParallelDataLoaderWrapper, |
| 28 | + DataCollatorForContextParallel, |
26 | 29 | DataCollatorWithFlattening, |
27 | 30 | _split_batch_by_cp_rank, |
28 | 31 | ) |
@@ -887,3 +890,106 @@ def test_bshd_and_thd_equivalence(tokenizer): |
887 | 890 | torch.sort(batch_bshd["input_ids"][1])[0], |
888 | 891 | msg="Reconstructed sequence 2 doesn't match original", |
889 | 892 | ) |
| 893 | + |
| 894 | + |
@pytest.mark.parametrize("cp_world_size", [2, 4])
def test_data_collator_for_context_parallel_returns_correct_list_size(tokenizer, cp_world_size):
    """Check that DataCollatorForContextParallel returns a list with one entry per CP rank."""
    # Two short sequences of different lengths (8 and 9 tokens).
    samples = [
        {"input_ids": [0, 5, 6, 7, 8, 9, 10, 2]},
        {"input_ids": [0, 11, 12, 13, 14, 15, 16, 17, 2]},
    ]

    # MLM collator wrapped by the flattening collator; sequences are padded to a multiple of 2 * cp_world_size.
    mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
    flattening_collator = DataCollatorWithFlattening(
        collator=mlm_collator,
        pad_sequences_to_be_divisible_by=2 * cp_world_size,
    )

    # Wrap once more for context parallelism and collate the samples.
    context_parallel_collator = DataCollatorForContextParallel(
        collator=flattening_collator,
        cp_world_size=cp_world_size,
    )
    result = context_parallel_collator(samples)

    # The output must be a plain list whose length equals the CP world size.
    assert isinstance(result, list), f"Expected list, got {type(result)}"
    assert len(result) == cp_world_size, f"Expected list of size {cp_world_size}, got {len(result)}"
| 921 | + |
| 922 | + |
def test_data_collator_for_context_parallel_thd(tokenizer):
    """Test that each THD shard from DataCollatorForContextParallel has exactly the keys declared on BatchType."""

    cp_world_size = 2
    divisibility_factor = 2 * cp_world_size

    # Create the wrapped collator that produces padded THD batches
    base_collator = DataCollatorWithFlattening(
        collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15),
        pad_sequences_to_be_divisible_by=divisibility_factor,
    )

    # Create the context parallel collator
    cp_collator = DataCollatorForContextParallel(collator=base_collator, cp_world_size=cp_world_size)

    # Create test sequences
    features = [
        {"input_ids": [0, 5, 6, 7, 8, 9, 10, 2]},  # 8 tokens
        {"input_ids": [0, 11, 12, 13, 14, 15, 16, 17, 2]},  # 9 tokens
    ]

    # Call the collator
    result = cp_collator(features)

    assert len(result) == cp_world_size, f"Expected list of size {cp_world_size}, got {len(result)}"

    # Expected keys are the names annotated directly on BatchType.
    # NOTE(review): __annotations__ only lists names declared on BatchType itself -- confirm it has no
    # inherited fields that should also be checked.
    required_keys = set(BatchType.__annotations__)

    # Each shard must carry exactly the expected keys. Report both directions of the mismatch so a
    # missing key is as visible in the failure output as an unexpected one.
    for cp_rank, shard in enumerate(result):
        shard_keys = set(shard.keys())
        assert shard_keys == required_keys, (
            f"CP rank {cp_rank}: missing keys: {sorted(required_keys - shard_keys)}, "
            f"unexpected keys: {sorted(shard_keys - required_keys)}"
        )
| 957 | + |
| 958 | + |
def test_data_collator_for_context_parallel_bshd(tokenizer):
    """Test that each shard produced with qkv_format="bshd" carries exactly the expected set of keys."""
    cp_world_size = 2

    # Plain MLM collator (no flattening wrapper); pad_to_multiple_of is set to 2 * cp_world_size.
    mlm_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm_probability=0.15,
        pad_to_multiple_of=2 * cp_world_size,
    )

    # Context-parallel wrapper configured for the BSHD layout.
    cp_collator = DataCollatorForContextParallel(
        collator=mlm_collator,
        cp_world_size=cp_world_size,
        qkv_format="bshd",
    )

    # Two short sequences of different lengths (8 and 9 tokens).
    samples = [
        {"input_ids": [0, 5, 6, 7, 8, 9, 10, 2]},
        {"input_ids": [0, 11, 12, 13, 14, 15, 16, 17, 2]},
    ]
    result = cp_collator(samples)

    # One shard per CP rank.
    assert len(result) == cp_world_size, f"Expected list of size {cp_world_size}, got {len(result)}"

    # Hard-coded key set expected on every BSHD shard.
    required_keys = {"input_ids", "labels", "max_length_q", "max_length_k"}

    # The key set of every shard must match exactly.
    for cp_rank, shard in enumerate(result):
        assert set(shard.keys()) == required_keys, (
            f"CP rank {cp_rank}: expected keys {required_keys}, got {set(shard.keys())}"
        )
0 commit comments