1616and objects for training / evaluating PyTorch models built around DataStreams, e.g., PyTorch
1717DataLoaders, with minimal boilerplate.
1818"""
19+ # Standard
1920from typing import Any , Iterator , Optional
2021
2122# Third Party
2223from torch .utils .data import IterableDataset , get_worker_info
2324
2425# First Party
25- from caikit .core .toolkit import error_handler
2626from caikit .core .data_model import DataStream
27+ from caikit .core .toolkit import error_handler
2728import alog
2829
2930log = alog .use_channel ("STREAM_WRAP" )
@@ -35,7 +36,13 @@ class SimpleIterableStreamWrapper(IterableDataset):
3536 compatibility with PyTorch data loaders.
3637 """
3738
38- def __init__ (self , stream : DataStream [Any ], shuffle : bool , buffer_size : Optional [int ]= None , seed : int = 42 ):
39+ def __init__ (
40+ self ,
41+ stream : DataStream [Any ],
42+ shuffle : bool ,
43+ buffer_size : Optional [int ] = None ,
44+ seed : int = 42 ,
45+ ):
3946 error .type_check ("<NLP12855513E>" , bool , shuffle = shuffle )
4047 error .type_check (
4148 "<NLP12813713E>" , int , buffer_size = buffer_size , allow_none = True
@@ -68,7 +75,7 @@ def __iter__(self) -> Iterator[Any]:
6875 # shuffles completed so far to ensure that every worker will
6976 # shuffle the same way for each epoch.
7077 shuffle_seed = self ._get_shuffle_seed (worker_info )
71- log .debug (f "Reshuffling training data with seed: { shuffle_seed } " )
78+ log .debug ("Reshuffling training data with seed: {}" . format ( shuffle_seed ) )
7279 cycle_stream = self .stream .shuffle (self .buffer_size , seed = shuffle_seed )
7380 self ._increment_shuffle_seed (worker_info )
7481 else :
@@ -113,10 +120,9 @@ def _increment_shuffle_seed(self, worker_info: Optional["WorkerInfo"]) -> None:
113120 else :
114121 worker_info .dataset .shuffles_completed += 1
115122
116- def _get_stream_partition (self ,
117- cycle_stream : DataStream [Any ],
118- worker_id : int ,
119- num_workers : int ):
123+ def _get_stream_partition (
124+ self , cycle_stream : DataStream [Any ], worker_id : int , num_workers : int
125+ ):
120126 """Generator for a subset of a wrapped datastream; here, we simply traverse a stream,
121127 which is assumed to be preshuffled, and yield the elements that align with the
122128 scheme 'worker n gets every nth entry' after shuffling. This ensures that each
@@ -137,7 +143,7 @@ def _get_stream_partition(self,
137143
138144 def __len__ (self ) -> int :
139145 """Gets the encapsulated stream length.
140-
146+
141147 Returns:
142148 int
143149 number of objects in the stream.
0 commit comments