Skip to content

Commit 5d178a8

Browse files
committed
fix: patching to match mkl_fft, lint, and review
1 parent 15404d2 commit 5d178a8

File tree

6 files changed

+456
-303
lines changed

6 files changed

+456
-303
lines changed

.pylintrc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[MASTER]
2+
extension-pkg-allow-list=numpy,mkl_random.mklrand
3+
4+
[TYPECHECK]
5+
generated-members=RandomState,min,max

mkl_random/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,19 @@
9393
test = PytestTester(__name__)
9494
del PytestTester
9595

96-
from ._patch import monkey_patch, use_in_numpy, restore, is_patched, patched_names, mkl_random
9796
from mkl_random import interfaces
9897

98+
from ._patch import (
99+
is_patched,
100+
mkl_random,
101+
monkey_patch,
102+
patch_numpy_random,
103+
patched_names,
104+
restore,
105+
restore_numpy_random,
106+
use_in_numpy,
107+
)
108+
99109
__all__ = [
100110
"MKLRandomState",
101111
"RandomState",

mkl_random/_patch.py

Lines changed: 348 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,348 @@
1+
# Copyright (c) 2019, Intel Corporation
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions are met:
5+
#
6+
# * Redistributions of source code must retain the above copyright notice,
7+
# this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of Intel Corporation nor the names of its contributors
12+
# may be used to endorse or promote products derived from this software
13+
# without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
19+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
26+
"""Define functions for patching NumPy with MKL-based NumPy interface."""
27+
28+
from contextlib import ContextDecorator
29+
from threading import Lock, local
30+
31+
import numpy as _np
32+
33+
from . import mklrand as _mr
34+
35+
36+
_DEFAULT_NAMES = (
37+
# Legacy seeding / state
38+
"seed",
39+
"get_state",
40+
"set_state",
41+
"RandomState",
42+
# Common global sampling helpers
43+
"random",
44+
"random_sample",
45+
"sample",
46+
"rand",
47+
"randn",
48+
"bytes",
49+
# Integers
50+
"randint",
51+
# Common distributions (only patched if present on both sides)
52+
"standard_normal",
53+
"normal",
54+
"uniform",
55+
"exponential",
56+
"gamma",
57+
"beta",
58+
"chisquare",
59+
"f",
60+
"lognormal",
61+
"laplace",
62+
"logistic",
63+
"multivariate_normal",
64+
"poisson",
65+
"power",
66+
"rayleigh",
67+
"triangular",
68+
"vonmises",
69+
"wald",
70+
"weibull",
71+
"zipf",
72+
# Permutations / choices
73+
"choice",
74+
"permutation",
75+
"shuffle",
76+
)
77+
78+
79+
class _GlobalPatch:
80+
def __init__(self):
81+
self._lock = Lock()
82+
self._patch_count = 0
83+
self._restore_dict = {}
84+
self._patched_functions = tuple(_DEFAULT_NAMES)
85+
self._numpy_module = None
86+
self._requested_names = None
87+
self._active_names = ()
88+
self._patched = ()
89+
self._tls = local()
90+
91+
def _normalize_names(self, names):
92+
if names is None:
93+
names = _DEFAULT_NAMES
94+
return tuple(names)
95+
96+
def _validate_module(self, numpy_module):
97+
if not hasattr(numpy_module, "random"):
98+
raise TypeError(
99+
"Expected a numpy-like module with a `.random` attribute."
100+
)
101+
102+
def _register_func(self, name, func):
103+
if name not in self._patched_functions:
104+
raise ValueError(f"{name} not an mkl_random function.")
105+
np_random = self._numpy_module.random
106+
if name not in self._restore_dict:
107+
self._restore_dict[name] = getattr(np_random, name)
108+
setattr(np_random, name, func)
109+
110+
def _restore_func(self, name, verbose=False):
111+
if name not in self._patched_functions:
112+
raise ValueError(f"{name} not an mkl_random function.")
113+
try:
114+
val = self._restore_dict[name]
115+
except KeyError:
116+
if verbose:
117+
print(f"failed to restore {name}")
118+
return
119+
else:
120+
if verbose:
121+
print(f"found and restoring {name}...")
122+
np_random = self._numpy_module.random
123+
setattr(np_random, name, val)
124+
125+
def _initialize_patch(self, numpy_module, names, strict):
126+
self._validate_module(numpy_module)
127+
np_random = numpy_module.random
128+
missing = []
129+
patchable = []
130+
for name in names:
131+
if name not in self._patched_functions:
132+
missing.append(name)
133+
continue
134+
if not hasattr(np_random, name) or not hasattr(_mr, name):
135+
missing.append(name)
136+
continue
137+
patchable.append(name)
138+
139+
if strict and missing:
140+
raise AttributeError(
141+
"Could not patch these names (missing on numpy.random or "
142+
"mkl_random.mklrand): "
143+
+ ", ".join([str(x) for x in missing])
144+
)
145+
146+
self._numpy_module = numpy_module
147+
self._requested_names = names
148+
self._active_names = tuple(patchable)
149+
self._patched = tuple(patchable)
150+
151+
def do_patch(
152+
self,
153+
numpy_module=None,
154+
names=None,
155+
strict=False,
156+
verbose=False,
157+
):
158+
if numpy_module is None:
159+
numpy_module = _np
160+
names = self._normalize_names(names)
161+
strict = bool(strict)
162+
163+
with self._lock:
164+
local_count = getattr(self._tls, "local_count", 0)
165+
if self._patch_count == 0:
166+
self._initialize_patch(numpy_module, names, strict)
167+
if verbose:
168+
print(
169+
"Now patching NumPy random submodule with mkl_random "
170+
"NumPy interface."
171+
)
172+
print(
173+
"Please direct bug reports to "
174+
"https://github.com/IntelPython/mkl_random"
175+
)
176+
for name in self._active_names:
177+
self._register_func(name, getattr(_mr, name))
178+
else:
179+
if self._numpy_module is not numpy_module:
180+
raise RuntimeError(
181+
"Already patched a different numpy module; "
182+
"call restore() first."
183+
)
184+
if names != self._requested_names:
185+
raise RuntimeError(
186+
"Already patched with a different names set; "
187+
"call restore() first."
188+
)
189+
self._patch_count += 1
190+
self._tls.local_count = local_count + 1
191+
192+
def do_restore(self, verbose=False):
193+
with self._lock:
194+
local_count = getattr(self._tls, "local_count", 0)
195+
if local_count <= 0:
196+
if verbose:
197+
print(
198+
"Warning: restore_numpy_random called more times than "
199+
"patch_numpy_random in this thread."
200+
)
201+
return
202+
203+
self._tls.local_count = local_count - 1
204+
self._patch_count -= 1
205+
if self._patch_count == 0:
206+
if verbose:
207+
print("Now restoring original NumPy random submodule.")
208+
for name in tuple(self._restore_dict):
209+
self._restore_func(name, verbose=verbose)
210+
self._restore_dict.clear()
211+
self._numpy_module = None
212+
self._requested_names = None
213+
self._active_names = ()
214+
self._patched = ()
215+
216+
def is_patched(self):
217+
with self._lock:
218+
return self._patch_count > 0
219+
220+
def patched_names(self):
221+
with self._lock:
222+
return list(self._patched)
223+
224+
225+
_patch = _GlobalPatch()
226+
227+
228+
def patch_numpy_random(
229+
numpy_module=None,
230+
names=None,
231+
strict=False,
232+
verbose=False,
233+
):
234+
"""
235+
Patch NumPy's random submodule with mkl_random's NumPy interface.
236+
237+
Parameters
238+
----------
239+
numpy_module : module, optional
240+
NumPy-like module to patch. Defaults to imported NumPy.
241+
names : iterable[str], optional
242+
Attributes under `numpy_module.random` to patch.
243+
strict : bool, optional
244+
Raise if any requested symbol cannot be patched.
245+
verbose : bool, optional
246+
Print messages when starting the patching process.
247+
248+
Examples
249+
--------
250+
>>> import numpy as np
251+
>>> import mkl_random
252+
>>> mkl_random.is_patched()
253+
False
254+
>>> mkl_random.patch_numpy_random(np)
255+
>>> mkl_random.is_patched()
256+
True
257+
>>> mkl_random.restore()
258+
>>> mkl_random.is_patched()
259+
False
260+
"""
261+
_patch.do_patch(
262+
numpy_module=numpy_module,
263+
names=names,
264+
strict=bool(strict),
265+
verbose=bool(verbose),
266+
)
267+
268+
269+
def restore_numpy_random(verbose=False):
270+
"""
271+
Restore NumPy's random submodule to its original implementations.
272+
273+
Parameters
274+
----------
275+
verbose : bool, optional
276+
Print message when starting restoration process.
277+
"""
278+
_patch.do_restore(verbose=bool(verbose))
279+
280+
281+
def monkey_patch(numpy_module=None, names=None, strict=False, verbose=False):
282+
"""Backward-compatible alias for patch_numpy_random()."""
283+
patch_numpy_random(
284+
numpy_module=numpy_module,
285+
names=names,
286+
strict=strict,
287+
verbose=verbose,
288+
)
289+
290+
291+
def use_in_numpy(numpy_module=None, names=None, strict=False, verbose=False):
292+
"""Backward-compatible alias for patch_numpy_random()."""
293+
patch_numpy_random(
294+
numpy_module=numpy_module,
295+
names=names,
296+
strict=strict,
297+
verbose=verbose,
298+
)
299+
300+
301+
def restore(verbose=False):
302+
"""Backward-compatible alias for restore_numpy_random()."""
303+
restore_numpy_random(verbose=verbose)
304+
305+
306+
def is_patched():
307+
"""
308+
Returns whether NumPy has been patched with mkl_random.
309+
"""
310+
return _patch.is_patched()
311+
312+
313+
def patched_names():
314+
"""
315+
Returns the names actually patched in `numpy.random`.
316+
"""
317+
return _patch.patched_names()
318+
319+
320+
class mkl_random(ContextDecorator):
321+
"""
322+
Context manager and decorator to temporarily patch NumPy random submodule
323+
with MKL-based implementations.
324+
325+
Examples
326+
--------
327+
>>> import numpy as np
328+
>>> import mkl_random
329+
>>> with mkl_random.mkl_random(np):
330+
... x = np.random.normal(size=10)
331+
"""
332+
333+
def __init__(self, numpy_module=None, names=None, strict=False):
334+
self._numpy_module = numpy_module
335+
self._names = names
336+
self._strict = strict
337+
338+
def __enter__(self):
339+
patch_numpy_random(
340+
numpy_module=self._numpy_module,
341+
names=self._names,
342+
strict=self._strict,
343+
)
344+
return self
345+
346+
def __exit__(self, *exc):
347+
restore_numpy_random()
348+
return False

0 commit comments

Comments
 (0)