Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions tensorflow_datasets/core/splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,11 @@ class MultiSplitInfo(SplitInfo):
This should only be used to read data and not when producing data.
"""

split_infos: list[SplitInfo] = dataclasses.field(default_factory=list)
split_infos: list[SplitInfo | SubSplitInfo] = dataclasses.field(
default_factory=list
)

def __init__(self, name: str, split_infos: list[SplitInfo]):
def __init__(self, name: str, split_infos: list[SplitInfo | SubSplitInfo]):
if not split_infos:
raise ValueError('Need to pass a non-empty list of SplitInfos')
object.__setattr__(self, 'split_infos', split_infos)
Expand Down Expand Up @@ -315,6 +317,16 @@ def __repr__(self) -> str:
f'split_infos={self.split_infos!r})'
)

@property
def examples_in_shards(self) -> list[int]:
  """Returns per-shard example counts of all contained splits, concatenated.

  The counts are gathered in the order of `self.split_infos`. Entries that are
  themselves `SubSplitInfo` or `MultiSplitInfo` contribute their own
  `examples_in_shards`; any other entry contributes its `shard_lengths`.
  """
  counts: list[int] = []
  for info in self.split_infos:
    is_derived = isinstance(info, (SubSplitInfo, MultiSplitInfo))
    # Derived splits already expose per-shard totals; plain splits only have
    # their shard lengths.
    counts.extend(info.examples_in_shards if is_derived else info.shard_lengths)
  return counts

@property
def file_instructions(self) -> list[shard_utils.FileInstruction]:
result = []
Expand Down Expand Up @@ -361,6 +373,10 @@ class SubSplitInfo:
def shard_lengths(self) -> list[int]:
return [f.take for f in self.file_instructions]

@property
def examples_in_shards(self) -> list[int]:
  """Returns the `examples_in_shard` value of each file instruction, in order."""
  totals = []
  for instruction in self.file_instructions:
    totals.append(instruction.examples_in_shard)
  return totals

@property
def num_examples(self) -> int:
"""Returns the number of example in the subsplit."""
Expand Down Expand Up @@ -526,7 +542,7 @@ def _make_absolute_instructions(

def _file_instructions_for_split(
instruction: _AbsoluteInstruction,
split_info: SplitInfo,
split_info: SplitInfo | SubSplitInfo,
) -> list[shard_utils.FileInstruction]:
"""Returns the file instructions from the given instruction applied to the given split info."""
if not split_info.num_examples:
Expand All @@ -537,9 +553,7 @@ def _file_instructions_for_split(
return []
to = split_info.num_examples if instruction.to is None else instruction.to
if isinstance(split_info, (SubSplitInfo, MultiSplitInfo)):
examples_in_shards = [
f.examples_in_shard for f in split_info.file_instructions
]
examples_in_shards = split_info.examples_in_shards
else:
examples_in_shards = None
return shard_utils.get_file_instructions(
Expand Down
37 changes: 37 additions & 0 deletions tensorflow_datasets/core/splits_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,43 @@ def test_multi_split_sub_split(self):
self.assertEqual(file_instruction.take, 2)
self.assertEqual(file_instruction.examples_in_shard, 10)

def test_multi_split_empty_shard(self):
  """Slices a multi-split whose only split has an empty middle shard.

  The split has shard lengths [5, 0, 5]; taking `train[:90%]` is expected to
  yield 9 examples read from the first and third shard files only.
  """
  train_info = splits.SplitInfo(
      name='train',
      shard_lengths=[5, 0, 5],
      num_bytes=0,
      filename_template=_filename_template(split='train', data_dir='/abc'),
  )
  multi_split = splits.MultiSplitInfo(name='train', split_infos=[train_info])
  sliced = splits.SplitDict([multi_split])['train[:90%]']

  # The empty shard (index 1) must not appear in the instructions.
  first_shard = shard_utils.FileInstruction(
      filename='/abc/ds_name-train.tfrecord-00000-of-00003',
      skip=0,
      take=5,
      examples_in_shard=5,
  )
  last_shard = shard_utils.FileInstruction(
      filename='/abc/ds_name-train.tfrecord-00002-of-00003',
      skip=0,
      take=4,
      examples_in_shard=5,
  )

  self.assertEqual(sliced.name, 'train[:90%]')
  self.assertEqual(sliced.num_examples, 9)
  self.assertEqual(sliced.shard_lengths, [5, 4])
  self.assertEqual(sliced.file_instructions, [first_shard, last_shard])


class SplitsTest(testing.TestCase):

Expand Down