From 7ba2c42d84ee48403376047fdca2af2d48e7c010 Mon Sep 17 00:00:00 2001 From: quettabit <27509167+quettabit@users.noreply.github.com> Date: Wed, 3 Jun 2026 15:03:56 -0700 Subject: [PATCH] wip --- src/s2_sdk/_s2s/_append_session.py | 90 +++++++++++++++++++++++++++--- 1 file changed, 81 insertions(+), 9 deletions(-) diff --git a/src/s2_sdk/_s2s/_append_session.py b/src/s2_sdk/_s2s/_append_session.py index e2d6e03..b765dfb 100644 --- a/src/s2_sdk/_s2s/_append_session.py +++ b/src/s2_sdk/_s2s/_append_session.py @@ -1,8 +1,9 @@ import asyncio import logging from collections import deque -from collections.abc import AsyncGenerator, AsyncIterable -from typing import NamedTuple +from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator +from dataclasses import dataclass +from typing import Any import s2_sdk._generated.s2.v1.s2_pb2 as pb from s2_sdk._client import HttpClient @@ -32,9 +33,11 @@ _QUEUE_MAX_SIZE = 100 -class _InflightInput(NamedTuple): +@dataclass(slots=True) +class _InflightInput: num_records: int encoded: bytes + ack_deadline: float | None = None async def run_append_session( @@ -149,11 +152,22 @@ async def _run_attempt( if encryption_key is not None: headers[_S2_ENCRYPTION_KEY_HEADER] = encryption_key + ack_deadline_updated = asyncio.Event() + for resend_inp in pending_resend: + resend_inp.ack_deadline = None + async with client.streaming_request( "POST", _stream_records_path(stream_name), headers=headers, - content=_body_gen(inflight_inputs, input_queue, pending_resend, compression), + content=_body_gen( + inflight_inputs, + input_queue, + pending_resend, + compression, + ack_deadline_updated, + ack_timeout, + ), frame_signal=frame_signal, ) as response: if response.status_code != 200: @@ -166,13 +180,14 @@ async def _run_attempt( messages = read_messages(response.aiter_bytes()) while True: try: - msg_body = await asyncio.wait_for( - messages.__anext__(), timeout=ack_timeout + msg_body = await _next_message( + messages, + inflight_inputs, + ack_deadline_updated, + ack_timeout, ) except StopAsyncIteration: break - except asyncio.TimeoutError: - raise ReadTimeoutError("Append session ack timeout") if attempt.value > 0: attempt.value = 0 @@ -185,6 +200,8 @@ async def _run_attempt( raise S2ClientError("Invalid ack: not monotonically increasing") prev_ack_end = ack.end.seq_num + if not inflight_inputs: + raise S2ClientError("Invalid ack: no inflight append") num_records_sent = inflight_inputs.popleft().num_records num_records_ackd = ack.end.seq_num - ack.start.seq_num if num_records_sent != num_records_ackd: @@ -203,11 +220,53 @@ async def _run_attempt( frame_signal.reset() +async def _next_message( + messages: AsyncIterator[bytes], + inflight_inputs: deque[_InflightInput], + ack_deadline_updated: asyncio.Event, + ack_timeout: float | None, +) -> bytes: + if ack_timeout is None: + return await messages.__anext__() + + pending_message: asyncio.Future[Any] | None = None + try: + while True: + deadline = inflight_inputs[0].ack_deadline if inflight_inputs else None + if deadline is not None: + try: + async with asyncio.timeout_at(deadline): + if pending_message is not None: + return await pending_message + return await messages.__anext__() + except TimeoutError: + raise ReadTimeoutError("Append session ack timeout") from None + + if pending_message is None: + pending_message = asyncio.ensure_future(messages.__anext__()) + deadline_update = asyncio.create_task(ack_deadline_updated.wait()) + done, _ = await asyncio.wait( + {pending_message, deadline_update}, + return_when=asyncio.FIRST_COMPLETED, + ) + if deadline_update not in done: + deadline_update.cancel() + + if pending_message in done: + return pending_message.result() + ack_deadline_updated.clear() + finally: + if pending_message is not None and not pending_message.done(): + pending_message.cancel() + + async def _body_gen( inflight_inputs: deque[_InflightInput], input_queue: asyncio.Queue[AppendInput | None], pending_resend: tuple[_InflightInput, ...], compression: Compression, + ack_deadline_updated: asyncio.Event | None = None, + ack_timeout: float | None = None, ) -> AsyncGenerator[bytes]: if pending_resend: logger.debug( @@ -215,7 +274,13 @@ async def _body_gen( len(pending_resend), sum(len(inp.encoded) for inp in pending_resend), ) + loop = asyncio.get_running_loop() for resend_inp in pending_resend: + resend_inp.ack_deadline = ( + loop.time() + ack_timeout if ack_timeout is not None else None + ) + if ack_deadline_updated is not None: + ack_deadline_updated.set() yield resend_inp.encoded if pending_resend: logger.debug("finished resending inflight appends") @@ -226,9 +291,16 @@ async def _body_gen( await input_queue.put(None) return encoded = _encode_input(inp, compression) + ack_deadline = loop.time() + ack_timeout if ack_timeout is not None else None inflight_inputs.append( - _InflightInput(num_records=len(inp.records), encoded=encoded) + _InflightInput( + num_records=len(inp.records), + encoded=encoded, + ack_deadline=ack_deadline, + ) ) + if ack_deadline_updated is not None: + ack_deadline_updated.set() yield encoded