2525
2626"""Define functions for patching NumPy with MKL-based NumPy interface."""
2727
28- import warnings
2928from contextlib import ContextDecorator
3029from threading import Lock , local
3130
32- import numpy as _np
31+ import numpy as np
3332
3433import mkl_random .interfaces .numpy_random as _nrand
3534
36- _DEFAULT_NAMES = tuple (_nrand .__all__ )
37-
3835
3936class _GlobalPatch :
4037 def __init__ (self ):
4138 self ._lock = Lock ()
4239 self ._patch_count = 0
4340 self ._restore_dict = {}
44- self ._patched_functions = tuple (_DEFAULT_NAMES )
45- self ._numpy_module = None
46- self ._requested_names = None
47- self ._active_names = ()
48- self ._patched = ()
41+ # make _patched_functions a tuple (immutable)
42+ self ._patched_functions = tuple (_nrand .__all__ )
4943 self ._tls = local ()
5044
51- def _normalize_names (self , names ):
52- if names is None :
53- names = _DEFAULT_NAMES
54- return tuple (names )
55-
56- def _validate_module (self , numpy_module ):
57- if not hasattr (numpy_module , "random" ):
58- raise TypeError (
59- "Expected a numpy-like module with a `.random` attribute."
60- )
61-
6245 def _register_func (self , name , func ):
6346 if name not in self ._patched_functions :
6447 raise ValueError (f"{ name } not an mkl_random function." )
65- np_random = self ._numpy_module .random
6648 if name not in self ._restore_dict :
67- self ._restore_dict [name ] = getattr (np_random , name )
68- setattr (np_random , name , func )
49+ self ._restore_dict [name ] = getattr (np . random , name )
50+ setattr (np . random , name , func )
6951
7052 def _restore_func (self , name , verbose = False ):
7153 if name not in self ._patched_functions :
@@ -79,51 +61,12 @@ def _restore_func(self, name, verbose=False):
7961 else :
8062 if verbose :
8163 print (f"found and restoring { name } ..." )
82- np_random = self ._numpy_module .random
83- setattr (np_random , name , val )
84-
85- def _initialize_patch (self , numpy_module , names , strict ):
86- self ._validate_module (numpy_module )
87- np_random = numpy_module .random
88- missing = []
89- patchable = []
90- for name in names :
91- if name not in self ._patched_functions :
92- missing .append (name )
93- continue
94- if not hasattr (np_random , name ) or not hasattr (_nrand , name ):
95- missing .append (name )
96- continue
97- patchable .append (name )
98-
99- if strict and missing :
100- raise AttributeError (
101- "Could not patch these names (missing on numpy.random or "
102- "mkl_random.interfaces.numpy_random): "
103- + ", " .join (str (x ) for x in missing )
104- )
105-
106- self ._numpy_module = numpy_module
107- self ._requested_names = names
108- self ._active_names = tuple (patchable )
109- self ._patched = tuple (patchable )
110-
111- def do_patch (
112- self ,
113- numpy_module = None ,
114- names = None ,
115- strict = False ,
116- verbose = False ,
117- ):
118- if numpy_module is None :
119- numpy_module = _np
120- names = self ._normalize_names (names )
121- strict = bool (strict )
64+ setattr (np .random , name , val )
12265
66+ def do_patch (self , verbose = False ):
12367 with self ._lock :
12468 local_count = getattr (self ._tls , "local_count" , 0 )
12569 if self ._patch_count == 0 :
126- self ._initialize_patch (numpy_module , names , strict )
12770 if verbose :
12871 print (
12972 "Now patching NumPy random submodule with mkl_random "
@@ -133,19 +76,8 @@ def do_patch(
13376 "Please direct bug reports to "
13477 "https://github.com/IntelPython/mkl_random"
13578 )
136- for name in self ._active_names :
137- self ._register_func (name , getattr (_nrand , name ))
138- else :
139- if self ._numpy_module is not numpy_module :
140- raise RuntimeError (
141- "Already patched a different numpy module; "
142- "call restore() first."
143- )
144- if names != self ._requested_names :
145- raise RuntimeError (
146- "Already patched with a different names set; "
147- "call restore() first."
148- )
79+ for f in self ._patched_functions :
80+ self ._register_func (f , getattr (_nrand , f ))
14981 self ._patch_count += 1
15082 self ._tls .local_count = local_count + 1
15183
@@ -154,77 +86,47 @@ def do_restore(self, verbose=False):
15486 local_count = getattr (self ._tls , "local_count" , 0 )
15587 if local_count <= 0 :
15688 if verbose :
157- warnings . warn (
89+ print (
15890 "Warning: restore_numpy_random called more times than "
159- "patch_numpy_random in this thread." ,
160- stacklevel = 2 ,
91+ "patch_numpy_random in this thread."
16192 )
16293 return
163-
164- self ._tls .local_count = local_count - 1
94+ self ._tls .local_count -= 1
16595 self ._patch_count -= 1
16696 if self ._patch_count == 0 :
16797 if verbose :
16898 print ("Now restoring original NumPy random submodule." )
16999 for name in tuple (self ._restore_dict ):
170100 self ._restore_func (name , verbose = verbose )
171101 self ._restore_dict .clear ()
172- self ._numpy_module = None
173- self ._requested_names = None
174- self ._active_names = ()
175- self ._patched = ()
176102
177103 def is_patched (self ):
178104 with self ._lock :
179105 return self ._patch_count > 0
180106
181- def patched_names (self ):
182- with self ._lock :
183- return list (self ._patched )
184-
185107
186108_patch = _GlobalPatch ()
187109
188110
189- def patch_numpy_random (
190- numpy_module = None ,
191- names = None ,
192- strict = False ,
193- verbose = False ,
194- ):
111+ def patch_numpy_random (verbose = False ):
195112 """
196- Patch NumPy's random submodule with mkl_random's NumPy interface .
113+ Patch NumPy's random submodule with mkl_random's numpy_interface .
197114
198115 Parameters
199116 ----------
200- numpy_module : module, optional
201- NumPy-like module to patch. Defaults to imported NumPy.
202- names : iterable[str], optional
203- Attributes under `numpy_module.random` to patch.
204- strict : bool, optional
205- Raise if any requested symbol cannot be patched.
206117 verbose : bool, optional
207- Print messages when starting the patching process.
118+ print message when starting the patching process.
119+
120+ Notes
121+ -----
122+ This function uses reference-counted semantics. Each call increments a
123+ global patch counter. Restoration requires a matching number of calls
124+ between `patch_numpy_random` and `restore_numpy_random`.
125+
126+ In multi-threaded programs, prefer the `mkl_random` context manager.
208127
209- Examples
210- --------
211- >>> import numpy as np
212- >>> import mkl_random
213- >>> mkl_random.is_patched()
214- False
215- >>> mkl_random.patch_numpy_random(np)
216- >>> mkl_random.is_patched()
217- True
218- >>> mkl_random.restore()
219- >>> mkl_random.is_patched()
220- False
221128 """
222- _patch .do_patch (
223- numpy_module = numpy_module ,
224- names = names ,
225- strict = bool (strict ),
226- verbose = bool (verbose ),
227- )
129+ _patch .do_patch (verbose = verbose )
228130
229131
230132def restore_numpy_random (verbose = False ):
@@ -234,45 +136,47 @@ def restore_numpy_random(verbose=False):
234136 Parameters
235137 ----------
236138 verbose : bool, optional
237- Print message when starting restoration process.
139+ print message when starting restoration process.
140+
141+ Notes
142+ -----
143+ This function uses reference-counted semantics. Each call decrements a
144+ global patch counter. Restoration requires a matching number of calls
145+ between `patch_numpy_random` and `restore_numpy_random`.
146+
147+ In multi-threaded programs, prefer the `mkl_random` context manager.
148+
238149 """
239- _patch .do_restore (verbose = bool ( verbose ) )
150+ _patch .do_restore (verbose = verbose )
240151
241152
242153def is_patched ():
243- """Return whether NumPy has been patched with mkl_random."""
154+ """Return True if NumPy's random sm is currently patched by mkl_random."""
244155 return _patch .is_patched ()
245156
246157
247- def patched_names ():
248- """Return names actually patched in `numpy.random`."""
249- return _patch .patched_names ()
250-
251-
252158class mkl_random (ContextDecorator ):
253159 """
254160 Context manager and decorator to temporarily patch NumPy random submodule
255161 with MKL-based implementations.
256162
257163 Examples
258164 --------
259- >>> import numpy as np
260165 >>> import mkl_random
261- >>> with mkl_random.mkl_random(np):
262- ... x = np.random.normal(size=10)
263- """
166+ >>> mkl_random.is_patched()
167+ # False
168+
169+ >>> with mkl_random.mkl_random(): # Enable mkl_random in NumPy
170+ >>> print(mkl_random.is_patched())
171+ # True
264172
265- def __init__ ( self , numpy_module = None , names = None , strict = False ):
266- self . _numpy_module = numpy_module
267- self . _names = names
268- self . _strict = strict
173+ >>> mkl_random.is_patched()
174+ # False
175+
176+ """
269177
270178 def __enter__ (self ):
271- patch_numpy_random (
272- numpy_module = self ._numpy_module ,
273- names = self ._names ,
274- strict = self ._strict ,
275- )
179+ patch_numpy_random ()
276180 return self
277181
278182 def __exit__ (self , * exc ):
0 commit comments