Skip to content

Commit 38dc134

Browse files
committed
Fix abstraction of sample readers, create direct interface to sample readers.
1 parent c560494 commit 38dc134

7 files changed

Lines changed: 176 additions & 29 deletions

File tree

src/megatron/energon/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
CrudeJsonlDatasetFactory,
3434
CrudeSample,
3535
CrudeWebdataset,
36+
DatasetSampleReader,
3637
DefaultCrudeJsonlDatasetFactory,
3738
DefaultDecoderWebdatasetFactory,
3839
DefaultGenericWebdatasetFactory,
@@ -47,6 +48,7 @@
4748
MultiChoiceVQAWebdataset,
4849
OCRSample,
4950
OCRWebdataset,
51+
RawSample,
5052
Sample,
5153
SampleDecoder,
5254
SavableDataset,
@@ -61,7 +63,7 @@
6163
VQASample,
6264
VQAWebdataset,
6365
)
64-
from megatron.energon.loader import get_loader, get_savable_loader
66+
from megatron.energon.loader import get_dataset_reader, get_loader, get_savable_loader
6567
from megatron.energon.metadataset import (
6668
DatasetLoader,
6769
DatasetLoaderInterface,
@@ -136,6 +138,7 @@
136138
"CrudeWebdataset",
137139
"DatasetLoader",
138140
"DatasetLoaderInterface",
141+
"DatasetSampleReader",
139142
"DecodeFileStore",
140143
"DefaultCrudeJsonlDatasetFactory",
141144
"DefaultDecoderWebdatasetFactory",
@@ -152,6 +155,7 @@
152155
"GcDataset",
153156
"generic_batch",
154157
"generic_concat",
158+
"get_dataset_reader",
155159
"get_loader",
156160
"get_savable_loader",
157161
"get_train_dataset",
@@ -187,6 +191,7 @@
187191
"prepare_metadataset",
188192
"RepeatDataset",
189193
"reraise_exception",
194+
"RawSample",
190195
"Sample",
191196
"SampleDecoder",
192197
"SavableDataLoader",

src/megatron/energon/flavors/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
from megatron.energon.flavors.base_dataset import (
55
BaseCoreDatasetFactory,
6+
DatasetSampleReader,
67
PinMemoryMixin,
8+
RawSample,
79
Sample,
810
SavableDataset,
911
)
@@ -55,6 +57,7 @@
5557
"CrudeJsonlDatasetFactory",
5658
"CrudeSample",
5759
"CrudeWebdataset",
60+
"DatasetSampleReader",
5861
"DefaultCrudeJsonlDatasetFactory",
5962
"DefaultDecoderWebdatasetFactory",
6063
"DefaultGenericWebdatasetFactory",
@@ -72,6 +75,7 @@
7275
"OCRSample",
7376
"OCRWebdataset",
7477
"PinMemoryMixin",
78+
"RawSample",
7579
"Sample",
7680
"SampleDecoder",
7781
"SavableDataset",

src/megatron/energon/flavors/base_dataset.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Callable,
1313
ClassVar,
1414
Dict,
15+
Generator,
1516
Generic,
1617
List,
1718
Optional,
@@ -37,6 +38,10 @@
3738
T_sample = TypeVar("T_sample", covariant=True)
3839
T = TypeVar("T", covariant=True)
3940

41+
# Must contain at least the fields __key__, __restore_key__, __sources__.
42+
# Other fields contain the data.
43+
RawSample = Dict[str, Any]
44+
4045

4146
class PinMemoryMixin:
4247
"""A mixin class providing a generic `pin_memory` function."""
@@ -395,6 +400,30 @@ def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_s
395400
)
396401

397402

