diff --git a/pyproject.toml b/pyproject.toml index fa72313..6a9cb63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,7 +81,8 @@ reportImplicitStringConcatenation = false [tool.pytest.ini_options] markers = [ "functional: mark a test as a functional test", - "e2e: mark a test as an end-to-end test which asserts expected scores for known inputs. These tests should only be considered authoritative against the live classifier." + "e2e: mark a test as an end-to-end test which asserts expected scores for known inputs. These tests should only be considered authoritative against the live classifier.", + "soak: mark a test as a soak test that runs many requests over a longer duration to validate stability and performance." ] addopts = "--strict-markers" diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py index 750723a..895fab9 100644 --- a/tests/functional/conftest.py +++ b/tests/functional/conftest.py @@ -79,23 +79,25 @@ async def credential_helper() -> CredentialHelper: _ = load_dotenv() client_id = os.environ["OAUTH_CLIENT_ID"] client_secret = os.environ["OAUTH_CLIENT_SECRET"] - auth_url = os.getenv( - "OAUTH_AUTH_URL", "https://crispthinking.auth0.com/oauth/token" - ) - audience = os.getenv("OAUTH_AUDIENCE", "crisp-athena-live") + auth_url = os.getenv("OAUTH_AUTH_URL") + audience = os.getenv("OAUTH_AUDIENCE") # Create credential helper - return CredentialHelper( - client_id=client_id, - client_secret=client_secret, - auth_url=auth_url, - audience=audience, - ) + kwargs: dict[str, str] = { + "client_id": client_id, + "client_secret": client_secret, + } + if auth_url: + kwargs["auth_url"] = auth_url + if audience: + kwargs["audience"] = audience + + return CredentialHelper(proactive_refresh_threshold=0.25, **kwargs) def _load_options() -> AthenaOptions: _ = load_dotenv() - host = os.getenv("ATHENA_HOST", "localhost") + host = os.getenv("ATHENA_HOST") deployment_id = f"functional-test-{uuid.uuid4()}" if len(deployment_id) > MAX_DEPLOYMENT_ID_LENGTH: @@ -104,8 +106,7 @@ def _load_options() -> AthenaOptions: affiliate = os.environ["ATHENA_TEST_AFFILIATE"] # Run classification with OAuth authentication - return AthenaOptions( - host=host, + opts = AthenaOptions( resize_images=True, deployment_id=deployment_id, compress_images=True, @@ -115,6 +116,11 @@ def _load_options() -> AthenaOptions: compression_quality=2, ) + if host: + opts.host = host + + return opts + @pytest.fixture def athena_options() -> AthenaOptions: diff --git a/tests/functional/test_classify_single_soak.py b/tests/functional/test_classify_single_soak.py new file mode 100644 index 0000000..5cdcfe2 --- /dev/null +++ b/tests/functional/test_classify_single_soak.py @@ -0,0 +1,66 @@ +import grpc +import pytest + +from common_utils.image_generation import create_test_image +from resolver_athena_client.client.athena_client import AthenaClient +from resolver_athena_client.client.athena_options import AthenaOptions +from resolver_athena_client.client.channel import ( + CredentialHelper, + create_channel_with_credentials, +) +from resolver_athena_client.client.models import ImageData + +SOAK_ITERATIONS = 1000 +MIN_HASH_CHECK_SUCCESS_RATE = 0.95 + + +@pytest.mark.asyncio +@pytest.mark.functional +@pytest.mark.soak +async def test_classify_single_hash_check_soak( + athena_options: AthenaOptions, credential_helper: CredentialHelper +) -> None: + """Soak test: classify images and assert hash check result exists.""" + + channel = await create_channel_with_credentials( + athena_options.host, credential_helper + ) + + async with AthenaClient(channel, athena_options) as client: + successes = 0 + failures: list[str] = [] + + for i in range(SOAK_ITERATIONS): + image_bytes = create_test_image() + image_data = ImageData(image_bytes) + + try: + result = await client.classify_single(image_data) + except grpc.aio.AioRpcError as e: + failures.append( + f"Iteration {i}: gRPC error {e.code()} - {e.details()}" + ) + continue + + if result.error.code: + failures.append( + f"Iteration {i}: error {result.error.code}" + f" - {result.error.message}" + ) + continue + + found_hash_check = any( + c.label.startswith("KnownCSAM-") for c in result.classifications + ) + if found_hash_check: + successes += 1 + else: + failures.append( + f"Iteration {i}: no KnownCSAM- classification found" + ) + + success_rate = successes / SOAK_ITERATIONS + assert success_rate >= MIN_HASH_CHECK_SUCCESS_RATE, ( + f"Hash check success rate {success_rate:.1%} is below 95%. " + f"{len(failures)} failures:\n" + "\n".join(failures[:20]) + )