@@ -171,6 +171,9 @@ def __init__(self, *, num_threads=None, device_id=None, cuda_stream=None):
171171
172172 # Used to disallow the EvalContext to be active in two threads simultaneously
173173 self ._lock = threading .RLock ()
174+ # Python's RLock doesn't expose the owner nor the number of locks held so we track it here
175+ self ._lock_count = 0
176+ self ._locking_thread : threading .Thread | None = None
174177
175178 def _purge_operator_cache (self ):
176179 """Empties the operator instance cache"""
@@ -210,14 +213,13 @@ def _is_current(self) -> bool:
210213 return self is _tls .default .get (current_device_id )
211214
212215 def __enter__ (self ):
213- if not self ._lock .acquire (blocking = False ):
214- raise RuntimeError ("An EvalContext cannot be active in two threads simultaneously." )
216+ self ._try_acquire_lock ()
215217 try :
216218 _tls .stack .append (self )
217219 if self ._device :
218220 self ._device .__enter__ ()
219221 except Exception :
220- self ._lock . release ()
222+ self ._try_release_lock ()
221223 raise
222224 return self
223225
@@ -235,7 +237,7 @@ def __exit__(self, exc_type, exc_value, traceback):
235237 if self ._device :
236238 self ._device .__exit__ (exc_type , exc_value , traceback )
237239 finally :
238- self ._lock . release ()
240+ self ._try_release_lock ()
239241
240242 def evaluate_all (self ):
241243 """Evaluates all pending invocations."""
@@ -331,6 +333,37 @@ def _snapshot(self):
331333 def _is_in_background_thread (self ):
332334 return threading .current_thread () is self ._async_executor ._thread
333335
336+ def _try_acquire_lock (self ):
337+ # Skip checking if the eval context is already active if either:
338+ # - We're in the background thread for async execution and it's active in the main thread
339+ # - We're in the main thread and it's active in the background thread
340+ background_thread = self ._async_executor ._thread
341+ main_thread = self ._async_executor ._main_thread
342+ if (
343+ self ._locking_thread is not None
344+ and background_thread is not None
345+ and (threading .current_thread (), self ._locking_thread )
346+ in (
347+ (main_thread , background_thread ),
348+ (background_thread , main_thread ),
349+ )
350+ ):
351+ return
352+
353+ if not self ._lock .acquire (blocking = False ):
354+ raise RuntimeError ("An EvalContext cannot be active in two threads simultaneously." )
355+ self ._locking_thread = threading .current_thread ()
356+ self ._lock_count += 1
357+
358+ def _try_release_lock (self ):
359+ if self ._locking_thread is not threading .current_thread ():
360+ return
361+
362+ self ._lock_count -= 1
363+ if self ._lock_count == 0 :
364+ self ._locking_thread = None
365+ self ._lock .release ()
366+
334367
335368__all__ = [
336369 "EvalContext" ,
0 commit comments