403+
class DatasetSampleReader(ABC):
    """
    Interface for dataset readers for random access to samples.

    Implementations must provide indexed access (`__getitem__`), the number of
    indices (`__len__`) and resource cleanup (`close`). Sequential iteration and
    context-manager support are derived from those three, so a subclass only
    needs to override `__iter__` if it can iterate more efficiently than by index.
    """

    @abstractmethod
    def __getitem__(self, idx: int) -> RawSample | None:
        """Return the raw sample at position `idx`.

        The result may be `None`; when that happens is defined by the
        implementation.
        """
        ...

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of indices that can be passed to `__getitem__`."""
        ...

    def __iter__(self) -> Generator[RawSample | None, None, None]:
        """Yield `self[idx]` for every `idx` in `range(len(self))`, in order.

        `None` results from `__getitem__` are yielded unchanged, so the number of
        yielded items always equals `len(self)`.
        """
        for idx in range(len(self)):
            yield self[idx]

    @abstractmethod
    def close(self) -> None:
        """Release the resources held by this reader."""
        ...

    def __enter__(self) -> Self:
        """Return the reader itself so it can be bound in a `with` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the reader when the `with` block ends."""
        # Returning None (falsy) lets an exception raised inside the `with`
        # body propagate after the reader has been closed.
        self.close()
425+
426+
398427
class BaseCoreDatasetFactory(Generic[T_sample], ABC):
399428
"""Base type for an inner dataset sample loader. This factory can be used to construct a sample loader, or for
400429
joining in a joined dataset."""

src/megatron/energon/flavors/jsonl/ijsonl_reader.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
# Copyright (c) 2025, NVIDIA CORPORATION.
22
# SPDX-License-Identifier: BSD-3-Clause
33

4-
from abc import ABC
54
from typing import (
65
Callable,
76
Generator,
87
Optional,
98
Tuple,
109
TypeVar,
10+
overload,
1111
)
1212

13+
from typing_extensions import Self
14+
1315
from megatron.energon.epathlib import EPath
16+
from megatron.energon.flavors.base_dataset import DatasetSampleReader, RawSample
1417
from megatron.energon.flavors.jsonl.ijsonl import (
1518
CachedIJsonlOffsetReader,
1619
IJsonlFile,
@@ -23,7 +26,7 @@
2326
T_index = TypeVar("T_index", covariant=False)
2427

2528

26-
class IJsonlReader(ABC):
29+
class IJsonlReader(DatasetSampleReader):
2730
"""
2831
Class for reading indexed jsonl files containing json samples.
2932
@@ -102,6 +105,12 @@ def _get_item_by_sample_pointer(
102105
json=json_data,
103106
)
104107

108+
@overload
109+
def __getitem__(self, idx: int) -> FilteredSample | None: ...
110+
111+
@overload
112+
def __getitem__(self, idx: str) -> tuple[bytes, SourceInfo] | None: ...
113+
105114
def __getitem__(self, idx: int | str) -> FilteredSample | tuple[bytes, SourceInfo] | None:
106115
"""
107116
Get a sample from the dataset.
@@ -136,6 +145,10 @@ def __getitem__(self, idx: int | str) -> FilteredSample | tuple[bytes, SourceInf
136145
else:
137146
return sample
138147

148+
def __iter__(self) -> Generator[RawSample | None, None, None]:
149+
for idx in range(len(self)):
150+
yield self[idx]
151+
139152
def list_all_samples(self) -> Generator[Tuple[str, int, int], None, None]:
140153
"""List all samples in the jsonl file.
141154
@@ -184,6 +197,12 @@ def list_sample_parts(self, sample_key: str) -> Generator[Tuple[str, int, int],
184197
def get_total_size(self) -> int:
185198
return self.cached_offset_reader.get_total_size()
186199

200+
def __enter__(self) -> Self:
201+
return self
202+
203+
def __exit__(self, exc_type, exc_value, traceback) -> None:
204+
self.close()
205+
187206
def close(self):
188207
if self.ijsonl_file is not None:
189208
self.ijsonl_file.close()

src/megatron/energon/flavors/webdataset/itar_reader.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
Tuple,
1515
TypeVar,
1616
Union,
17-
overload,
1817
)
1918

