From fb11d8d9180cdb7a5c676b32619656d0846433b0 Mon Sep 17 00:00:00 2001 From: The TensorFlow Datasets Authors Date: Fri, 25 Apr 2025 01:32:51 -0700 Subject: [PATCH] Fix crash when taking a subset of a MultiSplitInfo with empty shard. PiperOrigin-RevId: 751317460 --- tensorflow_datasets/core/splits.py | 26 +++++++++++++---- tensorflow_datasets/core/splits_test.py | 37 +++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 6 deletions(-) diff --git a/tensorflow_datasets/core/splits.py b/tensorflow_datasets/core/splits.py index c70f66417e2..297c9a4cdb6 100644 --- a/tensorflow_datasets/core/splits.py +++ b/tensorflow_datasets/core/splits.py @@ -282,9 +282,11 @@ class MultiSplitInfo(SplitInfo): This should only be used to read data and not when producing data. """ - split_infos: list[SplitInfo] = dataclasses.field(default_factory=list) + split_infos: list[SplitInfo | SubSplitInfo] = dataclasses.field( + default_factory=list + ) - def __init__(self, name: str, split_infos: list[SplitInfo]): + def __init__(self, name: str, split_infos: list[SplitInfo | SubSplitInfo]): if not split_infos: raise ValueError('Need to pass a non-empty list of SplitInfos') object.__setattr__(self, 'split_infos', split_infos) @@ -315,6 +317,16 @@ def __repr__(self) -> str: f'split_infos={self.split_infos!r})' ) + @property + def examples_in_shards(self) -> list[int]: + result = [] + for split_info in self.split_infos: + if isinstance(split_info, (SubSplitInfo, MultiSplitInfo)): + result.extend(split_info.examples_in_shards) + else: + result.extend(split_info.shard_lengths) + return result + @property def file_instructions(self) -> list[shard_utils.FileInstruction]: result = [] @@ -361,6 +373,10 @@ class SubSplitInfo: def shard_lengths(self) -> list[int]: return [f.take for f in self.file_instructions] + @property + def examples_in_shards(self) -> list[int]: + return [f.examples_in_shard for f in self.file_instructions] + @property def num_examples(self) -> int: """Returns the number of example in the subsplit.""" @@ -526,7 +542,7 @@ def _make_absolute_instructions( def _file_instructions_for_split( instruction: _AbsoluteInstruction, - split_info: SplitInfo, + split_info: SplitInfo | SubSplitInfo, ) -> list[shard_utils.FileInstruction]: """Returns the file instructions from the given instruction applied to the given split info.""" if not split_info.num_examples: @@ -537,9 +553,7 @@ def _file_instructions_for_split( return [] to = split_info.num_examples if instruction.to is None else instruction.to if isinstance(split_info, (SubSplitInfo, MultiSplitInfo)): - examples_in_shards = [ - f.examples_in_shard for f in split_info.file_instructions - ] + examples_in_shards = split_info.examples_in_shards else: examples_in_shards = None return shard_utils.get_file_instructions( diff --git a/tensorflow_datasets/core/splits_test.py b/tensorflow_datasets/core/splits_test.py index 32a277349f8..d2b4115e631 100644 --- a/tensorflow_datasets/core/splits_test.py +++ b/tensorflow_datasets/core/splits_test.py @@ -255,6 +255,43 @@ def test_multi_split_sub_split(self): self.assertEqual(file_instruction.take, 2) self.assertEqual(file_instruction.examples_in_shard, 10) + def test_multi_split_empty_shard(self): + split_info = splits.MultiSplitInfo( + name='train', + split_infos=[ + splits.SplitInfo( + name='train', + shard_lengths=[5, 0, 5], + num_bytes=0, + filename_template=_filename_template( + split='train', data_dir='/abc' + ), + ), + ], + ) + split_dict = splits.SplitDict([split_info]) + sub_split = split_dict['train[:90%]'] + self.assertEqual(sub_split.name, 'train[:90%]') + self.assertEqual(sub_split.num_examples, 9) + self.assertEqual(sub_split.shard_lengths, [5, 4]) + self.assertEqual( + sub_split.file_instructions, + [ + shard_utils.FileInstruction( + filename='/abc/ds_name-train.tfrecord-00000-of-00003', + skip=0, + take=5, + examples_in_shard=5, + ), + shard_utils.FileInstruction( + filename='/abc/ds_name-train.tfrecord-00002-of-00003', + skip=0, + take=4, + examples_in_shard=5, + ), + ], + ) + class SplitsTest(testing.TestCase):