1111 Generic ,
1212 List ,
1313 Optional ,
14+ Self ,
1415 Tuple ,
1516 TypeVar ,
1617 Union ,
17- overload ,
1818)
1919
2020from megatron .energon .epathlib import EPath
21+ from megatron .energon .flavors .base_dataset import DatasetSampleReader , RawSample
2122from megatron .energon .flavors .webdataset .config import (
2223 INDEX_SQLITE_FILENAME ,
2324 skip_meta_re ,
@@ -275,7 +276,7 @@ def __getitem__(self, idx: T_index) -> FilteredSample | None:
275276 return self ._get_item_by_sample_pointer (sample_pointer , idx )
276277
277278
278- class JoinIndexFileITarReader (ITarReader [int ]):
279+ class JoinIndexFileITarReader (ITarReader [int ], DatasetSampleReader ):
279280 """
280281 A concrete ITarReader that reads samples from a join index file (via JoinIndexReader).
281282 """
@@ -369,6 +370,16 @@ def __len__(self) -> int:
369370
370371 return len (index_reader )
371372
def __iter__(self) -> Generator[RawSample | None, None, None]:
    """
    Iterate over all samples of this reader in index order.

    Walks the indices ``0 .. len(self) - 1`` and yields ``self[idx]`` for each,
    so every item is produced lazily by ``__getitem__``. As the return
    annotation indicates, an individual item may be ``None``.
    """
    for idx in range(len(self)):
        yield self[idx]
376+
def __enter__(self) -> Self:
    """
    Enter the context manager.

    Returns the reader itself, so it can be bound with ``with ... as reader``.
    """
    return self
379+
def __exit__(self, exc_type, exc_value, traceback) -> None:
    """
    Leave the context manager by calling ``self.close()``.

    The exception arguments are ignored. Because the method returns ``None``,
    an exception raised inside the ``with`` body is not suppressed.
    """
    self.close()
382+
372383 def __str__ (self ) -> str :
373384 return (
374385 f"JoinIndexFileITarReader("
@@ -378,7 +389,7 @@ def __str__(self) -> str:
378389 )
379390
380391
381- class ShardInfosITarReader (ITarReader [int ]):
392+ class ShardInfosITarReader (ITarReader [int ], DatasetSampleReader ):
382393 """
383394 A concrete ITarReader that constructs its internal sample list from a list of ShardInfos.
384395 """
@@ -469,6 +480,16 @@ def _get_itar_sample_pointer(self, idx: int) -> ITarSamplePointer:
def __len__(self) -> int:
    """
    Return the number of samples in this reader.

    The value is the last entry of ``self.shard_count_cumsum``.
    NOTE(review): presumably a cumulative sum of per-shard sample counts, so
    the last entry is the total; confirm where the attribute is built.
    """
    return self.shard_count_cumsum[-1]
471482
def __iter__(self) -> Generator[RawSample | None, None, None]:
    """
    Iterate over all samples of this reader in index order.

    Walks the indices ``0 .. len(self) - 1`` and yields ``self[idx]`` for each,
    so every item is produced lazily by ``__getitem__``. As the return
    annotation indicates, an individual item may be ``None``.
    """
    for idx in range(len(self)):
        yield self[idx]
486+
def __enter__(self) -> Self:
    """
    Enter the context manager.

    Returns the reader itself, so it can be bound with ``with ... as reader``.
    """
    return self
489+
def __exit__(self, exc_type, exc_value, traceback) -> None:
    """
    Leave the context manager by calling ``self.close()``.

    The exception arguments are ignored. Because the method returns ``None``,
    an exception raised inside the ``with`` body is not suppressed.
    """
    self.close()
492+
472493 def __str__ (self ) -> str :
473494 return (
474495 f"ShardInfosITarReader("
@@ -524,7 +545,6 @@ def _get_itar_sample_pointer(self, sample_key: str) -> ITarSamplePointer:
524545 """
525546 Get the ITarSample object for the given index.
526547 """
527-
528548 return self .sqlite_reader .get_sample_pointer_by_key (sample_key )
529549
530550 def list_all_samples (self ) -> Generator [Tuple [str , int , int ], None , None ]:
@@ -577,25 +597,12 @@ def list_sample_parts(
def get_total_size(self) -> int:
    """
    Return the total size reported by the underlying SQLite index reader.

    This is a thin delegation to ``self.sqlite_reader.get_total_size()``.
    NOTE(review): the unit of the returned size is not visible here; confirm
    against the sqlite reader.
    """
    return self.sqlite_reader.get_total_size()
579599
580- @overload
581- def __getitem__ (self , key : str ) -> Union [FilteredSample , tuple [bytes , SourceInfo ]]: ...
582-
583- @overload
584- def __getitem__ (self , key : slice ) -> "ITarReader" : ...
585-
586- def __getitem__ (
587- self , key : Union [slice , str ]
588- ) -> Union [FilteredSample , tuple [bytes , SourceInfo ], ITarReader ]:
600+ def __getitem__ (self , key : str ) -> Union [FilteredSample , tuple [bytes , SourceInfo ]]:
589601 """
590602 Either get a sample from the dataset by the sample key including all its entries,
591603 or get the bytes of a specific entry by the full filename of the entry inside the tar.
592604 """
593605
594- if isinstance (key , slice ):
595- # Return a new reader with a sliced samples tensor
596- raise NotImplementedError ("Slicing is not yet implemented" )
597- assert isinstance (key , str ), "Invalid argument type for __getitem__"
598-
599606 if self .key_is_full_entryname :
600607 m = split_name_re .match (key )
601608 if not m :
0 commit comments