Skip to content

Commit aaf4bfe

Browse files
authored
fix: linger timeout no longer closes the async iterator in append_record_batches (#28)
1 parent 496acfe commit aaf4bfe

3 files changed

Lines changed: 48 additions & 21 deletions

File tree

src/s2_sdk/_batching.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -50,41 +50,47 @@ async def append_record_batches(
5050
validate_batching(batching.max_records, batching.max_bytes)
5151
acc = BatchAccumulator(batching)
5252
linger_secs = batching.linger.total_seconds()
53-
aiter = records.__aiter__()
53+
record_iter = records.__aiter__()
54+
pending_next = None
5455

55-
while True:
56-
try:
57-
record = await anext(aiter)
58-
except StopAsyncIteration:
59-
break
56+
try:
57+
while True:
58+
if pending_next is not None:
59+
record = await pending_next
60+
pending_next = None
61+
else:
62+
record = await anext(record_iter, None)
63+
if record is None:
64+
break
6065

61-
acc.add(record)
62-
if acc.is_full():
63-
yield acc.take()
64-
continue
66+
acc.add(record)
6567

66-
try:
6768
deadline = (
68-
asyncio.get_event_loop().time() + linger_secs
69+
asyncio.get_running_loop().time() + linger_secs
6970
if linger_secs > 0
7071
else None
7172
)
7273
while not acc.is_full():
7374
if deadline is not None:
74-
remaining = deadline - asyncio.get_event_loop().time()
75+
remaining = deadline - asyncio.get_running_loop().time()
7576
if remaining <= 0:
7677
break
77-
record = await asyncio.wait_for(anext(aiter), timeout=remaining)
78+
next_task = asyncio.create_task(anext(record_iter, None))
79+
done, _ = await asyncio.wait({next_task}, timeout=remaining)
80+
if not done:
81+
pending_next = next_task
82+
break
83+
record = next_task.result()
7884
else:
79-
record = await anext(aiter)
85+
record = await anext(record_iter, None)
86+
if record is None:
87+
break
8088
acc.add(record)
81-
except StopAsyncIteration:
82-
pass
83-
except TimeoutError:
84-
pass
8589

86-
if not acc.is_empty():
8790
yield acc.take()
91+
finally:
92+
if pending_next is not None:
93+
pending_next.cancel()
8894

8995

9096
async def append_inputs(

src/s2_sdk/_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ class Batching:
208208
"""Maximum metered bytes per batch. Must be between 8 and 1 MiB. Default is 1 MiB."""
209209

210210
linger: timedelta = field(default_factory=lambda: timedelta(milliseconds=5))
211-
"""Maximum time to wait for more records before flushing a partial batch. Default is 5 ms.
211+
"""Maximum time to wait for more records before flushing a batch. Default is 5 ms.
212212
213213
Note:
214214
If set to 0, batches are flushed only when ``max_records`` or ``max_bytes`` is reached, or when the input ends.

tests/test_batching.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from datetime import timedelta
23

34
import pytest
@@ -63,6 +64,26 @@ async def test_oversized_record_passes():
6364
assert len(batches[0]) == 1
6465

6566

67+
@pytest.mark.asyncio
68+
async def test_linger_flushes_batches():
69+
async def delayed_records():
70+
yield Record(body=b"r1")
71+
await asyncio.sleep(0.1) # Longer than linger
72+
yield Record(body=b"r2")
73+
yield Record(body=b"r3")
74+
75+
batches = []
76+
async for batch in append_record_batches(
77+
delayed_records(),
78+
batching=Batching(max_records=10, linger=timedelta(seconds=0.01)),
79+
):
80+
batches.append(batch)
81+
82+
assert len(batches) == 2
83+
assert len(batches[0]) == 1 # r1 (linger expired)
84+
assert len(batches[1]) == 2 # r2, r3
85+
86+
6687
@pytest.mark.asyncio
6788
async def test_append_inputs_skips_empty_batches():
6889
inputs = []

0 commit comments

Comments
 (0)