2525
2626"""Define functions for patching NumPy with MKL-based NumPy interface."""
2727
28+ import warnings
2829from contextlib import ContextDecorator
2930from threading import Lock , local
3031
3132import numpy as _np
3233
33- from . import mklrand as _mr
34-
35-
36- _DEFAULT_NAMES = (
37- # Legacy seeding / state
38- "seed" ,
39- "get_state" ,
40- "set_state" ,
41- "RandomState" ,
42- # Common global sampling helpers
43- "random" ,
44- "random_sample" ,
45- "sample" ,
46- "rand" ,
47- "randn" ,
48- "bytes" ,
49- # Integers
50- "randint" ,
51- # Common distributions (only patched if present on both sides)
52- "standard_normal" ,
53- "normal" ,
54- "uniform" ,
55- "exponential" ,
56- "gamma" ,
57- "beta" ,
58- "chisquare" ,
59- "f" ,
60- "lognormal" ,
61- "laplace" ,
62- "logistic" ,
63- "multivariate_normal" ,
64- "poisson" ,
65- "power" ,
66- "rayleigh" ,
67- "triangular" ,
68- "vonmises" ,
69- "wald" ,
70- "weibull" ,
71- "zipf" ,
72- # Permutations / choices
73- "choice" ,
74- "permutation" ,
75- "shuffle" ,
76- )
34+ import mkl_random .interfaces .numpy_random as _nrand
35+
36+ _DEFAULT_NAMES = tuple (_nrand .__all__ )
7737
7838
7939class _GlobalPatch :
@@ -131,16 +91,16 @@ def _initialize_patch(self, numpy_module, names, strict):
13191 if name not in self ._patched_functions :
13292 missing .append (name )
13393 continue
134- if not hasattr (np_random , name ) or not hasattr (_mr , name ):
94+ if not hasattr (np_random , name ) or not hasattr (_nrand , name ):
13595 missing .append (name )
13696 continue
13797 patchable .append (name )
13898
13999 if strict and missing :
140100 raise AttributeError (
141101 "Could not patch these names (missing on numpy.random or "
142- "mkl_random.mklrand ): "
143- + ", " .join ([ str (x ) for x in missing ] )
102+ "mkl_random.interfaces.numpy_random ): "
103+ + ", " .join (str (x ) for x in missing )
144104 )
145105
146106 self ._numpy_module = numpy_module
@@ -174,7 +134,7 @@ def do_patch(
174134 "https://github.com/IntelPython/mkl_random"
175135 )
176136 for name in self ._active_names :
177- self ._register_func (name , getattr (_mr , name ))
137+ self ._register_func (name , getattr (_nrand , name ))
178138 else :
179139 if self ._numpy_module is not numpy_module :
180140 raise RuntimeError (
@@ -194,9 +154,10 @@ def do_restore(self, verbose=False):
194154 local_count = getattr (self ._tls , "local_count" , 0 )
195155 if local_count <= 0 :
196156 if verbose :
197- print (
157+ warnings . warn (
198158 "Warning: restore_numpy_random called more times than "
199- "patch_numpy_random in this thread."
159+ "patch_numpy_random in this thread." ,
160+ stacklevel = 2 ,
200161 )
201162 return
202163
@@ -279,16 +240,12 @@ def restore_numpy_random(verbose=False):
279240
280241
281242def is_patched ():
282- """
283- Returns whether NumPy has been patched with mkl_random.
284- """
243+ """Return whether NumPy has been patched with mkl_random."""
285244 return _patch .is_patched ()
286245
287246
288247def patched_names ():
289- """
290- Returns the names actually patched in `numpy.random`.
291- """
248+ """Return names actually patched in `numpy.random`."""
292249 return _patch .patched_names ()
293250
294251
0 commit comments