Skip to content

Commit 27e275a

Browse files
committed
chore: pre-commit fix and actually match mkl_fft
1 parent a7b9fde commit 27e275a

File tree

3 files changed

+15
-63
lines changed

3 files changed

+15
-63
lines changed

.pylintrc

Lines changed: 0 additions & 5 deletions
This file was deleted.

mkl_random/_patch_numpy.py

Lines changed: 13 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -25,55 +25,15 @@
2525

2626
"""Define functions for patching NumPy with MKL-based NumPy interface."""
2727

28+
import warnings
2829
from contextlib import ContextDecorator
2930
from threading import Lock, local
3031

3132
import numpy as _np
3233

33-
from . import mklrand as _mr
34-
35-
36-
_DEFAULT_NAMES = (
37-
# Legacy seeding / state
38-
"seed",
39-
"get_state",
40-
"set_state",
41-
"RandomState",
42-
# Common global sampling helpers
43-
"random",
44-
"random_sample",
45-
"sample",
46-
"rand",
47-
"randn",
48-
"bytes",
49-
# Integers
50-
"randint",
51-
# Common distributions (only patched if present on both sides)
52-
"standard_normal",
53-
"normal",
54-
"uniform",
55-
"exponential",
56-
"gamma",
57-
"beta",
58-
"chisquare",
59-
"f",
60-
"lognormal",
61-
"laplace",
62-
"logistic",
63-
"multivariate_normal",
64-
"poisson",
65-
"power",
66-
"rayleigh",
67-
"triangular",
68-
"vonmises",
69-
"wald",
70-
"weibull",
71-
"zipf",
72-
# Permutations / choices
73-
"choice",
74-
"permutation",
75-
"shuffle",
76-
)
34+
import mkl_random.interfaces.numpy_random as _nrand
35+
36+
_DEFAULT_NAMES = tuple(_nrand.__all__)
7737

7838

7939
class _GlobalPatch:
@@ -131,16 +91,16 @@ def _initialize_patch(self, numpy_module, names, strict):
13191
if name not in self._patched_functions:
13292
missing.append(name)
13393
continue
134-
if not hasattr(np_random, name) or not hasattr(_mr, name):
94+
if not hasattr(np_random, name) or not hasattr(_nrand, name):
13595
missing.append(name)
13696
continue
13797
patchable.append(name)
13898

13999
if strict and missing:
140100
raise AttributeError(
141101
"Could not patch these names (missing on numpy.random or "
142-
"mkl_random.mklrand): "
143-
+ ", ".join([str(x) for x in missing])
102+
"mkl_random.interfaces.numpy_random): "
103+
+ ", ".join(str(x) for x in missing)
144104
)
145105

146106
self._numpy_module = numpy_module
@@ -174,7 +134,7 @@ def do_patch(
174134
"https://github.com/IntelPython/mkl_random"
175135
)
176136
for name in self._active_names:
177-
self._register_func(name, getattr(_mr, name))
137+
self._register_func(name, getattr(_nrand, name))
178138
else:
179139
if self._numpy_module is not numpy_module:
180140
raise RuntimeError(
@@ -194,9 +154,10 @@ def do_restore(self, verbose=False):
194154
local_count = getattr(self._tls, "local_count", 0)
195155
if local_count <= 0:
196156
if verbose:
197-
print(
157+
warnings.warn(
198158
"Warning: restore_numpy_random called more times than "
199-
"patch_numpy_random in this thread."
159+
"patch_numpy_random in this thread.",
160+
stacklevel=2,
200161
)
201162
return
202163

@@ -279,16 +240,12 @@ def restore_numpy_random(verbose=False):
279240

280241

281242
def is_patched():
282-
"""
283-
Returns whether NumPy has been patched with mkl_random.
284-
"""
243+
"""Return whether NumPy has been patched with mkl_random."""
285244
return _patch.is_patched()
286245

287246

288247
def patched_names():
289-
"""
290-
Returns the names actually patched in `numpy.random`.
291-
"""
248+
"""Return names actually patched in `numpy.random`."""
292249
return _patch.patched_names()
293250

294251

mkl_random/tests/test_patch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,10 @@ def test_patch_redundant_patching():
135135
mkl_random.patch_numpy_random(np)
136136
mkl_random.patch_numpy_random(np)
137137
assert mkl_random.is_patched()
138-
assert np.random.normal is mkl_random.mklrand.normal
138+
assert np.random.normal is mkl_random.normal
139139
mkl_random.restore_numpy_random()
140140
assert mkl_random.is_patched()
141-
assert np.random.normal is mkl_random.mklrand.normal
141+
assert np.random.normal is mkl_random.normal
142142
mkl_random.restore_numpy_random()
143143
assert not mkl_random.is_patched()
144144
assert np.random.normal is orig_normal

0 commit comments

Comments
 (0)