Skip to content

Commit 175ca4c

Browse files
committed
Fix abstraction of sample readers, create direct interface to sample readers.
1 parent c560494 commit 175ca4c

7 files changed

Lines changed: 174 additions & 29 deletions

File tree

src/megatron/energon/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
CrudeJsonlDatasetFactory,
3434
CrudeSample,
3535
CrudeWebdataset,
36+
DatasetSampleReader,
3637
DefaultCrudeJsonlDatasetFactory,
3738
DefaultDecoderWebdatasetFactory,
3839
DefaultGenericWebdatasetFactory,
@@ -47,6 +48,7 @@
4748
MultiChoiceVQAWebdataset,
4849
OCRSample,
4950
OCRWebdataset,
51+
RawSample,
5052
Sample,
5153
SampleDecoder,
5254
SavableDataset,
@@ -61,7 +63,7 @@
6163
VQASample,
6264
VQAWebdataset,
6365
)
64-
from megatron.energon.loader import get_loader, get_savable_loader
66+
from megatron.energon.loader import get_dataset_reader, get_loader, get_savable_loader
6567
from megatron.energon.metadataset import (
6668
DatasetLoader,
6769
DatasetLoaderInterface,
@@ -136,6 +138,7 @@
136138
"CrudeWebdataset",
137139
"DatasetLoader",
138140
"DatasetLoaderInterface",
141+
"DatasetSampleReader",
139142
"DecodeFileStore",
140143
"DefaultCrudeJsonlDatasetFactory",
141144
"DefaultDecoderWebdatasetFactory",
@@ -152,6 +155,7 @@
152155
"GcDataset",
153156
"generic_batch",
154157
"generic_concat",
158+
"get_dataset_reader",
155159
"get_loader",
156160
"get_savable_loader",
157161
"get_train_dataset",
@@ -187,6 +191,7 @@
187191
"prepare_metadataset",
188192
"RepeatDataset",
189193
"reraise_exception",
194+
"RawSample",
190195
"Sample",
191196
"SampleDecoder",
192197
"SavableDataLoader",

src/megatron/energon/flavors/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
from megatron.energon.flavors.base_dataset import (
55
BaseCoreDatasetFactory,
6+
DatasetSampleReader,
67
PinMemoryMixin,
8+
RawSample,
79
Sample,
810
SavableDataset,
911
)
@@ -55,6 +57,7 @@
5557
"CrudeJsonlDatasetFactory",
5658
"CrudeSample",
5759
"CrudeWebdataset",
60+
"DatasetSampleReader",
5861
"DefaultCrudeJsonlDatasetFactory",
5962
"DefaultDecoderWebdatasetFactory",
6063
"DefaultGenericWebdatasetFactory",
@@ -72,6 +75,7 @@
7275
"OCRSample",
7376
"OCRWebdataset",
7477
"PinMemoryMixin",
78+
"RawSample",
7579
"Sample",
7680
"SampleDecoder",
7781
"SavableDataset",

src/megatron/energon/flavors/base_dataset.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Callable,
1313
ClassVar,
1414
Dict,
15+
Generator,
1516
Generic,
1617
List,
1718
Optional,
@@ -37,6 +38,10 @@
3738
T_sample = TypeVar("T_sample", covariant=True)
3839
T = TypeVar("T", covariant=True)
3940

41+
# Must contain at least the fields __key__, __restore_key__, __sources__.
42+
# Other fields contain the data.
43+
RawSample = Dict[str, Any]
44+
4045

4146
class PinMemoryMixin:
4247
"""A mixin class providing a generic `pin_memory` function."""
@@ -395,6 +400,30 @@ def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_s
395400
)
396401

397402

403+
class DatasetSampleReader(ABC):
404+
"""
405+
Interface for dataset readers for random access to samples.
406+
"""
407+
408+
@abstractmethod
409+
def __getitem__(self, idx: int) -> RawSample | None: ...
410+
411+
@abstractmethod
412+
def __len__(self) -> int: ...
413+
414+
@abstractmethod
415+
def __iter__(self) -> Generator[RawSample | None, None, None]: ...
416+
417+
@abstractmethod
418+
def close(self) -> None: ...
419+
420+
def __enter__(self) -> Self:
421+
return self
422+
423+
def __exit__(self, exc_type, exc_value, traceback) -> None:
424+
self.close()
425+
426+
398427
class BaseCoreDatasetFactory(Generic[T_sample], ABC):
399428
"""Base type for an inner dataset sample loader. This factory can be used to construct a sample loader, or for
400429
joining in a joined dataset."""

