Skip to content

Commit d4b5692

Browse files
snus-kinCopilot
andauthored
ClassifySingle Endpoint support (#20)
* feat: s/athena-client/resolver-athena-client/g * Update docs/conf.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * feat: classify single endpoint * feat: protobuf gen dynamically instead of static inclusion - Update CI workflow to use simplified protobuf generation approach - Modify .gitignore to exclude generated protobuf code from version control - Update compile_proto.sh to be more flexible and handle multiple proto files - Fix all import paths for new directory structure (resolver_athena_client) - Update AGENTS.md documentation to reflect new workflow - Change package name in pyproject.toml to use underscores This brings the protobuf compilation process in line with modern practices where generated code is built dynamically rather than checked into the repo. * fix: correct package name in version.py * fix: delete untracked generated files * docs: clean up classify single endpoint examples * fix: correlation ids now limited to 63 * docs: fix single example * style: ignore style in doc files --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 5a7afa3 commit d4b5692

35 files changed

Lines changed: 1189 additions & 721 deletions

.github/workflows/ci.yml

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -52,31 +52,11 @@ jobs:
5252
- name: Install the project
5353
run: uv sync --locked --all-extras --dev
5454

55-
- name: Ensure no differences in generated code
56-
if: runner.os != 'Windows'
55+
- name: Generate protobuf code
5756
shell: bash
5857
run: |
5958
source .venv/bin/activate
60-
GENERATED_DIR="src/resolver_athena_client/generated"
61-
BACKUP_DIR="src/resolver_athena_client/generated_backup"
62-
63-
cp -r $GENERATED_DIR $BACKUP_DIR
64-
65-
./scripts/compile_proto.sh || (echo "Protobuf compilation failed. Ensure submodules are initialized and the proto file exists." && exit 1)
66-
67-
# Fix imports in generated files
68-
if [[ -f "$GENERATED_DIR/athena/athena_pb2_grpc.py" && -f "$GENERATED_DIR/athena/athena_pb2.py" ]]; then
69-
sed -i.bak 's/^from athena /from resolver_athena_client.generated.athena /' "$GENERATED_DIR/athena/athena_pb2_grpc.py"
70-
sed -i.bak 's/^from athena /from resolver_athena_client.generated.athena /' "$GENERATED_DIR/athena/athena_pb2.py"
71-
rm -f "$GENERATED_DIR/athena/athena_pb2_grpc.py.bak" "$GENERATED_DIR/athena/athena_pb2.py.bak"
72-
else
73-
echo "Error: Expected files not found in $GENERATED_DIR/athena"
74-
exit 1
75-
fi
76-
77-
diff -r $GENERATED_DIR $BACKUP_DIR || (echo "Generated code differs. Please commit the changes after running compile_proto.sh." && exit 1)
78-
79-
rm -rf $BACKUP_DIR
59+
./scripts/compile_proto.sh || (echo "Protobuf compilation failed. Ensure submodules are initialized and proto files exist." && exit 1)
8060
8161
- name: Run linter
8262
run: |
@@ -123,6 +103,17 @@ jobs:
123103
with:
124104
enable-cache: true
125105

106+
- name: Set up Python
107+
run: uv python install
108+
109+
- name: Install the project
110+
run: uv sync --locked --all-extras --dev
111+
112+
- name: Generate protobuf code
113+
run: |
114+
source .venv/bin/activate
115+
./scripts/compile_proto.sh || (echo "Protobuf compilation failed. Ensure submodules are initialized and proto files exist." && exit 1)
116+
126117
- name: Set version from tag
127118
run: |
128119
if [[ "$GITHUB_REF" == refs/tags/* ]]; then

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@ wheels/
1111
.env
1212

1313
docs/_build/
14+
15+
# Generated protobuf code
16+
src/resolver_athena_client/generated/

AGENTS.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
- Format code: `ruff format`
99
- Lint code: `ruff check`
1010
- Install git hooks: `pre-commit install`
11-
- Compile protobufs: `bash scripts/compile_proto.sh` (run from root)
11+
- Compile protobufs: `bash scripts/compile_proto.sh` (run from root, required for local development)
1212

1313
## Code style
1414
- Use Python type hints throughout
@@ -30,7 +30,7 @@
3030

3131
## PR instructions
3232
- Title format: [component] Description
33-
- Run `ruff check`, `pyright`, and `pytest` before committing
33+
- Run `bash scripts/compile_proto.sh`, `ruff check`, `pyright`, and `pytest` before committing
3434
- Keep PRs focused on a single change
3535
- Add tests for new functionality
3636
- Update documentation for API changes
@@ -40,7 +40,8 @@
4040
- Use `uv` package manager instead of pip
4141
- Don't use `uv pip` commands, just the base `uv` commands
4242
- Run formatters before committing
43-
- Check generated code in `src/athena_client/generated/`
43+
- Generated code is built automatically in CI, but run `bash scripts/compile_proto.sh` locally for development
44+
- Generated code in `src/resolver_athena_client/generated/` is not committed to the repo
4445
- Add error handling at each pipeline stage
4546
- Use correlation IDs for request tracing
4647

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
#!/usr/bin/env python3
2+
"""Example script demonstrating the classify_single method."""
3+
4+
import asyncio
5+
import logging
6+
import os
7+
import sys
8+
import uuid
9+
from pathlib import Path
10+
11+
from create_image import create_test_image
12+
from dotenv import load_dotenv
13+
14+
from resolver_athena_client.client.athena_client import AthenaClient
15+
from resolver_athena_client.client.athena_options import AthenaOptions
16+
from resolver_athena_client.client.channel import (
17+
CredentialHelper,
18+
create_channel_with_credentials,
19+
)
20+
from resolver_athena_client.client.models import ImageData
21+
22+
23+
async def classify_single_image_example(
24+
logger: logging.Logger,
25+
options: AthenaOptions,
26+
credential_helper: CredentialHelper,
27+
image_path: str | None = None,
28+
) -> bool:
29+
"""Demonstrate single image classification.
30+
31+
Args:
32+
logger: Logger instance for output
33+
options: Configuration options for the Athena client
34+
credential_helper: OAuth credential helper for authentication
35+
image_path: Path to image file to classify (optional)
36+
37+
Returns:
38+
True if classification was successful, False otherwise
39+
40+
"""
41+
# Create gRPC channel with credentials
42+
channel = await create_channel_with_credentials(
43+
options.host, credential_helper
44+
)
45+
46+
async with AthenaClient(channel, options) as client:
47+
# Load image data
48+
if image_path and Path(image_path).exists():
49+
logger.info("Loading image from: %s", image_path)
50+
image_bytes = Path(image_path).read_bytes()
51+
else:
52+
# Create a simple test image if no path provided
53+
logger.info("Creating synthetic test image")
54+
image_bytes = create_test_image()
55+
56+
# Create ImageData object
57+
image_data = ImageData(image_bytes)
58+
logger.info(
59+
"Image loaded: %d bytes, MD5: %s",
60+
len(image_data.data),
61+
image_data.md5_hashes[0][:8] + "...",
62+
)
63+
64+
try:
65+
# Classify the single image
66+
logger.info("Classifying single image...")
67+
correlation_id = uuid.uuid4().hex[:63]
68+
logger.info("Correlation ID: %s", correlation_id)
69+
result = await client.classify_single(
70+
image_data, correlation_id=correlation_id
71+
)
72+
73+
# Process the result
74+
logger.info("Classification completed successfully!")
75+
76+
if result.error.code:
77+
logger.error(
78+
"Classification error: %s (%s)",
79+
result.error.message,
80+
result.error.code,
81+
)
82+
if result.error.details:
83+
logger.error("Error details: %s", result.error.details)
84+
return False
85+
86+
if result.classifications:
87+
logger.info(
88+
"Found %d classifications:", len(result.classifications)
89+
)
90+
for i, classification in enumerate(result.classifications, 1):
91+
logger.info(
92+
" %d. Label: %s, Weight: %.3f",
93+
i,
94+
classification.label,
95+
classification.weight,
96+
)
97+
else:
98+
logger.info("No classifications found for this image")
99+
100+
except Exception:
101+
logger.exception("Error during single image classification")
102+
return False
103+
else:
104+
return True
105+
106+
107+
async def classify_multiple_single_images_example(
108+
logger: logging.Logger,
109+
options: AthenaOptions,
110+
credential_helper: CredentialHelper,
111+
num_images: int = 3,
112+
) -> int:
113+
"""Demonstrate classifying multiple images individually.
114+
115+
This shows how classify_single can be used for multiple images
116+
when you want individual control over each classification request.
117+
118+
Args:
119+
logger: Logger instance for output
120+
options: Configuration options for the Athena client
121+
credential_helper: OAuth credential helper for authentication
122+
num_images: Number of test images to classify
123+
124+
Returns:
125+
Number of successfully classified images
126+
127+
"""
128+
# Create gRPC channel with credentials
129+
channel = await create_channel_with_credentials(
130+
options.host, credential_helper
131+
)
132+
133+
successful_count = 0
134+
135+
async with AthenaClient(channel, options) as client:
136+
logger.info("Classifying %d images individually...", num_images)
137+
138+
for i in range(num_images):
139+
try:
140+
# Create a unique test image for each iteration
141+
image_bytes = create_test_image(seed=i)
142+
image_data = ImageData(image_bytes)
143+
144+
# Classify with auto-generated correlation ID
145+
result = await client.classify_single(image_data)
146+
147+
logger.info(
148+
"Image %d/%d - Correlation: %s",
149+
i + 1,
150+
num_images,
151+
result.correlation_id[:8] + "...",
152+
)
153+
154+
if result.error.code:
155+
logger.warning(
156+
"Image %d failed: %s", i + 1, result.error.message
157+
)
158+
elif result.classifications:
159+
top_classification = max(
160+
result.classifications, key=lambda c: c.weight
161+
)
162+
logger.info(
163+
"Image %d - Top result: %s (%.3f)",
164+
i + 1,
165+
top_classification.label,
166+
top_classification.weight,
167+
)
168+
successful_count += 1
169+
else:
170+
logger.info("Image %d - No classifications", i + 1)
171+
successful_count += 1
172+
173+
except Exception: # noqa: PERF203
174+
logger.exception("Failed to classify image %d", i + 1)
175+
176+
logger.info(
177+
"Completed: %d/%d images classified successfully",
178+
successful_count,
179+
num_images,
180+
)
181+
return successful_count
182+
183+
184+
async def main() -> int:
185+
"""Run the classify_single examples."""
186+
logger = logging.getLogger(__name__)
187+
load_dotenv()
188+
189+
# OAuth credentials from environment
190+
client_id = os.getenv("OAUTH_CLIENT_ID")
191+
client_secret = os.getenv("OAUTH_CLIENT_SECRET")
192+
auth_url = os.getenv(
193+
"OAUTH_AUTH_URL", "https://crispthinking.auth0.com/oauth/token"
194+
)
195+
audience = os.getenv("OAUTH_AUDIENCE", "crisp-athena-dev")
196+
197+
if not client_id or not client_secret:
198+
logger.error("OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET must be set")
199+
return 1
200+
201+
host = os.getenv("ATHENA_HOST", "localhost")
202+
logger.info("Connecting to %s", host)
203+
204+
# Create credential helper
205+
credential_helper = CredentialHelper(
206+
client_id=client_id,
207+
client_secret=client_secret,
208+
auth_url=auth_url,
209+
audience=audience,
210+
)
211+
212+
# Test token acquisition
213+
try:
214+
logger.info("Acquiring OAuth token...")
215+
token = await credential_helper.get_token()
216+
logger.info("Successfully acquired token (length: %d)", len(token))
217+
except Exception:
218+
logger.exception("Failed to acquire OAuth token")
219+
return 1
220+
221+
# Configure client options
222+
options = AthenaOptions(
223+
host=host,
224+
resize_images=True,
225+
compress_images=True,
226+
timeout=30.0, # Shorter timeout for single requests
227+
affiliate="Crisp",
228+
deployment_id="single-example-deployment", # Not used
229+
)
230+
231+
try:
232+
# Example 1: Classify a single image
233+
logger.info("\n=== Example 1: Single Image Classification ===")
234+
success = await classify_single_image_example(
235+
logger,
236+
options,
237+
credential_helper,
238+
image_path=os.getenv("TEST_IMAGE_PATH"), # Optional image path
239+
)
240+
241+
if not success:
242+
logger.error("Single image classification failed")
243+
return 1
244+
245+
# Example 2: Classify multiple images individually
246+
logger.info("\n=== Example 2: Multiple Individual Classifications ===")
247+
successful_count = await classify_multiple_single_images_example(
248+
logger, options, credential_helper, num_images=5
249+
)
250+
251+
if successful_count == 0:
252+
logger.error("No images were successfully classified")
253+
return 1
254+
255+
logger.info("\n=== All examples completed successfully! ===")
256+
257+
except Exception:
258+
logger.exception("Examples failed")
259+
return 1
260+
else:
261+
return 0
262+
263+
264+
if __name__ == "__main__":
265+
logging.basicConfig(
266+
level=logging.INFO,
267+
format="%(asctime)s.%(msecs)03d %(levelname)s: %(message)s",
268+
datefmt="%H:%M:%S",
269+
)
270+
271+
sys.exit(asyncio.run(main()))

examples/create_image.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,23 @@ async def iter_images(
139139
counter[0] += 1
140140
yield ImageData(img_bytes)
141141
count += 1
142+
143+
144+
def create_test_image(
145+
width: int = 160, height: int = 120, seed: int | None = None
146+
) -> bytes:
147+
"""Create a test image with specified dimensions and optional seed.
148+
149+
Args:
150+
width: Width of the test image in pixels (default: 160)
151+
height: Height of the test image in pixels (default: 120)
152+
seed: Optional seed for reproducible image generation
153+
154+
Returns:
155+
PNG image bytes
156+
157+
"""
158+
if seed is not None:
159+
_rng.seed(seed)
160+
161+
return create_random_image(width, height)

0 commit comments

Comments
 (0)