Skip to content

Commit 2ecbb53

Browse files
committed
Skip checking if the eval context is active if we're in the background thread of async execution.
Signed-off-by: Rostan Tabet <rtabet@nvidia.com>
1 parent 2acd546 commit 2ecbb53

2 files changed

Lines changed: 53 additions & 4 deletions

File tree

dali/python/nvidia/dali/experimental/dynamic/_async.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def __init__(self):
7878
self._queue = queue.SimpleQueue[Optional[_Future]]()
7979
self._thread: Optional[threading.Thread] = None
8080
self._event = threading.Event()
81+
self._main_thread = threading.current_thread()
8182

8283
def _worker(self):
8384
while True:

dali/python/nvidia/dali/experimental/dynamic/_eval_context.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,9 @@ def __init__(self, *, num_threads=None, device_id=None, cuda_stream=None):
171171

172172
# Used to disallow the EvalContext to be active in two threads simultaneously
173173
self._lock = threading.RLock()
174+
# Python's RLock doesn't expose the owner nor the number of locks held so we track it here
175+
self._lock_count = 0
176+
self._locking_thread: threading.Thread | None = None
174177

175178
def _purge_operator_cache(self):
176179
"""Empties the operator instance cache"""
@@ -210,14 +213,13 @@ def _is_current(self) -> bool:
210213
return self is _tls.default.get(current_device_id)
211214

212215
def __enter__(self):
213-
if not self._lock.acquire(blocking=False):
214-
raise RuntimeError("An EvalContext cannot be active in two threads simultaneously.")
216+
self._try_acquire_lock()
215217
try:
216218
_tls.stack.append(self)
217219
if self._device:
218220
self._device.__enter__()
219221
except Exception:
220-
self._lock.release()
222+
self._try_release_lock()
221223
raise
222224
return self
223225

@@ -235,7 +237,7 @@ def __exit__(self, exc_type, exc_value, traceback):
235237
if self._device:
236238
self._device.__exit__(exc_type, exc_value, traceback)
237239
finally:
238-
self._lock.release()
240+
self._try_release_lock()
239241

240242
def evaluate_all(self):
241243
"""Evaluates all pending invocations."""
@@ -331,6 +333,52 @@ def _snapshot(self):
331333
def _is_in_background_thread(self):
332334
return threading.current_thread() is self._async_executor._thread
333335

336+
def _should_acquire_lock(self):
337+
# Skip checking if the eval context is already active if either:
338+
# - We're in the background thread for async execution and it's active in the main thread
339+
# - We're in the main thread and it's active in the background thread
340+
background_thread = self._async_executor._thread
341+
main_thread = self._async_executor._main_thread
342+
343+
if self._locking_thread is None or background_thread is None:
344+
return True
345+
346+
return (threading.current_thread, self._locking_thread) not in (
347+
(main_thread, background_thread),
348+
(background_thread, main_thread),
349+
)
350+
351+
def _try_acquire_lock(self):
352+
# Skip checking if the eval context is already active if either:
353+
# - We're in the background thread for async execution and it's active in the main thread
354+
# - We're in the main thread and it's active in the background thread
355+
background_thread = self._async_executor._thread
356+
main_thread = self._async_executor._main_thread
357+
if (
358+
self._locking_thread is not None
359+
and background_thread is not None
360+
and (threading.current_thread(), self._locking_thread)
361+
in (
362+
(main_thread, background_thread),
363+
(background_thread, main_thread),
364+
)
365+
):
366+
return
367+
368+
if not self._lock.acquire(blocking=False):
369+
raise RuntimeError("An EvalContext cannot be active in two threads simultaneously.")
370+
self._locking_thread = threading.current_thread()
371+
self._lock_count += 1
372+
373+
def _try_release_lock(self):
374+
if self._locking_thread is not threading.current_thread():
375+
return
376+
377+
self._lock_count -= 1
378+
if self._lock_count == 0:
379+
self._locking_thread = None
380+
self._lock.release()
381+
334382

335383
__all__ = [
336384
"EvalContext",

0 commit comments

Comments
 (0)