1616from __future__ import annotations
1717
1818import asyncio
19- import contextlib
2019import functools
2120import socket as _socket
2221import ssl as _ssl
@@ -179,41 +178,45 @@ async def slow_sock_connect(sock, addr):
179178 f"socket leaked across cancellation: { sock !r} " ,
180179 )
181180
182- async def _assert_cancellation_closes_socket (
183- self ,
184- * ,
185- connection_creator ,
186- loop_method_name ,
187- make_slow ,
188- ssl_context = None ,
189- ):
181+ async def test_cancellation_closes_socket_during_ssl_wrap_socket (self ):
190182 address = (await async_client_context .host , await async_client_context .port )
191183 options = (await async_get_pool (self .client )).opts
184+ fake_ssl_context = _ssl .create_default_context ()
192185
193- created_sockets = []
186+ created_sockets : list [ _socket . socket ] = []
194187 real_socket_cls = _socket .socket
195188 target_task = None
196189
197- def is_target ():
198- return asyncio .current_task () is target_task
199-
200- def tracked_socket (* args , ** kwargs ):
190+ def tracking_socket (* args , ** kwargs ):
201191 s = real_socket_cls (* args , ** kwargs )
202- if is_target () :
192+ if asyncio . current_task () is target_task :
203193 created_sockets .append (s )
204194 return s
205195
206196 loop = asyncio .get_running_loop ()
197+ real_run_in_executor = loop .run_in_executor
207198 started = asyncio .Event ()
208- block_forever = asyncio .Event ()
209- slow_method = make_slow (getattr (loop , loop_method_name ), started , block_forever , is_target )
210-
211- with contextlib .ExitStack () as stack :
212- stack .enter_context (patch .object (_socket , "socket" , tracked_socket ))
213- stack .enter_context (patch .object (loop , loop_method_name , slow_method ))
214- if ssl_context is not None :
215- stack .enter_context (patch .object (options , "_PoolOptions__ssl_context" , ssl_context ))
216- task = asyncio .create_task (connection_creator (address , options ))
199+
200+ def slow_run_in_executor (executor , func , * args ):
201+ # Need to unwrap the SNI branch here if present
202+ inner = func .func if isinstance (func , functools .partial ) else func
203+ # Each `ctx.wrap_socket` access returns a fresh bound-method
204+ # object, so we check the bound instance (__self__) instead
205+ if (
206+ getattr (inner , "__self__" , None ) is fake_ssl_context
207+ and asyncio .current_task () is target_task
208+ ):
209+ started .set ()
210+ # Return a future that never completes for cancellation.
211+ return asyncio .get_running_loop ().create_future ()
212+ return real_run_in_executor (executor , func , * args )
213+
214+ with (
215+ patch .object (_socket , "socket" , tracking_socket ),
216+ patch .object (loop , "run_in_executor" , slow_run_in_executor ),
217+ patch .object (options , "_PoolOptions__ssl_context" , fake_ssl_context ),
218+ ):
219+ task = asyncio .create_task (pool_shared ._async_configured_socket (address , options ))
217220 target_task = task
218221 await asyncio .wait_for (started .wait (), timeout = 5 )
219222 task .cancel ()
@@ -227,62 +230,3 @@ def tracked_socket(*args, **kwargs):
227230 - 1 ,
228231 f"socket leaked across cancellation: { sock !r} " ,
229232 )
230-
231- async def test_cancellation_closes_socket_during_protocol_create_connection (self ):
232- def make_slow (real , started , block_forever , is_target ):
233- async def slow_create_connection (* args , ** kwargs ):
234- if is_target ():
235- started .set ()
236- await block_forever .wait ()
237- return await real (* args , ** kwargs )
238-
239- return slow_create_connection
240-
241- await self ._assert_cancellation_closes_socket (
242- connection_creator = pool_shared ._configured_protocol_interface ,
243- loop_method_name = "create_connection" ,
244- make_slow = make_slow ,
245- )
246-
247- async def test_cancellation_closes_socket_during_ssl_wrap_socket (self ):
248- fake_ssl_context = _ssl .create_default_context ()
249-
250- def make_slow (real , started , _ , is_target ):
251- def slow_run_in_executor (executor , func , * args ):
252- # Need to unwrap the SNI branch here if present
253- inner = func .func if isinstance (func , functools .partial ) else func
254- # Each `ctx.wrap_socket` access returns a fresh bound-method
255- # object, so we check the bound instance (__self__) instead
256- if getattr (inner , "__self__" , None ) is fake_ssl_context and is_target ():
257- started .set ()
258- # Return a future that never completes for cancellation.
259- return asyncio .get_running_loop ().create_future ()
260- return real (executor , func , * args )
261-
262- return slow_run_in_executor
263-
264- await self ._assert_cancellation_closes_socket (
265- connection_creator = pool_shared ._async_configured_socket ,
266- loop_method_name = "run_in_executor" ,
267- make_slow = make_slow ,
268- ssl_context = fake_ssl_context ,
269- )
270-
271- async def test_cancellation_closes_socket_during_ssl_create_connection (self ):
272- fake_ssl_context = _ssl .create_default_context ()
273-
274- def make_slow (real , started , block_forever , is_target ):
275- async def slow_create_connection (* args , ** kwargs ):
276- if kwargs .get ("ssl" ) is fake_ssl_context and is_target ():
277- started .set ()
278- await block_forever .wait ()
279- return await real (* args , ** kwargs )
280-
281- return slow_create_connection
282-
283- await self ._assert_cancellation_closes_socket (
284- connection_creator = pool_shared ._configured_protocol_interface ,
285- loop_method_name = "create_connection" ,
286- make_slow = make_slow ,
287- ssl_context = fake_ssl_context ,
288- )
0 commit comments