Skip to content

Commit 39a3c8d

Browse files
committed
.buffer: make up_to optional
1 parent e1119dc commit 39a3c8d

8 files changed

Lines changed: 60 additions & 19 deletions

File tree

streamable/_afunctions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
def buffer(
3535
aiterator: AsyncIterator[T],
36-
up_to: int,
36+
up_to: Optional[int] = None,
3737
) -> AsyncIterator[T]:
3838
return _aiterators.BufferAsyncIterator(aiterator, up_to)
3939

streamable/_aiterators.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
FIFOFutureResults,
3838
FutureResults,
3939
)
40-
from streamable._tools._async import AsyncFunction, empty_aiter
40+
from streamable._tools._async import AsyncFunction, NoopSemaphore, empty_aiter
4141
from streamable._tools._context import noop_context_manager
4242
from streamable._tools._error import ExceptionContainer, RaisingAsyncIterator
4343

@@ -59,7 +59,7 @@ class _BufferAsyncIterable(AsyncIterable[Union[T, ExceptionContainer]]):
5959
def __init__(
6060
self,
6161
iterator: AsyncIterator[T],
62-
up_to: int,
62+
up_to: Optional[int],
6363
) -> None:
6464
self.iterator = iterator
6565
self.up_to = up_to
@@ -76,7 +76,9 @@ def _lazy_buffer(self) -> "asyncio.Queue[Union[T, ExceptionContainer]]":
7676
@property
7777
def _lazy_slots(self) -> asyncio.Semaphore:
7878
if not self._slots:
79-
self._slots = asyncio.Semaphore(self.up_to)
79+
self._slots = (
80+
asyncio.Semaphore(self.up_to) if self.up_to else NoopSemaphore()
81+
)
8082
return self._slots
8183

8284
async def _buffer_upstream(self) -> None:
@@ -115,7 +117,7 @@ class BufferAsyncIterator(RaisingAsyncIterator[T]):
115117
def __init__(
116118
self,
117119
iterator: AsyncIterator[T],
118-
up_to: int,
120+
up_to: Optional[int],
119121
) -> None:
120122
super().__init__(_BufferAsyncIterable(iterator, up_to).__aiter__())
121123

streamable/_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
def buffer(
3232
iterator: Iterator[T],
33-
up_to: int,
33+
up_to: Optional[int] = None,
3434
) -> Iterator[T]:
3535
return _iterators.BufferIterator(iterator, up_to)
3636

streamable/_iterators.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
from streamable._tools._context import noop_context_manager
2929
from streamable._tools._observation import Observation
30+
from streamable._tools._threading import NoopSemaphore
3031
from streamable._tools._sentinel import STOP_ITERATION
3132
from streamable._tools._validation import validate_sync_flatten_iterable
3233

@@ -55,11 +56,11 @@ class _BufferIterable(Iterable[Union[T, ExceptionContainer]]):
5556
def __init__(
5657
self,
5758
iterator: Iterator[T],
58-
up_to: int,
59+
up_to: Optional[int],
5960
) -> None:
6061
self.iterator = iterator
6162
self._buffer: "queue.Queue[Union[T, ExceptionContainer]]" = queue.Queue()
62-
self._slots = Semaphore(up_to)
63+
self._slots = Semaphore(up_to) if up_to else NoopSemaphore()
6364
self._stopped = False
6465

6566
def _buffer_upstream(self) -> None:
@@ -99,7 +100,7 @@ class BufferIterator(RaisingIterator[T]):
99100
def __init__(
100101
self,
101102
iterator: Iterator[T],
102-
up_to: int,
103+
up_to: Optional[int],
103104
) -> None:
104105
super().__init__(_BufferIterable(iterator, up_to).__iter__())
105106

streamable/_stream.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -300,17 +300,15 @@ def cast(self, into: Type[U]) -> "stream[U]":
300300

