Skip to content

Commit 31394d1

Browse files
author
remimd
committed
feat: ✨ define_scope can be threadsafe
1 parent af3a276 commit 31394d1

4 files changed

Lines changed: 79 additions & 53 deletions

File tree

injection/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,14 @@ def adefine_scope(
3535
name: str,
3636
/,
3737
kind: ScopeKind | ScopeKindStr = ...,
38+
threadsafe: bool = ...,
3839
) -> AsyncIterator[Scope]: ...
3940
@contextmanager
4041
def define_scope(
4142
name: str,
4243
/,
4344
kind: ScopeKind | ScopeKindStr = ...,
45+
threadsafe: bool = ...,
4446
) -> Iterator[Scope]: ...
4547
def mod(name: str = ..., /) -> Module:
4648
"""

injection/_core/module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -996,7 +996,7 @@ class InjectMetadata[**P, T](Caller[P, T], EventListener):
996996

997997
def __init__(self, wrapped: Callable[P, T], /, threadsafe: bool) -> None:
998998
self.__dependencies = Dependencies.empty()
999-
self.__lock = threading.Lock() if threadsafe else nullcontext()
999+
self.__lock = threading.RLock() if threadsafe else nullcontext()
10001000
self.__owner = None
10011001
self.__tasks = deque()
10021002
self.__wrapped = wrapped

injection/_core/scope.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
from __future__ import annotations
22

33
import itertools
4+
import threading
45
from abc import ABC, abstractmethod
56
from collections import defaultdict
67
from collections.abc import AsyncIterator, Iterator, Mapping, MutableMapping
7-
from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
8+
from contextlib import (
9+
AsyncExitStack,
10+
ExitStack,
11+
asynccontextmanager,
12+
contextmanager,
13+
nullcontext,
14+
)
815
from contextvars import ContextVar
916
from dataclasses import dataclass, field
1017
from enum import StrEnum
@@ -129,9 +136,10 @@ async def adefine_scope(
129136
name: str,
130137
/,
131138
kind: ScopeKind | ScopeKindStr = ScopeKind.get_default(),
139+
threadsafe: bool = False,
132140
) -> AsyncIterator[ScopeFacade]:
133141
async with AsyncScope() as scope:
134-
with _bind_scope(name, scope, kind) as facade:
142+
with _bind_scope(name, scope, kind, threadsafe) as facade:
135143
yield facade
136144

137145

@@ -140,9 +148,10 @@ def define_scope(
140148
name: str,
141149
/,
142150
kind: ScopeKind | ScopeKindStr = ScopeKind.get_default(),
151+
threadsafe: bool = False,
143152
) -> Iterator[ScopeFacade]:
144153
with SyncScope() as scope:
145-
with _bind_scope(name, scope, kind) as facade:
154+
with _bind_scope(name, scope, kind, threadsafe) as facade:
146155
yield facade
147156

148157

@@ -191,27 +200,39 @@ def _bind_scope(
191200
name: str,
192201
scope: Scope,
193202
kind: ScopeKind | ScopeKindStr,
203+
threadsafe: bool,
194204
) -> Iterator[ScopeFacade]:
195-
match ScopeKind(kind):
196-
case ScopeKind.CONTEXTUAL:
197-
is_already_defined = bool(get_scope(name, default=None))
198-
states = __CONTEXTUAL_SCOPES
205+
lock = threading.RLock() if threadsafe else nullcontext()
199206

200-
case ScopeKind.SHARED:
201-
is_already_defined = bool(get_active_scopes(name))
202-
states = __SHARED_SCOPES
207+
with lock:
208+
match ScopeKind(kind):
209+
case ScopeKind.CONTEXTUAL:
210+
is_already_defined = bool(get_scope(name, default=None))
211+
states = __CONTEXTUAL_SCOPES
203212

204-
case _:
205-
raise NotImplementedError
213+
case ScopeKind.SHARED:
214+
is_already_defined = bool(get_active_scopes(name))
215+
states = __SHARED_SCOPES
206216

207-
if is_already_defined:
208-
raise ScopeAlreadyDefinedError(
209-
f"Scope `{name}` is already defined in the current context."
210-
)
217+
case _:
218+
raise NotImplementedError
219+
220+
if is_already_defined:
221+
raise ScopeAlreadyDefinedError(
222+
f"Scope `{name}` is already defined in the current context."
223+
)
211224

212-
with states[name].bind(scope):
225+
stack = ExitStack()
226+
binder = states[name].bind(scope)
227+
stack.enter_context(binder)
228+
229+
try:
213230
yield _UserScope(scope)
214231

232+
finally:
233+
with lock:
234+
stack.close()
235+
215236

216237
@runtime_checkable
217238
class Scope(Protocol):

0 commit comments

Comments
 (0)