Skip to content

Commit 4055e3c

Browse files
committed
task: review fixes
1 parent 8a8b942 commit 4055e3c

File tree

5 files changed

+97
-251
lines changed

5 files changed

+97
-251
lines changed

.pylintrc

Lines changed: 0 additions & 5 deletions
This file was deleted.

mkl_random/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@
9999
is_patched,
100100
mkl_random,
101101
patch_numpy_random,
102-
patched_names,
103102
restore_numpy_random,
104103
)
105104

@@ -155,6 +154,10 @@
155154
"shuffle",
156155
"permutation",
157156
"interfaces",
157+
"mkl_random",
158+
"patch_numpy_random",
159+
"restore_numpy_random",
160+
"is_patched",
158161
]
159162

160163
del _init_helper

mkl_random/_patch_numpy.py

Lines changed: 47 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -25,47 +25,29 @@
2525

2626
"""Define functions for patching NumPy with MKL-based NumPy interface."""
2727

28-
import warnings
2928
from contextlib import ContextDecorator
3029
from threading import Lock, local
3130

32-
import numpy as _np
31+
import numpy as np
3332

3433
import mkl_random.interfaces.numpy_random as _nrand
3534

36-
_DEFAULT_NAMES = tuple(_nrand.__all__)
37-
3835

3936
class _GlobalPatch:
4037
def __init__(self):
4138
self._lock = Lock()
4239
self._patch_count = 0
4340
self._restore_dict = {}
44-
self._patched_functions = tuple(_DEFAULT_NAMES)
45-
self._numpy_module = None
46-
self._requested_names = None
47-
self._active_names = ()
48-
self._patched = ()
41+
# make _patched_functions a tuple (immutable)
42+
self._patched_functions = tuple(_nrand.__all__)
4943
self._tls = local()
5044

51-
def _normalize_names(self, names):
52-
if names is None:
53-
names = _DEFAULT_NAMES
54-
return tuple(names)
55-
56-
def _validate_module(self, numpy_module):
57-
if not hasattr(numpy_module, "random"):
58-
raise TypeError(
59-
"Expected a numpy-like module with a `.random` attribute."
60-
)
61-
6245
def _register_func(self, name, func):
6346
if name not in self._patched_functions:
6447
raise ValueError(f"{name} not an mkl_random function.")
65-
np_random = self._numpy_module.random
6648
if name not in self._restore_dict:
67-
self._restore_dict[name] = getattr(np_random, name)
68-
setattr(np_random, name, func)
49+
self._restore_dict[name] = getattr(np.random, name)
50+
setattr(np.random, name, func)
6951

7052
def _restore_func(self, name, verbose=False):
7153
if name not in self._patched_functions:
@@ -79,51 +61,12 @@ def _restore_func(self, name, verbose=False):
7961
else:
8062
if verbose:
8163
print(f"found and restoring {name}...")
82-
np_random = self._numpy_module.random
83-
setattr(np_random, name, val)
84-
85-
def _initialize_patch(self, numpy_module, names, strict):
86-
self._validate_module(numpy_module)
87-
np_random = numpy_module.random
88-
missing = []
89-
patchable = []
90-
for name in names:
91-
if name not in self._patched_functions:
92-
missing.append(name)
93-
continue
94-
if not hasattr(np_random, name) or not hasattr(_nrand, name):
95-
missing.append(name)
96-
continue
97-
patchable.append(name)
98-
99-
if strict and missing:
100-
raise AttributeError(
101-
"Could not patch these names (missing on numpy.random or "
102-
"mkl_random.interfaces.numpy_random): "
103-
+ ", ".join(str(x) for x in missing)
104-
)
105-
106-
self._numpy_module = numpy_module
107-
self._requested_names = names
108-
self._active_names = tuple(patchable)
109-
self._patched = tuple(patchable)
110-
111-
def do_patch(
112-
self,
113-
numpy_module=None,
114-
names=None,
115-
strict=False,
116-
verbose=False,
117-
):
118-
if numpy_module is None:
119-
numpy_module = _np
120-
names = self._normalize_names(names)
121-
strict = bool(strict)
64+
setattr(np.random, name, val)
12265

