Skip to content

Commit c56e34b

Browse files
committed
fix: patching to match mkl_fft
1 parent 15404d2 commit c56e34b

File tree

2 files changed

+121
-85
lines changed

2 files changed

+121
-85
lines changed

mkl_random/src/_patch.pyx

Lines changed: 83 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ replace NumPy's `Generator`/`default_rng()` unless mkl_random provides fully
3535
compatible replacements.
3636
"""
3737

38-
from threading import local as threading_local
38+
from threading import Lock, local
3939
from contextlib import ContextDecorator
4040

4141
import numpy as _np
@@ -89,51 +89,30 @@ cdef tuple _DEFAULT_NAMES = (
8989
)
9090

9191

92-
cdef class patch:
93-
cdef bint _is_patched
94-
cdef object _numpy_module
95-
cdef object _originals # dict: name -> original object
96-
cdef object _patched # list of names actually patched
97-
98-
def __cinit__(self):
99-
self._is_patched = False
92+
class _GlobalPatch:
93+
def __init__(self):
94+
self._lock = Lock()
95+
self._patch_count = 0
10096
self._numpy_module = None
97+
self._requested_names = None
10198
self._originals = {}
102-
self._patched = []
103-
104-
def do_patch(self, numpy_module=None, names=None, bint strict=False):
105-
"""
106-
Patch the given numpy module (default: imported numpy) in-place.
107-
108-
Parameters
109-
----------
110-
numpy_module : module, optional
111-
The numpy module to patch (e.g. `import numpy as np; use_in_numpy(np)`).
112-
names : iterable[str], optional
113-
Attributes under `numpy_module.random` to patch. Defaults to _DEFAULT_NAMES.
114-
strict : bool
115-
If True, raise if any requested symbol cannot be patched.
116-
"""
117-
if numpy_module is None:
118-
numpy_module = _np
99+
self._patched = ()
100+
self._tls = local()
101+
102+
def _normalize_names(self, names):
119103
if names is None:
120104
names = _DEFAULT_NAMES
105+
return tuple(names)
121106

107+
def _validate_module(self, numpy_module):
122108
if not hasattr(numpy_module, "random"):
123109
raise TypeError("Expected a numpy-like module with a `.random` attribute.")
124110

125-
# If already patched, only allow idempotent re-entry for the same numpy module.
126-
if self._is_patched:
127-
if self._numpy_module is numpy_module:
128-
return
129-
raise RuntimeError("Already patched a different numpy module; call restore() first.")
130-
111+
def _apply_patch(self, numpy_module, names, strict):
131112
np_random = numpy_module.random
132-
133113
originals = {}
134114
patched = []
135115
missing = []
136-
137116
for name in names:
138117
if not hasattr(np_random, name) or not hasattr(_mr, name):
139118
missing.append(name)
@@ -143,58 +122,75 @@ cdef class patch:
143122
patched.append(name)
144123

145124
if strict and missing:
146-
# revert partial patch before raising
147-
for n, v in originals.items():
148-
setattr(np_random, n, v)
125+
for name, value in originals.items():
126+
setattr(np_random, name, value)
149127
raise AttributeError(
150128
"Could not patch these names (missing on numpy.random or mkl_random.mklrand): "
151129
+ ", ".join([str(x) for x in missing])
152130
)
153131

154132
self._numpy_module = numpy_module
133+
self._requested_names = names
155134
self._originals = originals
156-
self._patched = patched
157-
self._is_patched = True
158-
159-
def do_unpatch(self):
160-
"""
161-
Restore the previously patched numpy module.
162-
"""
163-
if not self._is_patched:
164-
return
165-
numpy_module = self._numpy_module
166-
np_random = numpy_module.random
167-
for n, v in self._originals.items():
168-
setattr(np_random, n, v)
135+
self._patched = tuple(patched)
169136

170-
self._numpy_module = None
171-
self._originals = {}
172-
self._patched = []
173-
self._is_patched = False
137+
def do_patch(self, numpy_module=None, names=None, strict=False, verbose=False):
138+
if numpy_module is None:
139+
numpy_module = _np
140+
names = self._normalize_names(names)
141+
self._validate_module(numpy_module)
142+
strict = bool(strict)
143+
144+
with self._lock:
145+
local_count = getattr(self._tls, "local_count", 0)
146+
if self._patch_count == 0:
147+
self._apply_patch(numpy_module, names, strict)
148+
else:
149+
if self._numpy_module is not numpy_module:
150+
raise RuntimeError(
151+
"Already patched a different numpy module; call restore() first."
152+
)
153+
if names != self._requested_names:
154+
raise RuntimeError(
155+
"Already patched with a different names set; call restore() first."
156+
)
157+
self._patch_count += 1
158+
self._tls.local_count = local_count + 1
159+
160+
def do_restore(self, verbose=False):
161+
with self._lock:
162+
local_count = getattr(self._tls, "local_count", 0)
163+
if local_count <= 0:
164+
if verbose:
165+
print(
166+
"Warning: restore called more times than monkey_patch in this thread."
167+
)
168+
return
169+
170+
self._tls.local_count = local_count - 1
171+
self._patch_count -= 1
172+
if self._patch_count == 0:
173+
np_random = self._numpy_module.random
174+
for name, value in self._originals.items():
175+
setattr(np_random, name, value)
176+
self._numpy_module = None
177+
self._requested_names = None
178+
self._originals = {}
179+
self._patched = ()
174180

175181
def is_patched(self):
176-
return self._is_patched
182+
with self._lock:
183+
return self._patch_count > 0
177184

178185
def patched_names(self):
179-
"""
180-
Returns list of names that were actually patched.
181-
"""
182-
return list(self._patched)
183-
184-
185-
_tls = threading_local()
186-
187-
188-
def _is_tls_initialized():
189-
return (getattr(_tls, "initialized", None) is not None) and (_tls.initialized is True)
186+
with self._lock:
187+
return list(self._patched)
190188

191189

192-
def _initialize_tls():
193-
_tls.patch = patch()
194-
_tls.initialized = True
190+
_patch = _GlobalPatch()
195191

196192

197-
def monkey_patch(numpy_module=None, names=None, strict=False):
193+
def monkey_patch(numpy_module=None, names=None, strict=False, verbose=False):
198194
"""
199195
Enables using mkl_random in the given NumPy module by patching `numpy.random`.
200196
@@ -211,43 +207,45 @@ def monkey_patch(numpy_module=None, names=None, strict=False):
211207
>>> mkl_random.is_patched()
212208
False
213209
"""
214-
if not _is_tls_initialized():
215-
_initialize_tls()
216-
_tls.patch.do_patch(numpy_module=numpy_module, names=names, strict=bool(strict))
210+
_patch.do_patch(
211+
numpy_module=numpy_module,
212+
names=names,
213+
strict=bool(strict),
214+
verbose=bool(verbose),
215+
)
217216

218217

219-
def use_in_numpy(numpy_module=None, names=None, strict=False):
218+
def use_in_numpy(numpy_module=None, names=None, strict=False, verbose=False):
220219
"""
221220
Backward-compatible alias for monkey_patch().
222221
"""
223-
monkey_patch(numpy_module=numpy_module, names=names, strict=strict)
222+
monkey_patch(
223+
numpy_module=numpy_module,
224+
names=names,
225+
strict=strict,
226+
verbose=verbose,
227+
)
224228

225229

226-
def restore():
230+
def restore(verbose=False):
227231
"""
228232
Disables using mkl_random in NumPy by restoring the original `numpy.random` symbols.
229233
"""
230-
if not _is_tls_initialized():
231-
_initialize_tls()
232-
_tls.patch.do_unpatch()
234+
_patch.do_restore(verbose=bool(verbose))
233235

234236

235237
def is_patched():
236238
"""
237239
Returns whether NumPy has been patched with mkl_random.
238240
"""
239-
if not _is_tls_initialized():
240-
_initialize_tls()
241-
return bool(_tls.patch.is_patched())
241+
return _patch.is_patched()
242242

243243

244244
def patched_names():
245245
"""
246246
Returns the names actually patched in `numpy.random`.
247247
"""
248-
if not _is_tls_initialized():
249-
_initialize_tls()
250-
return _tls.patch.patched_names()
248+
return _patch.patched_names()
251249

252250

253251
class mkl_random(ContextDecorator):

mkl_random/tests/test_patch.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,41 @@ def test_patched_names():
9393
assert "RandomState" in names
9494
finally:
9595
mkl_random.restore()
96+
97+
98+
def test_patch_redundant_patching():
99+
orig_normal = np.random.normal
100+
assert not mkl_random.is_patched()
101+
102+
mkl_random.monkey_patch(np)
103+
mkl_random.monkey_patch(np)
104+
105+
assert mkl_random.is_patched()
106+
assert np.random.normal is mkl_random.mklrand.normal
107+
108+
mkl_random.restore()
109+
assert mkl_random.is_patched()
110+
assert np.random.normal is mkl_random.mklrand.normal
111+
112+
mkl_random.restore()
113+
assert not mkl_random.is_patched()
114+
assert np.random.normal is orig_normal
115+
116+
117+
def test_patch_reentrant():
118+
orig_uniform = np.random.uniform
119+
assert not mkl_random.is_patched()
120+
121+
with mkl_random.mkl_random(np):
122+
assert mkl_random.is_patched()
123+
assert np.random.uniform is not orig_uniform
124+
125+
with mkl_random.mkl_random(np):
126+
assert mkl_random.is_patched()
127+
assert np.random.uniform is not orig_uniform
128+
129+
assert mkl_random.is_patched()
130+
assert np.random.uniform is not orig_uniform
131+
132+
assert not mkl_random.is_patched()
133+
assert np.random.uniform is orig_uniform

0 commit comments

Comments
 (0)