Skip to content

Commit f11bf96

Browse files
Merge remote-tracking branch 'origin/main' into feature/report-any-enabled
# Conflicts: # tests/functional/conftest.py # tests/functional/e2e/test_classify_single.py Co-authored-by: corpo-iwillspeak <265613520+corpo-iwillspeak@users.noreply.github.com>
2 parents eac65cb + c53449c commit f11bf96

8 files changed

Lines changed: 282 additions & 187 deletions

File tree

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ jobs:
131131
run: uv build
132132

133133
- name: Upload build artifacts
134-
uses: actions/upload-artifact@v6
134+
uses: actions/upload-artifact@v7
135135
with:
136136
name: dist
137137
path: dist/
@@ -159,7 +159,7 @@ jobs:
159159
enable-cache: true
160160

161161
- name: Download build artifacts
162-
uses: actions/download-artifact@v7
162+
uses: actions/download-artifact@v8
163163
with:
164164
name: dist
165165
path: dist/

.github/workflows/docs.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,24 +72,24 @@ jobs:
7272
uv run make clean html
7373
7474
- name: Setup Pages
75-
uses: actions/configure-pages@v5
75+
uses: actions/configure-pages@v6
7676
if: github.event_name == 'release' && github.event.action == 'published'
7777

7878
- name: Upload artifact for GitHub Pages
79-
uses: actions/upload-pages-artifact@v4
79+
uses: actions/upload-pages-artifact@v5
8080
if: github.event_name == 'release' && github.event.action == 'published'
8181
with:
8282
path: docs/_build/html
8383

8484
- name: Upload documentation artifacts
85-
uses: actions/upload-artifact@v6
85+
uses: actions/upload-artifact@v7
8686
if: github.event_name != 'release'
8787
with:
8888
name: documentation
8989
path: docs/_build/html
9090

9191
- name: Upload build artifacts for debugging
92-
uses: actions/upload-artifact@v6
92+
uses: actions/upload-artifact@v7
9393
if: failure()
9494
with:
9595
name: docs-build-artifacts
@@ -107,4 +107,4 @@ jobs:
107107
steps:
108108
- name: Deploy to GitHub Pages
109109
id: deployment
110-
uses: actions/deploy-pages@v4
110+
uses: actions/deploy-pages@v5

README.md

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ ATHENA_NON_EXISTENT_AFFILIATE=non-existent-affiliate-id (default:
162162
thisaffiliatedoesnotexist123) - this is used to test error handling.
163163
ATHENA_NON_PERMITTED_AFFILIATE=non-permitted-affiliate-id (default:
164164
thisaffiliatedoesnothaveathenaenabled) - this is used to test error handling.
165+
ATHENA_E2E_TESTCASE_DIR=test-case-directory (default: integrator_sample) - this is the test case directory to use for the e2e tests.
166+
See E2E Tests section below for more details.
165167
```
166168

167169
Then run the functional tests with:
@@ -170,8 +172,18 @@ Then run the functional tests with:
170172
pytest -m functional
171173
```
172174

173-
To exclude the e2e tests, which require usage of the live classifier and
174-
therefore are unsuitable for regular development runs, use:
175+
#### E2E Tests
176+
177+
The e2e tests assert that the API returns some expected _scores_ rather than
178+
exercising different API paths. As such, they are dependent on the classifier
179+
that you are calling through the API. Right now, there are 2 types of
180+
classifier, benign and live. By default, the tests will run the
181+
`integrator_sample` test set, which uses the live classifier. If you wish to
182+
use the benign classifier instead, you may set the `ATHENA_E2E_TESTCASE_DIR`
183+
environment variable to `benign_model`.
184+
185+
Alternatively, you may disable these tests altogether, by excluding tests that
186+
have the `e2e` marker, something like this:
175187

