Skip to content

Commit f3a6084

Browse files
authored
Add CP tests for Llama3 (#1414)
Adds CP tests to models/llama3, and fixes the tests in recipes/llama3_native_te Closes BIO-40 --------- Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 57bc767 commit f3a6084

15 files changed

Lines changed: 2111 additions & 430 deletions

File tree

bionemo-recipes/models/esm2/src/esm/collator.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ def set_epoch(self, epoch: int):
275275
self.dataset.set_epoch(epoch)
276276

277277

278+
@dataclass
278279
class DataCollatorForContextParallel:
279280
"""A collator that is aware of context parallelism.
280281
@@ -285,15 +286,9 @@ class DataCollatorForContextParallel:
285286
appropriate GPUs.
286287
"""
287288

288-
def __init__(self, collator: DataCollator, cp_world_size: int):
289-
"""Initialize the DataCollatorForContextParallel.
290-
291-
Args:
292-
collator: The collator to use for masking tokens.
293-
cp_world_size: The size of the context parallelism group.
294-
"""
295-
self.collator = collator
296-
self.cp_world_size = cp_world_size
289+
collator: DataCollator
290+
cp_world_size: int
291+
qkv_format: str = "thd"
297292

298293
def __call__(self, features) -> list[dict[str, Any]]:
299294
"""Process batches of data and create shards for each context parallelism rank.
@@ -309,21 +304,29 @@ def __call__(self, features) -> list[dict[str, Any]]:
309304
combined_batch = []
310305
for cp_rank in range(self.cp_world_size):
311306
input_ids_sharded, labels_sharded = _split_batch_by_cp_rank(
312-
cu_seqlens_padded=batch["cu_seq_lens_q_padded"],
307+
cu_seqlens_padded=batch.get("cu_seq_lens_q_padded", None), # This will be None for BSHD format.
313308
input_ids_padded=batch["input_ids"],
314309
labels_padded=batch["labels"],
315-
qvk_format="thd",
310+
qvk_format=self.qkv_format,
316311
cp_rank=cp_rank,
317312
cp_world_size=self.cp_world_size,
318313
)
319314
batch_shard = dict(batch)
320315
batch_shard["input_ids"] = input_ids_sharded
321316
batch_shard["labels"] = labels_sharded
322317
# Now determine the max length of the sequence.
323-
seqlens_q = batch_shard["cu_seq_lens_q_padded"][1:] - batch_shard["cu_seq_lens_q_padded"][:-1]
324-
batch_shard["max_length_q"] = int((seqlens_q.max().item() + 63) // 64 * 64)
325-
batch_shard["max_length_k"] = batch_shard["max_length_q"]
326-
batch_shard["pad_between_seqs"] = True
318+
if self.qkv_format == "thd":
319+
seqlens_q = batch_shard["cu_seq_lens_q_padded"][1:] - batch_shard["cu_seq_lens_q_padded"][:-1]
320+
max_length = seqlens_q.max().item()
321+
batch_shard["pad_between_seqs"] = True
322+
elif self.qkv_format == "bshd":
323+
max_length = batch["input_ids"].shape[1]
324+
# For BSHD context parallelism, we can't handle padding, so we remove the attention mask.
325+
del batch_shard["attention_mask"]
326+
else:
327+
raise ValueError(f"Unsupported qvk_format: {self.qkv_format}!")
328+
329+
batch_shard["max_length_k"] = batch_shard["max_length_q"] = max_length * round(max_length / 64)
327330
combined_batch.append(batch_shard)
328331

329332
return combined_batch
@@ -727,7 +730,7 @@ def process_tensor_bshd(val):
727730

728731

729732
class BatchType(TypedDict):
730-
"""The fields in the batch dictionary for context parallel."""
733+
"""The fields in the batch dictionary for THD context parallel."""
731734

732735
input_ids: torch.Tensor
733736
labels: torch.Tensor
@@ -737,6 +740,7 @@ class BatchType(TypedDict):
737740
cu_seq_lens_k_padded: torch.Tensor
738741
max_length_q: int
739742
max_length_k: int
743+
pad_between_seqs: bool
740744

741745

742746
def _scatter_batch_to_cp_ranks(

bionemo-recipes/models/esm2/tests/test_collator_context_parallel.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717
from typing import Dict, Iterator, List
1818
from unittest import mock
1919

20+
import pytest
2021
import torch
2122
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import pad_thd_sequences_for_cp
2223
from transformers import DataCollatorForLanguageModeling
2324

2425
from esm.collator import (
26+
BatchType,
2527
ContextParallelDataLoaderWrapper,
28+
DataCollatorForContextParallel,
2629
DataCollatorWithFlattening,
2730
_split_batch_by_cp_rank,
2831
)
@@ -887,3 +890,106 @@ def test_bshd_and_thd_equivalence(tokenizer):
887890
torch.sort(batch_bshd["input_ids"][1])[0],
888891
msg="Reconstructed sequence 2 doesn't match original",
889892
)
893+
894+
895+
@pytest.mark.parametrize("cp_world_size", [2, 4])
896+
def test_data_collator_for_context_parallel_returns_correct_list_size(tokenizer, cp_world_size):
897+
"""Test that DataCollatorForContextParallel returns a list of the correct size."""
898+
divisibility_factor = 2 * cp_world_size
899+
900+
# Create the wrapped collator that produces padded THD batches
901+
base_collator = DataCollatorWithFlattening(
902+
collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15),
903+
pad_sequences_to_be_divisible_by=divisibility_factor,
904+
)
905+
906+
# Create the context parallel collator
907+
cp_collator = DataCollatorForContextParallel(collator=base_collator, cp_world_size=cp_world_size)
908+
909+
# Create test sequences
910+
features = [
911+
{"input_ids": [0, 5, 6, 7, 8, 9, 10, 2]}, # 8 tokens
912+
{"input_ids": [0, 11, 12, 13, 14, 15, 16, 17, 2]}, # 9 tokens
913+
]
914+
915+
# Call the collator
916+
result = cp_collator(features)
917+
918+
# Assert that the result is a list of the correct size
919+
assert isinstance(result, list), f"Expected list, got {type(result)}"
920+
assert len(result) == cp_world_size, f"Expected list of size {cp_world_size}, got {len(result)}"
921+
922+
923+
def test_data_collator_for_context_parallel_thd(tokenizer):
924+
"""Test that each shard from DataCollatorForContextParallel has all required keys from BatchType."""
925+
926+
cp_world_size = 2
927+
divisibility_factor = 2 * cp_world_size
928+
929+
# Create the wrapped collator that produces padded THD batches
930+
base_collator = DataCollatorWithFlattening(
931+
collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15),
932+
pad_sequences_to_be_divisible_by=divisibility_factor,
933+
)
934+
935+
# Create the context parallel collator
936+
cp_collator = DataCollatorForContextParallel(collator=base_collator, cp_world_size=cp_world_size)
937+
938+
# Create test sequences
939+
features = [
940+
{"input_ids": [0, 5, 6, 7, 8, 9, 10, 2]}, # 8 tokens
941+
{"input_ids": [0, 11, 12, 13, 14, 15, 16, 17, 2]}, # 9 tokens
942+
]
943+
944+
# Call the collator
945+
result = cp_collator(features)
946+
947+
assert len(result) == cp_world_size, f"Expected list of size {cp_world_size}, got {len(result)}"
948+
949+
# Define the required keys from BatchType
950+
required_keys = set(BatchType.__annotations__.keys())
951+
952+
# Assert each shard has all required keys
953+
for cp_rank, shard in enumerate(result):
954+
assert set(shard.keys()) == required_keys, (
955+
f"CP rank {cp_rank}: difference: {set(shard.keys()) - required_keys}"
956+
)
957+
958+
959+
def test_data_collator_for_context_parallel_bshd(tokenizer):
960+
"""Test that each BSHD shard from DataCollatorForContextParallel has exactly the expected keys."""
961+
962+
cp_world_size = 2
963+
divisibility_factor = 2 * cp_world_size
964+
965+
# Create the base collator that produces padded BSHD batches
966+
base_collator = DataCollatorForLanguageModeling(
967+
tokenizer=tokenizer,
968+
mlm_probability=0.15,
969+
pad_to_multiple_of=divisibility_factor,
970+
)
971+
972+
# Create the context parallel collator
973+
cp_collator = DataCollatorForContextParallel(
974+
collator=base_collator, cp_world_size=cp_world_size, qkv_format="bshd"
975+
)
976+
977+
# Create test sequences
978+
features = [
979+
{"input_ids": [0, 5, 6, 7, 8, 9, 10, 2]}, # 8 tokens
980+
{"input_ids": [0, 11, 12, 13, 14, 15, 16, 17, 2]}, # 9 tokens
981+
]
982+
983+
# Call the collator
984+
result = cp_collator(features)
985+
986+
assert len(result) == cp_world_size, f"Expected list of size {cp_world_size}, got {len(result)}"
987+
988+
# Define the keys expected in each BSHD shard
989+
required_keys = {"input_ids", "labels", "max_length_q", "max_length_k"}
990+
991+
# Assert each shard has all required keys
992+
for cp_rank, shard in enumerate(result):
993+
assert set(shard.keys()) == required_keys, (
994+
f"CP rank {cp_rank}: expected keys {required_keys}, got {set(shard.keys())}"
995+
)

0 commit comments

Comments
 (0)