Skip to content

Commit 6dc84e4

Browse files
committed
Skip checking whether the eval context is already active when we're in the background thread of async execution.
Signed-off-by: Rostan Tabet <rtabet@nvidia.com>
1 parent 2acd546 commit 6dc84e4

2 files changed

Lines changed: 38 additions & 4 deletions

File tree

dali/python/nvidia/dali/experimental/dynamic/_async.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def __init__(self):
7878
self._queue = queue.SimpleQueue[Optional[_Future]]()
7979
self._thread: Optional[threading.Thread] = None
8080
self._event = threading.Event()
81+
self._main_thread = threading.current_thread()
8182

8283
def _worker(self):
8384
while True:

dali/python/nvidia/dali/experimental/dynamic/_eval_context.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,9 @@ def __init__(self, *, num_threads=None, device_id=None, cuda_stream=None):
171171

172172
# Used to prevent the EvalContext from being active in two threads simultaneously
173173
self._lock = threading.RLock()
174+
# Python's RLock exposes neither its owner nor how many times it is held, so we track both here
175+
self._lock_count = 0
176+
self._locking_thread: threading.Thread | None = None
174177

175178
def _purge_operator_cache(self):
176179
"""Empties the operator instance cache"""
@@ -210,14 +213,13 @@ def _is_current(self) -> bool:
210213
return self is _tls.default.get(current_device_id)
211214

212215
def __enter__(self):
213-
if not self._lock.acquire(blocking=False):
214-
raise RuntimeError("An EvalContext cannot be active in two threads simultaneously.")
216+
self._try_acquire_lock()
215217
try:
216218
_tls.stack.append(self)
217219
if self._device:
218220
self._device.__enter__()
219221
except Exception:
220-
self._lock.release()
222+
self._try_release_lock()
221223
raise
222224
return self
223225

@@ -235,7 +237,7 @@ def __exit__(self, exc_type, exc_value, traceback):
235237
if self._device:
236238
self._device.__exit__(exc_type, exc_value, traceback)
237239
finally:
238-
self._lock.release()
240+
self._try_release_lock()
239241

240242
def evaluate_all(self):
241243
"""Evaluates all pending invocations."""
@@ -331,6 +333,37 @@ def _snapshot(self):
331333
def _is_in_background_thread(self):
    """Return True if the calling thread is the async executor's worker thread.

    The current thread is compared by identity with
    ``self._async_executor._thread``. The executor initialises that attribute
    to ``None``, in which case this returns ``False``.
    """
    return threading.current_thread() is self._async_executor._thread
333335

336+
def _try_acquire_lock(self):
    """Take the context lock for the calling thread without blocking.

    The acquisition is skipped entirely (nothing is locked and no bookkeeping
    changes) when the lock is already owned and the owner/caller pair is
    exactly the executor's main thread and its background thread, in either
    order:

    - the caller is the background thread and the owner is the main thread, or
    - the caller is the main thread and the owner is the background thread.

    Here "main thread" is ``self._async_executor._main_thread`` and
    "background thread" is ``self._async_executor._thread``; the skip never
    applies while the latter is ``None``.

    Otherwise the re-entrant lock is acquired in non-blocking mode and, on
    success, the calling thread is recorded as the owner and the hold count is
    incremented.

    Raises:
        RuntimeError: If the lock could not be acquired without blocking.
    """
    background_thread = self._async_executor._thread
    main_thread = self._async_executor._main_thread

    # Read the owner once so the None check and the comparison below are made
    # against the same value.
    owner = self._locking_thread
    if owner is not None and background_thread is not None:
        exempt_pairs = (
            (main_thread, background_thread),
            (background_thread, main_thread),
        )
        if (threading.current_thread(), owner) in exempt_pairs:
            return

    if not self._lock.acquire(blocking=False):
        raise RuntimeError("An EvalContext cannot be active in two threads simultaneously.")

    # RLock does not expose its owner or recursion depth, so mirror them here.
    self._locking_thread = threading.current_thread()
    self._lock_count += 1
357+
358+
def _try_release_lock(self):
    """Undo one ``_try_acquire_lock`` made by the calling thread.

    Does nothing when the calling thread is not the recorded owner, which
    includes the case where no owner is recorded. Otherwise the hold count is
    decremented, the owner record is cleared once the count reaches zero, and
    the underlying lock is released once.
    """
    if self._locking_thread is not threading.current_thread():
        return

    self._lock_count -= 1
    if self._lock_count == 0:
        # Clear the owner before releasing, while this thread still holds the lock.
        self._locking_thread = None
    self._lock.release()
366+
334367

335368
__all__ = [
336369
"EvalContext",

0 commit comments

Comments
 (0)