Skip to content

Commit 5ffdbce

Browse files
authored
Raise exceptions when an EvalContext is active in multiple threads (#6221)
* Disallow EvalContext from being active in multiple threads * Update the threading guide to reflect the fact that exceptions are raised * Skip checking if the eval context is active if we're in the background thread of async execution --------- Signed-off-by: Rostan Tabet <rtabet@nvidia.com>
1 parent 1d6d57f commit 5ffdbce

3 files changed

Lines changed: 62 additions & 54 deletions

File tree

dali/python/nvidia/dali/experimental/dynamic/_eval_context.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,17 @@
1313
# limitations under the License.
1414

1515
import copy
16+
import sys
17+
import threading
1618
import weakref
17-
from threading import current_thread, local
1819

1920
import nvidia.dali.backend_impl as _b
2021

2122
from . import _device, _stream
2223
from ._async import _AsyncExecutor
2324

2425

25-
class _ThreadLocalStorage(local):
26+
class _ThreadLocalStorage(threading.local):
2627
def __init__(self):
2728
super().__init__()
2829
self.default = {} # per-device default context
@@ -35,7 +36,6 @@ def __init__(self):
3536
def _default_num_threads():
3637
"""Gets the default number of threads used in DALI dynamic mode."""
3738
import os
38-
import sys
3939
from functools import wraps
4040

4141
mod = sys.modules[__name__]
@@ -164,11 +164,14 @@ def __init__(self, *, num_threads=None, device_id=None, cuda_stream=None):
164164
self._instance_cache = {}
165165

166166
# The thread pool needs to be thread-local because of eager execution
167-
self._tls = local()
167+
self._tls = threading.local()
168168

169169
self._async_executor = _AsyncExecutor()
170170
weakref.finalize(self, self._async_executor.shutdown)
171171

172+
# Used to disallow the EvalContext to be active in two threads simultaneously
173+
self._lock = threading.RLock()
174+
172175
def _purge_operator_cache(self):
173176
"""Empties the operator instance cache"""
174177
self._instance_cache = {}
@@ -207,18 +210,35 @@ def _is_current(self) -> bool:
207210
return self is _tls.default.get(current_device_id)
208211

209212
def __enter__(self):
210-
_tls.stack.append(self)
211-
if self._device:
212-
self._device.__enter__()
213+
skip_lock = self._is_in_background_thread()
214+
if not skip_lock and not self._lock.acquire(blocking=False):
215+
raise RuntimeError("An EvalContext cannot be active in two threads simultaneously.")
216+
try:
217+
_tls.stack.append(self)
218+
if self._device:
219+
self._device.__enter__()
220+
except Exception:
221+
if not skip_lock:
222+
self._lock.release()
223+
raise
213224
return self
214225

215226
def __exit__(self, exc_type, exc_value, traceback):
216-
assert _tls.stack[-1] is self
217-
if len(_tls.stack) < 2 or (_tls.stack[-2] is not self):
218-
self.evaluate_all()
219-
_tls.stack.pop()
220-
if self._device:
221-
self._device.__exit__(exc_type, exc_value, traceback)
227+
try:
228+
# During interpreter shutdown, finalizers of objects created in background threads
229+
# can be called from the main thread.
230+
if _tls.stack:
231+
assert _tls.stack[-1] is self
232+
if len(_tls.stack) < 2 or (_tls.stack[-2] is not self):
233+
self.evaluate_all()
234+
_tls.stack.pop()
235+
else:
236+
assert sys.is_finalizing()
237+
if self._device:
238+
self._device.__exit__(exc_type, exc_value, traceback)
239+
finally:
240+
if not self._is_in_background_thread():
241+
self._lock.release()
222242

223243
def evaluate_all(self):
224244
"""Evaluates all pending invocations."""
@@ -312,7 +332,7 @@ def _snapshot(self):
312332
return ctx
313333

314334
def _is_in_background_thread(self):
315-
return current_thread() is self._async_executor._thread
335+
return threading.current_thread() is self._async_executor._thread
316336

317337

318338
__all__ = [

dali/test/python/experimental_mode/test_multithreading.py

Lines changed: 22 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,52 +13,31 @@
1313
# limitations under the License.
1414

1515

16-
import functools
1716
import os
18-
import sys
1917
import threading
2018
from collections.abc import Callable
2119
from typing import TypeVar
2220

2321
import numpy as np
2422
import nvidia.dali.experimental.dynamic as ndd
2523
from nose2.tools import cartesian_params, params
26-
from nose_utils import SkipTest
27-
28-
29-
def allow_nogil_failure(exc_type: type[Exception]):
30-
"""
31-
Skip the test on free-threaded Python if a specific exception is raised.
32-
This is useful until https://github.com/python/cpython/pull/133305 is backported.
33-
"""
34-
35-
def decorator(test_func):
36-
if getattr(sys, "_is_gil_enabled", lambda: True)():
37-
return test_func
38-
39-
@functools.wraps(test_func)
40-
def wrapper(*args, **kwargs):
41-
try:
42-
return test_func(*args, **kwargs)
43-
except exc_type:
44-
raise SkipTest(f"{exc_type.__name__} allowed for this test with the GIL disabled")
45-
46-
return wrapper
47-
48-
return decorator
49-
24+
from nose_utils import assert_raises
5025

5126
T = TypeVar("T")
5227

5328

54-
def run_parallel(function: Callable[[int], T], num_threads: int | None = None) -> dict[int, T]:
29+
def get_num_threads(num_threads: int | None = None):
5530
if num_threads is None:
5631
try:
5732
num_threads = len(os.sched_getaffinity(0))
5833
except AttributeError:
5934
num_threads = os.cpu_count() or 4
6035

61-
num_threads = min(32, num_threads)
36+
return min(32, num_threads)
37+
38+
39+
def run_parallel(function: Callable[[int], T], num_threads: int | None = None) -> dict[int, T]:
40+
num_threads = get_num_threads(num_threads)
6241

6342
barrier = threading.Barrier(num_threads)
6443
results = {}
@@ -89,7 +68,6 @@ def wrapper(thread_id: int):
8968
return results
9069

9170

92-
@allow_nogil_failure(KeyError)
9371
@params("cpu", "gpu")
9472
def test_parallel_eval_contexts(device):
9573
def worker(thread_id: int):
@@ -109,7 +87,6 @@ def worker(thread_id: int):
10987
np.testing.assert_equal(actual.cpu(), expected)
11088

11189

112-
@allow_nogil_failure(KeyError)
11390
@params("cpu", "gpu")
11491
def test_parallel_creation(device):
11592
def worker(thread_id: int):
@@ -135,7 +112,6 @@ def worker(thread_id: int):
135112
np.testing.assert_array_equal(actual.cpu(), expected)
136113

137114

138-
@allow_nogil_failure(KeyError)
139115
def test_parallel_different_devices():
140116
def worker(thread_id: int):
141117
device = "cpu" if thread_id % 2 == 0 else "gpu"
@@ -154,7 +130,6 @@ def worker(thread_id: int):
154130
np.testing.assert_equal(result.cpu(), expected)
155131

156132

157-
@allow_nogil_failure(KeyError)
158133
@cartesian_params(("cpu", "gpu"), ndd.EvalMode)
159134
def test_parallel_eval_modes(device, eval_mode):
160135
def worker(thread_id: int):
@@ -173,7 +148,6 @@ def worker(thread_id: int):
173148
np.testing.assert_array_almost_equal(actual.cpu(), expected)
174149

175150

176-
@allow_nogil_failure(KeyError)
177151
@params("cpu", "gpu")
178152
def test_parallel_mixed_eval_modes(device):
179153
eval_modes = tuple(ndd.EvalMode)
@@ -200,7 +174,6 @@ def worker(thread_id: int):
200174
np.testing.assert_array_almost_equal(data["result"].cpu(), data["expected"])
201175

202176

203-
@allow_nogil_failure(KeyError)
204177
@params("cpu", "gpu")
205178
def test_parallel_indexing(device):
206179
tensor = ndd.tensor([[1, 2, 3], [4, 5, 6]], device=device)
@@ -220,7 +193,6 @@ def worker(thread_id: int):
220193
assert result == tensor.cpu()[slice].item()
221194

222195

223-
@allow_nogil_failure(KeyError)
224196
@params("cpu", "gpu")
225197
def test_thread_local_rng_determinism(device):
226198
def worker(_):
@@ -242,7 +214,6 @@ def worker(_):
242214
np.testing.assert_array_equal(data["normal"].cpu(), reference["normal"].cpu())
243215

244216

245-
@allow_nogil_failure(KeyError)
246217
@params("cpu", "gpu")
247218
def test_chained_threads(device):
248219
source = ndd.tensor([1, 2, 3, 4], dtype=ndd.float32, device=device).evaluate()
@@ -266,3 +237,18 @@ def worker2(tensor: ndd.Tensor):
266237

267238
assert result is not None
268239
np.testing.assert_array_almost_equal(result.cpu(), source.cpu())
240+
241+
242+
def test_error_parallel_eval_contexts():
243+
def worker(_):
244+
with ctx:
245+
try:
246+
barrier.wait(0.1)
247+
except threading.BrokenBarrierError:
248+
pass
249+
250+
barrier = threading.Barrier(get_num_threads())
251+
ctx = ndd.EvalContext()
252+
253+
with assert_raises(RuntimeError, glob="EvalContext"):
254+
run_parallel(worker)

docs/dali_dynamic/threading.rst

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@ threads.
1919
:octicon:`alert-fill;1.2em;align-text-bottom text-warning` Multiple threads using the same :class:`EvalContext`:
2020

2121
.. code-block:: python
22-
:emphasize-lines: 4
22+
:emphasize-lines: 7
2323
2424
import threading
2525
import nvidia.dali.experimental.dynamic as ndd
2626
2727
ctx = ndd.EvalContext(num_threads=4)
2828
2929
def worker():
30-
with ctx: # Bad: using the same EvalContext in multiple threads simultaneously
30+
with ctx: # Raises an exception
3131
img = ndd.random.uniform(shape=(100, 100, 3), range=(0, 255), dtype=ndd.uint8)
3232
flipped = ndd.flip(img, horizontal=True)
3333
...
@@ -39,10 +39,12 @@ threads.
3939
t.join()
4040
4141
Here, the code should either create an instance of the evaluation context per thread, or use
42-
:func:`set_num_threads`. Here's a corrected version:
42+
:func:`set_num_threads`.
43+
44+
:octicon:`check-circle-fill;1.2em;align-text-bottom text-success` Correct code using
45+
:func:`set_num_threads`:
4346

4447
.. code-block:: python
45-
:emphasize-lines: 4
4648
4749
import threading
4850
import nvidia.dali.experimental.dynamic as ndd

0 commit comments

Comments
 (0)