Skip to content

Commit 0279dcb

Browse files
authored
Clean up context parallel class names and test locations (#1378)
Will make this easier to re-use for Llama3.
- Renames (and moves) `ContextParallelDataLoaderWrapper` to collator.py
- Renames the CP data collator to `DataCollatorForContextParallel`
- Moves the tests for `ContextParallelDataLoaderWrapper` to the esm2 model recipe

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent e0b8624 commit 0279dcb

8 files changed

Lines changed: 457 additions & 295 deletions

File tree

bionemo-recipes/models/esm2/src/esm/collator.py

Lines changed: 134 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -296,58 +296,6 @@ def _pad_batch_to_multiple_of(self, batch):
296296
)
297297

298298

299-
class MLMDataCollatorWithFlatteningCPAware:
300-
"""A collator that is aware of context parallelism.
301-
302-
For the case of context parallelism, padded sequences will be returned from the wrapped collator, and then split into shards for each context parallelism rank.
303-
304-
The shards are then typically sent to the CPAwareDataloader which will scatter them to the appropriate GPUs.
305-
"""
306-
307-
def __init__(self, collator: MLMDataCollatorWithFlattening, cp_world_size: int):
308-
"""Initialize the MLMDataCollatorWithFlatteningCPAware.
309-
310-
Args:
311-
collator: The collator to use for masking tokens.
312-
cp_world_size: The size of the context parallelism group.
313-
"""
314-
self.collator = collator
315-
self.cp_world_size = cp_world_size
316-
317-
def __call__(self, features) -> list[dict[str, Any]]:
318-
"""Process batches of data and create shards for each context parallelism rank.
319-
320-
Args:
321-
features: List of tokenized sequences, each containing 'input_ids' and optionally 'labels'.
322-
323-
Returns:
324-
A list of dictionaries, each containing a shard of the batch for a given context parallelism rank.
325-
"""
326-
batch = self.collator(features)
327-
328-
combined_batch = []
329-
for cp_rank in range(self.cp_world_size):
330-
input_ids_sharded, labels_sharded = split_batch_by_cp_rank(
331-
cu_seqlens_padded=batch["cu_seq_lens_q_padded"],
332-
input_ids_padded=batch["input_ids"],
333-
labels_padded=batch["labels"],
334-
qvk_format="thd",
335-
cp_rank=cp_rank,
336-
cp_world_size=self.cp_world_size,
337-
)
338-
batch_shard = dict(batch)
339-
batch_shard["input_ids"] = input_ids_sharded
340-
batch_shard["labels"] = labels_sharded
341-
# Now determine the max length of the sequence.
342-
seqlens_q = batch_shard["cu_seq_lens_q_padded"][1:] - batch_shard["cu_seq_lens_q_padded"][:-1]
343-
batch_shard["max_length_q"] = int((seqlens_q.max().item() + 63) // 64 * 64)
344-
batch_shard["max_length_k"] = batch_shard["max_length_q"]
345-
batch_shard["pad_between_seqs"] = True
346-
combined_batch.append(batch_shard)
347-
348-
return combined_batch
349-
350-
351299
@dataclass
352300
class DataCollatorWithFlattening(DefaultDataCollator):
353301
"""Data collator for sequence packing with flash attention's cu_seqlens-style attention.
@@ -444,7 +392,7 @@ def __iter__(self):
444392
tokens_in_batch = current_length - len(sample["input_ids"])
445393
# Calculate how many tokens we can fit from this sample
446394
tokens_available = self.max_tokens_per_batch - tokens_in_batch
447-
first_part, remaining_part = split_sample_by_num_tokens(sample, tokens_available)
395+
first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available)
448396
yield [*samples, first_part]
449397
samples = [remaining_part]
450398

@@ -460,7 +408,138 @@ def set_epoch(self, epoch: int):
460408
self.dataset.set_epoch(epoch)
461409

462410

463-
def split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
411+
class DataCollatorForContextParallel:
412+
"""A collator that is aware of context parallelism.
413+
414+
For the case of context parallelism, padded sequences will be returned from the wrapped collator, and then split
415+
into shards for each context parallelism rank.
416+
417+
The shards are then typically sent to the ContextParallelDataLoaderWrapper which will scatter them to the
418+
appropriate GPUs.
419+
"""
420+
421+
def __init__(self, collator: DefaultDataCollator, cp_world_size: int):
422+
"""Initialize the DataCollatorForContextParallel.
423+
424+
Args:
425+
collator: The collator to use for masking tokens.
426+
cp_world_size: The size of the context parallelism group.
427+
"""
428+
self.collator = collator
429+
self.cp_world_size = cp_world_size
430+
431+
def __call__(self, features) -> list[dict[str, Any]]:
432+
"""Process batches of data and create shards for each context parallelism rank.
433+
434+
Args:
435+
features: List of tokenized sequences, each containing 'input_ids' and optionally 'labels'.
436+
437+
Returns:
438+
A list of dictionaries, each containing a shard of the batch for a given context parallelism rank.
439+
"""
440+
batch = self.collator(features)
441+
442+
combined_batch = []
443+
for cp_rank in range(self.cp_world_size):
444+
input_ids_sharded, labels_sharded = _split_batch_by_cp_rank(
445+
cu_seqlens_padded=batch["cu_seq_lens_q_padded"],
446+
input_ids_padded=batch["input_ids"],
447+
labels_padded=batch["labels"],
448+
qvk_format="thd",
449+
cp_rank=cp_rank,
450+
cp_world_size=self.cp_world_size,
451+
)
452+
batch_shard = dict(batch)
453+
batch_shard["input_ids"] = input_ids_sharded
454+
batch_shard["labels"] = labels_sharded
455+
# Now determine the max length of the sequence.
456+
seqlens_q = batch_shard["cu_seq_lens_q_padded"][1:] - batch_shard["cu_seq_lens_q_padded"][:-1]
457+
batch_shard["max_length_q"] = int((seqlens_q.max().item() + 63) // 64 * 64)
458+
batch_shard["max_length_k"] = batch_shard["max_length_q"]
459+
batch_shard["pad_between_seqs"] = True
460+
combined_batch.append(batch_shard)
461+
462+
return combined_batch
463+
464+
465+
class ContextParallelDataLoaderWrapper:
466+
"""A dataloader that is aware of context parallelism."""
467+
468+
def __init__(
469+
self,
470+
dataloader: torch.utils.data.DataLoader,
471+
cp_group: torch.distributed.ProcessGroup,
472+
cp_rank: int,
473+
):
474+
"""A dataloader wrapper that distributes the data across the context parallelism group.
475+
476+
This class will get the batch from the dataloader on CP rank 0, and then determine the shards for all the
477+
different CP group members. Then it will scatter the shards to the different CP group members. The shards are
478+
then returned to the caller for the current CP rank.
479+
480+
Args:
481+
dataloader: The dataloader to use.
482+
cp_group: The context parallel group.
483+
cp_rank: The rank of the current context parallel process.
484+
"""
485+
self.dataloader = dataloader
486+
self.cp_rank = cp_rank
487+
self.cp_group = cp_group
488+
self.num_cp_ranks = cp_group.size()
489+
self._iterator = None
490+
491+
def __iter__(self):
492+
"""Make the dataloader iterable."""
493+
self._iterator = iter(self.dataloader) # < --- collator output.
494+
return self
495+
496+
def __next__(self):
497+
"""Get the batch from the dataloader for the current CP rank."""
498+
batch = self._send_data_to_cp_ranks()
499+
return batch
500+
501+
def _send_data_to_cp_ranks(self):
502+
"""Send data to all the CP ranks.
503+
504+
This function will get the batch from the dataloader on CP rank 0, and then determine
505+
the shards for all the different CP group members.
506+
combined_batch = [<cp_rank_0_shard>, <cp_rank_1_shard>, ..., <cp_rank_n_shard>]
507+
Then it will scatter the shards to the different CP group members.
508+
The shards are then combined into a single batch and returned to the caller
509+
for the current CP rank.
510+
511+
Scalability:
512+
Rank 0's work grows linearly with CP size, but the other ranks do not need to store all the shards so they do not
513+
grow linearly with CP size.
514+
515+
Args:
516+
None
517+
518+
Returns:
519+
batch: The batch for the current CP rank.
520+
521+
"""
522+
if self.cp_rank == 0:
523+
# Get data once, then make copies for each rank.
524+
if self._iterator is None:
525+
self._iterator = iter(self.dataloader)
526+
combined_batch = next(self._iterator)
527+
528+
else:
529+
combined_batch = None
530+
531+
scatter_object_output_list = [None]
532+
# Note: This does not provide an async_op handle, so the call is blocking.
533+
torch.distributed.scatter_object_list(
534+
scatter_object_output_list=scatter_object_output_list,
535+
scatter_object_input_list=combined_batch,
536+
group=self.cp_group,
537+
group_src=0,
538+
)
539+
return scatter_object_output_list[0]
540+
541+
542+
def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
464543
"""Split a sample dictionary at a specified number of tokens.
465544
466545
This function splits a sample into two parts: the first part contains exactly `num_tokens` tokens,
@@ -615,7 +694,7 @@ def _pt_pad_to_multiple_of(batch: dict[str, Any], pad_to_multiple_of: int, token
615694

616695
# TODO(@jomitchell): Once this gets merged: https://github.com/NVIDIA/TransformerEngine/pull/2387
617696
# we can replace this with the one in TransformerEngine.
618-
def split_batch_by_cp_rank(
697+
def _split_batch_by_cp_rank(
619698
cu_seqlens_padded: torch.Tensor,
620699
input_ids_padded: torch.Tensor,
621700
labels_padded: torch.Tensor,

bionemo-recipes/models/esm2/tests/test_collator.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
DataCollatorWithFlattening,
2424
MLMDataCollatorWithFlattening,
2525
TokenPackingDataset,
26-
split_sample_by_num_tokens,
26+
_split_sample_by_num_tokens,
2727
)
2828

2929

@@ -494,36 +494,36 @@ def __iter__(self):
494494
assert sum(len(sample["input_ids"]) for sample in batches[0]) == 90
495495

496496

497-
def test_split_sample_by_num_tokens_basic():
498-
"""Test split_sample_by_num_tokens with basic input_ids."""
497+
def test__split_sample_by_num_tokens_basic():
498+
"""Test _split_sample_by_num_tokens with basic input_ids."""
499499
sample = {"input_ids": [0, 5, 6, 7, 8, 9, 2]}
500-
first, remaining = split_sample_by_num_tokens(sample, 3)
500+
first, remaining = _split_sample_by_num_tokens(sample, 3)
501501

502502
assert first["input_ids"] == [0, 5, 6]
503503
assert remaining["input_ids"] == [7, 8, 9, 2]
504504
assert len(first["input_ids"]) == 3
505505
assert len(remaining["input_ids"]) == 4
506506

507507

508-
def test_split_sample_by_num_tokens_with_labels():
509-
"""Test split_sample_by_num_tokens with input_ids and labels."""
508+
def test__split_sample_by_num_tokens_with_labels():
509+
"""Test _split_sample_by_num_tokens with input_ids and labels."""
510510
sample = {"input_ids": [0, 5, 6, 7, 8, 2], "labels": [0, 5, 6, 7, 8, 2]}
511-
first, remaining = split_sample_by_num_tokens(sample, 3)
511+
first, remaining = _split_sample_by_num_tokens(sample, 3)
512512

513513
assert first["input_ids"] == [0, 5, 6]
514514
assert first["labels"] == [0, 5, 6]
515515
assert remaining["input_ids"] == [7, 8, 2]
516516
assert remaining["labels"] == [7, 8, 2]
517517

518518

519-
def test_split_sample_by_num_tokens_with_attention_mask():
520-
"""Test split_sample_by_num_tokens with input_ids, attention_mask, and labels."""
519+
def test__split_sample_by_num_tokens_with_attention_mask():
520+
"""Test _split_sample_by_num_tokens with input_ids, attention_mask, and labels."""
521521
sample = {
522522
"input_ids": [0, 5, 6, 7, 8, 2],
523523
"attention_mask": [1, 1, 1, 1, 1, 1],
524524
"labels": [0, 5, 6, 7, 8, 2],
525525
}
526-
first, remaining = split_sample_by_num_tokens(sample, 4)
526+
first, remaining = _split_sample_by_num_tokens(sample, 4)
527527

528528
assert first["input_ids"] == [0, 5, 6, 7]
529529
assert first["attention_mask"] == [1, 1, 1, 1]
@@ -533,14 +533,14 @@ def test_split_sample_by_num_tokens_with_attention_mask():
533533
assert remaining["labels"] == [8, 2]
534534

535535

536-
def test_split_sample_by_num_tokens_with_token_type_ids():
537-
"""Test split_sample_by_num_tokens with token_type_ids."""
536+
def test__split_sample_by_num_tokens_with_token_type_ids():
537+
"""Test _split_sample_by_num_tokens with token_type_ids."""
538538
sample = {
539539
"input_ids": [0, 5, 6, 7, 8, 2],
540540
"token_type_ids": [0, 0, 0, 1, 1, 1],
541541
"labels": [0, 5, 6, 7, 8, 2],
542542
}
543-
first, remaining = split_sample_by_num_tokens(sample, 3)
543+
first, remaining = _split_sample_by_num_tokens(sample, 3)
544544

545545
assert first["input_ids"] == [0, 5, 6]
546546
assert first["token_type_ids"] == [0, 0, 0]
@@ -550,14 +550,14 @@ def test_split_sample_by_num_tokens_with_token_type_ids():
550550
assert remaining["labels"] == [7, 8, 2]
551551

552552

553-
def test_split_sample_by_num_tokens_with_token_type():
554-
"""Test split_sample_by_num_tokens with token_type (alternative name)."""
553+
def test__split_sample_by_num_tokens_with_token_type():
554+
"""Test _split_sample_by_num_tokens with token_type (alternative name)."""
555555
sample = {
556556
"input_ids": [0, 5, 6, 7, 8, 2],
557557
"token_type": [0, 0, 0, 1, 1, 1],
558558
"labels": [0, 5, 6, 7, 8, 2],
559559
}
560-
first, remaining = split_sample_by_num_tokens(sample, 3)
560+
first, remaining = _split_sample_by_num_tokens(sample, 3)
561561

562562
assert first["input_ids"] == [0, 5, 6]
563563
assert first["token_type"] == [0, 0, 0]
@@ -567,14 +567,14 @@ def test_split_sample_by_num_tokens_with_token_type():
567567
assert remaining["labels"] == [7, 8, 2]
568568

569569

570-
def test_split_sample_by_num_tokens_with_tensors():
571-
"""Test split_sample_by_num_tokens with torch tensors."""
570+
def test__split_sample_by_num_tokens_with_tensors():
571+
"""Test _split_sample_by_num_tokens with torch tensors."""
572572
sample = {
573573
"input_ids": torch.tensor([0, 5, 6, 7, 8, 2]),
574574
"attention_mask": torch.tensor([1, 1, 1, 1, 1, 1]),
575575
"labels": torch.tensor([0, 5, 6, 7, 8, 2]),
576576
}
577-
first, remaining = split_sample_by_num_tokens(sample, 3)
577+
first, remaining = _split_sample_by_num_tokens(sample, 3)
578578

579579
assert torch.equal(first["input_ids"], torch.tensor([0, 5, 6]))
580580
assert torch.equal(first["attention_mask"], torch.tensor([1, 1, 1]))
@@ -584,14 +584,14 @@ def test_split_sample_by_num_tokens_with_tensors():
584584
assert torch.equal(remaining["labels"], torch.tensor([7, 8, 2]))
585585

586586

587-
def test_split_sample_by_num_tokens_with_metadata():
588-
"""Test split_sample_by_num_tokens preserves non-sequence fields."""
587+
def test__split_sample_by_num_tokens_with_metadata():
588+
"""Test _split_sample_by_num_tokens preserves non-sequence fields."""
589589
sample = {
590590
"input_ids": [0, 5, 6, 7, 8, 2],
591591
"labels": [0, 5, 6, 7, 8, 2],
592592
"metadata": {"id": 123, "source": "test"},
593593
}
594-
first, remaining = split_sample_by_num_tokens(sample, 3)
594+
first, remaining = _split_sample_by_num_tokens(sample, 3)
595595

596596
# Sequence fields should be split
597597
assert first["input_ids"] == [0, 5, 6]
@@ -602,23 +602,23 @@ def test_split_sample_by_num_tokens_with_metadata():
602602
assert remaining["metadata"] == {"id": 123, "source": "test"}
603603

604604

605-
def test_split_sample_by_num_tokens_errors():
606-
"""Test split_sample_by_num_tokens raises errors for invalid inputs."""
605+
def test__split_sample_by_num_tokens_errors():
606+
"""Test _split_sample_by_num_tokens raises errors for invalid inputs."""
607607
sample = {"input_ids": [0, 5, 6, 7, 2]}
608608

609609
# num_tokens >= sample_length should raise ValueError
610610
with pytest.raises(ValueError, match="num_tokens.*must be less than sample length"):
611-
split_sample_by_num_tokens(sample, 5)
611+
_split_sample_by_num_tokens(sample, 5)
612612

613613
with pytest.raises(ValueError, match="num_tokens.*must be less than sample length"):
614-
split_sample_by_num_tokens(sample, 10)
614+
_split_sample_by_num_tokens(sample, 10)
615615

616616
# num_tokens <= 0 should raise ValueError
617617
with pytest.raises(ValueError, match="num_tokens.*must be positive"):
618-
split_sample_by_num_tokens(sample, 0)
618+
_split_sample_by_num_tokens(sample, 0)
619619

620620
with pytest.raises(ValueError, match="num_tokens.*must be positive"):
621-
split_sample_by_num_tokens(sample, -1)
621+
_split_sample_by_num_tokens(sample, -1)
622622

623623

624624
def test_token_packing_dataset_with_split_samples():

0 commit comments

Comments
 (0)