@@ -171,6 +171,9 @@ def __init__(self, *, num_threads=None, device_id=None, cuda_stream=None):
171171
172172 # Used to disallow the EvalContext to be active in two threads simultaneously
173173 self ._lock = threading .RLock ()
174+ # Python's RLock doesn't expose the owner nor the number of locks held so we track it here
175+ self ._lock_count = 0
176+ self ._locking_thread : threading .Thread | None = None
174177
175178 def _purge_operator_cache (self ):
176179 """Empties the operator instance cache"""
@@ -210,14 +213,13 @@ def _is_current(self) -> bool:
210213 return self is _tls .default .get (current_device_id )
211214
212215 def __enter__ (self ):
213- if not self ._lock .acquire (blocking = False ):
214- raise RuntimeError ("An EvalContext cannot be active in two threads simultaneously." )
216+ self ._try_acquire_lock ()
215217 try :
216218 _tls .stack .append (self )
217219 if self ._device :
218220 self ._device .__enter__ ()
219221 except Exception :
220- self ._lock . release ()
222+ self ._try_release_lock ()
221223 raise
222224 return self
223225
@@ -235,7 +237,7 @@ def __exit__(self, exc_type, exc_value, traceback):
235237 if self ._device :
236238 self ._device .__exit__ (exc_type , exc_value , traceback )
237239 finally :
238- self ._lock . release ()
240+ self ._try_release_lock ()
239241
240242 def evaluate_all (self ):
241243 """Evaluates all pending invocations."""
@@ -331,6 +333,52 @@ def _snapshot(self):
331333 def _is_in_background_thread (self ):
332334 return threading .current_thread () is self ._async_executor ._thread
333335
336+ def _should_acquire_lock (self ):
337+ # Skip checking if the eval context is already active if either:
338+ # - We're in the background thread for async execution and it's active in the main thread
339+ # - We're in the main thread and it's active in the background thread
340+ background_thread = self ._async_executor ._thread
341+ main_thread = self ._async_executor ._main_thread
342+
343+ if self ._locking_thread is None or background_thread is None :
344+ return True
345+
346+ return (threading .current_thread , self ._locking_thread ) not in (
347+ (main_thread , background_thread ),
348+ (background_thread , main_thread ),
349+ )
350+
351+ def _try_acquire_lock (self ):
352+ # Skip checking if the eval context is already active if either:
353+ # - We're in the background thread for async execution and it's active in the main thread
354+ # - We're in the main thread and it's active in the background thread
355+ background_thread = self ._async_executor ._thread
356+ main_thread = self ._async_executor ._main_thread
357+ if (
358+ self ._locking_thread is not None
359+ and background_thread is not None
360+ and (threading .current_thread (), self ._locking_thread )
361+ in (
362+ (main_thread , background_thread ),
363+ (background_thread , main_thread ),
364+ )
365+ ):
366+ return
367+
368+ if not self ._lock .acquire (blocking = False ):
369+ raise RuntimeError ("An EvalContext cannot be active in two threads simultaneously." )
370+ self ._locking_thread = threading .current_thread ()
371+ self ._lock_count += 1
372+
373+ def _try_release_lock (self ):
374+ if self ._locking_thread is not threading .current_thread ():
375+ return
376+
377+ self ._lock_count -= 1
378+ if self ._lock_count == 0 :
379+ self ._locking_thread = None
380+ self ._lock .release ()
381+
334382
335383__all__ = [
336384 "EvalContext" ,
0 commit comments