src/megatron/energon/flavors/jsonl/ijsonl_reader.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
# Copyright (c) 2025, NVIDIA CORPORATION.
22
# SPDX-License-Identifier: BSD-3-Clause
33

4-
from abc import ABC
54
from typing import (
65
Callable,
76
Generator,
87
Optional,
8+
Self,
99
Tuple,
1010
TypeVar,
11+
overload,
1112
)
1213

1314
from megatron.energon.epathlib import EPath
15+
from megatron.energon.flavors.base_dataset import DatasetSampleReader, RawSample
1416
from megatron.energon.flavors.jsonl.ijsonl import (
1517
CachedIJsonlOffsetReader,
1618
IJsonlFile,
@@ -23,7 +25,7 @@
2325
T_index = TypeVar("T_index", covariant=False)
2426

2527

26-
class IJsonlReader(ABC):
28+
class IJsonlReader(DatasetSampleReader):
2729
"""
2830
Class for reading indexed jsonl files containing json samples.
2931
@@ -102,6 +104,12 @@ def _get_item_by_sample_pointer(
102104
json=json_data,
103105
)
104106

107+
@overload
108+
def __getitem__(self, idx: int) -> FilteredSample | None: ...
109+
110+
@overload
111+
def __getitem__(self, idx: str) -> tuple[bytes, SourceInfo] | None: ...
112+
105113
def __getitem__(self, idx: int | str) -> FilteredSample | tuple[bytes, SourceInfo] | None:
106114
"""
107115
Get a sample from the dataset.
@@ -136,6 +144,10 @@ def __getitem__(self, idx: int | str) -> FilteredSample | tuple[bytes, SourceInf
136144
else:
137145
return sample
138146

147+
def __iter__(self) -> Generator[RawSample | None, None, None]:
148+
for idx in range(len(self)):
149+
yield self[idx]
150+
139151
def list_all_samples(self) -> Generator[Tuple[str, int, int], None, None]:
140152
"""List all samples in the jsonl file.
141153
@@ -184,6 +196,12 @@ def list_sample_parts(self, sample_key: str) -> Generator[Tuple[str, int, int],
184196
def get_total_size(self) -> int:
185197
return self.cached_offset_reader.get_total_size()
186198

199+
def __enter__(self) -> Self:
200+
return self
201+
202+
def __exit__(self, exc_type, exc_value, traceback) -> None:
203+
self.close()
204+
187205
def close(self):
188206
if self.ijsonl_file is not None:
189207
self.ijsonl_file.close()

src/megatron/energon/flavors/webdataset/itar_reader.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111
Generic,
1212
List,
1313
Optional,
14+
Self,
1415
Tuple,
1516
TypeVar,
1617
Union,
17-
overload,
1818
)
1919

2020
from megatron.energon.epathlib import EPath
21+
from megatron.energon.flavors.base_dataset import DatasetSampleReader, RawSample
2122
from megatron.energon.flavors.webdataset.config import (
2223
INDEX_SQLITE_FILENAME,
2324
skip_meta_re,
@@ -275,7 +276,7 @@ def __getitem__(self, idx: T_index) -> FilteredSample | None:
275276
return self._get_item_by_sample_pointer(sample_pointer, idx)
276277

277278

278-
class JoinIndexFileITarReader(ITarReader[int]):
279+
class JoinIndexFileITarReader(ITarReader[int], DatasetSampleReader):
279280
"""
280281
A concrete ITarReader that reads samples from a join index file (via JoinIndexReader).
281282
"""
@@ -369,6 +370,16 @@ def __len__(self) -> int:
369370

370371
return len(index_reader)
371372

373+
def __iter__(self) -> Generator[RawSample | None, None, None]:
374+
for idx in range(len(self)):
375+
yield self[idx]
376+
377+
def __enter__(self) -> Self:
378+
return self
379+
380+
def __exit__(self, exc_type, exc_value, traceback) -> None:
381+
self.close()
382+
372383
def __str__(self) -> str:
373384
return (
374385
f"JoinIndexFileITarReader("
@@ -378,7 +389,7 @@ def __str__(self) -> str:
378389
)
379390

380391

381-
class ShardInfosITarReader(ITarReader[int]):
392+
class ShardInfosITarReader(ITarReader[int], DatasetSampleReader):
382393
"""
383394
A concrete ITarReader that constructs its internal sample list from a list of ShardInfos.
384395
"""
@@ -469,6 +480,16 @@ def _get_itar_sample_pointer(self, idx: int) -> ITarSamplePointer:
469480
def __len__(self) -> int:
470481
return self.shard_count_cumsum[-1]
471482

483+
def __iter__(self) -> Generator[RawSample | None, None, None]:
484+
for idx in range(len(self)):
485+
yield self[idx]
486+
487+
def __enter__(self) -> Self:
488+
return self
489+
490+
def __exit__(self, exc_type, exc_value, traceback) -> None:
491+
self.close()
492+
472493
def __str__(self) -> str:
473494
return (
474495
f"ShardInfosITarReader("
@@ -524,7 +545,6 @@ def _get_itar_sample_pointer(self, sample_key: str) -> ITarSamplePointer:
524545
"""
525546
Get the ITarSamplePointer for the given sample key.
526547
"""
527-
528548
return self.sqlite_reader.get_sample_pointer_by_key(sample_key)
529549

530550
def list_all_samples(self) -> Generator[Tuple[str, int, int], None, None]:
@@ -577,25 +597,12 @@ def list_sample_parts(
577597
def get_total_size(self) -> int:
578598
return self.sqlite_reader.get_total_size()
579599

580-
@overload
581-
def __getitem__(self, key: str) -> Union[FilteredSample, tuple[bytes, SourceInfo]]: ...
582-
583-
@overload
584-
def __getitem__(self, key: slice) -> "ITarReader": ...
585-
586-
def __getitem__(
587-
self, key: Union[slice, str]
588-
) -> Union[FilteredSample, tuple[bytes, SourceInfo], ITarReader]:
600+
def __getitem__(self, key: str) -> Union[FilteredSample, tuple[bytes, SourceInfo]]:
589601
"""
590602
Either get a sample from the dataset by the sample key including all its entries,
591603
or get the bytes of a specific entry by the full filename of the entry inside the tar.
592604
"""
593605

594-
if isinstance(key, slice):
595-
# Return a new reader with a sliced samples tensor
596-
raise NotImplementedError("Slicing is not yet implemented")
597-
assert isinstance(key, str), "Invalid argument type for __getitem__"
598-
599606
if self.key_is_full_entryname:
600607
m = split_name_re.match(key)
601608
if not m:

src/megatron/energon/flavors/webdataset/sample_loader.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
import torch
77

88
from megatron.energon.edataclass import edataclass
9-
from megatron.energon.flavors.base_dataset import FlexState, SavableDataset
10-
from megatron.energon.flavors.webdataset.itar_reader import ITarReader
9+
from megatron.energon.flavors.base_dataset import DatasetSampleReader, FlexState, SavableDataset
1110
from megatron.energon.flavors.webdataset.structs import FilteredSample
1211
from megatron.energon.rng import WorkerRng
1312
from megatron.energon.worker import WorkerConfig
@@ -34,10 +33,10 @@ class SliceState:
3433

3534

3635
class WebdatasetSampleLoaderDataset(SavableDataset[RawSampleData]):
37-
"""Internal class for loading samples from webdataset slices"""
36+
"""Internal class for sampling from random access datasets efficiently (the "core sampler")."""
3837

3938
#: The readers for each joined dataset
40-
join_readers: Sequence[ITarReader]
39+
join_readers: Sequence[DatasetSampleReader]
4140

4241
#: The offsets of the slices to iterate over for the current worker
4342
slice_offsets: Optional[Sequence[int]]
@@ -83,7 +82,7 @@ class WebdatasetSampleLoaderDataset(SavableDataset[RawSampleData]):
8382

8483
def __init__(
8584
self,
86-
join_readers: Sequence[ITarReader],
85+
join_readers: Sequence[DatasetSampleReader],
8786
workers_sample_slice_offsets: Sequence[Sequence[int]],
8887
*,
8988
worker_config: WorkerConfig,

0 commit comments

Comments
 (0)