Skip to content

Commit d3614f3

Browse files
anna-singleton-resolveranna-singleton-resolver
andauthored
more configuration options for the client (#92)
* feat: correlationid can be specified on ImageData * feat: set brotli compression quality * style: lint * fix: fixup example code * feat: configurable resizing algorithm --------- Co-authored-by: anna-singleton-resolver <anna.singleton@resolver.com>
1 parent 26c8819 commit d3614f3

6 files changed

Lines changed: 52 additions & 26 deletions

File tree

examples/classify_single_example.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,8 @@ async def classify_single_image_example(
6969
logger.info("Classifying single image...")
7070
correlation_id = uuid.uuid4().hex[:MAX_DEPLOYMENT_ID_LENGTH]
7171
logger.info("Correlation ID: %s", correlation_id)
72-
result = await client.classify_single(
73-
image_data, correlation_id=correlation_id
74-
)
72+
image_data.correlation_id = correlation_id # Optional
73+
result = await client.classify_single(image_data)
7574

7675
# Process the result
7776
logger.info("Classification completed successfully!")

src/resolver_athena_client/client/athena_client.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ async def image_stream():
121121
yield response
122122

123123
async def classify_single(
124-
self, image_data: ImageData, correlation_id: str | None = None
124+
self, image_data: ImageData
125125
) -> ClassificationOutput:
126126
"""Classify a single image synchronously without deployment context.
127127
@@ -169,18 +169,19 @@ async def classify_single(
169169
f"Weight: {classification.weight}")
170170
171171
"""
172-
if correlation_id is None:
173-
correlation_id = str(uuid.uuid4())
174-
175172
processed_image = image_data
176173

177174
# Apply image resizing if enabled
178175
if self.options.resize_images:
179-
processed_image = await resize_image(processed_image)
176+
processed_image = await resize_image(
177+
processed_image, self.options.resampling_algorithm
178+
)
180179

181180
# Apply compression if enabled
182181
if self.options.compress_images:
183-
processed_image = compress_image(processed_image)
182+
processed_image = compress_image(
183+
processed_image, self.options.compression_quality
184+
)
184185

185186
request_encoding = (
186187
RequestEncoding.REQUEST_ENCODING_BROTLI
@@ -196,7 +197,7 @@ async def classify_single(
196197

197198
classification_input = ClassificationInput(
198199
affiliate=self.options.affiliate,
199-
correlation_id=correlation_id,
200+
correlation_id=processed_image.correlation_id or str(uuid.uuid4()),
200201
encoding=request_encoding,
201202
data=processed_image.data,
202203
format=image_format,
@@ -240,13 +241,17 @@ async def transform_image(image_data: ImageData) -> ClassificationInput:
240241
"""Transform a single image through the full pipeline."""
241242
# Apply image resizing if enabled
242243
if self.options.resize_images:
243-
resized_image = await resize_image(image_data)
244+
resized_image = await resize_image(
245+
image_data, self.options.resampling_algorithm
246+
)
244247
else:
245248
resized_image = image_data
246249

247250
# Apply compression if enabled
248251
if self.options.compress_images:
249-
compressed_image = compress_image(resized_image)
252+
compressed_image = compress_image(
253+
resized_image, self.options.compression_quality
254+
)
250255
else:
251256
compressed_image = resized_image
252257

@@ -257,19 +262,23 @@ async def transform_image(image_data: ImageData) -> ClassificationInput:
257262
else RequestEncoding.REQUEST_ENCODING_UNCOMPRESSED
258263
)
259264

260-
# Create classification input directly
261-
correlation_provider = self.options.correlation_provider()
262-
263265
# Ensure we never send UNSPECIFIED format over the API
264266
image_format = compressed_image.image_format
265267
if image_format == ImageFormat.IMAGE_FORMAT_UNSPECIFIED:
266268
image_format = ImageFormat.IMAGE_FORMAT_RAW_UINT8_BGR
267269

270+
if compressed_image.correlation_id:
271+
correlation_id = compressed_image.correlation_id
272+
else:
273+
# Create classification input directly
274+
correlation_provider = self.options.correlation_provider()
275+
correlation_id = correlation_provider.get_correlation_id(
276+
compressed_image.data
277+
)
278+
268279
return ClassificationInput(
269280
affiliate=self.options.affiliate,
270-
correlation_id=correlation_provider.get_correlation_id(
271-
compressed_image.data
272-
),
281+
correlation_id=correlation_id,
273282
data=compressed_image.data,
274283
encoding=request_encoding,
275284
format=image_format,

src/resolver_athena_client/client/athena_options.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from dataclasses import dataclass
44

5+
from PIL.Image import Resampling
6+
57
from resolver_athena_client.client.correlation import (
68
CorrelationProvider,
79
HashCorrelationProvider,
@@ -66,3 +68,5 @@ class AthenaOptions:
6668
correlation_provider: type[CorrelationProvider] = HashCorrelationProvider
6769
timeout: float | None = 120.0
6870
keepalive_interval: float | None = None
71+
compression_quality: int = 11 # Brotli quality level (0-11)
72+
resampling_algorithm: Resampling = Resampling.LANCZOS

src/resolver_athena_client/client/models/input_model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,16 @@ async def image_stream():
6767
6868
"""
6969

70-
def __init__(self, image_bytes: bytes) -> None:
70+
def __init__(
71+
self, image_bytes: bytes, correlation_id: None | str = None
72+
) -> None:
7173
"""Initialize ImageData with bytes and calculate hashes.
7274
7375
Args:
7476
----
7577
image_bytes: The raw bytes of the image.
78+
correlation_id: Optional correlation ID to associate with this
79+
image data, if not provided, it will be generated by the client.
7680
7781
"""
7882
self.data: bytes = image_bytes
@@ -83,6 +87,7 @@ def __init__(self, image_bytes: bytes) -> None:
8387
hashlib.sha256(image_bytes).hexdigest()
8488
]
8589
self.md5_hashes: list[str] = [hashlib.md5(image_bytes).hexdigest()]
90+
self.correlation_id: None | str = correlation_id
8691

8792
def add_transformation_hashes(self) -> None:
8893
"""Add new hashes for the current data to track transformations.

src/resolver_athena_client/client/transformers/core.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,17 @@ def _is_raw_bgr_expected_size(data: bytes) -> bool:
2525
return len(data) == _expected_raw_size
2626

2727

28-
async def resize_image(image_data: ImageData) -> ImageData:
28+
async def resize_image(
29+
image_data: ImageData,
30+
sampling_algorithm: Image.Resampling = Image.Resampling.LANCZOS,
31+
) -> ImageData:
2932
"""Resize an image to expected dimensions.
3033
3134
Args:
3235
----
3336
image_data: The ImageData object to resize
37+
sampling_algorithm: The resampling algorithm to use for resizing.
38+
Defaults to LANCZOS.
3439
3540
Returns:
3641
-------
@@ -53,7 +58,7 @@ def process_image() -> tuple[bytes, bool]:
5358
# Resize if needed
5459
if rgb_image.size != _target_size:
5560
resized_image = rgb_image.resize(
56-
_target_size, Image.Resampling.LANCZOS
61+
_target_size, sampling_algorithm
5762
)
5863
else:
5964
resized_image = rgb_image
@@ -82,19 +87,21 @@ def process_image() -> tuple[bytes, bool]:
8287
return image_data
8388

8489

85-
def compress_image(image_data: ImageData) -> ImageData:
90+
def compress_image(image_data: ImageData, quality: int = 11) -> ImageData:
8691
"""Compress image data using Brotli compression.
8792
8893
Args:
8994
----
9095
image_data: The ImageData object to compress
96+
quality: Compression quality level (0-11), higher is better compression
97+
but slower. Default is 11 for maximum compression.
9198
9299
Returns:
93100
-------
94101
The same ImageData object with compressed data (modified in-place)
95102
96103
"""
97-
compressed_bytes = brotli.compress(image_data.data)
104+
compressed_bytes = brotli.compress(image_data.data, quality=quality)
98105
# Modify existing ImageData with compressed bytes but preserve hashes
99106
# since compression doesn't change image content
100107
image_data.data = compressed_bytes

tests/test_classify_single.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,13 @@ async def test_classify_single_with_correlation_id(
122122
return_value=mock_output
123123
)
124124

125-
# Call classify_single with custom correlation ID
126-
_ = await athena_client.classify_single(
127-
sample_image_data, correlation_id=custom_correlation_id
125+
copied_image_data = ImageData(
126+
sample_image_data.data, correlation_id=custom_correlation_id
128127
)
129128

129+
# Call classify_single with custom correlation ID
130+
_ = await athena_client.classify_single(copied_image_data)
131+
130132
# Verify correlation ID was used
131133
call_args = athena_client.classifier.classify_single.call_args[0][0]
132134
assert call_args.correlation_id == custom_correlation_id

0 commit comments

Comments
 (0)