Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 81 additions & 9 deletions src/s2_sdk/_s2s/_append_session.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
import logging
from collections import deque
from collections.abc import AsyncGenerator, AsyncIterable
from typing import NamedTuple
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
from dataclasses import dataclass
from typing import Any

import s2_sdk._generated.s2.v1.s2_pb2 as pb
from s2_sdk._client import HttpClient
Expand Down Expand Up @@ -32,9 +33,11 @@
_QUEUE_MAX_SIZE = 100


class _InflightInput(NamedTuple):
@dataclass(slots=True)
class _InflightInput:
num_records: int
encoded: bytes
ack_deadline: float | None = None


async def run_append_session(
Expand Down Expand Up @@ -149,11 +152,22 @@ async def _run_attempt(
if encryption_key is not None:
headers[_S2_ENCRYPTION_KEY_HEADER] = encryption_key

ack_deadline_updated = asyncio.Event()
for resend_inp in pending_resend:
resend_inp.ack_deadline = None

async with client.streaming_request(
"POST",
_stream_records_path(stream_name),
headers=headers,
content=_body_gen(inflight_inputs, input_queue, pending_resend, compression),
content=_body_gen(
inflight_inputs,
input_queue,
pending_resend,
compression,
ack_deadline_updated,
ack_timeout,
),
frame_signal=frame_signal,
) as response:
if response.status_code != 200:
Expand All @@ -166,13 +180,14 @@ async def _run_attempt(
messages = read_messages(response.aiter_bytes())
while True:
try:
msg_body = await asyncio.wait_for(
messages.__anext__(), timeout=ack_timeout
msg_body = await _next_message(
messages,
inflight_inputs,
ack_deadline_updated,
ack_timeout,
)
except StopAsyncIteration:
break
except asyncio.TimeoutError:
raise ReadTimeoutError("Append session ack timeout")

if attempt.value > 0:
attempt.value = 0
Expand All @@ -185,6 +200,8 @@ async def _run_attempt(
raise S2ClientError("Invalid ack: not monotonically increasing")
prev_ack_end = ack.end.seq_num

if not inflight_inputs:
raise S2ClientError("Invalid ack: no inflight append")
num_records_sent = inflight_inputs.popleft().num_records
num_records_ackd = ack.end.seq_num - ack.start.seq_num
if num_records_sent != num_records_ackd:
Expand All @@ -203,19 +220,67 @@ async def _run_attempt(
frame_signal.reset()


async def _next_message(
messages: AsyncIterator[bytes],
inflight_inputs: deque[_InflightInput],
ack_deadline_updated: asyncio.Event,
ack_timeout: float | None,
) -> bytes:
if ack_timeout is None:
return await messages.__anext__()

pending_message: asyncio.Future[Any] | None = None
try:
while True:
deadline = inflight_inputs[0].ack_deadline if inflight_inputs else None
if deadline is not None:
try:
async with asyncio.timeout_at(deadline):
if pending_message is not None:
return await pending_message
return await messages.__anext__()
except TimeoutError:
raise ReadTimeoutError("Append session ack timeout") from None

if pending_message is None:
pending_message = asyncio.ensure_future(messages.__anext__())
deadline_update = asyncio.create_task(ack_deadline_updated.wait())
done, _ = await asyncio.wait(
{pending_message, deadline_update},
return_when=asyncio.FIRST_COMPLETED,
)
if deadline_update not in done:
deadline_update.cancel()

if pending_message in done:
return pending_message.result()
ack_deadline_updated.clear()
finally:
if pending_message is not None and not pending_message.done():
pending_message.cancel()


async def _body_gen(
inflight_inputs: deque[_InflightInput],
input_queue: asyncio.Queue[AppendInput | None],
pending_resend: tuple[_InflightInput, ...],
compression: Compression,
ack_deadline_updated: asyncio.Event | None = None,
ack_timeout: float | None = None,
) -> AsyncGenerator[bytes]:
if pending_resend:
logger.debug(
"resending inflight appends: count=%d bytes=%d",
len(pending_resend),
sum(len(inp.encoded) for inp in pending_resend),
)
loop = asyncio.get_running_loop()
for resend_inp in pending_resend:
resend_inp.ack_deadline = (
loop.time() + ack_timeout if ack_timeout is not None else None
)
if ack_deadline_updated is not None:
ack_deadline_updated.set()
yield resend_inp.encoded
if pending_resend:
logger.debug("finished resending inflight appends")
Expand All @@ -226,9 +291,16 @@ async def _body_gen(
await input_queue.put(None)
return
encoded = _encode_input(inp, compression)
ack_deadline = loop.time() + ack_timeout if ack_timeout is not None else None
inflight_inputs.append(
_InflightInput(num_records=len(inp.records), encoded=encoded)
_InflightInput(
num_records=len(inp.records),
encoded=encoded,
ack_deadline=ack_deadline,
)
)
if ack_deadline_updated is not None:
ack_deadline_updated.set()
yield encoded


Expand Down
Loading