Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 13 additions & 22 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,31 +52,11 @@ jobs:
- name: Install the project
run: uv sync --locked --all-extras --dev

- name: Ensure no differences in generated code
if: runner.os != 'Windows'
- name: Generate protobuf code
shell: bash
run: |
source .venv/bin/activate
GENERATED_DIR="src/resolver_athena_client/generated"
BACKUP_DIR="src/resolver_athena_client/generated_backup"

cp -r $GENERATED_DIR $BACKUP_DIR

./scripts/compile_proto.sh || (echo "Protobuf compilation failed. Ensure submodules are initialized and the proto file exists." && exit 1)

# Fix imports in generated files
if [[ -f "$GENERATED_DIR/athena/athena_pb2_grpc.py" && -f "$GENERATED_DIR/athena/athena_pb2.py" ]]; then
sed -i.bak 's/^from athena /from resolver_athena_client.generated.athena /' "$GENERATED_DIR/athena/athena_pb2_grpc.py"
sed -i.bak 's/^from athena /from resolver_athena_client.generated.athena /' "$GENERATED_DIR/athena/athena_pb2.py"
rm -f "$GENERATED_DIR/athena/athena_pb2_grpc.py.bak" "$GENERATED_DIR/athena/athena_pb2.py.bak"
else
echo "Error: Expected files not found in $GENERATED_DIR/athena"
exit 1
fi

diff -r $GENERATED_DIR $BACKUP_DIR || (echo "Generated code differs. Please commit the changes after running compile_proto.sh." && exit 1)

rm -rf $BACKUP_DIR
./scripts/compile_proto.sh || (echo "Protobuf compilation failed. Ensure submodules are initialized and proto files exist." && exit 1)

- name: Run linter
run: |
Expand Down Expand Up @@ -123,6 +103,17 @@ jobs:
with:
enable-cache: true

- name: Set up Python
run: uv python install

- name: Install the project
run: uv sync --locked --all-extras --dev

- name: Generate protobuf code
run: |
source .venv/bin/activate
./scripts/compile_proto.sh || (echo "Protobuf compilation failed. Ensure submodules are initialized and proto files exist." && exit 1)

