diff --git a/frontend/docs/pages/reference/changelog/python.mdx b/frontend/docs/pages/reference/changelog/python.mdx index 34ff5e5972..ad593864f1 100644 --- a/frontend/docs/pages/reference/changelog/python.mdx +++ b/frontend/docs/pages/reference/changelog/python.mdx @@ -1,5 +1,12 @@ {/* AUTOGENERATED — do not edit. Run `task sync-changelog` to regenerate from sdks/python/CHANGELOG.md */} +## v1.33.13 - 2026-06-26 + +### Fixed + +- Reworks the internals of event pushes and stream event pubs to use `grpc.aio` directly to limit threading overhead on high-throughput workers. +- Reworks how logs are forwarded to the engine to publish from a thread instead of from an `asyncio.Task` to try to avoid event loop blocking issues. + ## v1.33.12 - 2026-06-21 ### Fixed diff --git a/sdks/python/CHANGELOG.md b/sdks/python/CHANGELOG.md index 2b5364bcdf..7d27670b79 100644 --- a/sdks/python/CHANGELOG.md +++ b/sdks/python/CHANGELOG.md @@ -5,6 +5,13 @@ All notable changes to Hatchet's Python SDK will be documented in this changelog The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.33.13] - 2026-06-26 + +### Fixed + +- Reworks the internals of event pushes and stream event pubs to use `grpc.aio` directly to limit threading overhead on high-throughput workers. +- Reworks how logs are forwarded to the engine to publish from a thread instead of from an `asyncio.Task` to try to avoid event loop blocking issues. + ## [1.33.12] - 2026-06-21 ### Fixed diff --git a/sdks/python/examples/events/test_event.py b/sdks/python/examples/events/test_event.py index bc32671cff..32b5fa1100 100644 --- a/sdks/python/examples/events/test_event.py +++ b/sdks/python/examples/events/test_event.py @@ -280,6 +280,41 @@ async def test_async_event_bulk_push(hatchet: Hatchet) -> None: assert returned_event.key == hatchet.config.apply_namespace(original_event.key) +@pytest.mark.asyncio(loop_scope="session") +async def test_event_bulk_push(hatchet: Hatchet) -> None: + events = [ + BulkPushEventWithMetadata( + key="event1", + payload={"message": "This is event 1", "should_skip": False}, + additional_metadata={"source": "test", "user_id": "user123"}, + ), + BulkPushEventWithMetadata( + key="event2", + payload={"message": "This is event 2", "should_skip": False}, + additional_metadata={"source": "test", "user_id": "user456"}, + ), + BulkPushEventWithMetadata( + key="event3", + payload={"message": "This is event 3", "should_skip": False}, + additional_metadata={"source": "test", "user_id": "user789"}, + ), + ] + + e = hatchet.event.bulk_push(events) + + assert len(e) == 3 + + # Sort both lists of events by their key to ensure comparison order + sorted_events = sorted(events, key=lambda x: x.key) + sorted_returned_events = sorted(e, key=lambda x: x.key) + + # Check that the returned events match the original events + for original_event, returned_event in zip( + sorted_events, sorted_returned_events, strict=False + ): + assert returned_event.key == hatchet.config.apply_namespace(original_event.key) + + @pytest.fixture(scope="function") def test_run_id() -> str: return str(uuid4()) diff --git a/sdks/python/hatchet_sdk/clients/events.py b/sdks/python/hatchet_sdk/clients/events.py index 8288802959..b4aee9cfea 100644 --- a/sdks/python/hatchet_sdk/clients/events.py +++ b/sdks/python/hatchet_sdk/clients/events.py @@ -26,6 +26,7 @@ PushEventRequest, PutLogRequest, PutStreamEventRequest, + PutStreamEventResponse, ) from hatchet_sdk.contracts.events_pb2 import Event as EventProto from hatchet_sdk.contracts.events_pb2 import Events as EventsProto @@ -115,11 +116,14 @@ class EventClient(BaseRestClient): def __init__(self, config: ClientConfig): super().__init__(config) - conn = new_conn(config, False) - self.events_service_client = EventsServiceStub(conn) + self._client: EventsServiceStub | None = None + self._aio_client: EventsServiceStub | None = None self.token = config.token self.namespace = config.namespace + self._retrying_aio_put_stream_event = tenacity_retry( + self._put_stream_event, self.client_config.tenacity + ) def _wra(self, client: ApiClient) -> WorkflowRunsApi: return WorkflowRunsApi(client) @@ -127,18 +131,76 @@ def _wra(self, client: ApiClient) -> WorkflowRunsApi: def _ea(self, client: ApiClient) -> EventApi: return EventApi(client) + def _get_or_create_aio_client(self) -> EventsServiceStub: + if self._aio_client is None: + self._aio_client = EventsServiceStub(new_conn(self.client_config, True)) + + return self._aio_client + + def _get_or_create_client(self) -> EventsServiceStub: + if self._client is None: + self._client = EventsServiceStub(new_conn(self.client_config, False)) + + return self._client + + def _prepare_push_event_request( + self, + key: str, + payload: JSONSerializableMapping, + options: PushEventOptions, + additional_metadata: JSONSerializableMapping | None = None, + priority: Priority | None = None, + scope: str | None = None, + namespace_override: str | None = None, + ) -> PushEventRequest: + namespace = namespace_override or options.namespace or self.namespace + namespaced_key = self.client_config.apply_namespace(key, namespace) + + try: + meta = _inject_source_info( + additional_metadata or options.additional_metadata + ) + meta_bytes = json.dumps(meta) + except Exception as e: + raise ValueError("Error encoding meta") from e + + try: + payload_str = json.dumps(payload) + except (TypeError, ValueError) as e: + raise ValueError("Error encoding payload") from e + + return PushEventRequest( + key=namespaced_key, + payload=payload_str, + event_timestamp=proto_timestamp_now(), + additional_metadata=meta_bytes, + priority=priority or options.priority, + scope=scope or options.scope, + ) + async def aio_push( self, event_key: str, payload: JSONSerializableMapping, - options: PushEventOptions = PushEventOptions(), + options: PushEventOptions | None = None, additional_metadata: JSONSerializableMapping | None = None, priority: Priority | None = None, scope: str | None = None, ) -> Event: - return await asyncio.to_thread( - self.push, - event_key=event_key, + if options is not None: + warnings.warn( + "The `options` parameter is deprecated and will be removed in v2.0.0. The namespace should be set on the `ClientConfig`", + stacklevel=2, + category=DeprecationWarning, + ) + else: + options = PushEventOptions() + + aio_client = self._get_or_create_aio_client() + push_event = tenacity_retry(aio_client.Push, self.client_config.tenacity) + + request = self._prepare_push_event_request( + key=event_key, payload=payload, options=options, additional_metadata=additional_metadata, @@ -146,12 +208,12 @@ async def aio_push( scope=scope, ) - async def aio_bulk_push( - self, - events: list[BulkPushEventWithMetadata], - options: BulkPushEventOptions | None = None, - ) -> list[Event]: - return await asyncio.to_thread(self.bulk_push, events=events, options=options) + response = cast( + EventProto, + await push_event(request, metadata=create_authorization_header(self.token)), # type: ignore[misc] + ) + + return Event.from_proto(response) def push( self, @@ -171,69 +233,74 @@ def push( else: options = PushEventOptions() - namespace = options.namespace or self.namespace - namespaced_event_key = self.client_config.apply_namespace(event_key, namespace) - push_event = tenacity_retry( - self.events_service_client.Push, self.client_config.tenacity - ) - - try: - meta = _inject_source_info( - additional_metadata or options.additional_metadata - ) - meta_bytes = json.dumps(meta) - except Exception as e: - raise ValueError("Error encoding meta") from e - - try: - payload_str = json.dumps(payload) - except (TypeError, ValueError) as e: - raise ValueError("Error encoding payload") from e + client = self._get_or_create_client() + push_event = tenacity_retry(client.Push, self.client_config.tenacity) - request = PushEventRequest( - key=namespaced_event_key, - payload=payload_str, - event_timestamp=proto_timestamp_now(), - additional_metadata=meta_bytes, - priority=priority or options.priority, - scope=scope or options.scope, + request = self._prepare_push_event_request( + key=event_key, + payload=payload, + options=options, + additional_metadata=additional_metadata, + priority=priority, + scope=scope, ) response = cast( EventProto, push_event(request, metadata=create_authorization_header(self.token)), ) + return Event.from_proto(response) - def _create_push_event_request( + async def aio_bulk_push( self, - event: BulkPushEventWithMetadata, - namespace: str, - ) -> PushEventRequest: - event_key = self.client_config.apply_namespace(event.key, namespace) - payload = event.payload + events: list[BulkPushEventWithMetadata], + options: BulkPushEventOptions | None = None, + ) -> list[Event]: + if options: + warnings.warn( + "The `options` parameter is deprecated and will be removed in v2.0.0. The namespace should be set on the `ClientConfig`", + stacklevel=2, + category=DeprecationWarning, + ) + else: + options = BulkPushEventOptions() - meta = _inject_source_info(event.additional_metadata) + namespace = options.namespace or self.namespace - try: - meta_str = json.dumps(meta) - except Exception as e: - raise ValueError("Error encoding meta") from e + bulk_request = BulkPushEventRequest( + events=[ + self._prepare_push_event_request( + key=event.key, + payload=event.payload, + additional_metadata=event.additional_metadata, + options=PushEventOptions(), + priority=( + Priority(event.priority) + if isinstance(event.priority, int) + else event.priority + ), + scope=event.scope, + namespace_override=namespace, + ) + for event in events + ] + ) - try: - serialized_payload = json.dumps(payload) - except (TypeError, ValueError) as e: - raise ValueError("Error serializing payload") from e + client = self._get_or_create_aio_client() - return PushEventRequest( - key=event_key, - payload=serialized_payload, - event_timestamp=proto_timestamp_now(), - additional_metadata=meta_str, - priority=event.priority, - scope=event.scope, + bulk_push = tenacity_retry(client.BulkPush, self.client_config.tenacity) + + response = cast( + EventsProto, + await bulk_push( # type: ignore[misc] + bulk_request, + metadata=create_authorization_header(self.token), + ), ) + return [Event.from_proto(event) for event in response.events] + def bulk_push( self, events: list[BulkPushEventWithMetadata], @@ -249,20 +316,35 @@ def bulk_push( options = BulkPushEventOptions() namespace = options.namespace or self.namespace - bulk_push = tenacity_retry( - self.events_service_client.BulkPush, self.client_config.tenacity - ) bulk_request = BulkPushEventRequest( events=[ - self._create_push_event_request(event, namespace) for event in events + self._prepare_push_event_request( + key=event.key, + payload=event.payload, + additional_metadata=event.additional_metadata, + options=PushEventOptions(), + priority=( + Priority(event.priority) + if isinstance(event.priority, int) + else event.priority + ), + scope=event.scope, + namespace_override=namespace, + ) + for event in events ] ) + client = self._get_or_create_client() + + bulk_push = tenacity_retry(client.BulkPush, self.client_config.tenacity) + response = cast( EventsProto, bulk_push(bulk_request, metadata=create_authorization_header(self.token)), ) + return [Event.from_proto(event) for event in response.events] def log( @@ -276,9 +358,8 @@ def log( logger.warning("truncating log message to 10,000 characters") message = message[:10_000] - put_log = tenacity_retry( - self.events_service_client.PutLog, self.client_config.tenacity - ) + client = self._get_or_create_client() + put_log = tenacity_retry(client.PutLog, self.client_config.tenacity) request = PutLogRequest( task_run_external_id=step_run_id, created_at=proto_timestamp_now(), @@ -289,10 +370,9 @@ def log( put_log(request, metadata=create_authorization_header(self.token)) - def stream(self, data: str | bytes, step_run_id: str, index: int) -> None: - put_stream_event = tenacity_retry( - self.events_service_client.PutStreamEvent, self.client_config.tenacity - ) + def _create_put_stream_event_request( + self, data: str | bytes, step_run_id: str, index: int + ) -> PutStreamEventRequest: if isinstance(data, str): data_bytes = data.encode("utf-8") elif isinstance(data, bytes): @@ -300,18 +380,45 @@ def stream(self, data: str | bytes, step_run_id: str, index: int) -> None: else: raise ValueError("Invalid data type. Expected str, bytes, or file.") - request = PutStreamEventRequest( + return PutStreamEventRequest( task_run_external_id=step_run_id, created_at=proto_timestamp_now(), message=data_bytes, event_index=index, ) + def stream(self, data: str | bytes, step_run_id: str, index: int) -> None: + client = self._get_or_create_client() + put_stream_event = tenacity_retry( + client.PutStreamEvent, self.client_config.tenacity + ) + request = self._create_put_stream_event_request(data, step_run_id, index) + try: put_stream_event(request, metadata=create_authorization_header(self.token)) except Exception: raise + async def _put_stream_event( + self, + request: PutStreamEventRequest, + metadata: tuple[tuple[str, str]], + ) -> PutStreamEventResponse: + client = self._get_or_create_aio_client() + return cast( + PutStreamEventResponse, + await client.PutStreamEvent( # type: ignore[misc] + request, metadata=metadata + ), + ) + + async def aio_stream(self, data: str | bytes, step_run_id: str, index: int) -> None: + request = self._create_put_stream_event_request(data, step_run_id, index) + + await self._retrying_aio_put_stream_event( + request, create_authorization_header(self.token) + ) + async def aio_list( self, offset: int | None = None, diff --git a/sdks/python/hatchet_sdk/context/context.py b/sdks/python/hatchet_sdk/context/context.py index f04de2daa5..45b05d4275 100644 --- a/sdks/python/hatchet_sdk/context/context.py +++ b/sdks/python/hatchet_sdk/context/context.py @@ -544,7 +544,16 @@ async def aio_put_stream(self, data: str | bytes) -> None: :param data: The data to send to the Hatchet API. Can be a string or bytes. :return: None """ - await asyncio.to_thread(self.put_stream, data) + try: + ix = self._increment_stream_index() + + await self._event_client.aio_stream( + data=data, + step_run_id=self._step_run_id, + index=ix, + ) + except Exception: + logger.exception("error putting stream event") def refresh_timeout(self, increment_by: Duration) -> None: """ diff --git a/sdks/python/hatchet_sdk/opentelemetry/instrumentor.py b/sdks/python/hatchet_sdk/opentelemetry/instrumentor.py index 136f9a0b89..09a4449006 100644 --- a/sdks/python/hatchet_sdk/opentelemetry/instrumentor.py +++ b/sdks/python/hatchet_sdk/opentelemetry/instrumentor.py @@ -404,6 +404,18 @@ def _instrument(self, **kwargs: InstrumentKwargs) -> None: self._wrap_bulk_push_event, ) + wrap_function_wrapper( + hatchet_sdk, + "clients.events.EventClient.aio_push", + self._wrap_aio_push_event, + ) + + wrap_function_wrapper( + hatchet_sdk, + "clients.events.EventClient.aio_bulk_push", + self._wrap_aio_bulk_push_event, + ) + wrap_function_wrapper( hatchet_sdk, "clients.admin.AdminClient.run_workflow", @@ -650,6 +662,129 @@ def _wrap_bulk_push_event( options, ) + async def _wrap_aio_push_event( + self, + wrapped: Callable[..., Coroutine[None, None, Event]], + instance: EventClient, + args: tuple[ + str, + JSONSerializableMapping, + PushEventOptions | None, + JSONSerializableMapping | None, + Priority | None, + str | None, + ], + kwargs: dict[ + str, + str | JSONSerializableMapping | PushEventOptions | Priority | None, + ], + ) -> Event: + params = self.extract_bound_args(wrapped, args, kwargs) + + event_key = cast(str, params[0]) + payload = cast(JSONSerializableMapping, params[1]) + options = cast(PushEventOptions | None, params[2]) + additional_metadata = cast(JSONSerializableMapping | None, params[3]) + priority = cast(Priority | None, params[4]) + scope = cast(str | None, params[5]) + + additional_metadata = additional_metadata or ( + options.additional_metadata if options else {} + ) + + priority_option = options.priority if options else None + + if isinstance(priority_option, int): + priority_option = Priority(priority_option) + + priority = priority or priority_option + scope = scope or (options.scope if options else None) + + attributes = { + OTelAttribute.EVENT_KEY: event_key, + OTelAttribute.ACTION_PAYLOAD: json.dumps(payload, default=str), + OTelAttribute.ADDITIONAL_METADATA: json.dumps( + additional_metadata, default=str + ), + OTelAttribute.PRIORITY: priority, + OTelAttribute.FILTER_SCOPE: scope, + } + + with self._tracer.start_as_current_span( + "hatchet.push_event", + attributes={ + "instrumentor": "hatchet", + **{ + f"hatchet.{k.value}": v + for k, v in attributes.items() + if v + and k not in self.config.otel.excluded_attributes + and v != "{}" + and v != "[]" + }, + }, + kind=SpanKind.PRODUCER, + ): + return await wrapped( + event_key, + payload, + None, + _inject_source_info( + _inject_traceparent_into_metadata(dict(additional_metadata)), + ), + priority, + scope, + ) + + async def _wrap_aio_bulk_push_event( + self, + wrapped: Callable[ + [list[BulkPushEventWithMetadata], BulkPushEventOptions | None], + Coroutine[None, None, list[Event]], + ], + instance: EventClient, + args: tuple[ + list[BulkPushEventWithMetadata], + BulkPushEventOptions | None, + ], + kwargs: dict[ + str, list[BulkPushEventWithMetadata] | BulkPushEventOptions | None + ], + ) -> list[Event]: + params = self.extract_bound_args(wrapped, args, kwargs) + + bulk_events = cast(list[BulkPushEventWithMetadata], params[0]) + options = cast(BulkPushEventOptions | None, params[1]) + + num_bulk_events = len(bulk_events) + unique_event_keys = {event.key for event in bulk_events} + + with self._tracer.start_as_current_span( + "hatchet.bulk_push_event", + attributes={ + "instrumentor": "hatchet", + "hatchet.num_events": num_bulk_events, + "hatchet.unique_event_keys": json.dumps(unique_event_keys, default=str), + }, + kind=SpanKind.PRODUCER, + ): + bulk_events_with_meta = [ + BulkPushEventWithMetadata( + **event.model_dump(exclude={"additional_metadata"}), + additional_metadata=_inject_source_info( + _inject_traceparent_into_metadata( + event.additional_metadata, + ) + ), + ) + for event in bulk_events + ] + + return await wrapped( + bulk_events_with_meta, + options, + ) + def _build_run_workflow_attributes( self, config: WorkflowRunTriggerConfig ) -> dict[str, Any]: diff --git a/sdks/python/hatchet_sdk/worker/runner/run_loop_manager.py b/sdks/python/hatchet_sdk/worker/runner/run_loop_manager.py index 6b36c50e82..1a6cbf4bd0 100644 --- a/sdks/python/hatchet_sdk/worker/runner/run_loop_manager.py +++ b/sdks/python/hatchet_sdk/worker/runner/run_loop_manager.py @@ -57,8 +57,8 @@ def __init__( self.client = Client(config=self.config, debug=self.debug) self.start_loop_manager_task: asyncio.Task[None] | None = None self.log_sender = AsyncLogSender(self.client.event) - self.log_task = self.loop.create_task(self.log_sender.consume()) + self.log_sender.start() self.start() def start(self) -> None: @@ -85,7 +85,7 @@ def cleanup(self) -> None: self.killing = True self.action_queue.put(STOP_LOOP) - self.log_sender.publish(STOP_LOOP) + self.log_sender.stop() async def evict_all_waiting_durable_runs(self) -> None: if self.runner: diff --git a/sdks/python/hatchet_sdk/worker/runner/utils/capture_logs.py b/sdks/python/hatchet_sdk/worker/runner/utils/capture_logs.py index ab08dd08d3..9082a7b985 100644 --- a/sdks/python/hatchet_sdk/worker/runner/utils/capture_logs.py +++ b/sdks/python/hatchet_sdk/worker/runner/utils/capture_logs.py @@ -1,8 +1,8 @@ -import asyncio import functools import logging +import queue +import threading from collections.abc import Awaitable, Callable -from contextlib import suppress from dataclasses import dataclass from io import StringIO from typing import Any, Literal, ParamSpec, TypeVar @@ -127,24 +127,17 @@ class LogRecord: class AsyncLogSender: def __init__(self, event_client: EventClient): - self.event_client = event_client - self.q = asyncio.Queue[LogRecord | STOP_LOOP_TYPE]( - maxsize=event_client.client_config.log_queue_size - ) - self._owner_loop: asyncio.AbstractEventLoop | None = None - - async def consume(self) -> None: - self._owner_loop = asyncio.get_running_loop() + self._event_client = event_client + self.q: queue.SimpleQueue[LogRecord | STOP_LOOP_TYPE] = queue.SimpleQueue() + self._thread: threading.Thread | None = None + def _consume(self) -> None: while True: - record = await self.q.get() - + record = self.q.get() if record == STOP_LOOP: break - try: - await asyncio.to_thread( - self.event_client.log, + self._event_client.log( message=record.message, step_run_id=record.step_run_id, level=record.level, @@ -154,30 +147,18 @@ async def consume(self) -> None: logger.exception("failed to send log to Hatchet") def publish(self, record: LogRecord | STOP_LOOP_TYPE) -> None: - owner_loop = self._owner_loop - - if owner_loop is None: - self._enqueue_or_drop(record) - return + self.q.put(record) - try: - running_loop = asyncio.get_running_loop() - except RuntimeError: - running_loop = None + def start(self) -> None: + self._thread = threading.Thread(target=self._consume, daemon=True) + self._thread.start() - if running_loop is owner_loop: - self._enqueue_or_drop(record) + def stop(self, timeout: float = 5.0) -> None: + if self._thread is None: return - - with suppress(RuntimeError): - # The owner loop may already be closed during worker shutdown. - owner_loop.call_soon_threadsafe(self._enqueue_or_drop, record) - - def _enqueue_or_drop(self, record: LogRecord | STOP_LOOP_TYPE) -> None: - try: - self.q.put_nowait(record) - except asyncio.QueueFull: - logger.warning("log queue is full, dropping log message") + self.q.put(STOP_LOOP) + self._thread.join(timeout) + self._thread = None class LogForwardingHandler(logging.StreamHandler): # type: ignore[type-arg] diff --git a/sdks/python/pyproject.toml b/sdks/python/pyproject.toml index 4446ebfb43..a3826d335d 100644 --- a/sdks/python/pyproject.toml +++ b/sdks/python/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "hatchet-sdk" -version = "1.33.12" +version = "1.33.13" description = "This is the official Python SDK for Hatchet, a distributed, fault-tolerant task queue. The SDK allows you to easily integrate Hatchet's task scheduling and workflow orchestration capabilities into your Python applications." readme = "README.md" license = { text = "MIT" } diff --git a/sdks/python/tests/unit/test_aio_put_stream.py b/sdks/python/tests/unit/test_aio_put_stream.py new file mode 100644 index 0000000000..586de87e72 --- /dev/null +++ b/sdks/python/tests/unit/test_aio_put_stream.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable +from typing import cast + +import grpc +import grpc.aio +import pytest +import tenacity + +from hatchet_sdk.clients.events import EventClient +from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry +from hatchet_sdk.config import ClientConfig, TenacityConfig +from hatchet_sdk.context.context import Context +from hatchet_sdk.contracts.events_pb2 import ( + PutStreamEventRequest, + PutStreamEventResponse, +) +from hatchet_sdk.contracts.events_pb2_grpc import EventsServiceStub + + +def _make_grpc_error(code: grpc.StatusCode, details: str = "") -> grpc.aio.AioRpcError: + empty: grpc.aio.Metadata = grpc.aio.Metadata() + return grpc.aio.AioRpcError(code, empty, empty, details) + + +class _GeneratedAioUnaryCall: + """Matches grpc.aio generated unary calls: sync call, awaitable result.""" + + def __init__(self, failures_before_success: int) -> None: + self.failures_before_success = failures_before_success + self.calls = 0 + self.requests: list[PutStreamEventRequest] = [] + self.metadata: list[tuple[tuple[str, str]]] = [] + + def __call__( + self, + request: PutStreamEventRequest, + *, + metadata: tuple[tuple[str, str]], + ) -> Awaitable[PutStreamEventResponse]: + self.calls += 1 + self.requests.append(request) + self.metadata.append(metadata) + + async def response() -> PutStreamEventResponse: + if self.calls <= self.failures_before_success: + raise _make_grpc_error(grpc.StatusCode.UNAVAILABLE, "transient") + + return PutStreamEventResponse() + + return response() + + +class _FakeAioEventsServiceStub: + def __init__(self, put_stream_event: _GeneratedAioUnaryCall) -> None: + self.PutStreamEvent = put_stream_event + + +def _event_client(aio_stub: _FakeAioEventsServiceStub) -> EventClient: + client = EventClient.__new__(EventClient) + client.client_config = ClientConfig.model_construct( + tenant_id="tenant", + token="token", + namespace="", + server_url="http://localhost", + host_port="localhost:7070", + tenacity=TenacityConfig(max_attempts=3, wait=tenacity.wait_none), + ) + client.token = "token" + client.namespace = "" + client._aio_client = cast(EventsServiceStub, aio_stub) + client._retrying_aio_put_stream_event = tenacity_retry( + client._put_stream_event, client.client_config.tenacity + ) + + return client + + +@pytest.mark.parametrize( + ("data", "expected_message"), + [ + ("hello", b"hello"), + (b"hello", b"hello"), + ], +) +async def test_aio_stream_retries_generated_aio_callable( + data: str | bytes, expected_message: bytes +) -> None: + put_stream_event = _GeneratedAioUnaryCall(failures_before_success=2) + client = _event_client(_FakeAioEventsServiceStub(put_stream_event)) + + await client.aio_stream(data, step_run_id="step-run-id", index=7) + + assert put_stream_event.calls == 3 + assert [request.task_run_external_id for request in put_stream_event.requests] == [ + "step-run-id", + "step-run-id", + "step-run-id", + ] + assert [request.message for request in put_stream_event.requests] == [ + expected_message, + expected_message, + expected_message, + ] + assert [request.event_index for request in put_stream_event.requests] == [7, 7, 7] + assert put_stream_event.metadata == [ + (("authorization", "bearer token"),), + (("authorization", "bearer token"),), + (("authorization", "bearer token"),), + ] + + +class _RecordingEventClient: + def __init__(self) -> None: + self.calls: list[tuple[str | bytes, str, int]] = [] + self.both_calls_started = asyncio.Event() + self.release_sends = asyncio.Event() + + async def aio_stream(self, data: str | bytes, step_run_id: str, index: int) -> None: + self.calls.append((data, step_run_id, index)) + if len(self.calls) == 2: + self.both_calls_started.set() + await self.release_sends.wait() + + +def _context(event_client: _RecordingEventClient) -> Context: + ctx = Context.__new__(Context) + ctx._stream_index = 0 + ctx._step_run_id = "step-run-id" + ctx._event_client = cast(EventClient, event_client) + + return ctx + + +async def test_aio_put_stream_assigns_index_before_async_send() -> None: + event_client = _RecordingEventClient() + ctx = _context(event_client) + + tasks = [ + asyncio.create_task(ctx.aio_put_stream("first")), + asyncio.create_task(ctx.aio_put_stream(b"second")), + ] + + try: + await asyncio.wait_for(event_client.both_calls_started.wait(), timeout=1) + + assert event_client.calls == [ + ("first", "step-run-id", 0), + (b"second", "step-run-id", 1), + ] + finally: + event_client.release_sends.set() + await asyncio.gather(*tasks) diff --git a/sdks/python/tests/unit/test_capture_logs.py b/sdks/python/tests/unit/test_capture_logs.py index cc4c272e88..21da68c4cd 100644 --- a/sdks/python/tests/unit/test_capture_logs.py +++ b/sdks/python/tests/unit/test_capture_logs.py @@ -2,93 +2,54 @@ import asyncio import logging -from contextlib import suppress from io import StringIO from types import SimpleNamespace -from typing import Any, cast +from typing import cast from hatchet_sdk.clients.events import EventClient from hatchet_sdk.runnables.contextvars import ctx_step_run_id, ctx_task_retry_count -from hatchet_sdk.utils.typing import STOP_LOOP, LogLevel +from hatchet_sdk.utils.typing import LogLevel from hatchet_sdk.worker.runner.utils.capture_logs import ( AsyncLogSender, LogForwardingHandler, + LogRecord, ) class FakeEventClient: - def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + def __init__(self) -> None: self.client_config = SimpleNamespace(log_queue_size=10) - self.loop = loop - self.logged = asyncio.Event() - self.records: list[dict[str, Any]] = [] - def log( - self, - message: str, - step_run_id: str, - level: LogLevel | None = None, - task_retry_count: int | None = None, - ) -> None: - self.records.append( - { - "message": message, - "step_run_id": step_run_id, - "level": level.value if level else None, - "task_retry_count": task_retry_count, - } - ) - self.loop.call_soon_threadsafe(self.logged.set) - -async def test_log_forwarding_from_to_thread_uses_sender_loop() -> None: - loop = asyncio.get_running_loop() - previous_debug = loop.get_debug() - loop.set_debug(True) - - event_client = FakeEventClient(loop) +async def test_log_forwarding_handler_enqueues_correct_record() -> None: + event_client = FakeEventClient() log_sender = AsyncLogSender(cast(EventClient, event_client)) - consume_task = asyncio.create_task(log_sender.consume()) - await asyncio.sleep(0.01) + + target_logger = logging.getLogger("capture-log-test") + previous_level = target_logger.level + target_logger.setLevel(logging.INFO) handler = LogForwardingHandler(log_sender, StringIO()) - root_logger = logging.getLogger() - previous_level = root_logger.level - root_logger.setLevel(logging.INFO) - root_logger.addHandler(handler) + target_logger.addHandler(handler) step_token = ctx_step_run_id.set("step-run-id") retry_token = ctx_task_retry_count.set(2) - log_sent = False try: def log_from_worker_thread() -> None: logging.getLogger("capture-log-test").info("hello from worker thread") - await asyncio.wait_for(asyncio.to_thread(log_from_worker_thread), timeout=1) - await asyncio.wait_for(event_client.logged.wait(), timeout=1) - log_sent = True + await asyncio.to_thread(log_from_worker_thread) - assert event_client.records == [ - { - "message": "hello from worker thread", - "step_run_id": "step-run-id", - "level": "INFO", - "task_retry_count": 2, - } - ] + record = log_sender.q.get() + assert isinstance(record, LogRecord) + assert record.message == "hello from worker thread" + assert record.step_run_id == "step-run-id" + assert record.level == LogLevel.INFO + assert record.task_retry_count == 2 finally: ctx_step_run_id.reset(step_token) ctx_task_retry_count.reset(retry_token) - root_logger.removeHandler(handler) - root_logger.setLevel(previous_level) - loop.set_debug(previous_debug) - - if log_sent: - log_sender.publish(STOP_LOOP) - await consume_task - else: - consume_task.cancel() - with suppress(asyncio.CancelledError): - await consume_task + target_logger.removeHandler(handler) + target_logger.setLevel(previous_level)