|
1 | 1 | import asyncio |
| 2 | +import threading |
| 3 | +import weakref |
2 | 4 | from collections import defaultdict |
3 | 5 | from contextlib import asynccontextmanager |
| 6 | +from dataclasses import dataclass, field |
| 7 | + |
| 8 | + |
| 9 | +@dataclass |
| 10 | +class _LoopLockState: |
| 11 | + access_lock: asyncio.Lock = field(default_factory=asyncio.Lock) |
| 12 | + locks: dict[str, asyncio.Lock] = field(default_factory=dict) |
| 13 | + lock_count: dict[str, int] = field(default_factory=lambda: defaultdict(int)) |
4 | 14 |
|
5 | 15 |
|
6 | 16 | class SessionLockManager: |
7 | 17 | def __init__(self) -> None: |
8 | | - self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) |
9 | | - self._lock_count: dict[str, int] = defaultdict(int) |
10 | | - self._access_lock = asyncio.Lock() |
| 18 | + self._state_guard = threading.Lock() |
| 19 | + self._loop_states: weakref.WeakKeyDictionary[ |
| 20 | + asyncio.AbstractEventLoop, _LoopLockState |
| 21 | + ] = weakref.WeakKeyDictionary() |
| 22 | + |
| 23 | + def _get_loop_state(self) -> _LoopLockState: |
| 24 | + loop = asyncio.get_running_loop() |
| 25 | + with self._state_guard: |
| 26 | + state = self._loop_states.get(loop) |
| 27 | + if state is None: |
| 28 | + state = _LoopLockState() |
| 29 | + self._loop_states[loop] = state |
| 30 | + return state |
11 | 31 |
|
12 | 32 | @asynccontextmanager |
13 | 33 | async def acquire_lock(self, session_id: str): |
14 | | - async with self._access_lock: |
15 | | - lock = self._locks[session_id] |
16 | | - self._lock_count[session_id] += 1 |
| 34 | + state = self._get_loop_state() |
| 35 | + |
| 36 | + async with state.access_lock: |
| 37 | + lock = state.locks.get(session_id) |
| 38 | + if lock is None: |
| 39 | + lock = asyncio.Lock() |
| 40 | + state.locks[session_id] = lock |
| 41 | + state.lock_count[session_id] += 1 |
17 | 42 |
|
18 | 43 | try: |
19 | 44 | async with lock: |
20 | 45 | yield |
21 | 46 | finally: |
22 | | - async with self._access_lock: |
23 | | - self._lock_count[session_id] -= 1 |
24 | | - if self._lock_count[session_id] == 0: |
25 | | - self._locks.pop(session_id, None) |
26 | | - self._lock_count.pop(session_id, None) |
| 47 | + async with state.access_lock: |
| 48 | + state.lock_count[session_id] -= 1 |
| 49 | + if state.lock_count[session_id] == 0: |
| 50 | + state.locks.pop(session_id, None) |
| 51 | + state.lock_count.pop(session_id, None) |
27 | 52 |
|
28 | 53 |
|
29 | 54 | session_lock_manager = SessionLockManager() |
0 commit comments