Skip to content

Commit 7b810cd

Browse files
snus-kinThomas CarrollCopilotcorpo-iwillspeak
authored
feat: enable reportAny in typechecker (#111)
* feat: enable reportAny * refactor: use TypedDict * style: lint fix --------- Co-authored-by: Thomas Carroll <tcarroll@Thomass-MacBook-Pro.local> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: corpo-iwillspeak <265613520+corpo-iwillspeak@users.noreply.github.com>
1 parent 38c88ee commit 7b810cd

13 files changed

Lines changed: 342 additions & 132 deletions

File tree

.vscode/settings.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
],
55
"python.testing.unittestEnabled": false,
66
"python.testing.pytestEnabled": true
7-
}
7+
}

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ venvPath = "."
7777
venv = ".venv"
7878
stubPath = "stubs"
7979
reportImplicitStringConcatenation = false
80-
reportAny = false
8180

8281
[tool.pytest.ini_options]
8382
markers = [

src/resolver_athena_client/client/athena_client.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -410,8 +410,8 @@ async def shutdown_worker(
410410
) -> None:
411411
"""Safely shutdown a single worker, handling mocks/errors."""
412412
try:
413-
shutdown_method = getattr(worker_batcher, "shutdown", None)
414-
if shutdown_method and callable(shutdown_method):
413+
if hasattr(worker_batcher, "shutdown"):
414+
shutdown_method = worker_batcher.shutdown
415415
shutdown_coro = shutdown_method()
416416
# Only await if it's actually a coroutine (not a mock)
417417
if asyncio.iscoroutine(shutdown_coro):
@@ -421,8 +421,7 @@ async def shutdown_worker(
421421
self.logger.debug(
422422
"Skipping non-coroutine shutdown method"
423423
)
424-
else:
425-
self.logger.debug("Worker has no shutdown method")
424+
426425
except (AttributeError, TypeError):
427426
# Worker doesn't have shutdown method or it's not callable
428427
self.logger.debug("Worker shutdown failed, skipping")
@@ -442,7 +441,7 @@ async def shutdown_worker(
442441
for result in results
443442
if isinstance(
444443
result,
445-
(asyncio.CancelledError, ConnectionError, OSError),
444+
asyncio.CancelledError | ConnectionError | OSError,
446445
)
447446
]
448447

src/resolver_athena_client/client/channel.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import threading
66
import time
77
from dataclasses import dataclass
8-
from typing import override
8+
from typing import cast, override
99

1010
import grpc
1111
import httpx
@@ -230,10 +230,18 @@ def _refresh_token(self) -> None:
230230
)
231231
_ = response.raise_for_status()
232232

233-
raw = response.json()
234-
access_token: str = raw["access_token"]
235-
expires_in: int = raw.get("expires_in", 3600) # Default 1 hour
236-
token_type = raw.get("token_type", "Bearer")
233+
raw = cast("dict[str, object]", response.json())
234+
access_token = str(raw["access_token"])
235+
expires_in_raw = raw.get("expires_in")
236+
expires_in: int = (
237+
int(cast("int", expires_in_raw))
238+
if expires_in_raw is not None
239+
else 3600
240+
)
241+
token_type_raw = raw.get("token_type")
242+
token_type: str = (
243+
str(token_type_raw) if token_type_raw is not None else "Bearer"
244+
)
237245
scheme: str = token_type.strip() if token_type else "Bearer"
238246
current_time = time.time()
239247
self._token_data = TokenData(
@@ -247,7 +255,7 @@ def _refresh_token(self) -> None:
247255
except httpx.HTTPStatusError as e:
248256
error_detail = ""
249257
try:
250-
error_data = e.response.json()
258+
error_data = cast("dict[str, str]", e.response.json())
251259
error_desc = error_data.get(
252260
"error_description", error_data.get("error", "")
253261
)

src/resolver_athena_client/client/transformers/core.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import asyncio
99
import enum
10+
from typing import cast
1011

1112
import brotli
1213
import cv2 as cv
@@ -72,7 +73,10 @@ def process_image() -> tuple[bytes, bool]:
7273
err = "Failed to decode image data for resizing"
7374
raise ValueError(err)
7475

75-
if img.shape[0] == EXPECTED_HEIGHT and img.shape[1] == EXPECTED_WIDTH:
76+
shape = cast("tuple[int, int, int]", img.shape)
77+
height: int = shape[0]
78+
width: int = shape[1]
79+
if height == EXPECTED_HEIGHT and width == EXPECTED_WIDTH:
7680
resized_img = img
7781
else:
7882
resized_img = cv.resize(

tests/client/test_athena_client.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import contextlib
5+
from typing import cast
56
from unittest import mock
67

78
import pytest
@@ -64,9 +65,10 @@ async def test_classify_images_success(
6465

6566
# Setup mock classifier client
6667
with mock.patch(
67-
"resolver_athena_client.client.athena_client.ClassifierServiceClient"
68+
"resolver_athena_client.client.athena_client.ClassifierServiceClient",
69+
spec=ClassifierServiceClient,
6870
) as mock_client_cls:
69-
mock_client = mock_client_cls.return_value
71+
mock_client = cast("mock.MagicMock", mock_client_cls.return_value)
7072

7173
# Create mock stream that returns our responses
7274
mock_classify = MockAsyncIterator(test_responses)
@@ -121,9 +123,10 @@ async def test_client_context_manager_success(
121123
) # Success response will have default empty global_error
122124

123125
with mock.patch(
124-
"resolver_athena_client.client.athena_client.ClassifierServiceClient"
126+
"resolver_athena_client.client.athena_client.ClassifierServiceClient",
127+
spec=ClassifierServiceClient,
125128
) as mock_client_cls:
126-
mock_client = mock_client_cls.return_value
129+
mock_client = cast("mock.MagicMock", mock_client_cls.return_value)
127130

128131
# Create mock stream that returns our response
129132
mock_classify = MockAsyncIterator([init_response])
@@ -157,7 +160,8 @@ async def get_one_response() -> None:
157160
await classify_task
158161

159162
# Verify channel was closed
160-
mock_channel.close.assert_called_once()
163+
close_mock = cast("mock.MagicMock", mock_channel.close)
164+
close_mock.assert_called_once()
161165

162166

163167
@pytest.mark.asyncio
@@ -176,9 +180,10 @@ async def test_client_context_manager_error(
176180
)
177181

178182
with mock.patch(
179-
"resolver_athena_client.client.athena_client.ClassifierServiceClient"
183+
"resolver_athena_client.client.athena_client.ClassifierServiceClient",
184+
spec=ClassifierServiceClient,
180185
) as mock_client_cls:
181-
mock_client = mock_client_cls.return_value
186+
mock_client = cast("mock.MagicMock", mock_client_cls.return_value)
182187

183188
# Create mock stream that returns our error response
184189
mock_classify = MockAsyncIterator([error_response])
@@ -225,9 +230,10 @@ async def test_client_transformers_disabled(
225230
)
226231

227232
with mock.patch(
228-
"resolver_athena_client.client.athena_client.ClassifierServiceClient"
233+
"resolver_athena_client.client.athena_client.ClassifierServiceClient",
234+
spec=ClassifierServiceClient,
229235
) as mock_client_cls:
230-
mock_client = mock_client_cls.return_value
236+
mock_client = cast("mock.MagicMock", mock_client_cls.return_value)
231237
mock_classify = MockAsyncIterator([test_response])
232238
mock_client.classify = mock_classify
233239

@@ -277,9 +283,10 @@ async def test_client_transformers_enabled(
277283
)
278284

279285
with mock.patch(
280-
"resolver_athena_client.client.athena_client.ClassifierServiceClient"
286+
"resolver_athena_client.client.athena_client.ClassifierServiceClient",
287+
spec=ClassifierServiceClient,
281288
) as mock_client_cls:
282-
mock_client = mock_client_cls.return_value
289+
mock_client = cast("mock.MagicMock", mock_client_cls.return_value)
283290
mock_classify = MockAsyncIterator([test_response])
284291
mock_client.classify = mock_classify
285292

@@ -337,13 +344,14 @@ async def test_client_num_workers_configuration(
337344

338345
with (
339346
mock.patch(
340-
"resolver_athena_client.client.athena_client.ClassifierServiceClient"
347+
"resolver_athena_client.client.athena_client.ClassifierServiceClient",
348+
spec=ClassifierServiceClient,
341349
) as mock_client_cls,
342350
mock.patch(
343351
"resolver_athena_client.client.athena_client.WorkerBatcher"
344352
) as mock_worker_batcher_cls,
345353
):
346-
mock_client = mock_client_cls.return_value
354+
mock_client = cast("mock.MagicMock", mock_client_cls.return_value)
347355
mock_classify = MockAsyncIterator([test_response])
348356
mock_client.classify = mock_classify
349357

@@ -391,4 +399,5 @@ async def test_client_close(
391399

392400
await client.close()
393401

394-
mock_channel.close.assert_called_once()
402+
close_mock = cast("mock.MagicMock", mock_channel.close)
403+
close_mock.assert_called_once()

0 commit comments

Comments
 (0)