66+
def do_patch(self, verbose=False):
12367
with self._lock:
12468
local_count = getattr(self._tls, "local_count", 0)
12569
if self._patch_count == 0:
126-
self._initialize_patch(numpy_module, names, strict)
12770
if verbose:
12871
print(
12972
"Now patching NumPy random submodule with mkl_random "
@@ -133,19 +76,8 @@ def do_patch(
13376
"Please direct bug reports to "
13477
"https://github.com/IntelPython/mkl_random"
13578
)
136-
for name in self._active_names:
137-
self._register_func(name, getattr(_nrand, name))
138-
else:
139-
if self._numpy_module is not numpy_module:
140-
raise RuntimeError(
141-
"Already patched a different numpy module; "
142-
"call restore() first."
143-
)
144-
if names != self._requested_names:
145-
raise RuntimeError(
146-
"Already patched with a different names set; "
147-
"call restore() first."
148-
)
79+
for f in self._patched_functions:
80+
self._register_func(f, getattr(_nrand, f))
14981
self._patch_count += 1
15082
self._tls.local_count = local_count + 1
15183

@@ -154,77 +86,47 @@ def do_restore(self, verbose=False):
15486
local_count = getattr(self._tls, "local_count", 0)
15587
if local_count <= 0:
15688
if verbose:
157-
warnings.warn(
89+
print(
15890
"Warning: restore_numpy_random called more times than "
159-
"patch_numpy_random in this thread.",
160-
stacklevel=2,
91+
"patch_numpy_random in this thread."
16192
)
16293
return
163-
164-
self._tls.local_count = local_count - 1
94+
self._tls.local_count -= 1
16595
self._patch_count -= 1
16696
if self._patch_count == 0:
16797
if verbose:
16898
print("Now restoring original NumPy random submodule.")
16999
for name in tuple(self._restore_dict):
170100
self._restore_func(name, verbose=verbose)
171101
self._restore_dict.clear()
172-
self._numpy_module = None
173-
self._requested_names = None
174-
self._active_names = ()
175-
self._patched = ()
176102

177103
def is_patched(self):
178104
with self._lock:
179105
return self._patch_count > 0
180106

181-
def patched_names(self):
182-
with self._lock:
183-
return list(self._patched)
184-
185107

186108
_patch = _GlobalPatch()
187109

188110

189-
def patch_numpy_random(
190-
numpy_module=None,
191-
names=None,
192-
strict=False,
193-
verbose=False,
194-
):
111+
def patch_numpy_random(verbose=False):
195112
"""
196-
Patch NumPy's random submodule with mkl_random's NumPy interface.
113+
Patch NumPy's random submodule with mkl_random's numpy_interface.
197114
198115
Parameters
199116
----------
200-
numpy_module : module, optional
201-
NumPy-like module to patch. Defaults to imported NumPy.
202-
names : iterable[str], optional
203-
Attributes under `numpy_module.random` to patch.
204-
strict : bool, optional
205-
Raise if any requested symbol cannot be patched.
206117
verbose : bool, optional
207-
Print messages when starting the patching process.
118+
print message when starting the patching process.
119+
120+
Notes
121+
-----
122+
This function uses reference-counted semantics. Each call increments a
123+
global patch counter. Restoration requires a matching number of calls
124+
between `patch_numpy_random` and `restore_numpy_random`.
125+
126+
In multi-threaded programs, prefer the `mkl_random` context manager.
208127
209-
Examples
210-
--------
211-
>>> import numpy as np
212-
>>> import mkl_random
213-
>>> mkl_random.is_patched()
214-
False
215-
>>> mkl_random.patch_numpy_random(np)
216-
>>> mkl_random.is_patched()
217-
True
218-
>>> mkl_random.restore()
219-
>>> mkl_random.is_patched()
220-
False
221128
"""
222-
_patch.do_patch(
223-
numpy_module=numpy_module,
224-
names=names,
225-
strict=bool(strict),
226-
verbose=bool(verbose),
227-
)
129+
_patch.do_patch(verbose=verbose)
228130

229131

230132
def restore_numpy_random(verbose=False):
@@ -234,45 +136,47 @@ def restore_numpy_random(verbose=False):
234136
Parameters
235137
----------
236138
verbose : bool, optional
237-
Print message when starting restoration process.
139+
print message when starting restoration process.
140+
141+
Notes
142+
-----
143+
This function uses reference-counted semantics. Each call decrements a
144+
global patch counter. Restoration requires a matching number of calls
145+
between `patch_numpy_random` and `restore_numpy_random`.
146+
147+
In multi-threaded programs, prefer the `mkl_random` context manager.
148+
238149
"""
239-
_patch.do_restore(verbose=bool(verbose))
150+
_patch.do_restore(verbose=verbose)
240151

241152

242153
def is_patched():
243-
"""Return whether NumPy has been patched with mkl_random."""
154+
"""Return True if NumPy's random sm is currently patched by mkl_random."""
244155
return _patch.is_patched()
245156

246157

247-
def patched_names():
248-
"""Return names actually patched in `numpy.random`."""
249-
return _patch.patched_names()
250-
251-
252158
class mkl_random(ContextDecorator):
253159
"""
254160
Context manager and decorator to temporarily patch NumPy random submodule
255161
with MKL-based implementations.
256162
257163
Examples
258164
--------
259-
>>> import numpy as np
260165
>>> import mkl_random
261-
>>> with mkl_random.mkl_random(np):
262-
... x = np.random.normal(size=10)
263-
"""
166+
>>> mkl_random.is_patched()
167+
# False
168+
169+
>>> with mkl_random.mkl_random(): # Enable mkl_random in NumPy
170+
>>> print(mkl_random.is_patched())
171+
# True
264172
265-
def __init__(self, numpy_module=None, names=None, strict=False):
266-
self._numpy_module = numpy_module
267-
self._names = names
268-
self._strict = strict
173+
>>> mkl_random.is_patched()
174+
# False
175+
176+
"""
269177

270178
def __enter__(self):
271-
patch_numpy_random(
272-
numpy_module=self._numpy_module,
273-
names=self._names,
274-
strict=self._strict,
275-
)
179+
patch_numpy_random()
276180
return self
277181

278182
def __exit__(self, *exc):

0 commit comments

Comments
 (0)