Skip to content

Commit 710cc68

Browse files
authored
Merge pull request #42 from PickwickSoft/bugfix/#41/stream-not-closing-after-terminal-operation
Bugfix/#41/stream not closing after terminal operation
2 parents 178b928 + 9e21499 commit 710cc68

7 files changed

Lines changed: 190 additions & 11 deletions

File tree

pystreamapi/_streams/__base_stream.py

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,50 @@
1+
# pylint: disable=protected-access
2+
from __future__ import annotations
13
import functools
24
import itertools
35
from abc import abstractmethod
46
from builtins import reversed
57
from functools import cmp_to_key
6-
from typing import Iterable, Callable, Any, TypeVar, Iterator
8+
from typing import Iterable, Callable, Any, TypeVar, Iterator, TYPE_CHECKING
79

810
from pystreamapi.__optional import Optional
11+
from pystreamapi._itertools.tools import dropwhile
912
from pystreamapi._lazy.process import Process
1013
from pystreamapi._lazy.queue import ProcessQueue
1114
from pystreamapi._streams.error.__error import ErrorHandler
12-
from pystreamapi._itertools.tools import dropwhile
15+
if TYPE_CHECKING:
16+
from pystreamapi._streams.numeric.__numeric_base_stream import NumericBaseStream
1317

1418
K = TypeVar('K')
1519
_V = TypeVar('_V')
1620
_identity_missing = object()
1721

1822

23+
def _operation(func):
24+
"""
25+
Decorator to execute all the processes in the queue before executing the decorated function.
26+
To be applied to intermediate operations.
27+
"""
28+
@functools.wraps(func)
29+
def wrapper(*args, **kwargs):
30+
self: BaseStream = args[0]
31+
self._verify_open()
32+
return func(*args, **kwargs)
33+
34+
return wrapper
35+
36+
1937
def terminal(func):
    """
    Decorator that finalizes a stream before running the decorated function.

    On each call the following happens in order:

    1. ``_operation`` verifies that the stream is still open,
    2. all processes in the queue are executed,
    3. the stream is closed,
    4. the decorated function is called.

    Because the stream is already closed when the decorated function runs, that
    function must not call other guarded operations on the same stream.
    To be applied to terminal operations.

    :param func: the terminal stream method; it receives the stream as first argument
    :return: the wrapped method, carrying the metadata of ``func``
    """
    @functools.wraps(func)
    @_operation
    def wrapper(self: "BaseStream", *args, **kwargs):
        self._queue.execute_all()
        # Close before delegating so the stream counts as consumed even if func raises.
        self._close()
        return func(self, *args, **kwargs)

    return wrapper
@@ -47,6 +66,16 @@ class BaseStream(Iterable[K], ErrorHandler):
4766
def __init__(self, source: Iterable[K]):
4867
self._source = source
4968
self._queue = ProcessQueue()
69+
self._open = True
70+
71+
def _close(self):
72+
"""Close the stream."""
73+
self._open = False
74+
75+
def _verify_open(self):
76+
"""Verify if stream is open. If not, raise an exception."""
77+
if not self._open:
78+
raise RuntimeError("The stream has been closed")
5079

5180
@terminal
5281
def __iter__(self) -> Iterator[K]:
@@ -63,6 +92,7 @@ def concat(cls, *streams: "BaseStream[K]"):
6392
"""
6493
return cls(itertools.chain(*list(streams)))
6594

95+
@_operation
6696
def distinct(self) -> 'BaseStream[_V]':
6797
"""Returns a stream consisting of the distinct elements of this stream."""
6898
self._queue.append(Process(self.__distinct))
@@ -72,6 +102,7 @@ def __distinct(self):
72102
"""Removes duplicate elements from the stream."""
73103
self._source = list(set(self._source))
74104

105+
@_operation
75106
def drop_while(self, predicate: Callable[[K], bool]) -> 'BaseStream[_V]':
76107
"""
77108
Returns, if this stream is ordered, a stream consisting of the remaining elements of this
@@ -86,6 +117,7 @@ def __drop_while(self, predicate: Callable[[Any], bool]):
86117
"""Drops elements from the stream while the predicate is true."""
87118
self._source = list(dropwhile(predicate, self._source, self))
88119

