1+ import asyncio
12import os
23import uuid
4+ from asyncio import Future , Queue , Task , create_task
5+ from collections .abc import AsyncIterator
6+ from copy import deepcopy
37from typing import cast
48
59import cv2 as cv
610import numpy as np
711import pytest
812import pytest_asyncio
913from dotenv import load_dotenv
14+ from grpc .aio import Channel
15+ from resolver_athena_client .generated .athena .models_pb2 import (
16+ ClassificationOutput ,
17+ )
1018
19+ from resolver_athena_client .client .athena_client import AthenaClient
1120from resolver_athena_client .client .athena_options import AthenaOptions
12- from resolver_athena_client .client .channel import CredentialHelper
21+ from resolver_athena_client .client .channel import (
22+ CredentialHelper ,
23+ create_channel_with_credentials ,
24+ )
1325from resolver_athena_client .client .consts import (
1426 EXPECTED_HEIGHT ,
1527 EXPECTED_WIDTH ,
1628 MAX_DEPLOYMENT_ID_LENGTH ,
1729)
30+ from resolver_athena_client .client .models .input_model import ImageData
1831
1932
2033def _create_base_test_image_opencv (width : int , height : int ) -> np .ndarray :
@@ -80,8 +93,7 @@ async def credential_helper() -> CredentialHelper:
8093 )
8194
8295
83- @pytest .fixture
84- def athena_options () -> AthenaOptions :
96+ def _load_options () -> AthenaOptions :
8597 _ = load_dotenv ()
8698 host = os .getenv ("ATHENA_HOST" , "localhost" )
8799
@@ -100,9 +112,15 @@ def athena_options() -> AthenaOptions:
100112 timeout = 120.0 , # Maximum duration, not forced timeout
101113 keepalive_interval = 30.0 , # Longer intervals for persistent streams
102114 affiliate = affiliate ,
115+ compression_quality = 2 ,
103116 )
104117
105118
119+ @pytest .fixture
120+ def athena_options () -> AthenaOptions :
121+ return _load_options ()
122+
123+
106124@pytest .fixture (scope = "session" , params = SUPPORTED_TEST_FORMATS )
107125def valid_formatted_image (
108126 request : pytest .FixtureRequest ,
@@ -145,3 +163,75 @@ def valid_formatted_image(
145163 _ = f .write (image_bytes )
146164
147165 return image_bytes
166+
167+
168+ class StreamingSender :
169+ """Helper class to provide a single-send-like interface with speed
170+
171+ The class provides a 'send' method that can be passed an imagedata and will
172+ send it along a stream, and collect all results into an internal buffer.
173+
174+ The 'send' method will asynchronously wait for the result and return it,
175+ providing an interface that mimics a single request-response call, while
176+ under the hood it is using a streaming connection for speed.
177+ """
178+
179+ def __init__ (self , grpc_channel : Channel , options : AthenaOptions ) -> None :
180+ self ._request_queue : Queue [ImageData ] = Queue ()
181+ self ._pending_results : dict [str , Future [ClassificationOutput ]] = {}
182+
183+ # tests are run in series, so we gain nothing here from waiting for a
184+ # batch that will never fill, so just send it immediately for better
185+ # latency
186+ streaming_options = deepcopy (options )
187+ streaming_options .max_batch_size = 1
188+
189+ self ._run_task : Task [None ] = create_task (
190+ self ._run (grpc_channel , streaming_options )
191+ )
192+
193+ async def _run (self , grpc_channel : Channel , options : AthenaOptions ) -> None :
194+ async with AthenaClient (grpc_channel , options ) as client :
195+ generator = self ._send_from_queue ()
196+ responses = client .classify_images (generator )
197+ async for response in responses :
198+ for output in response .outputs :
199+ if output .correlation_id in self ._pending_results :
200+ future = self ._pending_results .pop (
201+ output .correlation_id
202+ )
203+ future .set_result (output )
204+
205+ async def _send_from_queue (self ) -> AsyncIterator [ImageData ]:
206+ """Async generator to yield requests from the queue."""
207+ while True :
208+ if image_data := await self ._request_queue .get ():
209+ yield image_data
210+ self ._request_queue .task_done ()
211+
212+ async def send (self , image_data : ImageData ) -> ClassificationOutput :
213+ """Send an image and wait for the corresponding result."""
214+ if self ._run_task .done ():
215+ self ._run_task .result ()
216+
217+ if image_data .correlation_id is None :
218+ image_data .correlation_id = str (uuid .uuid4 ())
219+ future = asyncio .get_event_loop ().create_future ()
220+ self ._pending_results [image_data .correlation_id ] = future
221+
222+ await self ._request_queue .put (image_data )
223+
224+ return await future
225+
226+
227+ @pytest_asyncio .fixture (scope = "session" , loop_scope = "session" )
228+ async def streaming_sender (
229+ credential_helper : CredentialHelper ,
230+ ) -> StreamingSender :
231+ """Fixture to provide a helper for sending over a streaming connection."""
232+ # Create gRPC channel with credentials
233+ opts = _load_options ()
234+ channel = await create_channel_with_credentials (
235+ opts .host , credential_helper
236+ )
237+ return StreamingSender (channel , opts )
0 commit comments