Skip to content

Commit 1392991

Browse files
Cache datastream length
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
1 parent b6d1b82 commit 1392991

1 file changed

Lines changed: 6 additions & 3 deletions

File tree

caikit_nlp/toolkit/data_stream_wrapper.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -63,9 +63,10 @@ def __init__(
6363
self.stream = stream
6464
self.shuffle = shuffle
6565
self.buffer_size = buffer_size
66+
self.stream_length = len(stream)
6667
# Load the whole data set in memory
6768
if self.shuffle and buffer_size is None:
68-
self.buffer_size = len(stream)
69+
self.buffer_size = self.stream_length
6970
log.debug("Shuffling enabled? {}".format(self.shuffle))
7071
log.debug("Shuffling buffer size: {}".format(self.buffer_size))
7172

@@ -152,10 +153,12 @@ def _get_stream_partition(
152153
yield elem
153154

154155
def __len__(self) -> int:
155-
"""Gets the encapsulated stream length.
156+
"""Gets the encapsulated stream length. Note that we cache this attribute,
157+
because taking the length of a datastream (re-entrant generator) requires
158+
iterating until the end of it, which is expensive.
156159
157160
Returns:
158161
int
159162
number of objects in the stream.
160163
"""
161-
return len(self.stream)
164+
return self.stream_length

0 commit comments

Comments
 (0)