- name: Set version from tag
run: |
if [[ "$GITHUB_REF" == refs/tags/* ]]; then
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ wheels/
.env

docs/_build/

# Generated protobuf code
src/resolver_athena_client/generated/
7 changes: 4 additions & 3 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
- Format code: `ruff format`
- Lint code: `ruff check`
- Install git hooks: `pre-commit install`
- Compile protobufs: `bash scripts/compile_proto.sh` (run from root)
- Compile protobufs: `bash scripts/compile_proto.sh` (run from root, required for local development)

## Code style
- Use Python type hints throughout
Expand All @@ -30,7 +30,7 @@

## PR instructions
- Title format: [component] Description
- Run `ruff check`, `pyright`, and `pytest` before committing
- Run `bash scripts/compile_proto.sh`, `ruff check`, `pyright`, and `pytest` before committing
- Keep PRs focused on a single change
- Add tests for new functionality
- Update documentation for API changes
Expand All @@ -40,7 +40,8 @@
- Use `uv` package manager instead of pip
- Don't use `uv pip` commands, just the base `uv` commands
- Run formatters before committing
- Check generated code in `src/athena_client/generated/`
- Generated code is built automatically in CI, but run `bash scripts/compile_proto.sh` locally for development
- Generated code in `src/resolver_athena_client/generated/` is not committed to the repo
- Add error handling at each pipeline stage
- Use correlation IDs for request tracing

Expand Down
271 changes: 271 additions & 0 deletions examples/classify_single_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
#!/usr/bin/env python3
"""Example script demonstrating the classify_single method."""

import asyncio
import logging
import os
import sys
import uuid
from pathlib import Path

from create_image import create_test_image
from dotenv import load_dotenv

from resolver_athena_client.client.athena_client import AthenaClient
from resolver_athena_client.client.athena_options import AthenaOptions
from resolver_athena_client.client.channel import (
CredentialHelper,
create_channel_with_credentials,
)
from resolver_athena_client.client.models import ImageData


async def classify_single_image_example(
logger: logging.Logger,
options: AthenaOptions,
credential_helper: CredentialHelper,
image_path: str | None = None,
) -> bool:
"""Demonstrate single image classification.

Args:
logger: Logger instance for output
options: Configuration options for the Athena client
credential_helper: OAuth credential helper for authentication
image_path: Path to image file to classify (optional)

Returns:
True if classification was successful, False otherwise

"""
# Create gRPC channel with credentials
channel = await create_channel_with_credentials(
options.host, credential_helper
)

async with AthenaClient(channel, options) as client:
# Load image data
if image_path and Path(image_path).exists():
logger.info("Loading image from: %s", image_path)
image_bytes = Path(image_path).read_bytes()
else:
# Create a simple test image if no path provided
logger.info("Creating synthetic test image")
image_bytes = create_test_image()

# Create ImageData object
image_data = ImageData(image_bytes)
logger.info(
"Image loaded: %d bytes, MD5: %s",
len(image_data.data),
image_data.md5_hashes[0][:8] + "...",
)

try:
# Classify the single image
logger.info("Classifying single image...")
correlation_id = uuid.uuid4().hex[:63]
logger.info("Correlation ID: %s", correlation_id)
result = await client.classify_single(
image_data, correlation_id=correlation_id
)

# Process the result
logger.info("Classification completed successfully!")

if result.error.code:
logger.error(
"Classification error: %s (%s)",
result.error.message,
result.error.code,
)
if result.error.details:
logger.error("Error details: %s", result.error.details)
return False

if result.classifications:
logger.info(
"Found %d classifications:", len(result.classifications)
)
for i, classification in enumerate(result.classifications, 1):
logger.info(
" %d. Label: %s, Weight: %.3f",
i,
classification.label,
classification.weight,
)
else:
logger.info("No classifications found for this image")

except Exception:
logger.exception("Error during single image classification")
return False
else:
return True


async def classify_multiple_single_images_example(
logger: logging.Logger,
options: AthenaOptions,
credential_helper: CredentialHelper,
num_images: int = 3,
) -> int:
"""Demonstrate classifying multiple images individually.

This shows how classify_single can be used for multiple images
when you want individual control over each classification request.

Args:
logger: Logger instance for output
options: Configuration options for the Athena client
credential_helper: OAuth credential helper for authentication
num_images: Number of test images to classify

Returns:
Number of successfully classified images

"""
# Create gRPC channel with credentials
channel = await create_channel_with_credentials(
options.host, credential_helper
)

successful_count = 0

async with AthenaClient(channel, options) as client:
logger.info("Classifying %d images individually...", num_images)

for i in range(num_images):
try:
# Create a unique test image for each iteration
image_bytes = create_test_image(seed=i)
image_data = ImageData(image_bytes)

# Classify with auto-generated correlation ID
result = await client.classify_single(image_data)

logger.info(
"Image %d/%d - Correlation: %s",
i + 1,
num_images,
result.correlation_id[:8] + "...",
)

if result.error.code:
logger.warning(
"Image %d failed: %s", i + 1, result.error.message
)
elif result.classifications:
top_classification = max(
result.classifications, key=lambda c: c.weight
)
logger.info(
"Image %d - Top result: %s (%.3f)",
i + 1,
top_classification.label,
top_classification.weight,
)
successful_count += 1
else:
logger.info("Image %d - No classifications", i + 1)
successful_count += 1

except Exception: # noqa: PERF203
logger.exception("Failed to classify image %d", i + 1)

logger.info(
"Completed: %d/%d images classified successfully",
successful_count,
num_images,
)
return successful_count


async def main() -> int:
"""Run the classify_single examples."""
logger = logging.getLogger(__name__)
load_dotenv()

# OAuth credentials from environment
client_id = os.getenv("OAUTH_CLIENT_ID")
client_secret = os.getenv("OAUTH_CLIENT_SECRET")
auth_url = os.getenv(
"OAUTH_AUTH_URL", "https://crispthinking.auth0.com/oauth/token"
)
audience = os.getenv("OAUTH_AUDIENCE", "crisp-athena-dev")

if not client_id or not client_secret:
logger.error("OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET must be set")
return 1

host = os.getenv("ATHENA_HOST", "localhost")
logger.info("Connecting to %s", host)

# Create credential helper
credential_helper = CredentialHelper(
client_id=client_id,
client_secret=client_secret,
auth_url=auth_url,
audience=audience,
)

# Test token acquisition
try:
logger.info("Acquiring OAuth token...")
token = await credential_helper.get_token()
logger.info("Successfully acquired token (length: %d)", len(token))
except Exception:
logger.exception("Failed to acquire OAuth token")
return 1

# Configure client options
options = AthenaOptions(
host=host,
resize_images=True,
compress_images=True,
timeout=30.0, # Shorter timeout for single requests
affiliate="Crisp",
deployment_id="single-example-deployment", # Not used
)

try:
# Example 1: Classify a single image
logger.info("\n=== Example 1: Single Image Classification ===")
success = await classify_single_image_example(
logger,
options,
credential_helper,
image_path=os.getenv("TEST_IMAGE_PATH"), # Optional image path
)

if not success:
logger.error("Single image classification failed")
return 1

# Example 2: Classify multiple images individually
logger.info("\n=== Example 2: Multiple Individual Classifications ===")
successful_count = await classify_multiple_single_images_example(
logger, options, credential_helper, num_images=5
)

if successful_count == 0:
logger.error("No images were successfully classified")
return 1

logger.info("\n=== All examples completed successfully! ===")

except Exception:
logger.exception("Examples failed")
return 1
else:
return 0


if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s.%(msecs)03d %(levelname)s: %(message)s",
datefmt="%H:%M:%S",
)

sys.exit(asyncio.run(main()))
20 changes: 20 additions & 0 deletions examples/create_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,23 @@ async def iter_images(
counter[0] += 1
yield ImageData(img_bytes)
count += 1


def create_test_image(
width: int = 160, height: int = 120, seed: int | None = None
) -> bytes:
"""Create a test image with specified dimensions and optional seed.

Args:
width: Width of the test image in pixels (default: 160)
height: Height of the test image in pixels (default: 120)
seed: Optional seed for reproducible image generation

Returns:
PNG image bytes

"""
if seed is not None:
_rng.seed(seed)

return create_random_image(width, height)
Loading