Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 29 additions & 15 deletions fastapi/concurrency.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from collections.abc import AsyncGenerator
import functools
from collections.abc import AsyncGenerator, Callable
from contextlib import AbstractContextManager
from contextlib import asynccontextmanager as asynccontextmanager
from typing import TypeVar
from typing import ParamSpec, TypeVar

import anyio.to_thread
from anyio import CapacityLimiter
Expand All @@ -11,31 +12,44 @@
run_until_first_complete as run_until_first_complete,
)

_P = ParamSpec("_P")
_T = TypeVar("_T")

# Blocking __exit__ and other teardown operations from running can create race
# conditions/deadlocks if the context manager itself has its own internal pool
# (e.g. a database connection pool).
# To avoid this maintain a separate limiter for teardown operations, so that the
# operations acquiring resources can never block operations releasing resources.
# NOTE: 5 is arbitrary, we would like more than 1 so that teardowns are not serialised.
_teardown_limiter = CapacityLimiter(5)


@asynccontextmanager
async def contextmanager_in_threadpool(
cm: AbstractContextManager[_T],
) -> AsyncGenerator[_T, None]:
# blocking __exit__ from running waiting on a free thread
# can create race conditions/deadlocks if the context manager itself
# has its own internal pool (e.g. a database connection pool)
# to avoid this we let __exit__ run without a capacity limit
# since we're creating a new limiter for each call, any non-zero limit
# works (1 is arbitrary)
exit_limiter = CapacityLimiter(1)
try:
yield await run_in_threadpool(cm.__enter__)
except Exception as e:
ok = bool(
await anyio.to_thread.run_sync(
cm.__exit__, type(e), e, e.__traceback__, limiter=exit_limiter
)
await run_in_teardown_threadpool(cm.__exit__, type(e), e, e.__traceback__)
)
if not ok:
raise e
else:
await anyio.to_thread.run_sync(
cm.__exit__, None, None, None, limiter=exit_limiter
)
await run_in_teardown_threadpool(cm.__exit__, None, None, None)


async def run_in_teardown_threadpool(
func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
"""Run a function in the separate teardown threadpool.

This will run the function in the teardown threadpool in order to avoid it
being blocked by other operations waiting to acquire resources.

Unless you know what you are doing, you probably don't want this function,
use run_in_threadpool instead.
"""
func = functools.partial(func, *args, **kwargs)
return await anyio.to_thread.run_sync(func, limiter=_teardown_limiter)
8 changes: 6 additions & 2 deletions fastapi/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@
Undefined,
lenient_issubclass,
)
from fastapi.concurrency import (
iterate_in_threadpool,
run_in_teardown_threadpool,
run_in_threadpool,
)
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import (
Expand Down Expand Up @@ -75,7 +80,6 @@
from starlette import routing
from starlette._exception_handler import wrap_app_handling_exceptions
from starlette._utils import is_async_callable
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool
from starlette.datastructures import FormData
from starlette.exceptions import HTTPException
from starlette.requests import Request
Expand Down Expand Up @@ -292,7 +296,7 @@ async def serialize_response(
if is_coroutine:
value, errors = field.validate(response_content, {}, loc=("response",))
else:
value, errors = await run_in_threadpool(
value, errors = await run_in_teardown_threadpool(
field.validate, response_content, {}, loc=("response",)
)
if errors:
Expand Down
68 changes: 68 additions & 0 deletions tests/test_concurrency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import contextlib
import time
from collections.abc import Iterator

import anyio.to_thread
import pytest
from anyio import CapacityLimiter
from fastapi import concurrency


@pytest.fixture
def reset_teardown_limiter(monkeypatch: pytest.MonkeyPatch) -> None:
"""Reset the teardown limiter before/after tests to avoid interference
between different anyio backends."""
monkeypatch.setattr(concurrency, "_teardown_limiter", CapacityLimiter(5))


@pytest.mark.anyio
@pytest.mark.usefixtures("reset_teardown_limiter")
async def test_run_in_teardown_threadpool() -> None:
def func(x: int, y: int) -> int:
return x + y

result = await concurrency.run_in_teardown_threadpool(func, 1, y=2)
assert result == 3


@pytest.mark.anyio
@pytest.mark.usefixtures("reset_teardown_limiter")
async def test_contextmanager_in_threadpool() -> None:
@contextlib.contextmanager
def context_manager() -> Iterator[str]:
yield "entered"

async with concurrency.contextmanager_in_threadpool(context_manager()) as result:
assert result == "entered"


@pytest.mark.anyio
@pytest.mark.usefixtures("reset_teardown_limiter")
async def test_competing_acquire_release() -> None:
"""Check that the main threadpool does not block the teardown threadpool."""
pool_size = anyio.to_thread.current_default_thread_limiter().total_tokens
acquirable = False
acquired = []

def acquire() -> None:
while not acquirable:
time.sleep(0.001)
acquired.append(True)

def release() -> bool:
nonlocal acquirable
time.sleep(0.001)
acquirable = True
return acquirable

async with anyio.create_task_group() as tg:
for _ in range(pool_size):
tg.start_soon(concurrency.run_in_threadpool, acquire)

await anyio.sleep(0.001)

# The threadpool should now be full of threads waiting to acquire
# The release function should be able to run without being blocked by acquires
await concurrency.run_in_teardown_threadpool(release)

assert len(acquired) == pool_size
56 changes: 56 additions & 0 deletions tests/test_depends_deadlock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import asyncio
import threading
import time
from collections.abc import Iterator

from fastapi import Depends, FastAPI
from httpx import ASGITransport, AsyncClient
from pydantic import BaseModel

# Mutex, and dependency acting as our "connection pool" for a database for example
mutex = threading.Lock()


# Simulate releaasing a pooled resource in the teardown of a Depends,
# which in reality is usually a database connection or similar.
def release_resource() -> Iterator[None]:
try:
time.sleep(0.001)
yield
finally:
time.sleep(0.001)
mutex.release()


app = FastAPI()


class Item(BaseModel):
name: str
id: int


# An endpoint that uses Depends for resource management and also includes
# a response_model definition would previously deadlock in the validation
# of the model and the cleanup of the Depends
@app.get("/deadlock", response_model=Item)
def get_deadlock(dep: None = Depends(release_resource)) -> Item:
mutex.acquire()
return Item(name="foo", id=1)


# Fire off 100 requests in parallel(ish) in order to create contention
# over the shared resource (simulating a fastapi server that interacts with
# a database connection pool).
def test_depends_deadlock() -> None:
async def make_request(client: AsyncClient):
await client.get("/deadlock")

async def run_requests() -> None:
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://testserver"
) as aclient:
tasks = [make_request(aclient) for _ in range(100)]
await asyncio.gather(*tasks)

asyncio.run(run_requests())
Loading