Skip to content

Commit cff86c8

Browse files
snus-kin and Copilot authored
feat: aio workers + keep streams alive (#69)
* feat: aio workers + keep streams alive
* Update examples/example.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update src/resolver_athena_client/client/transformers/worker_batcher.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* feat: some cleanup with logging levels and dead code
* feat: cleanup some low hanging fruit
* Clean up batching system: remove RequestBatcher and BaseBatcher, use WorkerBatcher for all batching needs
* Update src/resolver_athena_client/client/athena_client.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update src/resolver_athena_client/client/transformers/worker_batcher.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update src/resolver_athena_client/client/transformers/worker_batcher.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update src/resolver_athena_client/client/athena_options.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Add configurable num_workers option to AthenaOptions
  - Add num_workers field to AthenaOptions with default value of 5
  - Update documentation with usage guidelines for CPU vs I/O bound scenarios
  - Add test to verify num_workers configuration is properly passed to WorkerBatcher
  - Allows users to tune concurrent processing based on their hardware and workload
* Update src/resolver_athena_client/client/transformers/worker_batcher.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* style: fix lint issue
* feat: remove unused config
* feat: this
* feat: don't over-handle errors
* feat: improve closing logic

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 417bea0 commit cff86c8

25 files changed

Lines changed: 1884 additions & 1661 deletions

examples/classify_single_example.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010

1111
from dotenv import load_dotenv
1212

13+
from examples.utils.image_generation import create_test_image
1314
from resolver_athena_client.client.athena_client import AthenaClient
1415
from resolver_athena_client.client.athena_options import AthenaOptions
1516
from resolver_athena_client.client.channel import (
@@ -18,7 +19,6 @@
1819
)
1920
from resolver_athena_client.client.consts import MAX_DEPLOYMENT_ID_LENGTH
2021
from resolver_athena_client.client.models import ImageData
21-
from tests.utils.image_generation import create_test_image
2222

2323

2424
async def classify_single_image_example(

examples/example.py

Lines changed: 81 additions & 67 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,8 @@
1010

1111
from dotenv import load_dotenv
1212

13+
from examples.utils.image_generation import iter_images
14+
from examples.utils.streaming_classify_utils import count_and_yield
1315
from resolver_athena_client.client.athena_client import AthenaClient
1416
from resolver_athena_client.client.athena_options import AthenaOptions
1517
from resolver_athena_client.client.channel import (
@@ -22,8 +24,9 @@
2224
has_output_errors,
2325
process_classification_outputs,
2426
)
25-
from tests.utils.image_generation import iter_images
26-
from tests.utils.streaming_classify_utils import count_and_yield
27+
28+
# Constants
29+
INITIAL_PROGRESS_THRESHOLD = 10
2730

2831

2932
async def run_oauth_example(
@@ -53,81 +56,77 @@ async def run_oauth_example(
5356
sent_counter = [0] # Use list to allow mutation in closure
5457
received_count = 0
5558

56-
async with AthenaClient(channel, options) as client:
59+
client = AthenaClient(channel, options)
60+
try:
61+
logger.info(
62+
"Generating %s test images...", max_test_images or "unlimited"
63+
)
5764
results = client.classify_images(
5865
count_and_yield(iter_images(max_test_images), sent_counter)
5966
)
6067

6168
start_time = time.time()
69+
logger.info("Starting to process classification results...")
70+
71+
async for result in results:
72+
received_count += len(result.outputs)
73+
74+
# Progress logging
75+
if (
76+
received_count % 100 == 0
77+
or received_count <= INITIAL_PROGRESS_THRESHOLD
78+
):
79+
elapsed = time.time() - start_time
80+
rate = received_count / elapsed if elapsed > 0 else 0
81+
logger.info(
82+
"Received %d results (%.1f/sec)",
83+
received_count,
84+
rate,
85+
)
6286

63-
try:
64-
async for result in results:
65-
received_count += len(result.outputs)
66-
67-
if received_count % 10 == 0:
68-
elapsed = time.time() - start_time
69-
rate = received_count / elapsed if elapsed > 0 else 0
70-
logger.info(
71-
"Sent %d requests, received %d responses (%.1f/sec)",
72-
sent_counter[0],
73-
received_count,
74-
rate,
75-
)
76-
77-
# Check for output errors and handle them
78-
if has_output_errors(result):
79-
error_summary = get_output_error_summary(result)
80-
logger.warning(
81-
"Received %d outputs with errors: %s",
82-
sum(error_summary.values()),
83-
error_summary,
84-
)
87+
# Check for output errors and handle them
88+
if has_output_errors(result):
89+
error_summary = get_output_error_summary(result)
90+
logger.warning(
91+
"Received %d outputs with errors: %s",
92+
sum(error_summary.values()),
93+
error_summary,
94+
)
8595

86-
# Process outputs, logging errors but continuing with
87-
# successful ones
88-
successful_outputs = process_classification_outputs(
89-
result, raise_on_error=False, log_errors=True
96+
# Process outputs, logging errors but continuing with successful
97+
# ones
98+
successful_outputs = process_classification_outputs(
99+
result, raise_on_error=False, log_errors=True
100+
)
101+
102+
# Log individual classification results at INFO level
103+
for i, output in enumerate(successful_outputs):
104+
top_classification = max(
105+
output.classifications,
106+
key=lambda c: c.weight,
107+
default=None,
90108
)
91109

92-
for output in successful_outputs:
93-
classifications = {
94-
c.label: round(c.weight, 3)
95-
for c in output.classifications
96-
}
97-
logger.debug(
98-
"Result [%s]: %s",
110+
if top_classification:
111+
logger.info(
112+
"Classification %d [%s]: %s (confidence: %.3f)",
113+
received_count - len(successful_outputs) + i + 1,
99114
output.correlation_id[:8],
100-
classifications,
115+
top_classification.label,
116+
top_classification.weight,
101117
)
102118

103-
except Exception:
104-
logger.exception("Error during classification")
105-
if received_count == 0:
106-
raise
107-
finally:
108-
duration = time.time() - start_time
109-
if received_count > 0:
110-
avg_rate = received_count / duration if duration > 0 else 0
119+
# Close client when we've received all expected outputs
120+
if received_count >= sent_counter[0]:
111121
logger.info(
112-
"Completed: sent=%d received=%d in %.1fs (%.1f/sec)",
113-
sent_counter[0],
122+
"Received %d outputs matching %d inputs - closing client",
114123
received_count,
115-
duration,
116-
avg_rate,
124+
sent_counter[0],
117125
)
126+
break
118127

119-
if options.timeout and duration >= options.timeout * 0.95:
120-
logger.info(
121-
"Stream reached maximum duration: %.1fs (limit: %.1fs)",
122-
duration,
123-
options.timeout,
124-
)
125-
elif options.timeout:
126-
logger.info(
127-
"Stream completed naturally in %.1fs (max: %.1fs)",
128-
duration,
129-
options.timeout,
130-
)
128+
finally:
129+
await client.close()
131130

132131
return (sent_counter[0], received_count)
133132

@@ -138,7 +137,7 @@ async def main() -> int:
138137
_ = load_dotenv()
139138

140139
# Configuration
141-
max_test_images = 10_000
140+
max_test_images = 100
142141

143142
# OAuth credentials from environment
144143
client_id = os.getenv("OAUTH_CLIENT_ID")
@@ -191,19 +190,34 @@ async def main() -> int:
191190
resize_images=True,
192191
deployment_id=deployment_id,
193192
compress_images=True,
194-
timeout=120.0, # Maximum duration, not forced timeout
195-
keepalive_interval=30.0, # Longer intervals for persistent streams
193+
keepalive_interval=5.0,
196194
affiliate=affiliate,
195+
max_batch_size=10,
197196
)
198197

199198
sent, received = await run_oauth_example(
200199
logger, options, credential_helper, max_test_images
201200
)
202201

203-
if sent == received:
204-
logger.info("Success: %d requests processed", sent)
202+
# Final verification
203+
if received >= sent:
204+
if received == sent:
205+
logger.info("✓ SUCCESS: Exact match - %d requests processed", sent)
206+
else:
207+
logger.info(
208+
"✓ SUCCESS: %d requests processed (sent %d + %d extra from "
209+
"shared queue)",
210+
received,
211+
sent,
212+
received - sent,
213+
)
205214
return 0
206-
logger.warning("Incomplete: %d sent, %d received", sent, received)
215+
logger.error(
216+
"✗ INCOMPLETE: sent=%d received=%d (missing %d)",
217+
sent,
218+
received,
219+
sent - received,
220+
)
207221
return 1
208222

209223

examples/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
"""Utility modules for examples."""

0 commit comments

Comments
 (0)