120+
@_operation
89121
def filter(self, predicate: Callable[[K], bool]) -> 'BaseStream[K]':
90122
"""
91123
Returns a stream consisting of the elements of this stream that match the given predicate.
@@ -99,6 +131,7 @@ def filter(self, predicate: Callable[[K], bool]) -> 'BaseStream[K]':
99131
def _filter(self, predicate: Callable[[K], bool]):
100132
"""Implementation of filter. Should be implemented by subclasses."""
101133

134+
@_operation
102135
def flat_map(self, predicate: Callable[[K], Iterable[_V]]) -> 'BaseStream[_V]':
103136
"""
104137
Returns a stream consisting of the results of replacing each element of this stream with
@@ -114,6 +147,7 @@ def flat_map(self, predicate: Callable[[K], Iterable[_V]]) -> 'BaseStream[_V]':
114147
def _flat_map(self, predicate: Callable[[K], Iterable[_V]]):
115148
"""Implementation of flat_map. Should be implemented by subclasses."""
116149

150+
@_operation
117151
def group_by(self, key_mapper: Callable[[K], Any]) -> 'BaseStream[K]':
118152
"""
119153
Returns a Stream consisting of the results of grouping the elements of this stream
@@ -133,6 +167,7 @@ def __group_by(self, key_mapper: Callable[[Any], Any]):
133167
def _group_to_dict(self, key_mapper: Callable[[K], Any]) -> dict[K, list]:
134168
"""Groups the stream into a dictionary. Should be implemented by subclasses."""
135169

170+
@_operation
136171
def limit(self, max_size: int) -> 'BaseStream[_V]':
137172
"""
138173
Returns a stream consisting of the elements of this stream, truncated to be no longer
@@ -147,6 +182,7 @@ def __limit(self, max_size: int):
147182
"""Limits the stream to the first n elements."""
148183
self._source = itertools.islice(self._source, max_size)
149184

185+
@_operation
150186
def map(self, mapper: Callable[[K], _V]) -> 'BaseStream[_V]':
151187
"""
152188
Returns a stream consisting of the results of applying the given function to the elements
@@ -161,18 +197,20 @@ def map(self, mapper: Callable[[K], _V]) -> 'BaseStream[_V]':
161197
def _map(self, mapper: Callable[[K], _V]):
162198
"""Implementation of map. Should be implemented by subclasses."""
163199

164-
def map_to_int(self) -> 'BaseStream[_V]':
200+
@_operation
201+
def map_to_int(self) -> NumericBaseStream[_V]:
165202
"""
166203
Returns a stream consisting of the results of converting the elements of this stream to
167204
integers.
168205
"""
169206
self._queue.append(Process(self.__map_to_int))
170-
return self
207+
return self._to_numeric_stream()
171208

172209
def __map_to_int(self):
173210
"""Converts the stream to integers."""
174211
self._map(int)
175212

213+
@_operation
176214
def map_to_str(self) -> 'BaseStream[_V]':
177215
"""
178216
Returns a stream consisting of the results of converting the elements of this stream to
@@ -185,6 +223,7 @@ def __map_to_str(self):
185223
"""Converts the stream to strings."""
186224
self._map(str)
187225

226+
@_operation
188227
def peek(self, action: Callable) -> 'BaseStream[_V]':
189228
"""
190229
Returns a stream consisting of the elements of this stream, additionally performing the
@@ -196,9 +235,11 @@ def peek(self, action: Callable) -> 'BaseStream[_V]':
196235
return self
197236

198237
@abstractmethod
238+
@_operation
199239
def _peek(self, action: Callable):
200240
"""Implementation of peek. Should be implemented by subclasses."""
201241

242+
@_operation
202243
def reversed(self) -> 'BaseStream[_V]':
203244
"""
204245
Returns a stream consisting of the elements of this stream, with their order being
@@ -214,6 +255,7 @@ def __reversed(self):
214255
except TypeError:
215256
self._source = reversed(list(self._source))
216257

258+
@_operation
217259
def skip(self, n: int) -> 'BaseStream[_V]':
218260
"""
219261
Returns a stream consisting of the remaining elements of this stream after discarding the
@@ -228,6 +270,7 @@ def __skip(self, n: int):
228270
"""Skips the first n elements of the stream."""
229271
self._source = self._source[n:]
230272

273+
@_operation
231274
def sorted(self, comparator: Callable[[K], int] = None) -> 'BaseStream[_V]':
232275
"""
233276
Returns a stream consisting of the elements of this stream, sorted according to natural
@@ -243,6 +286,7 @@ def __sorted(self, comparator: Callable[[K], int] = None):
243286
else:
244287
self._source = sorted(self._source, key=cmp_to_key(comparator))
245288