176188
```bash
177189
pytest -m 'functional and not e2e'

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ documentation = "https://crispthinking.github.io/athena-python-client/"
2424
dependencies = [
2525
"anyio>=4.10.0",
2626
"brotli>=1.1.0",
27-
"grpcio-tools>=1.74.0",
27+
"grpcio>=1.78.0,!=1.78.1,<2.0.0",
2828
"httpx>=0.25.0",
2929
"numpy>=2.2.6",
30-
"opencv-python-headless>=4.13.0.92"
30+
"opencv-python-headless>=4.13.0.92",
3131
]
3232

3333
[project.optional-dependencies]
@@ -43,6 +43,7 @@ docs = [
4343
dev = [
4444
"basedpyright>=1.31.4",
4545
"brotli-stubs>=1.1.0",
46+
"grpcio-tools>=1.78.0,!=1.78.1,<2.0.0",
4647
"load-dotenv>=0.1.0",
4748
"mypy-protobuf>=3.6.0",
4849
"pre-commit>=4.2.0",

tests/functional/conftest.py

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,33 @@
1+
import asyncio
12
import os
23
import uuid
4+
from asyncio import Future, Queue, Task, create_task
5+
from collections.abc import AsyncIterator
6+
from copy import deepcopy
37
from typing import cast
48

59
import cv2 as cv
610
import numpy as np
711
import pytest
812
import pytest_asyncio
913
from dotenv import load_dotenv
14+
from grpc.aio import Channel
15+
from resolver_athena_client.generated.athena.models_pb2 import (
16+
ClassificationOutput,
17+
)
1018

19+
from resolver_athena_client.client.athena_client import AthenaClient
1120
from resolver_athena_client.client.athena_options import AthenaOptions
12-
from resolver_athena_client.client.channel import CredentialHelper
21+
from resolver_athena_client.client.channel import (
22+
CredentialHelper,
23+
create_channel_with_credentials,
24+
)
1325
from resolver_athena_client.client.consts import (
1426
EXPECTED_HEIGHT,
1527
EXPECTED_WIDTH,
1628
MAX_DEPLOYMENT_ID_LENGTH,
1729
)
30+
from resolver_athena_client.client.models.input_model import ImageData
1831

1932

2033
def _create_base_test_image_opencv(width: int, height: int) -> np.ndarray:
@@ -80,8 +93,7 @@ async def credential_helper() -> CredentialHelper:
8093
)
8194

8295

83-
@pytest.fixture
84-
def athena_options() -> AthenaOptions:
96+
def _load_options() -> AthenaOptions:
8597
_ = load_dotenv()
8698
host = os.getenv("ATHENA_HOST", "localhost")
8799

@@ -100,9 +112,15 @@ def athena_options() -> AthenaOptions:
100112
timeout=120.0, # Maximum duration, not forced timeout
101113
keepalive_interval=30.0, # Longer intervals for persistent streams
102114
affiliate=affiliate,
115+
compression_quality=2,
103116
)
104117

105118

119+
@pytest.fixture
120+
def athena_options() -> AthenaOptions:
121+
return _load_options()
122+
123+
106124
@pytest.fixture(scope="session", params=SUPPORTED_TEST_FORMATS)
107125
def valid_formatted_image(
108126
request: pytest.FixtureRequest,
@@ -145,3 +163,75 @@ def valid_formatted_image(
145163
_ = f.write(image_bytes)
146164

147165
return image_bytes
166+
167+
168+
class StreamingSender:
169+
"""Helper class to provide a single-send-like interface with speed
170+
171+
The class provides a 'send' method that can be passed an imagedata and will
172+
send it along a stream, and collect all results into an internal buffer.
173+
174+
The 'send' method will asynchronously wait for the result and return it,
175+
providing an interface that mimics a single request-response call, while
176+
under the hood it is using a streaming connection for speed.
177+
"""
178+
179+
def __init__(self, grpc_channel: Channel, options: AthenaOptions) -> None:
180+
self._request_queue: Queue[ImageData] = Queue()
181+
self._pending_results: dict[str, Future[ClassificationOutput]] = {}
182+
183+
# tests are run in series, so we gain nothing here from waiting for a
184+
# batch that will never fill, so just send it immediately for better
185+
# latency
186+
streaming_options = deepcopy(options)
187+
streaming_options.max_batch_size = 1
188+
189+
self._run_task: Task[None] = create_task(
190+
self._run(grpc_channel, streaming_options)
191+
)
192+
193+
async def _run(self, grpc_channel: Channel, options: AthenaOptions) -> None:
194+
async with AthenaClient(grpc_channel, options) as client:
195+
generator = self._send_from_queue()
196+
responses = client.classify_images(generator)
197+
async for response in responses:
198+
for output in response.outputs:
199+
if output.correlation_id in self._pending_results:
200+
future = self._pending_results.pop(
201+
output.correlation_id
202+
)
203+
future.set_result(output)
204+
205+
async def _send_from_queue(self) -> AsyncIterator[ImageData]:
206+
"""Async generator to yield requests from the queue."""
207+
while True:
208+
if image_data := await self._request_queue.get():
209+
yield image_data
210+
self._request_queue.task_done()
211+
212+
async def send(self, image_data: ImageData) -> ClassificationOutput:
213+
"""Send an image and wait for the corresponding result."""
214+
if self._run_task.done():
215+
self._run_task.result()
216+
217+
if image_data.correlation_id is None:
218+
image_data.correlation_id = str(uuid.uuid4())
219+
future = asyncio.get_event_loop().create_future()
220+
self._pending_results[image_data.correlation_id] = future
221+
222+
await self._request_queue.put(image_data)
223+
224+
return await future
225+
226+
227+
@pytest_asyncio.fixture(scope="session", loop_scope="session")
228+
async def streaming_sender(
229+
credential_helper: CredentialHelper,
230+
) -> StreamingSender:
231+
"""Fixture to provide a helper for sending over a streaming connection."""
232+
# Create gRPC channel with credentials
233+
opts = _load_options()
234+
channel = await create_channel_with_credentials(
235+
opts.host, credential_helper
236+
)
237+
return StreamingSender(channel, opts)

tests/functional/e2e/test_classify_single.py

Lines changed: 30 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,24 @@
22

33
import pytest
44

5-
from resolver_athena_client.client.athena_client import AthenaClient
6-
from resolver_athena_client.client.athena_options import AthenaOptions
7-
from resolver_athena_client.client.channel import (
8-
CredentialHelper,
9-
create_channel_with_credentials,
10-
)
115
from resolver_athena_client.client.models import ImageData
6+
from tests.functional.conftest import StreamingSender
127
from tests.functional.e2e.testcases.parser import (
138
AthenaTestCase,
14-
load_test_cases,
9+
load_test_cases_by_env,
1510
)
1611

17-
TEST_CASES = load_test_cases("integrator_sample")
12+
TEST_CASES = load_test_cases_by_env()
1813

1914
FP_ERROR_TOLERANCE = 1e-4
2015

2116

22-
def _get_test_case_id(tc: AthenaTestCase) -> str:
23-
"""Get the test case ID for pytest parametrize."""
24-
return tc.id
25-
26-
27-
@pytest.mark.asyncio
17+
@pytest.mark.asyncio(loop_scope="session")
2818
@pytest.mark.functional
2919
@pytest.mark.e2e
30-
@pytest.mark.parametrize(
31-
"test_case",
32-
TEST_CASES,
33-
ids=_get_test_case_id,
34-
)
35-
async def test_classify_single(
36-
athena_options: AthenaOptions,
37-
credential_helper: CredentialHelper,
20+
@pytest.mark.parametrize("test_case", TEST_CASES, ids=lambda tc: tc.id)
21+
async def test_e2e_case(
22+
streaming_sender: StreamingSender,
3823
test_case: AthenaTestCase,
3924
) -> None:
4025
"""Functional test for ClassifySingle endpoint and API methods.
@@ -43,38 +28,33 @@ async def test_classify_single(
4328
4429
"""
4530

46-
# Create gRPC channel with credentials
47-
channel = await create_channel_with_credentials(
48-
athena_options.host, credential_helper
49-
)
5031
with Path.open(Path(test_case.filepath), "rb") as f:
5132
image_bytes = f.read()
5233

53-
async with AthenaClient(channel, athena_options) as client:
54-
image_data = ImageData(image_bytes)
34+
image_data = ImageData(image_bytes)
5535

56-
# Classify with auto-generated correlation ID
57-
result = await client.classify_single(image_data)
36+
# Classify with auto-generated correlation ID
37+
result = await streaming_sender.send(image_data)
5838

59-
if result.error.code:
60-
msg = f"Image Result Error: {result.error.message}"
61-
pytest.fail(msg)
39+
if result.error.code:
40+
msg = f"Image Result Error: {result.error.message}"
41+
pytest.fail(msg)
6242

63-
actual_output = {c.label: c.weight for c in result.classifications}
64-
assert set(test_case.expected_output.keys()).issubset(
65-
set(actual_output.keys())
66-
), (
67-
"Expected output to contain labels: ",
68-
f"{test_case.expected_output.keys() - actual_output.keys()}",
43+
actual_output = {c.label: c.weight for c in result.classifications}
44+
assert set(test_case.expected_output.keys()).issubset(
45+
set(actual_output.keys())
46+
), (
47+
"Expected output to contain labels: ",
48+
f"{test_case.expected_output.keys() - actual_output.keys()}",
49+
)
50+
actual_output = {k: actual_output[k] for k in test_case.expected_output}
51+
52+
for label in test_case.expected_output:
53+
expected = test_case.expected_output[label]
54+
actual = actual_output[label]
55+
diff = abs(expected - actual)
56+
assert diff < FP_ERROR_TOLERANCE, (
57+
f"Weight for label '{label}' differs by more than "
58+
f"{FP_ERROR_TOLERANCE}: expected={expected}, actual={actual}, "
59+
f"diff={diff}"
6960
)
70-
actual_output = {k: actual_output[k] for k in test_case.expected_output}
71-
72-
for label in test_case.expected_output:
73-
expected = test_case.expected_output[label]
74-
actual = actual_output[label]
75-
diff = abs(expected - actual)
76-
assert diff < FP_ERROR_TOLERANCE, (
77-
f"Weight for label '{label}' differs by more than "
78-
f"{FP_ERROR_TOLERANCE}: expected={expected}, actual={actual}, "
79-
f"diff={diff}"
80-
)

tests/functional/e2e/testcases/parser.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import json
2+
import os
23
from pathlib import Path
34
from typing import TypedDict, cast
45

6+
from dotenv import load_dotenv
7+
58
# Path to the shared testcases directory in athena-protobufs
69
_REPO_ROOT = Path(__file__).parent.parent.parent.parent.parent
710
TESTCASES_DIR = _REPO_ROOT / "athena-protobufs" / "testcases"
@@ -29,6 +32,13 @@ def __init__(
2932
self.classification_labels: list[str] = classification_labels
3033

3134

35+
def load_test_cases_by_env() -> list[AthenaTestCase]:
36+
_ = load_dotenv()
37+
return load_test_cases(
38+
os.getenv("ATHENA_E2E_TESTCASE_DIR", "integrator_sample")
39+
)
40+
41+
3242
def load_test_cases(dirname: str = "benign_model") -> list[AthenaTestCase]:
3343
with Path.open(
3444
Path(TESTCASES_DIR / dirname / "expected_outputs.json"),

0 commit comments

Comments
 (0)