Skip to content

Commit 26d801b

Browse files
committed
fix: 重构 SessionLockManager 以支持每个事件循环的独立锁管理,并添加单元测试
1 parent 729363f commit 26d801b

File tree

3 files changed

+69
-11
lines changed

3 files changed

+69
-11
lines changed

astrbot/core/utils/session_lock.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,54 @@
11
import asyncio
2+
import threading
3+
import weakref
24
from collections import defaultdict
35
from contextlib import asynccontextmanager
6+
from dataclasses import dataclass, field
7+
8+
9+
@dataclass
10+
class _LoopLockState:
11+
access_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
12+
locks: dict[str, asyncio.Lock] = field(default_factory=dict)
13+
lock_count: dict[str, int] = field(default_factory=lambda: defaultdict(int))
414

515

616
class SessionLockManager:
717
def __init__(self) -> None:
8-
self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
9-
self._lock_count: dict[str, int] = defaultdict(int)
10-
self._access_lock = asyncio.Lock()
18+
self._state_guard = threading.Lock()
19+
self._loop_states: weakref.WeakKeyDictionary[
20+
asyncio.AbstractEventLoop, _LoopLockState
21+
] = weakref.WeakKeyDictionary()
22+
23+
def _get_loop_state(self) -> _LoopLockState:
24+
loop = asyncio.get_running_loop()
25+
with self._state_guard:
26+
state = self._loop_states.get(loop)
27+
if state is None:
28+
state = _LoopLockState()
29+
self._loop_states[loop] = state
30+
return state
1131

1232
@asynccontextmanager
1333
async def acquire_lock(self, session_id: str):
14-
async with self._access_lock:
15-
lock = self._locks[session_id]
16-
self._lock_count[session_id] += 1
34+
state = self._get_loop_state()
35+
36+
async with state.access_lock:
37+
lock = state.locks.get(session_id)
38+
if lock is None:
39+
lock = asyncio.Lock()
40+
state.locks[session_id] = lock
41+
state.lock_count[session_id] += 1
1742

1843
try:
1944
async with lock:
2045
yield
2146
finally:
22-
async with self._access_lock:
23-
self._lock_count[session_id] -= 1
24-
if self._lock_count[session_id] == 0:
25-
self._locks.pop(session_id, None)
26-
self._lock_count.pop(session_id, None)
47+
async with state.access_lock:
48+
state.lock_count[session_id] -= 1
49+
if state.lock_count[session_id] == 0:
50+
state.locks.pop(session_id, None)
51+
state.lock_count.pop(session_id, None)
2752

2853

2954
session_lock_manager = SessionLockManager()

autogen-tests

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit 13e144e5476a76ca0d76bf4f07a6401d133a03ed

tests/unit/test_session_lock.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import asyncio
2+
import threading
3+
4+
from astrbot.core.utils.session_lock import SessionLockManager
5+
6+
7+
def _run_worker(manager: SessionLockManager, errors: list[BaseException]) -> None:
8+
async def worker():
9+
for _ in range(200):
10+
async with manager.acquire_lock("shared-session"):
11+
await asyncio.sleep(0)
12+
13+
try:
14+
asyncio.run(worker())
15+
except BaseException as exc: # noqa: BLE001
16+
errors.append(exc)
17+
18+
19+
def test_session_lock_manager_isolated_per_event_loop():
20+
manager = SessionLockManager()
21+
errors: list[BaseException] = []
22+
23+
threads = [
24+
threading.Thread(target=_run_worker, args=(manager, errors)),
25+
threading.Thread(target=_run_worker, args=(manager, errors)),
26+
]
27+
for thread in threads:
28+
thread.start()
29+
for thread in threads:
30+
thread.join()
31+
32+
assert not errors

0 commit comments

Comments
 (0)