Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/classify_single_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from dotenv import load_dotenv

from examples.utils.image_generation import create_test_image
from resolver_athena_client.client.athena_client import AthenaClient
from resolver_athena_client.client.athena_options import AthenaOptions
from resolver_athena_client.client.channel import (
Expand All @@ -18,7 +19,6 @@
)
from resolver_athena_client.client.consts import MAX_DEPLOYMENT_ID_LENGTH
from resolver_athena_client.client.models import ImageData
from tests.utils.image_generation import create_test_image


async def classify_single_image_example(
Expand Down
145 changes: 73 additions & 72 deletions examples/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from dotenv import load_dotenv

from examples.utils.image_generation import iter_images
from examples.utils.streaming_classify_utils import count_and_yield
from resolver_athena_client.client.athena_client import AthenaClient
from resolver_athena_client.client.athena_options import AthenaOptions
from resolver_athena_client.client.channel import (
Expand All @@ -22,8 +24,9 @@
has_output_errors,
process_classification_outputs,
)
from tests.utils.image_generation import iter_images
from tests.utils.streaming_classify_utils import count_and_yield

# Constants
INITIAL_PROGRESS_THRESHOLD = 10
Comment thread
snus-kin marked this conversation as resolved.


async def run_oauth_example(
Expand Down Expand Up @@ -54,81 +57,63 @@ async def run_oauth_example(
received_count = 0

async with AthenaClient(channel, options) as client:
logger.info(
"Generating %s test images...", max_test_images or "unlimited"
)
results = client.classify_images(
count_and_yield(iter_images(max_test_images), sent_counter)
)

start_time = time.time()

try:
async for result in results:
received_count += len(result.outputs)

if received_count % 10 == 0:
elapsed = time.time() - start_time
rate = received_count / elapsed if elapsed > 0 else 0
logger.info(
"Sent %d requests, received %d responses (%.1f/sec)",
sent_counter[0],
received_count,
rate,
)

# Check for output errors and handle them
if has_output_errors(result):
error_summary = get_output_error_summary(result)
logger.warning(
"Received %d outputs with errors: %s",
sum(error_summary.values()),
error_summary,
)

# Process outputs, logging errors but continuing with
# successful ones
successful_outputs = process_classification_outputs(
result, raise_on_error=False, log_errors=True
logger.info("Starting to process classification results...")

async for result in results:
received_count += len(result.outputs)

# Progress logging
if (
received_count % 100 == 0
or received_count <= INITIAL_PROGRESS_THRESHOLD
):
elapsed = time.time() - start_time
rate = received_count / elapsed if elapsed > 0 else 0
logger.info(
"Received %d results (%.1f/sec)",
received_count,
rate,
)

for output in successful_outputs:
classifications = {
c.label: round(c.weight, 3)
for c in output.classifications
}
logger.debug(
"Result [%s]: %s",
output.correlation_id[:8],
classifications,
)
# Check for output errors and handle them
if has_output_errors(result):
error_summary = get_output_error_summary(result)
logger.warning(
"Received %d outputs with errors: %s",
sum(error_summary.values()),
error_summary,
)

except Exception:
logger.exception("Error during classification")
if received_count == 0:
raise
finally:
duration = time.time() - start_time
if received_count > 0:
avg_rate = received_count / duration if duration > 0 else 0
logger.info(
"Completed: sent=%d received=%d in %.1fs (%.1f/sec)",
sent_counter[0],
received_count,
duration,
avg_rate,
# Process outputs, logging errors but continuing with successful
# ones
successful_outputs = process_classification_outputs(
result, raise_on_error=False, log_errors=True
)

# Log individual classification results at INFO level
for i, output in enumerate(successful_outputs):
top_classification = max(
output.classifications,
key=lambda c: c.weight,
default=None,
)

if options.timeout and duration >= options.timeout * 0.95:
if top_classification:
logger.info(
"Stream reached maximum duration: %.1fs (limit: %.1fs)",
duration,
options.timeout,
)
elif options.timeout:
logger.info(
"Stream completed naturally in %.1fs (max: %.1fs)",
duration,
options.timeout,
"Classification %d [%s]: %s (confidence: %.3f)",
received_count - len(successful_outputs) + i + 1,
output.correlation_id[:8],
top_classification.label,
top_classification.weight,
)

return (sent_counter[0], received_count)


Expand All @@ -138,7 +123,7 @@ async def main() -> int:
_ = load_dotenv()

# Configuration
max_test_images = 10_000
max_test_images = 100

# OAuth credentials from environment
client_id = os.getenv("OAUTH_CLIENT_ID")
Expand Down Expand Up @@ -185,25 +170,41 @@ async def main() -> int:

logger.info("Using deployment: %s", deployment_id)

# Run classification with OAuth authentication
# Run classification with OAuth authentication - maximize resilience
options = AthenaOptions(
host=host,
resize_images=True,
deployment_id=deployment_id,
compress_images=True,
timeout=120.0, # Maximum duration, not forced timeout
keepalive_interval=30.0, # Longer intervals for persistent streams
timeout=None, # No timeout - allow infinite streaming
Comment thread
snus-kin marked this conversation as resolved.
Outdated
keepalive_interval=5.0, # Frequent keepalives for max resilience
affiliate=affiliate,
max_batch_size=10, # Test with larger batch size
)

sent, received = await run_oauth_example(
logger, options, credential_helper, max_test_images
)

if sent == received:
logger.info("Success: %d requests processed", sent)
# Final verification
if received >= sent:
if received == sent:
logger.info("✓ SUCCESS: Exact match - %d requests processed", sent)
Comment thread Dismissed
else:
logger.info(
"✓ SUCCESS: %d requests processed (sent %d + %d extra from "
"shared queue)",
received,
Comment thread Dismissed
sent,
Comment thread Dismissed
received - sent,
)
return 0
logger.warning("Incomplete: %d sent, %d received", sent, received)
logger.error(
"✗ INCOMPLETE: sent=%d received=%d (missing %d)",
sent,
Comment thread Dismissed
received,
Comment thread Dismissed
sent - received,
)
return 1


Expand Down
1 change: 1 addition & 0 deletions examples/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Utility modules for examples."""
Loading