Skip to content

Commit 6d4e8bb

Browse files
authored
Add split_samples functionality to TokenPackingDataset (#1372)
This update introduces a new parameter, split_samples, to the TokenPackingDataset class, allowing samples to be split so that batches contain exactly max_tokens_per_batch tokens. A new utility function, split_sample_by_num_tokens, is also added to support this feature by splitting a sample dictionary at a specified number of tokens. This helps auto-regressive models, where we want to ensure the entire batch is filled for optimal token throughput. Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent b7ae1d2 commit 6d4e8bb

6 files changed

Lines changed: 512 additions & 20 deletions

File tree

bionemo-recipes/models/esm2/src/esm/collator.py

Lines changed: 95 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -411,21 +411,44 @@ class TokenPackingDataset(torch.utils.data.IterableDataset):
411411
"""Maximum number of tokens per batch."""
412412
drop_last: bool = True
413413
"""Whether to drop the last batch if it's less than max_length."""
414+
split_samples: bool = False
415+
"""Whether to split samples to ensure batches have exactly max_tokens_per_batch tokens."""
414416

415417
def __iter__(self):
416418
"""Yield batches of samples, each with a variable number of tokens up to the maximum length.
417419
420+
When split_samples=True, ensures each batch has exactly max_tokens_per_batch by splitting
421+
the final sample if needed. The remaining tokens from the split sample start the next batch.
422+
418423
Returns:
419424
A generator of batches of samples, each with a variable number of tokens up to the maximum length.
420425
"""
421426
samples = []
422427
current_length = 0
423428
for sample in iter(self.dataset):
424429
current_length += len(sample["input_ids"])
425-
if current_length > self.max_tokens_per_batch:
426-
yield samples
427-
samples = [sample]
428-
current_length = len(sample["input_ids"])
430+
if current_length == self.max_tokens_per_batch:
431+
yield [*samples, sample]
432+
samples = []
433+
current_length = 0
434+
435+
elif current_length > self.max_tokens_per_batch:
436+
if not self.split_samples:
437+
# If we are not splitting samples, we can just yield the current batch (before this sample) and
438+
# start a new one.
439+
yield samples
440+
samples = [sample]
441+
442+
else:
443+
# Calculate how many tokens are already in the batch
444+
tokens_in_batch = current_length - len(sample["input_ids"])
445+
# Calculate how many tokens we can fit from this sample
446+
tokens_available = self.max_tokens_per_batch - tokens_in_batch
447+
first_part, remaining_part = split_sample_by_num_tokens(sample, tokens_available)
448+
yield [*samples, first_part]
449+
samples = [remaining_part]
450+
451+
current_length = len(samples[0]["input_ids"])
429452
else:
430453
samples.append(sample)
431454

@@ -437,6 +460,74 @@ def set_epoch(self, epoch: int):
437460
self.dataset.set_epoch(epoch)
438461

439462

463+
def split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
    """Split a sample dictionary at a specified number of tokens.

    The sample is cut into two parts: the first part holds exactly ``num_tokens`` tokens and the second part
    holds the rest. Only the known per-token fields (``input_ids``, ``attention_mask``, ``token_type_ids``,
    ``token_type``, ``labels``) are sliced; every other key is passed through unchanged to both parts.

    Args:
        sample: Dictionary containing sample data. Must contain ``input_ids``; its length defines the
            sample length used for validation.
        num_tokens: Number of tokens to include in the first part of the split. Must satisfy
            ``0 < num_tokens < len(sample["input_ids"])``.

    Returns:
        A tuple of two dictionaries ``(first_part, remaining_part)``, where:
            - ``first_part`` contains the first ``num_tokens`` tokens from each sequence field
            - ``remaining_part`` contains the remaining tokens from each sequence field
        Non-sequence values (and sequence-field values that cannot be sliced) are the *same object* in both
        parts; tensor slices are cloned so they do not share storage with the input.

    Raises:
        ValueError: If ``num_tokens`` is not strictly between 0 and the sample length.

    Example:
        >>> sample = {
        ...     "input_ids": [0, 5, 6, 7, 8, 9, 2],
        ...     "attention_mask": [1, 1, 1, 1, 1, 1, 1],
        ...     "labels": [0, 5, 6, 7, 8, 9, 2],
        ... }
        >>> first, remaining = split_sample_by_num_tokens(sample, 3)
        >>> first["input_ids"]
        [0, 5, 6]
        >>> remaining["input_ids"]
        [7, 8, 9, 2]
    """
    sample_length = len(sample["input_ids"])
    if num_tokens >= sample_length:
        raise ValueError(
            f"num_tokens ({num_tokens}) must be less than sample length ({sample_length}) to split the sample"
        )
    if num_tokens <= 0:
        raise ValueError(f"num_tokens ({num_tokens}) must be positive")

    # Fields that should be split by tokens (sequence fields).
    sequence_fields = {"input_ids", "attention_mask", "token_type_ids", "token_type", "labels"}

    first_part: dict[str, Any] = {}
    remaining_part: dict[str, Any] = {}

    for key, value in sample.items():
        if key not in sequence_fields:
            # Metadata and other non-sequence fields are shared by both parts.
            first_part[key] = value
            remaining_part[key] = value
            continue

        if isinstance(value, torch.Tensor):
            # Clone so neither part is a view into the original tensor's storage.
            first_part[key] = value[:num_tokens].clone()
            remaining_part[key] = value[num_tokens:].clone()
            continue

        try:
            head, tail = value[:num_tokens], value[num_tokens:]
        except (TypeError, IndexError, KeyError):
            # The value cannot be sliced, so hand the same value to both parts. KeyError is included
            # because slice objects are hashable on Python 3.12+, so slicing a mapping raises KeyError
            # there instead of TypeError.
            head = tail = value
        first_part[key] = head
        remaining_part[key] = tail

    return first_part, remaining_part
529+
530+
440531
def _pt_flatten_collate(features: list[dict[str, list[int]]], return_position_ids: bool = False):
441532
is_labels_provided = "labels" in features[0]
442533
sample_lengths = [len(sample["input_ids"]) for sample in features]

bionemo-recipes/models/esm2/tests/test_collator.py

Lines changed: 216 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,16 @@
1515

1616
from unittest.mock import MagicMock
1717

18+
import pytest
1819
import torch
1920
from transformers import DataCollatorForLanguageModeling
2021

21-
from esm.collator import DataCollatorWithFlattening, MLMDataCollatorWithFlattening, TokenPackingDataset
22+
from esm.collator import (
23+
DataCollatorWithFlattening,
24+
MLMDataCollatorWithFlattening,
25+
TokenPackingDataset,
26+
split_sample_by_num_tokens,
27+
)
2228

2329

2430
def test_data_collator_with_flattening_basic():
@@ -486,3 +492,212 @@ def __iter__(self):
486492
assert len(batches) == 1
487493
assert len(batches[0]) == 3
488494
assert sum(len(sample["input_ids"]) for sample in batches[0]) == 90
495+
496+
497+
def test_split_sample_by_num_tokens_basic():
    """A sample holding only input_ids is cut into a 3-token head and a 4-token tail."""
    head, tail = split_sample_by_num_tokens({"input_ids": [0, 5, 6, 7, 8, 9, 2]}, 3)

    head_ids = head["input_ids"]
    tail_ids = tail["input_ids"]

    assert head_ids == [0, 5, 6]
    assert tail_ids == [7, 8, 9, 2]
    assert len(head_ids) == 3
    assert len(tail_ids) == 4
506+
507+
508+
def test_split_sample_by_num_tokens_with_labels():
    """input_ids and labels are both cut at the same token position."""
    head, tail = split_sample_by_num_tokens(
        {"input_ids": [0, 5, 6, 7, 8, 2], "labels": [0, 5, 6, 7, 8, 2]},
        3,
    )

    for field in ("input_ids", "labels"):
        assert head[field] == [0, 5, 6]
        assert tail[field] == [7, 8, 2]
517+
518+
519+
def test_split_sample_by_num_tokens_with_attention_mask():
    """input_ids, attention_mask, and labels are all split after the fourth token."""
    head, tail = split_sample_by_num_tokens(
        {
            "input_ids": [0, 5, 6, 7, 8, 2],
            "attention_mask": [1, 1, 1, 1, 1, 1],
            "labels": [0, 5, 6, 7, 8, 2],
        },
        4,
    )

    expected_head = {
        "input_ids": [0, 5, 6, 7],
        "attention_mask": [1, 1, 1, 1],
        "labels": [0, 5, 6, 7],
    }
    expected_tail = {
        "input_ids": [8, 2],
        "attention_mask": [1, 1],
        "labels": [8, 2],
    }

    for field, expected in expected_head.items():
        assert head[field] == expected
    for field, expected in expected_tail.items():
        assert tail[field] == expected
534+
535+
536+
def test_split_sample_by_num_tokens_with_token_type_ids():
    """token_type_ids is split alongside input_ids and labels."""
    head, tail = split_sample_by_num_tokens(
        {
            "input_ids": [0, 5, 6, 7, 8, 2],
            "token_type_ids": [0, 0, 0, 1, 1, 1],
            "labels": [0, 5, 6, 7, 8, 2],
        },
        3,
    )

    expected_head = {"input_ids": [0, 5, 6], "token_type_ids": [0, 0, 0], "labels": [0, 5, 6]}
    expected_tail = {"input_ids": [7, 8, 2], "token_type_ids": [1, 1, 1], "labels": [7, 8, 2]}

    for field, expected in expected_head.items():
        assert head[field] == expected
    for field, expected in expected_tail.items():
        assert tail[field] == expected
551+
552+
553+
def test_split_sample_by_num_tokens_with_token_type():
    """The alternative field name token_type is split like token_type_ids."""
    head, tail = split_sample_by_num_tokens(
        {
            "input_ids": [0, 5, 6, 7, 8, 2],
            "token_type": [0, 0, 0, 1, 1, 1],
            "labels": [0, 5, 6, 7, 8, 2],
        },
        3,
    )

    expected_head = {"input_ids": [0, 5, 6], "token_type": [0, 0, 0], "labels": [0, 5, 6]}
    expected_tail = {"input_ids": [7, 8, 2], "token_type": [1, 1, 1], "labels": [7, 8, 2]}

    for field, expected in expected_head.items():
        assert head[field] == expected
    for field, expected in expected_tail.items():
        assert tail[field] == expected
568+
569+
570+
def test_split_sample_by_num_tokens_with_tensors():
    """Tensor-valued fields are split at the same position as list-valued ones."""
    head, tail = split_sample_by_num_tokens(
        {
            "input_ids": torch.tensor([0, 5, 6, 7, 8, 2]),
            "attention_mask": torch.tensor([1, 1, 1, 1, 1, 1]),
            "labels": torch.tensor([0, 5, 6, 7, 8, 2]),
        },
        3,
    )

    expected_head = {
        "input_ids": torch.tensor([0, 5, 6]),
        "attention_mask": torch.tensor([1, 1, 1]),
        "labels": torch.tensor([0, 5, 6]),
    }
    expected_tail = {
        "input_ids": torch.tensor([7, 8, 2]),
        "attention_mask": torch.tensor([1, 1, 1]),
        "labels": torch.tensor([7, 8, 2]),
    }

    for field, expected in expected_head.items():
        assert torch.equal(head[field], expected)
    for field, expected in expected_tail.items():
        assert torch.equal(tail[field], expected)
585+
586+
587+
def test_split_sample_by_num_tokens_with_metadata():
    """Fields outside the sequence-field list are carried over to both halves unchanged."""
    head, tail = split_sample_by_num_tokens(
        {
            "input_ids": [0, 5, 6, 7, 8, 2],
            "labels": [0, 5, 6, 7, 8, 2],
            "metadata": {"id": 123, "source": "test"},
        },
        3,
    )

    # The token sequence is divided between the two halves...
    assert head["input_ids"] == [0, 5, 6]
    assert tail["input_ids"] == [7, 8, 2]

    # ...while the metadata shows up intact in each of them.
    for part in (head, tail):
        assert part["metadata"] == {"id": 123, "source": "test"}
603+
604+
605+
def test_split_sample_by_num_tokens_errors():
    """Out-of-range split points are rejected with a ValueError."""
    sample = {"input_ids": [0, 5, 6, 7, 2]}

    # A split point at or beyond the sample length leaves nothing for the second part.
    for too_large in (5, 10):
        with pytest.raises(ValueError, match="num_tokens.*must be less than sample length"):
            split_sample_by_num_tokens(sample, too_large)

    # A zero or negative split point leaves nothing for the first part.
    for non_positive in (0, -1):
        with pytest.raises(ValueError, match="num_tokens.*must be positive"):
            split_sample_by_num_tokens(sample, non_positive)
622+
623+
624+
def test_token_packing_dataset_with_split_samples():
    """Test TokenPackingDataset with split_samples=True ensures exact batch sizes.

    Three samples of 40, 50, and 30 tokens are packed with max_tokens_per_batch=100. The third sample
    must be split 10/20: its first 10 tokens top up the first batch to exactly 100, and the remaining
    20 tokens form the second (final, partial) batch, which is kept because drop_last=False.
    """

    class MockDataset(torch.utils.data.IterableDataset):
        def __iter__(self):
            yield {"input_ids": torch.arange(40)}  # 40 tokens
            yield {"input_ids": torch.arange(50)}  # 50 tokens
            yield {"input_ids": torch.arange(30)}  # 30 tokens

    dataset = MockDataset()
    token_packing_dataset = TokenPackingDataset(dataset, max_tokens_per_batch=100, split_samples=True, drop_last=False)
    batches = list(token_packing_dataset)

    # Assert the batch count outright so that a silently dropped remainder fails the test instead of
    # skipping the second-batch check.
    assert len(batches) == 2

    # First batch should have exactly 100 tokens (40 + 50 + 10 from the 30-token sample).
    assert [len(sample["input_ids"]) for sample in batches[0]] == [40, 50, 10]
    assert sum(len(sample["input_ids"]) for sample in batches[0]) == 100

    # Second batch should hold only the remaining 20 tokens from the split sample.
    assert [len(sample["input_ids"]) for sample in batches[1]] == [20]

    # The two pieces of the split sample must be contiguous: no token lost or duplicated at the cut.
    assert torch.equal(batches[0][-1]["input_ids"], torch.arange(10))
    assert torch.equal(batches[1][0]["input_ids"], torch.arange(10, 30))
644+
645+
646+
def test_token_packing_dataset_with_split_samples_exact_fit():
    """With split_samples=True, two 50-token samples fill one 100-token batch and nothing is split."""

    class MockDataset(torch.utils.data.IterableDataset):
        def __iter__(self):
            # Two samples of 50 tokens each: together they hit the 100-token limit exactly.
            for _ in range(2):
                yield {"input_ids": torch.arange(50)}

    packed = TokenPackingDataset(MockDataset(), max_tokens_per_batch=100, split_samples=True, drop_last=False)
    batches = list(packed)

    # Exactly one batch comes out, and it is completely full.
    assert len(batches) == 1
    token_total = sum(len(sample["input_ids"]) for sample in batches[0])
    assert token_total == 100
661+
662+
663+
def test_token_packing_dataset_with_split_samples_multiple_fields():
    """Test TokenPackingDataset with split_samples=True handles multiple fields correctly.

    Uses the same 40/50/30-token layout as the single-field test, but every sample also carries
    attention_mask and labels. After packing, every sample in every batch, including both halves of
    the split sample, must still carry all three fields with matching lengths.
    """

    class MockDataset(torch.utils.data.IterableDataset):
        def __iter__(self):
            for num_tokens in (40, 50, 30):
                yield {
                    "input_ids": torch.arange(num_tokens),
                    "attention_mask": torch.ones(num_tokens),
                    "labels": torch.arange(num_tokens),
                }

    dataset = MockDataset()
    token_packing_dataset = TokenPackingDataset(dataset, max_tokens_per_batch=100, split_samples=True, drop_last=False)
    batches = list(token_packing_dataset)

    # Assert the batch count before indexing so that a missing batch is reported as an assertion
    # failure rather than an IndexError.
    assert len(batches) == 2

    # First batch should have exactly 100 tokens.
    first_batch_total = sum(len(sample["input_ids"]) for sample in batches[0])
    assert first_batch_total == 100

    # Second batch should have exactly 20 tokens.
    second_batch_total = sum(len(sample["input_ids"]) for sample in batches[1])
    assert second_batch_total == 20

    # Verify all fields are present and consistent in both batches. The remainder of the split
    # sample lives in the second batch, so checking only the first would miss half of the split.
    for batch in batches:
        for sample in batch:
            assert "input_ids" in sample
            assert "attention_mask" in sample
            assert "labels" in sample
            assert len(sample["input_ids"]) == len(sample["attention_mask"])
            assert len(sample["input_ids"]) == len(sample["labels"])

0 commit comments

Comments
 (0)