1414 Tuple ,
1515 TypeVar ,
1616 Union ,
17- overload ,
1817)
1918
19+ from typing_extensions import Self
20+
2021from megatron .energon .epathlib import EPath
22+ from megatron .energon .flavors .base_dataset import DatasetSampleReader , RawSample
2123from megatron .energon .flavors .webdataset .config import (
2224 INDEX_SQLITE_FILENAME ,
2325 skip_meta_re ,
@@ -275,7 +277,7 @@ def __getitem__(self, idx: T_index) -> FilteredSample | None:
275277 return self ._get_item_by_sample_pointer (sample_pointer , idx )
276278
277279
278- class JoinIndexFileITarReader (ITarReader [int ]):
280+ class JoinIndexFileITarReader (ITarReader [int ], DatasetSampleReader ):
279281 """
280282 A concrete ITarReader that reads samples from a join index file (via JoinIndexReader).
281283 """
@@ -369,6 +371,16 @@ def __len__(self) -> int:
369371
370372 return len (index_reader )
371373
def __iter__(self) -> Generator[RawSample | None, None, None]:
    """
    Yield every sample of this reader in index order.

    The length is queried once when iteration starts; each position from 0 up to
    (but excluding) that length is then fetched through ``self[position]``.
    """
    total = len(self)
    position = 0
    while position < total:
        yield self[position]
        position += 1
377+
def __enter__(self) -> Self:
    """
    Enter the ``with`` block.

    Returns:
        This reader itself, so it can be bound via ``with ... as reader``.
    """
    reader = self
    return reader
380+
def __exit__(self, exc_type, exc_value, traceback) -> None:
    """
    Leave the ``with`` block by closing this reader.

    The exception details are ignored and nothing truthy is returned, so an
    exception raised inside the ``with`` body continues to propagate.
    """
    close_reader = self.close
    close_reader()
383+
372384 def __str__ (self ) -> str :
373385 return (
374386 f"JoinIndexFileITarReader("
@@ -378,7 +390,7 @@ def __str__(self) -> str:
378390 )
379391
380392
381- class ShardInfosITarReader (ITarReader [int ]):
393+ class ShardInfosITarReader (ITarReader [int ], DatasetSampleReader ):
382394 """
383395 A concrete ITarReader that constructs its internal sample list from a list of ShardInfos.
384396 """
@@ -469,6 +481,16 @@ def _get_itar_sample_pointer(self, idx: int) -> ITarSamplePointer:
def __len__(self) -> int:
    """
    Return the number of samples, taken as the last entry of ``shard_count_cumsum``.
    """
    cumulative_counts = self.shard_count_cumsum
    return cumulative_counts[-1]
471483
def __iter__(self) -> Generator[RawSample | None, None, None]:
    """
    Yield every sample of this reader in index order.

    The length is queried once when iteration starts; each position from 0 up to
    (but excluding) that length is then fetched through ``self[position]``.
    """
    total = len(self)
    position = 0
    while position < total:
        yield self[position]
        position += 1
487+
def __enter__(self) -> Self:
    """
    Enter the ``with`` block.

    Returns:
        This reader itself, so it can be bound via ``with ... as reader``.
    """
    reader = self
    return reader
490+
def __exit__(self, exc_type, exc_value, traceback) -> None:
    """
    Leave the ``with`` block by closing this reader.

    The exception details are ignored and nothing truthy is returned, so an
    exception raised inside the ``with`` body continues to propagate.
    """
    close_reader = self.close
    close_reader()
493+
472494 def __str__ (self ) -> str :
473495 return (
474496 f"ShardInfosITarReader("
@@ -524,7 +546,6 @@ def _get_itar_sample_pointer(self, sample_key: str) -> ITarSamplePointer:
524546 """
525547 Get the ITarSample object for the given index.
526548 """
527-
528549 return self .sqlite_reader .get_sample_pointer_by_key (sample_key )
529550
530551 def list_all_samples (self ) -> Generator [Tuple [str , int , int ], None , None ]:
@@ -577,25 +598,12 @@ def list_sample_parts(
def get_total_size(self) -> int:
    """
    Return the total size as reported by the underlying ``sqlite_reader``.

    This simply forwards to ``self.sqlite_reader.get_total_size()``.
    """
    index_reader = self.sqlite_reader
    return index_reader.get_total_size()
579600
580- @overload
581- def __getitem__ (self , key : str ) -> Union [FilteredSample , tuple [bytes , SourceInfo ]]: ...
582-
583- @overload
584- def __getitem__ (self , key : slice ) -> "ITarReader" : ...
585-
586- def __getitem__ (
587- self , key : Union [slice , str ]
588- ) -> Union [FilteredSample , tuple [bytes , SourceInfo ], ITarReader ]:
601+ def __getitem__ (self , key : str ) -> Union [FilteredSample , tuple [bytes , SourceInfo ]]:
589602 """
590603 Either get a sample from the dataset by the sample key including all its entries,
591604 or get the bytes of a specific entry by the full filename of the entry inside the tar.
592605 """
593606
594- if isinstance (key , slice ):
595- # Return a new reader with a sliced samples tensor
596- raise NotImplementedError ("Slicing is not yet implemented" )
597- assert isinstance (key , str ), "Invalid argument type for __getitem__"
598-
599607 if self .key_is_full_entryname :
600608 m = split_name_re .match (key )
601609 if not m :
0 commit comments