Skip to content

Commit 8fcf603

Browse files
committed
Skip checking if the eval context is active if we're in the background thread of async execution.
Signed-off-by: Rostan Tabet <rtabet@nvidia.com>
1 parent 2acd546 commit 8fcf603

2 files changed

Lines changed: 7 additions & 3 deletions

File tree

dali/python/nvidia/dali/experimental/dynamic/_async.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def __init__(self):
7878
self._queue = queue.SimpleQueue[Optional[_Future]]()
7979
self._thread: Optional[threading.Thread] = None
8080
self._event = threading.Event()
81+
self._main_thread = threading.current_thread()
8182

8283
def _worker(self):
8384
while True:

dali/python/nvidia/dali/experimental/dynamic/_eval_context.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,14 +210,16 @@ def _is_current(self) -> bool:
210210
return self is _tls.default.get(current_device_id)
211211

212212
def __enter__(self):
213-
if not self._lock.acquire(blocking=False):
213+
skip_lock = self._is_in_background_thread()
214+
if not skip_lock and not self._lock.acquire(blocking=False):
214215
raise RuntimeError("An EvalContext cannot be active in two threads simultaneously.")
215216
try:
216217
_tls.stack.append(self)
217218
if self._device:
218219
self._device.__enter__()
219220
except Exception:
220-
self._lock.release()
221+
if not skip_lock:
222+
self._lock.release()
221223
raise
222224
return self
223225

@@ -235,7 +237,8 @@ def __exit__(self, exc_type, exc_value, traceback):
235237
if self._device:
236238
self._device.__exit__(exc_type, exc_value, traceback)
237239
finally:
238-
self._lock.release()
240+
if not self._is_in_background_thread():
241+
self._lock.release()
239242

240243
def evaluate_all(self):
241244
"""Evaluates all pending invocations."""

0 commit comments

Comments
 (0)