1+ # pylint: disable=protected-access
2+ from __future__ import annotations
13import functools
24import itertools
35from abc import abstractmethod
46from builtins import reversed
57from functools import cmp_to_key
6- from typing import Iterable , Callable , Any , TypeVar , Iterator
8+ from typing import Iterable , Callable , Any , TypeVar , Iterator , TYPE_CHECKING
79
810from pystreamapi .__optional import Optional
11+ from pystreamapi ._itertools .tools import dropwhile
912from pystreamapi ._lazy .process import Process
1013from pystreamapi ._lazy .queue import ProcessQueue
1114from pystreamapi ._streams .error .__error import ErrorHandler
12- from pystreamapi ._itertools .tools import dropwhile
15+ if TYPE_CHECKING :
16+ from pystreamapi ._streams .numeric .__numeric_base_stream import NumericBaseStream
1317
1418K = TypeVar ('K' )
1519_V = TypeVar ('_V' )
1620_identity_missing = object ()
1721
1822
23+ def _operation (func ):
24+ """
25+ Decorator to execute all the processes in the queue before executing the decorated function.
26+ To be applied to intermediate operations.
27+ """
28+ @functools .wraps (func )
29+ def wrapper (* args , ** kwargs ):
30+ self : BaseStream = args [0 ]
31+ self ._verify_open ()
32+ return func (* args , ** kwargs )
33+
34+ return wrapper
35+
36+
1937def terminal (func ):
2038 """
2139 Decorator to execute all the processes in the queue before executing the decorated function.
2240 To be applied to terminal operations.
2341 """
2442 @functools .wraps (func )
43+ @_operation
2544 def wrapper (* args , ** kwargs ):
2645 self : BaseStream = args [0 ]
27- # pylint: disable=protected-access
2846 self ._queue .execute_all ()
47+ self ._close ()
2948 return func (* args , ** kwargs )
3049
3150 return wrapper
@@ -47,6 +66,16 @@ class BaseStream(Iterable[K], ErrorHandler):
4766 def __init__ (self , source : Iterable [K ]):
4867 self ._source = source
4968 self ._queue = ProcessQueue ()
69+ self ._open = True
70+
71+ def _close (self ):
72+ """Close the stream."""
73+ self ._open = False
74+
75+ def _verify_open (self ):
76+ """Verify if stream is open. If not, raise an exception."""
77+ if not self ._open :
78+ raise RuntimeError ("The stream has been closed" )
5079
5180 @terminal
5281 def __iter__ (self ) -> Iterator [K ]:
@@ -63,6 +92,7 @@ def concat(cls, *streams: "BaseStream[K]"):
6392 """
6493 return cls (itertools .chain (* list (streams )))
6594
95+ @_operation
6696 def distinct (self ) -> 'BaseStream[_V]' :
6797 """Returns a stream consisting of the distinct elements of this stream."""
6898 self ._queue .append (Process (self .__distinct ))
@@ -72,6 +102,7 @@ def __distinct(self):
72102 """Removes duplicate elements from the stream."""
73103 self ._source = list (set (self ._source ))
74104
105+ @_operation
75106 def drop_while (self , predicate : Callable [[K ], bool ]) -> 'BaseStream[_V]' :
76107 """
77108 Returns, if this stream is ordered, a stream consisting of the remaining elements of this
@@ -86,6 +117,7 @@ def __drop_while(self, predicate: Callable[[Any], bool]):
86117 """Drops elements from the stream while the predicate is true."""
87118 self ._source = list (dropwhile (predicate , self ._source , self ))
88119
120+ @_operation
89121 def filter (self , predicate : Callable [[K ], bool ]) -> 'BaseStream[K]' :
90122 """
91123 Returns a stream consisting of the elements of this stream that match the given predicate.
@@ -99,6 +131,7 @@ def filter(self, predicate: Callable[[K], bool]) -> 'BaseStream[K]':
99131 def _filter (self , predicate : Callable [[K ], bool ]):
100132 """Implementation of filter. Should be implemented by subclasses."""
101133
134+ @_operation
102135 def flat_map (self , predicate : Callable [[K ], Iterable [_V ]]) -> 'BaseStream[_V]' :
103136 """
104137 Returns a stream consisting of the results of replacing each element of this stream with
@@ -114,6 +147,7 @@ def flat_map(self, predicate: Callable[[K], Iterable[_V]]) -> 'BaseStream[_V]':
114147 def _flat_map (self , predicate : Callable [[K ], Iterable [_V ]]):
115148 """Implementation of flat_map. Should be implemented by subclasses."""
116149
150+ @_operation
117151 def group_by (self , key_mapper : Callable [[K ], Any ]) -> 'BaseStream[K]' :
118152 """
119153 Returns a Stream consisting of the results of grouping the elements of this stream
@@ -133,6 +167,7 @@ def __group_by(self, key_mapper: Callable[[Any], Any]):
133167 def _group_to_dict (self , key_mapper : Callable [[K ], Any ]) -> dict [K , list ]:
134168 """Groups the stream into a dictionary. Should be implemented by subclasses."""
135169
170+ @_operation
136171 def limit (self , max_size : int ) -> 'BaseStream[_V]' :
137172 """
138173 Returns a stream consisting of the elements of this stream, truncated to be no longer
@@ -147,6 +182,7 @@ def __limit(self, max_size: int):
147182 """Limits the stream to the first n elements."""
148183 self ._source = itertools .islice (self ._source , max_size )
149184
185+ @_operation
150186 def map (self , mapper : Callable [[K ], _V ]) -> 'BaseStream[_V]' :
151187 """
152188 Returns a stream consisting of the results of applying the given function to the elements
@@ -161,18 +197,20 @@ def map(self, mapper: Callable[[K], _V]) -> 'BaseStream[_V]':
161197 def _map (self , mapper : Callable [[K ], _V ]):
162198 """Implementation of map. Should be implemented by subclasses."""
163199
164- def map_to_int (self ) -> 'BaseStream[_V]' :
200+ @_operation
201+ def map_to_int (self ) -> NumericBaseStream [_V ]:
165202 """
166203 Returns a stream consisting of the results of converting the elements of this stream to
167204 integers.
168205 """
169206 self ._queue .append (Process (self .__map_to_int ))
170- return self
207+ return self . _to_numeric_stream ()
171208
172209 def __map_to_int (self ):
173210 """Converts the stream to integers."""
174211 self ._map (int )
175212
213+ @_operation
176214 def map_to_str (self ) -> 'BaseStream[_V]' :
177215 """
178216 Returns a stream consisting of the results of converting the elements of this stream to
@@ -185,6 +223,7 @@ def __map_to_str(self):
185223 """Converts the stream to strings."""
186224 self ._map (str )
187225
226+ @_operation
188227 def peek (self , action : Callable ) -> 'BaseStream[_V]' :
189228 """
190229 Returns a stream consisting of the elements of this stream, additionally performing the
@@ -196,9 +235,11 @@ def peek(self, action: Callable) -> 'BaseStream[_V]':
196235 return self
197236
198237 @abstractmethod
238+ @_operation
199239 def _peek (self , action : Callable ):
200240 """Implementation of peek. Should be implemented by subclasses."""
201241
242+ @_operation
202243 def reversed (self ) -> 'BaseStream[_V]' :
203244 """
204245 Returns a stream consisting of the elements of this stream, with their order being
@@ -214,6 +255,7 @@ def __reversed(self):
214255 except TypeError :
215256 self ._source = reversed (list (self ._source ))
216257
258+ @_operation
217259 def skip (self , n : int ) -> 'BaseStream[_V]' :
218260 """
219261 Returns a stream consisting of the remaining elements of this stream after discarding the
@@ -228,6 +270,7 @@ def __skip(self, n: int):
228270 """Skips the first n elements of the stream."""
229271 self ._source = self ._source [n :]
230272
273+ @_operation
231274 def sorted (self , comparator : Callable [[K ], int ] = None ) -> 'BaseStream[_V]' :
232275 """
233276 Returns a stream consisting of the elements of this stream, sorted according to natural
@@ -243,6 +286,7 @@ def __sorted(self, comparator: Callable[[K], int] = None):
243286 else :
244287 self ._source = sorted (self ._source , key = cmp_to_key (comparator ))
245288
289+ @_operation
246290 def take_while (self , predicate : Callable [[K ], bool ]) -> 'BaseStream[_V]' :
247291 """
248292 Returns, if this stream is ordered, a stream consisting of the longest prefix of elements
@@ -257,8 +301,6 @@ def __take_while(self, predicate: Callable[[Any], bool]):
257301 """Takes elements from the stream while the predicate is true."""
258302 self ._source = list (itertools .takewhile (predicate , self ._source ))
259303
260- # Terminal Operations:
261-
262304 @abstractmethod
263305 @terminal
264306 def all_match (self , predicate : Callable [[K ], bool ]):
@@ -373,3 +415,7 @@ def to_dict(self, key_mapper: Callable[[K], Any]) -> dict:
373415
374416 :param key_mapper:
375417 """
418+
419+ @abstractmethod
420+ def _to_numeric_stream (self ) -> NumericBaseStream [_V ]:
421+ """Converts a stream to a numeric stream. To be implemented by subclasses."""
0 commit comments