Skip to content

Commit 5b92f74

Browse files
committed
Remove racy closes
1 parent 784c6cd commit 5b92f74

2 files changed

Lines changed: 31 additions & 98 deletions

File tree

pymongo/pool_shared.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -322,17 +322,11 @@ async def _configured_protocol_interface(
322322
timeout = options.socket_timeout
323323

324324
if ssl_context is None:
325-
try:
326-
return AsyncNetworkingInterface(
327-
await asyncio.get_running_loop().create_connection(
328-
lambda: PyMongoProtocol(timeout=timeout), sock=sock
329-
)
325+
return AsyncNetworkingInterface(
326+
await asyncio.get_running_loop().create_connection(
327+
lambda: PyMongoProtocol(timeout=timeout), sock=sock
330328
)
331-
except BaseException:
332-
# Protect against cancellation or interruption before the transport
333-
# takes ownership of the raw socket.
334-
sock.close()
335-
raise
329+
)
336330

337331
host = address[0]
338332
try:
@@ -354,11 +348,6 @@ async def _configured_protocol_interface(
354348
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
355349
details = _get_timeout_details(options)
356350
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
357-
except BaseException:
358-
# Protect against cancellation or interruption before the transport
359-
# takes ownership of the raw socket.
360-
sock.close()
361-
raise
362351
try:
363352
if (
364353
ssl_context.verify_mode

test/asynchronous/test_async_cancellation.py

Lines changed: 27 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from __future__ import annotations
1717

1818
import asyncio
19-
import contextlib
2019
import functools
2120
import socket as _socket
2221
import ssl as _ssl
@@ -179,41 +178,45 @@ async def slow_sock_connect(sock, addr):
179178
f"socket leaked across cancellation: {sock!r}",
180179
)
181180

182-
async def _assert_cancellation_closes_socket(
183-
self,
184-
*,
185-
connection_creator,
186-
loop_method_name,
187-
make_slow,
188-
ssl_context=None,
189-
):
181+
async def test_cancellation_closes_socket_during_ssl_wrap_socket(self):
190182
address = (await async_client_context.host, await async_client_context.port)
191183
options = (await async_get_pool(self.client)).opts
184+
fake_ssl_context = _ssl.create_default_context()
192185

193-
created_sockets = []
186+
created_sockets: list[_socket.socket] = []
194187
real_socket_cls = _socket.socket
195188
target_task = None
196189

197-
def is_target():
198-
return asyncio.current_task() is target_task
199-
200-
def tracked_socket(*args, **kwargs):
190+
def tracking_socket(*args, **kwargs):
201191
s = real_socket_cls(*args, **kwargs)
202-
if is_target():
192+
if asyncio.current_task() is target_task:
203193
created_sockets.append(s)
204194
return s
205195

206196
loop = asyncio.get_running_loop()
197+
real_run_in_executor = loop.run_in_executor
207198
started = asyncio.Event()
208-
block_forever = asyncio.Event()
209-
slow_method = make_slow(getattr(loop, loop_method_name), started, block_forever, is_target)
210-
211-
with contextlib.ExitStack() as stack:
212-
stack.enter_context(patch.object(_socket, "socket", tracked_socket))
213-
stack.enter_context(patch.object(loop, loop_method_name, slow_method))
214-
if ssl_context is not None:
215-
stack.enter_context(patch.object(options, "_PoolOptions__ssl_context", ssl_context))
216-
task = asyncio.create_task(connection_creator(address, options))
199+
200+
def slow_run_in_executor(executor, func, *args):
201+
# Need to unwrap the SNI branch here if present
202+
inner = func.func if isinstance(func, functools.partial) else func
203+
# Each `ctx.wrap_socket` access returns a fresh bound-method
204+
# object, so we check the bound instance (__self__) instead
205+
if (
206+
getattr(inner, "__self__", None) is fake_ssl_context
207+
and asyncio.current_task() is target_task
208+
):
209+
started.set()
210+
# Return a future that never completes for cancellation.
211+
return asyncio.get_running_loop().create_future()
212+
return real_run_in_executor(executor, func, *args)
213+
214+
with (
215+
patch.object(_socket, "socket", tracking_socket),
216+
patch.object(loop, "run_in_executor", slow_run_in_executor),
217+
patch.object(options, "_PoolOptions__ssl_context", fake_ssl_context),
218+
):
219+
task = asyncio.create_task(pool_shared._async_configured_socket(address, options))
217220
target_task = task
218221
await asyncio.wait_for(started.wait(), timeout=5)
219222
task.cancel()
@@ -227,62 +230,3 @@ def tracked_socket(*args, **kwargs):
227230
-1,
228231
f"socket leaked across cancellation: {sock!r}",
229232
)
230-
231-
async def test_cancellation_closes_socket_during_protocol_create_connection(self):
232-
def make_slow(real, started, block_forever, is_target):
233-
async def slow_create_connection(*args, **kwargs):
234-
if is_target():
235-
started.set()
236-
await block_forever.wait()
237-
return await real(*args, **kwargs)
238-
239-
return slow_create_connection
240-
241-
await self._assert_cancellation_closes_socket(
242-
connection_creator=pool_shared._configured_protocol_interface,
243-
loop_method_name="create_connection",
244-
make_slow=make_slow,
245-
)
246-
247-
async def test_cancellation_closes_socket_during_ssl_wrap_socket(self):
248-
fake_ssl_context = _ssl.create_default_context()
249-
250-
def make_slow(real, started, _, is_target):
251-
def slow_run_in_executor(executor, func, *args):
252-
# Need to unwrap the SNI branch here if present
253-
inner = func.func if isinstance(func, functools.partial) else func
254-
# Each `ctx.wrap_socket` access returns a fresh bound-method
255-
# object, so we check the bound instance (__self__) instead
256-
if getattr(inner, "__self__", None) is fake_ssl_context and is_target():
257-
started.set()
258-
# Return a future that never completes for cancellation.
259-
return asyncio.get_running_loop().create_future()
260-
return real(executor, func, *args)
261-
262-
return slow_run_in_executor
263-
264-
await self._assert_cancellation_closes_socket(
265-
connection_creator=pool_shared._async_configured_socket,
266-
loop_method_name="run_in_executor",
267-
make_slow=make_slow,
268-
ssl_context=fake_ssl_context,
269-
)
270-
271-
async def test_cancellation_closes_socket_during_ssl_create_connection(self):
272-
fake_ssl_context = _ssl.create_default_context()
273-
274-
def make_slow(real, started, block_forever, is_target):
275-
async def slow_create_connection(*args, **kwargs):
276-
if kwargs.get("ssl") is fake_ssl_context and is_target():
277-
started.set()
278-
await block_forever.wait()
279-
return await real(*args, **kwargs)
280-
281-
return slow_create_connection
282-
283-
await self._assert_cancellation_closes_socket(
284-
connection_creator=pool_shared._configured_protocol_interface,
285-
loop_method_name="create_connection",
286-
make_slow=make_slow,
287-
ssl_context=fake_ssl_context,
288-
)

0 commit comments

Comments
 (0)