289+
@_operation
246290
def take_while(self, predicate: Callable[[K], bool]) -> 'BaseStream[_V]':
247291
"""
248292
Returns, if this stream is ordered, a stream consisting of the longest prefix of elements
@@ -257,8 +301,6 @@ def __take_while(self, predicate: Callable[[Any], bool]):
257301
"""Takes elements from the stream while the predicate is true."""
258302
self._source = list(itertools.takewhile(predicate, self._source))
259303

260-
# Terminal Operations:
261-
262304
@abstractmethod
263305
@terminal
264306
def all_match(self, predicate: Callable[[K], bool]):
@@ -373,3 +415,7 @@ def to_dict(self, key_mapper: Callable[[K], Any]) -> dict:
373415
374416
:param key_mapper:
375417
"""
418+
419+
@abstractmethod
420+
def _to_numeric_stream(self) -> NumericBaseStream[_V]:
421+
"""Converts a stream to a numeric stream. To be implemented by subclasses."""

pystreamapi/_streams/__parallel_stream.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,9 @@ def _set_parallelizer_src(self):
9191

9292
def __mapper(self, mapper):
9393
return lambda x: self._one(mapper=mapper, item=x)
94+
95+
def _to_numeric_stream(self):
96+
# pylint: disable=import-outside-toplevel
97+
from pystreamapi._streams.numeric.__parallel_numeric_stream import ParallelNumericStream
98+
self.__class__ = ParallelNumericStream
99+
return self

