22
33import asyncio
44import contextlib
5+ from typing import cast
56from unittest import mock
67
78import pytest
@@ -64,9 +65,10 @@ async def test_classify_images_success(
6465
6566 # Setup mock classifier client
6667 with mock .patch (
67- "resolver_athena_client.client.athena_client.ClassifierServiceClient"
68+ "resolver_athena_client.client.athena_client.ClassifierServiceClient" ,
69+ spec = ClassifierServiceClient ,
6870 ) as mock_client_cls :
69- mock_client = mock_client_cls .return_value
71+ mock_client = cast ( "mock.MagicMock" , mock_client_cls .return_value )
7072
7173 # Create mock stream that returns our responses
7274 mock_classify = MockAsyncIterator (test_responses )
@@ -121,9 +123,10 @@ async def test_client_context_manager_success(
121123 ) # Success response will have default empty global_error
122124
123125 with mock .patch (
124- "resolver_athena_client.client.athena_client.ClassifierServiceClient"
126+ "resolver_athena_client.client.athena_client.ClassifierServiceClient" ,
127+ spec = ClassifierServiceClient ,
125128 ) as mock_client_cls :
126- mock_client = mock_client_cls .return_value
129+ mock_client = cast ( "mock.MagicMock" , mock_client_cls .return_value )
127130
128131 # Create mock stream that returns our response
129132 mock_classify = MockAsyncIterator ([init_response ])
@@ -157,7 +160,8 @@ async def get_one_response() -> None:
157160 await classify_task
158161
159162 # Verify channel was closed
160- mock_channel .close .assert_called_once ()
163+ close_mock = cast ("mock.MagicMock" , mock_channel .close )
164+ close_mock .assert_called_once ()
161165
162166
163167@pytest .mark .asyncio
@@ -176,9 +180,10 @@ async def test_client_context_manager_error(
176180 )
177181
178182 with mock .patch (
179- "resolver_athena_client.client.athena_client.ClassifierServiceClient"
183+ "resolver_athena_client.client.athena_client.ClassifierServiceClient" ,
184+ spec = ClassifierServiceClient ,
180185 ) as mock_client_cls :
181- mock_client = mock_client_cls .return_value
186+ mock_client = cast ( "mock.MagicMock" , mock_client_cls .return_value )
182187
183188 # Create mock stream that returns our error response
184189 mock_classify = MockAsyncIterator ([error_response ])
@@ -225,9 +230,10 @@ async def test_client_transformers_disabled(
225230 )
226231
227232 with mock .patch (
228- "resolver_athena_client.client.athena_client.ClassifierServiceClient"
233+ "resolver_athena_client.client.athena_client.ClassifierServiceClient" ,
234+ spec = ClassifierServiceClient ,
229235 ) as mock_client_cls :
230- mock_client = mock_client_cls .return_value
236+ mock_client = cast ( "mock.MagicMock" , mock_client_cls .return_value )
231237 mock_classify = MockAsyncIterator ([test_response ])
232238 mock_client .classify = mock_classify
233239
@@ -277,9 +283,10 @@ async def test_client_transformers_enabled(
277283 )
278284
279285 with mock .patch (
280- "resolver_athena_client.client.athena_client.ClassifierServiceClient"
286+ "resolver_athena_client.client.athena_client.ClassifierServiceClient" ,
287+ spec = ClassifierServiceClient ,
281288 ) as mock_client_cls :
282- mock_client = mock_client_cls .return_value
289+ mock_client = cast ( "mock.MagicMock" , mock_client_cls .return_value )
283290 mock_classify = MockAsyncIterator ([test_response ])
284291 mock_client .classify = mock_classify
285292
@@ -337,13 +344,14 @@ async def test_client_num_workers_configuration(
337344
338345 with (
339346 mock .patch (
340- "resolver_athena_client.client.athena_client.ClassifierServiceClient"
347+ "resolver_athena_client.client.athena_client.ClassifierServiceClient" ,
348+ spec = ClassifierServiceClient ,
341349 ) as mock_client_cls ,
342350 mock .patch (
343351 "resolver_athena_client.client.athena_client.WorkerBatcher"
344352 ) as mock_worker_batcher_cls ,
345353 ):
346- mock_client = mock_client_cls .return_value
354+ mock_client = cast ( "mock.MagicMock" , mock_client_cls .return_value )
347355 mock_classify = MockAsyncIterator ([test_response ])
348356 mock_client .classify = mock_classify
349357
@@ -391,4 +399,5 @@ async def test_client_close(
391399
392400 await client .close ()
393401
394- mock_channel .close .assert_called_once ()
402+ close_mock = cast ("mock.MagicMock" , mock_channel .close )
403+ close_mock .assert_called_once ()
0 commit comments