Skip to content

Commit 695b894

Browse files
committed
fix: fix typechecking errors (still warnings)
1 parent e7d3393 commit 695b894

8 files changed

Lines changed: 52 additions & 26 deletions

File tree

src/resolver_athena_client/client/deployment_selector.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class DeploymentSelector:
2626
2727
"""
2828

29-
channel: grpc.aio.Channel
29+
channel: grpc.aio.Channel | None = None
3030
classifier: ClassifierServiceClient
3131

3232
def __init__(self, channel: grpc.aio.Channel) -> None:
@@ -37,7 +37,7 @@ def __init__(self, channel: grpc.aio.Channel) -> None:
3737
the Athena service.
3838
3939
"""
40-
self.logger = logging.getLogger(__name__)
40+
self.logger: logging.Logger = logging.getLogger(__name__)
4141
self.classifier = ClassifierServiceClient(channel)
4242

4343
async def list_deployments(self) -> ListDeploymentsResponse:
@@ -90,5 +90,5 @@ async def __aexit__(
9090
exc_tb: The traceback of the exception that was raised
9191
9292
"""
93-
if hasattr(self, "channel"):
93+
if hasattr(self, "channel") and self.channel is not None:
9494
await self.channel.close()

src/resolver_athena_client/grpc_wrappers/classifier_service.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Low-level GRPC client for the ClassifierService."""
22

33
from collections.abc import AsyncIterable
4-
from typing import TYPE_CHECKING
4+
from typing import TYPE_CHECKING, final
55

66
from google.protobuf.empty_pb2 import Empty
77
from grpc import aio
@@ -21,6 +21,7 @@
2121
from grpc.aio import StreamStreamCall
2222

2323

24+
@final
2425
class ClassifierServiceClient:
2526
"""Low-level gRPC wrapper for the ClassifierService."""
2627

tests/client/test_timeout_behavior.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
class MockGrpcError(AioRpcError):
2323
"""Mock gRPC error for testing."""
2424

25-
def __init__(self, code: StatusCode, details: str | None = None) -> None:
25+
def __init__(self, code: StatusCode, details: str | None = None) -> None: # pyright: ignore[reportMissingSuperCall] - Mock
2626
self._code = code
2727
self._details: str | None = details
2828
self._debug_error_string: str | None = f"MockGrpcError: {code.name}"

tests/test_classify_single.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ async def test_classify_single_with_correlation_id(
123123
)
124124

125125
# Call classify_single with custom correlation ID
126-
await athena_client.classify_single(
126+
_ = await athena_client.classify_single(
127127
sample_image_data, correlation_id=custom_correlation_id
128128
)
129129

@@ -147,14 +147,16 @@ async def test_classify_single_auto_correlation_id(
147147
)
148148

149149
# Call classify_single without correlation ID
150-
await athena_client.classify_single(sample_image_data)
150+
_ = await athena_client.classify_single(sample_image_data)
151151

152152
# Verify a correlation ID was generated
153153
call_args = athena_client.classifier.classify_single.call_args[0][0]
154154
assert call_args.correlation_id is not None
155155
assert len(call_args.correlation_id) > 0
156156
# Should be a valid UUID format
157-
uuid.UUID(call_args.correlation_id) # This will raise if not a valid UUID
157+
_ = uuid.UUID(
158+
call_args.correlation_id
159+
) # This will raise if not a valid UUID
158160

159161

160162
@pytest.mark.asyncio
@@ -175,7 +177,7 @@ async def test_classify_single_with_compression(
175177
)
176178

177179
# Call classify_single
178-
await athena_client.classify_single(sample_image_data)
180+
_ = await athena_client.classify_single(sample_image_data)
179181

180182
# Verify compression settings were applied
181183
call_args = athena_client.classifier.classify_single.call_args[0][0]
@@ -210,7 +212,7 @@ async def test_classify_single_error_handling(
210212
)
211213

212214
# Call classify_single with valid image
213-
await athena_client.classify_single(valid_image_data)
215+
_ = await athena_client.classify_single(valid_image_data)
214216

215217
# Verify resizing was processed (encoding should be uncompressed)
216218
call_args = athena_client.classifier.classify_single.call_args[0][0]
@@ -239,7 +241,7 @@ async def test_classify_single_with_error_response(
239241

240242
# Call classify_single and expect AthenaError
241243
with pytest.raises(AthenaError, match="Image is too large"):
242-
await athena_client.classify_single(sample_image_data)
244+
_ = await athena_client.classify_single(sample_image_data)
243245

244246

245247
@pytest.mark.asyncio
@@ -276,7 +278,7 @@ async def test_classify_single_timeout_parameter(
276278
)
277279

278280
# Call classify_single
279-
await athena_client.classify_single(sample_image_data)
281+
_ = await athena_client.classify_single(sample_image_data)
280282

281283
# Verify timeout was passed
282284
call_kwargs = athena_client.classifier.classify_single.call_args[1]
@@ -306,7 +308,7 @@ async def test_classify_single_multiple_hashes(
306308
)
307309

308310
# Call classify_single
309-
await athena_client.classify_single(image_data)
311+
_ = await athena_client.classify_single(image_data)
310312

311313
# Verify all hashes were included
312314
call_args = athena_client.classifier.classify_single.call_args[0][0]

tests/test_correlation.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Tests for correlation ID generation."""
22

3+
from typing import final, override
4+
35
import pytest
46

57
from resolver_athena_client.client.correlation import HashCorrelationProvider
@@ -47,15 +49,19 @@ def test_hash_correlation_provider_with_invalid_input() -> None:
4749
provider = HashCorrelationProvider()
4850

4951
# Create an object that raises an exception when converted to string
50-
class BadStr:
52+
@final
53+
class BadStr(str):
54+
__slots__ = ()
55+
56+
@override
5157
def __str__(self) -> str:
5258
error_msg = "Cannot convert to string"
5359
raise ValueError(error_msg)
5460

5561
with pytest.raises(
5662
ValueError, match="Failed to generate correlation ID from input"
5763
):
58-
provider.get_correlation_id(BadStr()) # type: ignore[arg-type]
64+
_ = provider.get_correlation_id(BadStr())
5965

6066

6167
def test_hash_correlation_provider_consistency() -> None:

tests/test_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def test_version_from_metadata() -> None:
1313
mock_version.return_value = "1.2.3"
1414
# Force reload of version module to get mocked value
1515

16-
importlib.reload(resolver_athena_client.version)
16+
_ = importlib.reload(resolver_athena_client.version)
1717
assert resolver_athena_client.version.__version__ == "1.2.3"
1818

1919

tests/utils/mock_async_iterator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@
66

77
class MockAsyncIterator(Generic[T]):
88
def __init__(self, items: list[T]) -> None:
9-
self._items = items.copy()
10-
self.call_count = 0
9+
self._items: list[T] = items.copy()
10+
self.call_count: int = 0
11+
self._timeout: float | None = None
1112

1213
async def __call__(
1314
self,
1415
_: AsyncIterable[bytes],
1516
*,
1617
timeout: float | None = None,
17-
) -> "MockAsyncIterator":
18+
) -> "MockAsyncIterator[T]":
1819
self.call_count += 1
1920
# Store timeout for potential use in testing
2021
self._timeout = timeout

tests/utils/mock_stream_call.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Mock gRPC stream call for testing."""
22

33
from collections.abc import AsyncIterator, Callable
4-
from typing import Generic, TypeVar
4+
from typing import Generic, TypeVar, override
55

66
import grpc
77
from grpc.aio import Call, Metadata, StreamStreamCall
@@ -24,8 +24,10 @@ class MockStreamCall(Generic[RequestT, ResponseT]):
2424

2525
def __init__(self, responses: list[ResponseT]) -> None:
2626
"""Initialize with response list."""
27-
self.responses = responses.copy()
28-
self.call_count = 0
27+
self.responses: list[ResponseT] = responses.copy()
28+
self.call_count: int = 0
29+
self._last_timeout: float | None = None
30+
self._last_wait_for_ready: bool = True
2931

3032
def __call__(
3133
self,
@@ -57,12 +59,13 @@ def __init__(
5759
) -> None:
5860
"""Initialize with request iterator and responses."""
5961
super().__init__()
60-
self._request_iter = request_iter
61-
self._responses = responses.copy()
62-
self._done = False
63-
self._cancelled = False
62+
self._request_iter: AsyncIterator[RequestT] = request_iter
63+
self._responses: list[ResponseT] = responses.copy()
64+
self._done: bool = False
65+
self._cancelled: bool = False
6466
self._done_callbacks: list[Callable[[Call], None]] = []
6567

68+
@override
6669
def __aiter__(self) -> AsyncIterator[ResponseT]:
6770
"""Get async iterator over responses."""
6871
return self
@@ -80,6 +83,7 @@ async def __anext__(self) -> ResponseT:
8083
raise StopAsyncIteration
8184
return self._responses.pop(0)
8285

86+
@override
8387
async def read(self) -> ResponseT:
8488
"""Read next response message.
8589
@@ -88,17 +92,20 @@ async def read(self) -> ResponseT:
8892
"""
8993
return await self.__anext__()
9094

95+
@override
9196
async def write(self, request: RequestT) -> None:
9297
"""Write a request message (no-op).
9398
9499
Args:
95100
request: Request message to write.
96101
"""
97102

103+
@override
98104
async def done_writing(self) -> None:
99105
"""Signal end of request stream."""
100106
self._done = True
101107

108+
@override
102109
def add_done_callback(self, callback: Callable[[Call], None]) -> None:
103110
"""Register completion callback.
104111
@@ -109,6 +116,7 @@ def add_done_callback(self, callback: Callable[[Call], None]) -> None:
109116
if self._done:
110117
callback(self)
111118

119+
@override
112120
def time_remaining(self) -> float | None:
113121
"""Get remaining timeout time.
114122
@@ -117,6 +125,7 @@ def time_remaining(self) -> float | None:
117125
"""
118126
return None
119127

128+
@override
120129
def cancel(self) -> bool:
121130
"""Cancel the call.
122131
@@ -131,6 +140,7 @@ def cancel(self) -> bool:
131140
return True
132141
return False
133142

143+
@override
134144
def cancelled(self) -> bool:
135145
"""Check if call was cancelled.
136146
@@ -139,6 +149,7 @@ def cancelled(self) -> bool:
139149
"""
140150
return self._cancelled
141151

152+
@override
142153
async def code(self) -> grpc.StatusCode:
143154
"""Get status code.
144155
@@ -147,6 +158,7 @@ async def code(self) -> grpc.StatusCode:
147158
"""
148159
return grpc.StatusCode.OK
149160

161+
@override
150162
async def details(self) -> str:
151163
"""Get error details.
152164
@@ -155,6 +167,7 @@ async def details(self) -> str:
155167
"""
156168
return ""
157169

170+
@override
158171
async def initial_metadata(self) -> Metadata:
159172
"""Get initial metadata.
160173
@@ -163,6 +176,7 @@ async def initial_metadata(self) -> Metadata:
163176
"""
164177
return Metadata()
165178

179+
@override
166180
async def trailing_metadata(self) -> Metadata:
167181
"""Get trailing metadata.
168182
@@ -171,6 +185,7 @@ async def trailing_metadata(self) -> Metadata:
171185
"""
172186
return Metadata()
173187

188+
@override
174189
def done(self) -> bool:
175190
"""Check if call is complete.
176191
@@ -179,5 +194,6 @@ def done(self) -> bool:
179194
"""
180195
return self._done
181196

197+
@override
182198
async def wait_for_connection(self) -> None:
183199
"""Wait for connection (no-op)."""

0 commit comments

Comments
 (0)