@@ -35,7 +35,7 @@ replace NumPy's `Generator`/`default_rng()` unless mkl_random provides fully
3535compatible replacements.
3636"""
3737
38- from threading import local as threading_local
38+ from threading import Lock, local
3939from contextlib import ContextDecorator
4040
4141import numpy as _np
@@ -89,51 +89,30 @@ cdef tuple _DEFAULT_NAMES = (
8989)
9090
9191
92- cdef class patch:
93- cdef bint _is_patched
94- cdef object _numpy_module
95- cdef object _originals # dict: name -> original object
96- cdef object _patched # list of names actually patched
97-
98- def __cinit__ (self ):
99- self ._is_patched = False
92+ class _GlobalPatch :
93+ def __init__ (self ):
94+ self ._lock = Lock()
95+ self ._patch_count = 0
10096 self ._numpy_module = None
97+ self ._requested_names = None
10198 self ._originals = {}
102- self ._patched = []
103-
104- def do_patch (self , numpy_module = None , names = None , bint strict = False ):
105- """
106- Patch the given numpy module (default: imported numpy) in-place.
107-
108- Parameters
109- ----------
110- numpy_module : module, optional
111- The numpy module to patch (e.g. `import numpy as np; use_in_numpy(np)`).
112- names : iterable[str], optional
113- Attributes under `numpy_module.random` to patch. Defaults to _DEFAULT_NAMES.
114- strict : bool
115- If True, raise if any requested symbol cannot be patched.
116- """
117- if numpy_module is None :
118- numpy_module = _np
99+ self ._patched = ()
100+ self ._tls = local()
101+
102+ def _normalize_names (self , names ):
119103 if names is None :
120104 names = _DEFAULT_NAMES
105+ return tuple (names)
121106
107+ def _validate_module (self , numpy_module ):
122108 if not hasattr (numpy_module, " random" ):
123109 raise TypeError (" Expected a numpy-like module with a `.random` attribute." )
124110
125- # If already patched, only allow idempotent re-entry for the same numpy module.
126- if self ._is_patched:
127- if self ._numpy_module is numpy_module:
128- return
129- raise RuntimeError (" Already patched a different numpy module; call restore() first." )
130-
111+ def _apply_patch (self , numpy_module , names , strict ):
131112 np_random = numpy_module.random
132-
133113 originals = {}
134114 patched = []
135115 missing = []
136-
137116 for name in names:
138117 if not hasattr (np_random, name) or not hasattr (_mr, name):
139118 missing.append(name)
@@ -143,58 +122,75 @@ cdef class patch:
143122 patched.append(name)
144123
145124 if strict and missing:
146- # revert partial patch before raising
147- for n, v in originals.items():
148- setattr (np_random, n, v)
125+ for name, value in originals.items():
126+ setattr (np_random, name, value)
149127 raise AttributeError (
150128 " Could not patch these names (missing on numpy.random or mkl_random.mklrand): "
151129 + " , " .join([str (x) for x in missing])
152130 )
153131
154132 self ._numpy_module = numpy_module
133+ self ._requested_names = names
155134 self ._originals = originals
156- self ._patched = patched
157- self ._is_patched = True
158-
159- def do_unpatch (self ):
160- """
161- Restore the previously patched numpy module.
162- """
163- if not self ._is_patched:
164- return
165- numpy_module = self ._numpy_module
166- np_random = numpy_module.random
167- for n, v in self ._originals.items():
168- setattr (np_random, n, v)
135+ self ._patched = tuple (patched)
169136
170- self ._numpy_module = None
171- self ._originals = {}
172- self ._patched = []
173- self ._is_patched = False
137+ def do_patch (self , numpy_module = None , names = None , strict = False , verbose = False ):
138+ if numpy_module is None :
139+ numpy_module = _np
140+ names = self ._normalize_names(names)
141+ self ._validate_module(numpy_module)
142+ strict = bool (strict)
143+
144+ with self ._lock:
145+ local_count = getattr (self ._tls, " local_count" , 0 )
146+ if self ._patch_count == 0 :
147+ self ._apply_patch(numpy_module, names, strict)
148+ else :
149+ if self ._numpy_module is not numpy_module:
150+ raise RuntimeError (
151+ " Already patched a different numpy module; call restore() first."
152+ )
153+ if names != self ._requested_names:
154+ raise RuntimeError (
155+ " Already patched with a different names set; call restore() first."
156+ )
157+ self ._patch_count += 1
158+ self ._tls.local_count = local_count + 1
159+
160+ def do_restore (self , verbose = False ):
161+ with self ._lock:
162+ local_count = getattr (self ._tls, " local_count" , 0 )
163+ if local_count <= 0 :
164+ if verbose:
165+ print (
166+ " Warning: restore called more times than monkey_patch in this thread."
167+ )
168+ return
169+
170+ self ._tls.local_count = local_count - 1
171+ self ._patch_count -= 1
172+ if self ._patch_count == 0 :
173+ np_random = self ._numpy_module.random
174+ for name, value in self ._originals.items():
175+ setattr (np_random, name, value)
176+ self ._numpy_module = None
177+ self ._requested_names = None
178+ self ._originals = {}
179+ self ._patched = ()
174180
175181 def is_patched (self ):
176- return self ._is_patched
182+ with self ._lock:
183+ return self ._patch_count > 0
177184
178185 def patched_names (self ):
179- """
180- Returns list of names that were actually patched.
181- """
182- return list (self ._patched)
183-
184-
185- _tls = threading_local()
186-
187-
188- def _is_tls_initialized ():
189- return (getattr (_tls, " initialized" , None ) is not None ) and (_tls.initialized is True )
186+ with self ._lock:
187+ return list (self ._patched)
190188
191189
192- def _initialize_tls ():
193- _tls.patch = patch()
194- _tls.initialized = True
190+ _patch = _GlobalPatch()
195191
196192
197- def monkey_patch (numpy_module = None , names = None , strict = False ):
193+ def monkey_patch (numpy_module = None , names = None , strict = False , verbose = False ):
198194 """
199195 Enables using mkl_random in the given NumPy module by patching `numpy.random`.
200196
@@ -211,43 +207,45 @@ def monkey_patch(numpy_module=None, names=None, strict=False):
211207 >>> mkl_random.is_patched()
212208 False
213209 """
214- if not _is_tls_initialized():
215- _initialize_tls()
216- _tls.patch.do_patch(numpy_module = numpy_module, names = names, strict = bool (strict))
210+ _patch.do_patch(
211+ numpy_module = numpy_module,
212+ names = names,
213+ strict = bool (strict),
214+ verbose = bool (verbose),
215+ )
217216
218217
219- def use_in_numpy (numpy_module = None , names = None , strict = False ):
218+ def use_in_numpy (numpy_module = None , names = None , strict = False , verbose = False ):
220219 """
221220 Backward-compatible alias for monkey_patch().
222221 """
223- monkey_patch(numpy_module = numpy_module, names = names, strict = strict)
222+ monkey_patch(
223+ numpy_module = numpy_module,
224+ names = names,
225+ strict = strict,
226+ verbose = verbose,
227+ )
224228
225229
226- def restore ():
230+ def restore (verbose = False ):
227231 """
228232 Disables using mkl_random in NumPy by restoring the original `numpy.random` symbols.
229233 """
230- if not _is_tls_initialized():
231- _initialize_tls()
232- _tls.patch.do_unpatch()
234+ _patch.do_restore(verbose = bool (verbose))
233235
234236
235237def is_patched ():
236238 """
237239 Returns whether NumPy has been patched with mkl_random.
238240 """
239- if not _is_tls_initialized():
240- _initialize_tls()
241- return bool (_tls.patch.is_patched())
241+ return _patch.is_patched()
242242
243243
244244def patched_names ():
245245 """
246246 Returns the names actually patched in `numpy.random`.
247247 """
248- if not _is_tls_initialized():
249- _initialize_tls()
250- return _tls.patch.patched_names()
248+ return _patch.patched_names()
251249
252250
253251class mkl_random (ContextDecorator ):
0 commit comments