Skip to content

Commit 38fe23d

Browse files
committed
fix: patching to match mkl_fft, lint, and review
1 parent 15404d2 commit 38fe23d

File tree

6 files changed

+392
-312
lines changed

6 files changed

+392
-312
lines changed

.pylintrc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[MASTER]
2+
extension-pkg-allow-list=numpy,mkl_random.mklrand
3+
4+
[TYPECHECK]
5+
generated-members=RandomState,min,max

mkl_random/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,16 @@
9393
test = PytestTester(__name__)
9494
del PytestTester
9595

96-
from ._patch import monkey_patch, use_in_numpy, restore, is_patched, patched_names, mkl_random
9796
from mkl_random import interfaces
9897

98+
from ._patch_numpy import (
99+
is_patched,
100+
mkl_random,
101+
patch_numpy_random,
102+
patched_names,
103+
restore_numpy_random,
104+
)
105+
99106
__all__ = [
100107
"MKLRandomState",
101108
"RandomState",

mkl_random/_patch_numpy.py

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
# Copyright (c) 2019, Intel Corporation
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions are met:
5+
#
6+
# * Redistributions of source code must retain the above copyright notice,
7+
# this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of Intel Corporation nor the names of its contributors
12+
# may be used to endorse or promote products derived from this software
13+
# without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
19+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
26+
"""Define functions for patching NumPy with MKL-based NumPy interface."""
27+
28+
import warnings
29+
from contextlib import ContextDecorator
30+
from threading import Lock, local
31+
32+
import numpy as _np
33+
34+
import mkl_random.interfaces.numpy_random as _nrand
35+
36+
_DEFAULT_NAMES = tuple(_nrand.__all__)
37+
38+
39+
class _GlobalPatch:
40+
def __init__(self):
41+
self._lock = Lock()
42+
self._patch_count = 0
43+
self._restore_dict = {}
44+
self._patched_functions = tuple(_DEFAULT_NAMES)
45+
self._numpy_module = None
46+
self._requested_names = None
47+
self._active_names = ()
48+
self._patched = ()
49+
self._tls = local()
50+
51+
def _normalize_names(self, names):
52+
if names is None:
53+
names = _DEFAULT_NAMES
54+
return tuple(names)
55+
56+
def _validate_module(self, numpy_module):
57+
if not hasattr(numpy_module, "random"):
58+
raise TypeError(
59+
"Expected a numpy-like module with a `.random` attribute."
60+
)
61+
62+
def _register_func(self, name, func):
63+
if name not in self._patched_functions:
64+
raise ValueError(f"{name} not an mkl_random function.")
65+
np_random = self._numpy_module.random
66+
if name not in self._restore_dict:
67+
self._restore_dict[name] = getattr(np_random, name)
68+
setattr(np_random, name, func)
69+
70+
def _restore_func(self, name, verbose=False):
71+
if name not in self._patched_functions:
72+
raise ValueError(f"{name} not an mkl_random function.")
73+
try:
74+
val = self._restore_dict[name]
75+
except KeyError:
76+
if verbose:
77+
print(f"failed to restore {name}")
78+
return
79+
else:
80+
if verbose:
81+
print(f"found and restoring {name}...")
82+
np_random = self._numpy_module.random
83+
setattr(np_random, name, val)
84+
85+
def _initialize_patch(self, numpy_module, names, strict):
86+
self._validate_module(numpy_module)
87+
np_random = numpy_module.random
88+
missing = []
89+
patchable = []
90+
for name in names:
91+
if name not in self._patched_functions:
92+
missing.append(name)
93+
continue
94+
if not hasattr(np_random, name) or not hasattr(_nrand, name):
95+
missing.append(name)
96+
continue
97+
patchable.append(name)
98+
99+
if strict and missing:
100+
raise AttributeError(
101+
"Could not patch these names (missing on numpy.random or "
102+
"mkl_random.interfaces.numpy_random): "
103+
+ ", ".join(str(x) for x in missing)
104+
)
105+
106+
self._numpy_module = numpy_module
107+
self._requested_names = names
108+
self._active_names = tuple(patchable)
109+
self._patched = tuple(patchable)
110+
111+
def do_patch(
112+
self,
113+
numpy_module=None,
114+
names=None,
115+
strict=False,
116+
verbose=False,
117+
):
118+
if numpy_module is None:
119+
numpy_module = _np
120+
names = self._normalize_names(names)
121+
strict = bool(strict)
122+
123+
with self._lock:
124+
local_count = getattr(self._tls, "local_count", 0)
125+
if self._patch_count == 0:
126+
self._initialize_patch(numpy_module, names, strict)
127+
if verbose:
128+
print(
129+
"Now patching NumPy random submodule with mkl_random "
130+
"NumPy interface."
131+
)
132+
print(
133+
"Please direct bug reports to "
134+
"https://github.com/IntelPython/mkl_random"
135+
)
136+
for name in self._active_names:
137+
self._register_func(name, getattr(_nrand, name))
138+
else:
139+
if self._numpy_module is not numpy_module:
140+
raise RuntimeError(
141+
"Already patched a different numpy module; "
142+
"call restore() first."
143+
)
144+
if names != self._requested_names:
145+
raise RuntimeError(
146+
"Already patched with a different names set; "
147+
"call restore() first."
148+
)
149+
self._patch_count += 1
150+
self._tls.local_count = local_count + 1
151+
152+
def do_restore(self, verbose=False):
153+
with self._lock:
154+
local_count = getattr(self._tls, "local_count", 0)
155+
if local_count <= 0:
156+
if verbose:
157+
warnings.warn(
158+
"Warning: restore_numpy_random called more times than "
159+
"patch_numpy_random in this thread.",
160+
stacklevel=2,
161+
)
162+
return
163+
164+
self._tls.local_count = local_count - 1
165+
self._patch_count -= 1
166+
if self._patch_count == 0:
167+
if verbose:
168+
print("Now restoring original NumPy random submodule.")
169+
for name in tuple(self._restore_dict):
170+
self._restore_func(name, verbose=verbose)
171+
self._restore_dict.clear()
172+
self._numpy_module = None
173+
self._requested_names = None
174+
self._active_names = ()
175+
self._patched = ()
176+
177+
def is_patched(self):
178+
with self._lock:
179+
return self._patch_count > 0
180+
181+
def patched_names(self):
182+
with self._lock:
183+
return list(self._patched)
184+
185+
186+
_patch = _GlobalPatch()
187+
188+
189+
def patch_numpy_random(
190+
numpy_module=None,
191+
names=None,
192+
strict=False,
193+
verbose=False,
194+
):
195+
"""
196+
Patch NumPy's random submodule with mkl_random's NumPy interface.
197+
198+
Parameters
199+
----------
200+
numpy_module : module, optional
201+
NumPy-like module to patch. Defaults to imported NumPy.
202+
names : iterable[str], optional
203+
Attributes under `numpy_module.random` to patch.
204+
strict : bool, optional
205+
Raise if any requested symbol cannot be patched.
206+
verbose : bool, optional
207+
Print messages when starting the patching process.
208+
209+
Examples
210+
--------
211+
>>> import numpy as np
212+
>>> import mkl_random
213+
>>> mkl_random.is_patched()
214+
False
215+
>>> mkl_random.patch_numpy_random(np)
216+
>>> mkl_random.is_patched()
217+
True
218+
>>> mkl_random.restore()
219+
>>> mkl_random.is_patched()
220+
False
221+
"""
222+
_patch.do_patch(
223+
numpy_module=numpy_module,
224+
names=names,
225+
strict=bool(strict),
226+
verbose=bool(verbose),
227+
)
228+
229+
230+
def restore_numpy_random(verbose=False):
231+
"""
232+
Restore NumPy's random submodule to its original implementations.
233+
234+
Parameters
235+
----------
236+
verbose : bool, optional
237+
Print message when starting restoration process.
238+
"""
239+
_patch.do_restore(verbose=bool(verbose))
240+
241+
242+
def is_patched():
243+
"""Return whether NumPy has been patched with mkl_random."""
244+
return _patch.is_patched()
245+
246+
247+
def patched_names():
248+
"""Return names actually patched in `numpy.random`."""
249+
return _patch.patched_names()
250+
251+
252+
class mkl_random(ContextDecorator):
253+
"""
254+
Context manager and decorator to temporarily patch NumPy random submodule
255+
with MKL-based implementations.
256+
257+
Examples
258+
--------
259+
>>> import numpy as np
260+
>>> import mkl_random
261+
>>> with mkl_random.mkl_random(np):
262+
... x = np.random.normal(size=10)
263+
"""
264+
265+
def __init__(self, numpy_module=None, names=None, strict=False):
266+
self._numpy_module = numpy_module
267+
self._names = names
268+
self._strict = strict
269+
270+
def __enter__(self):
271+
patch_numpy_random(
272+
numpy_module=self._numpy_module,
273+
names=self._names,
274+
strict=self._strict,
275+
)
276+
return self
277+
278+
def __exit__(self, *exc):
279+
restore_numpy_random()
280+
return False

0 commit comments

Comments
 (0)