pystreamapi/_streams/__sequential_stream.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,9 @@ def reduce(self, predicate: Callable, identity=_identity_missing, depends_on_sta
6262
@stream.terminal
6363
def to_dict(self, key_mapper: Callable[[Any], Any]) -> dict:
6464
return self._group_to_dict(key_mapper)
65+
66+
def _to_numeric_stream(self):
67+
# pylint: disable=import-outside-toplevel
68+
from pystreamapi._streams.numeric.__sequential_numeric_stream import SequentialNumericStream
69+
self.__class__ = SequentialNumericStream
70+
return self

pystreamapi/_streams/numeric/__numeric_base_stream.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,22 @@ def interquartile_range(self) -> Union[float, int, None]:
1818
Calculates the interquartile range of a numerical Stream
1919
:return: The interquartile range, can be int or float
2020
"""
21-
return self.third_quartile() - self.first_quartile() if len(self._source) > 0 else None
21+
return self._interquartile_range()
22+
23+
def _interquartile_range(self):
24+
"""Implementation of the interquartile range calculation"""
25+
return self._third_quartile() - self._first_quartile() if len(self._source) > 0 else None
2226

2327
@terminal
2428
def first_quartile(self) -> Union[float, int, None]:
2529
"""
2630
Calculates the first quartile of a numerical Stream
2731
:return: The first quartile, can be int or float
2832
"""
33+
return self._first_quartile()
34+
35+
def _first_quartile(self):
36+
"""Implementation of the first quartile calculation"""
2937
self._source = sorted(self._source)
3038
return self.__median(self._source[:(len(self._source)) // 2])
3139

@@ -59,7 +67,7 @@ def __median(source) -> Union[float, int, None]:
5967
@terminal
6068
def mode(self) -> Union[list[Union[int, float]], None]:
6169
"""
62-
Calculates the mode(s) (most frequently occurring element) of a numerical Stream
70+
Calculates the mode/modes (most frequently occurring element/elements) of a numerical Stream
6371
:return: The mode, can be int or float
6472
"""
6573
frequency = Counter(self._source)
@@ -90,5 +98,9 @@ def third_quartile(self) -> Union[float, int, None]:
9098
Calculates the third quartile of a numerical Stream
9199
:return: The third quartile, can be int or float
92100
"""
101+
return self._third_quartile()
102+
103+
def _third_quartile(self):
104+
"""Implementation of the third quartile calculation"""
93105
self._source = sorted(self._source)
94106
return self.__median(self._source[(len(self._source) + 1) // 2:])

pystreamapi/_streams/numeric/__parallel_numeric_stream.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ def sum(self) -> Union[float, int, None]:
1919
_sum = self.__sum()
2020
return 0 if _sum == [] else _sum
2121

22-
@terminal
2322
def __sum(self):
2423
"""Parallel sum method"""
2524
self._set_parallelizer_src()

tests/test_stream_closed.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import unittest
2+
3+
from parameterized import parameterized_class
4+
5+
from pystreamapi._streams.__parallel_stream import ParallelStream
6+
from pystreamapi._streams.__sequential_stream import SequentialStream
7+
from pystreamapi._streams.numeric.__parallel_numeric_stream import ParallelNumericStream
8+
from pystreamapi._streams.numeric.__sequential_numeric_stream import SequentialNumericStream
9+
10+
11+
@parameterized_class("stream", [
12+
[SequentialStream],
13+
[ParallelStream],
14+
[SequentialNumericStream],
15+
[ParallelNumericStream]])
16+
class BaseStreamClosed(unittest.TestCase):
17+
def test_closed_stream_throws_exception(self):
18+
# pylint: disable=too-many-statements
19+
closed_stream = self.stream([])
20+
closed_stream.for_each(lambda _: ...)
21+
22+
# Verify that all methods throw a RuntimeError
23+
with self.assertRaises(RuntimeError):
24+
list(closed_stream)
25+
26+
with self.assertRaises(RuntimeError):
27+
closed_stream.distinct()
28+
29+
with self.assertRaises(RuntimeError):
30+
closed_stream.drop_while(lambda x: True)
31+
32+
with self.assertRaises(RuntimeError):
33+
closed_stream.filter(lambda x: True)
34+
35+
with self.assertRaises(RuntimeError):
36+
closed_stream.flat_map(lambda x: [x])
37+
38+
with self.assertRaises(RuntimeError):
39+
closed_stream.group_by(lambda x: x)
40+
41+
with self.assertRaises(RuntimeError):
42+
closed_stream.limit(5)
43+
44+
with self.assertRaises(RuntimeError):
45+
closed_stream.map(lambda x: x)
46+
47+
with self.assertRaises(RuntimeError):
48+
closed_stream.map_to_int()
49+
50+
with self.assertRaises(RuntimeError):
51+
closed_stream.map_to_str()
52+
53+
with self.assertRaises(RuntimeError):
54+
closed_stream.peek(lambda x: None)
55+
56+
with self.assertRaises(RuntimeError):
57+
closed_stream.reversed()
58+
59+
with self.assertRaises(RuntimeError):
60+
closed_stream.skip(5)
61+
62+
with self.assertRaises(RuntimeError):
63+
closed_stream.sorted()
64+
65+
with self.assertRaises(RuntimeError):
66+
closed_stream.take_while(lambda x: True)
67+
68+
with self.assertRaises(RuntimeError):
69+
closed_stream.all_match(lambda x: True)
70+
71+
with self.assertRaises(RuntimeError):
72+
closed_stream.any_match(lambda x: True)
73+
74+
with self.assertRaises(RuntimeError):
75+
closed_stream.count()
76+
77+
with self.assertRaises(RuntimeError):
78+
closed_stream.find_any()
79+
80+
with self.assertRaises(RuntimeError):
81+
closed_stream.find_first()
82+
83+
with self.assertRaises(RuntimeError):
84+
closed_stream.for_each(lambda x: None)
85+
86+
with self.assertRaises(RuntimeError):
87+
closed_stream.none_match(lambda x: True)
88+
89+
with self.assertRaises(RuntimeError):
90+
closed_stream.min()
91+
92+
with self.assertRaises(RuntimeError):
93+
closed_stream.max()
94+
95+
with self.assertRaises(RuntimeError):
96+
closed_stream.reduce(lambda x, y: x + y)
97+
98+
with self.assertRaises(RuntimeError):
99+
closed_stream.to_list()
100+
101+
with self.assertRaises(RuntimeError):
102+
closed_stream.to_tuple()
103+
104+
with self.assertRaises(RuntimeError):
105+
closed_stream.to_set()
106+
107+
with self.assertRaises(RuntimeError):
108+
closed_stream.to_dict(lambda x: x)

0 commit comments

Comments
 (0)