19+
from typing_extensions import Self
20+
2021
from megatron.energon.epathlib import EPath
22+
from megatron.energon.flavors.base_dataset import DatasetSampleReader, RawSample
2123
from megatron.energon.flavors.webdataset.config import (
2224
INDEX_SQLITE_FILENAME,
2325
skip_meta_re,
@@ -275,7 +277,7 @@ def __getitem__(self, idx: T_index) -> FilteredSample | None:
275277
return self._get_item_by_sample_pointer(sample_pointer, idx)
276278

277279

278-
class JoinIndexFileITarReader(ITarReader[int]):
280+
class JoinIndexFileITarReader(ITarReader[int], DatasetSampleReader):
279281
"""
280282
A concrete ITarReader that reads samples from a join index file (via JoinIndexReader).
281283
"""
@@ -369,6 +371,16 @@ def __len__(self) -> int:
369371

370372
return len(index_reader)
371373

374+
def __iter__(self) -> Generator[RawSample | None, None, None]:
375+
for idx in range(len(self)):
376+
yield self[idx]
377+
378+
def __enter__(self) -> Self:
379+
return self
380+
381+
def __exit__(self, exc_type, exc_value, traceback) -> None:
382+
self.close()
383+
372384
def __str__(self) -> str:
373385
return (
374386
f"JoinIndexFileITarReader("
@@ -378,7 +390,7 @@ def __str__(self) -> str:
378390
)
379391

380392

381-
class ShardInfosITarReader(ITarReader[int]):
393+
class ShardInfosITarReader(ITarReader[int], DatasetSampleReader):
382394
"""
383395
A concrete ITarReader that constructs its internal sample list from a list of ShardInfos.
384396
"""
@@ -469,6 +481,16 @@ def _get_itar_sample_pointer(self, idx: int) -> ITarSamplePointer:
469481
def __len__(self) -> int:
470482
return self.shard_count_cumsum[-1]
471483

484+
def __iter__(self) -> Generator[RawSample | None, None, None]:
485+
for idx in range(len(self)):
486+
yield self[idx]
487+
488+
def __enter__(self) -> Self:
489+
return self
490+
491+
def __exit__(self, exc_type, exc_value, traceback) -> None:
492+
self.close()
493+
472494
def __str__(self) -> str:
473495
return (
474496
f"ShardInfosITarReader("
@@ -524,7 +546,6 @@ def _get_itar_sample_pointer(self, sample_key: str) -> ITarSamplePointer:
524546
"""
525547
Get the ITarSamplePointer for the given sample key.
526548
"""
527-
528549
return self.sqlite_reader.get_sample_pointer_by_key(sample_key)
529550

530551
def list_all_samples(self) -> Generator[Tuple[str, int, int], None, None]:
@@ -577,25 +598,12 @@ def list_sample_parts(
577598
def get_total_size(self) -> int:
578599
return self.sqlite_reader.get_total_size()
579600

580-
@overload
581-
def __getitem__(self, key: str) -> Union[FilteredSample, tuple[bytes, SourceInfo]]: ...
582-
583-
@overload
584-
def __getitem__(self, key: slice) -> "ITarReader": ...
585-
586-
def __getitem__(
587-
self, key: Union[slice, str]
588-
) -> Union[FilteredSample, tuple[bytes, SourceInfo], ITarReader]:
601+
def __getitem__(self, key: str) -> Union[FilteredSample, tuple[bytes, SourceInfo]]:
589602
"""
590603
Either get a sample from the dataset by the sample key including all its entries,
591604
or get the bytes of a specific entry by the full filename of the entry inside the tar.
592605
"""
593606

594-
if isinstance(key, slice):
595-
# Return a new reader with a sliced samples tensor
596-
raise NotImplementedError("Slicing is not yet implemented")
597-
assert isinstance(key, str), "Invalid argument type for __getitem__"
598-
599607
if self.key_is_full_entryname:
600608
m = split_name_re.match(key)
601609
if not m:

src/megatron/energon/flavors/webdataset/sample_loader.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
import torch
77

88
from megatron.energon.edataclass import edataclass
9-
from megatron.energon.flavors.base_dataset import FlexState, SavableDataset
10-
from megatron.energon.flavors.webdataset.itar_reader import ITarReader
9+
from megatron.energon.flavors.base_dataset import DatasetSampleReader, FlexState, SavableDataset
1110
from megatron.energon.flavors.webdataset.structs import FilteredSample
1211
from megatron.energon.rng import WorkerRng
1312
from megatron.energon.worker import WorkerConfig
@@ -34,10 +33,10 @@ class SliceState:
3433

3534

3635
class WebdatasetSampleLoaderDataset(SavableDataset[RawSampleData]):
37-
"""Internal class for loading samples from webdataset slices"""
36+
"""Internal class for sampling from random access datasets efficiently (the "core sampler")."""
3837

3938
#: The readers for each joined dataset
40-
join_readers: Sequence[ITarReader]
39+
join_readers: Sequence[DatasetSampleReader]
4140

4241
#: The offsets of the slices to iterate over for the current worker
4342
slice_offsets: Optional[Sequence[int]]
@@ -83,7 +82,7 @@ class WebdatasetSampleLoaderDataset(SavableDataset[RawSampleData]):
8382

8483
def __init__(
8584
self,
86-
join_readers: Sequence[ITarReader],
85+
join_readers: Sequence[DatasetSampleReader],
8786
workers_sample_slice_offsets: Sequence[Sequence[int]],
8887
*,
8988
worker_config: WorkerConfig,

0 commit comments

Comments
 (0)