301301
def buffer(
302302
self,
303-
up_to: int,
303+
up_to: Optional[int] = None,
304304
) -> "stream[T]":
305305
"""
306-
Buffer upstream elements into a bounded queue (max size ``up_to``), via a background task.
307-
308-
Allow to decouple the upstream production rate from the downstream consumption rate.
306+
Buffer upstream elements into a queue, via a background task, decoupling upstream production rate from downstream consumption rate.
309307
310308
The background task is a thread during a sync iteration, and an async task during an async iteration.
311309
312310
Args:
313-
up_to (``int``): The buffer size. Must be >= 1. When reached, upstream pulling pauses until an element is yielded out of the buffer.
311+
up_to (``int | None``): The buffer size, must be >= 1 when set. When reached, upstream pulling pauses until an element is yielded out of the buffer.
314312
315313
Returns:
316314
``stream[T]``: Upstream with buffering.
@@ -327,7 +325,8 @@ def buffer(
327325
time.sleep(1e-3)
328326
assert pulled == [0, 1, 2, 3, 4, 5]
329327
"""
330-
validate_int(up_to, gte=1, name="up_to")
328+
if up_to is not None:
329+
validate_int(up_to, gte=1, name="up_to")
331330
return BufferStream(self, up_to)
332331

333332
@overload
@@ -1082,7 +1081,7 @@ class BufferStream(DownStream[T, T]):
10821081
def __init__(
10831082
self,
10841083
upstream: stream[T],
1085-
up_to: int,
1084+
up_to: Optional[int],
10861085
) -> None:
10871086
super().__init__(upstream)
10881087
self._up_to = up_to

streamable/_tools/_async.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from asyncio import Semaphore
12
from typing import (
23
Any,
34
AsyncIterator,
45
Awaitable,
56
Callable,
67
Coroutine,
8+
Literal,
79
TypeVar,
810
)
911

@@ -25,3 +27,19 @@ async def awaitable_to_coroutine(aw: Awaitable[T]) -> T:
2527
async def empty_aiter() -> AsyncIterator[Any]:
2628
return
2729
yield # pragma: no cover
30+
31+
32+
class NoopSemaphore(Semaphore):
33+
__slots__ = ()
34+
35+
def __init__(self) -> None:
36+
pass
37+
38+
async def acquire(self) -> Literal[True]:
39+
return True
40+
41+
def release(self) -> None:
42+
return
43+
44+
def locked(self) -> bool:
45+
return False

tests/test_buffer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
from typing import Any, AsyncIterable, Callable, Iterable, List
23

34
import pytest
@@ -26,7 +27,7 @@ def test_buffer_preserves_elements(itype: IterableType) -> None:
2627
assert alist_or_list(ints.buffer(5), itype) == list(INTEGERS)
2728

2829

29-
@pytest.mark.parametrize("buffer_size", [1, 10])
30+
@pytest.mark.parametrize("buffer_size", [1, 10, None])
3031
@pytest.mark.parametrize(
3132
"itype, slow_identity",
3233
[(Iterable, slow_identity), (AsyncIterable, async_slow_identity)],
@@ -39,9 +40,13 @@ def test_buffer_size_is_respected(
3940
buffering_ints_iter = aiter_or_iter(buffering_ints, itype)
4041
assert buffered == []
4142
assert anext_or_next(buffering_ints_iter, itype) == 0
42-
assert buffered == list(INTEGERS)[: buffer_size + 1]
43+
assert (
44+
buffered == list(INTEGERS)[: (buffer_size + 1) if buffer_size else sys.maxsize]
45+
)
4346
assert anext_or_next(buffering_ints_iter, itype) == 1
44-
assert buffered == list(INTEGERS)[: buffer_size + 2]
47+
assert (
48+
buffered == list(INTEGERS)[: (buffer_size + 2) if buffer_size else sys.maxsize]
49+
)
4550

4651

4752
@pytest.mark.parametrize(

tests/test_tools.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33
from streamable._tools._func import sidify
44
from streamable._tools._logging import logfmt_str_escape
5+
from streamable._tools import _async, _threading
56

67
from typing import Any, Callable, List
78

@@ -88,3 +89,18 @@ def test_logfmt_str_escape():
8889
assert logfmt_str_escape("in ts") == '"in ts"'
8990
assert logfmt_str_escape("in\\ts") == r'"in\\ts"'
9091
assert logfmt_str_escape('"ints"') == r'"\"ints\""'
92+
93+
94+
def test_noop_semaphore() -> None:
95+
semaphore = _threading.NoopSemaphore()
96+
assert semaphore.acquire()
97+
semaphore.release()
98+
assert not semaphore.locked()
99+
100+
101+
@pytest.mark.asyncio
102+
async def test_noop_semaphore_async() -> None:
103+
semaphore = _async.NoopSemaphore()
104+
assert await semaphore.acquire()
105+
semaphore.release()
106+
assert not semaphore.locked()

0 commit comments

Comments
 (0)