Skip to content

Commit b7bd4cb

Browse files
Stream wrapper linter and code formatting
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
1 parent 89b2858 commit b7bd4cb

2 files changed

Lines changed: 23 additions & 16 deletions

File tree

caikit_nlp/toolkit/data_stream_wrapper.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@
1616
and objects for training / evaluating PyTorch models built around DataStreams, e.g., PyTorch
1717
DataLoaders, with minimal boilerplate.
1818
"""
19+
# Standard
1920
from typing import Any, Iterator, Optional
2021

2122
# Third Party
2223
from torch.utils.data import IterableDataset, get_worker_info
2324

2425
# First Party
25-
from caikit.core.toolkit import error_handler
2626
from caikit.core.data_model import DataStream
27+
from caikit.core.toolkit import error_handler
2728
import alog
2829

2930
log = alog.use_channel("STREAM_WRAP")
@@ -35,7 +36,13 @@ class SimpleIterableStreamWrapper(IterableDataset):
3536
compatability with PyTorch data loaders.
3637
"""
3738

38-
def __init__(self, stream: DataStream[Any], shuffle: bool, buffer_size: Optional[int]=None, seed: int=42):
39+
def __init__(
40+
self,
41+
stream: DataStream[Any],
42+
shuffle: bool,
43+
buffer_size: Optional[int] = None,
44+
seed: int = 42,
45+
):
3946
error.type_check("<NLP12855513E>", bool, shuffle=shuffle)
4047
error.type_check(
4148
"<NLP12813713E>", int, buffer_size=buffer_size, allow_none=True
@@ -68,7 +75,7 @@ def __iter__(self) -> Iterator[Any]:
6875
# shuffles completed so far to ensure that every worker will
6976
# shuffle the same way for each epoch.
7077
shuffle_seed = self._get_shuffle_seed(worker_info)
71-
log.debug(f"Reshuffling training data with seed: {shuffle_seed}")
78+
log.debug("Reshuffling training data with seed: {}".format(shuffle_seed))
7279
cycle_stream = self.stream.shuffle(self.buffer_size, seed=shuffle_seed)
7380
self._increment_shuffle_seed(worker_info)
7481
else:
@@ -113,10 +120,9 @@ def _increment_shuffle_seed(self, worker_info: Optional["WorkerInfo"]) -> None:
113120
else:
114121
worker_info.dataset.shuffles_completed += 1
115122

116-
def _get_stream_partition(self,
117-
cycle_stream: DataStream[Any],
118-
worker_id: int,
119-
num_workers: int):
123+
def _get_stream_partition(
124+
self, cycle_stream: DataStream[Any], worker_id: int, num_workers: int
125+
):
120126
"""Generator for a subset of a wrapped datastream; here, we simply traverse a stream,
121127
which is assumed to be preshuffled, and yield the elements that align with the
122128
scheme 'worker n gets every nth entry' after shuffling. This ensures that each
@@ -137,7 +143,7 @@ def _get_stream_partition(self,
137143

138144
def __len__(self) -> int:
139145
"""Gets the encapsulated stream length.
140-
146+
141147
Returns:
142148
int
143149
number of objects in the stream.

tests/toolkit/test_data_stream_wrapper.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,24 +76,25 @@ def test_iter_with_multi_worker():
7676
# Since we don't shuffle in this patched test, they should just be
7777
# divided as is.
7878
index_stream = [
79-
{"label": 0}, # goes to worker 0
80-
{"label": 1}, # goes to worker 1
81-
{"label": 2}, # goes to worker 2
82-
{"label": 3}, # goes to worker 0
83-
{"label": 4}, # goes to worker 1
84-
{"label": 5}, # goes to worker 2
79+
{"label": 0}, # goes to worker 0
80+
{"label": 1}, # goes to worker 1
81+
{"label": 2}, # goes to worker 2
82+
{"label": 3}, # goes to worker 0
83+
{"label": 4}, # goes to worker 1
84+
{"label": 5}, # goes to worker 2
8585
]
8686
worker_info = [
8787
(w1_info, [index_stream[0], index_stream[3]]),
8888
(w2_info, [index_stream[1], index_stream[4]]),
8989
(w3_info, [index_stream[2], index_stream[5]]),
9090
]
9191
for (dummy_worker, expected_elements) in worker_info:
92-
with mock.patch.object(worker, '_worker_info', dummy_worker):
92+
with mock.patch.object(worker, "_worker_info", dummy_worker):
9393
wrapper = SimpleIterableStreamWrapper(stream=index_stream, shuffle=False)
9494
for _ in range(NUM_CYCLES):
9595
actual_elements = list(wrapper)
9696
test_results.append(
97-
actual_elements == expected_elements and len(actual_elements) == len(expected_elements)
97+
actual_elements == expected_elements
98+
and len(actual_elements) == len(expected_elements)
9899
)
99100
assert all(test_results)

0 commit comments

Comments
 (0)