Skip to content

Commit e2d91b2

Browse files
authored
Merge pull request #90 from jharlow-intel/task/patch-numpy
task: add patch methods for mkl_random
2 parents 18bd0f4 + 4055e3c commit e2d91b2

File tree

6 files changed

+309
-1
lines changed

6 files changed

+309
-1
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77
## [dev] (MM/DD/YYYY)
88

99
### Added
10+
* Added `mkl_random` patching for NumPy, with `mkl_random` context manager, `is_patched` query, and `patch_numpy_random` and `restore_numpy_random` calls to replace `numpy.random` calls with calls from `mkl_random.interfaces.numpy_random` [gh-90](https://github.com/IntelPython/mkl_random/pull/90)
11+
1012
* Added `mkl_random.interfaces` with `mkl_random.interfaces.numpy_random` interface, which aliases `mkl_random` functionality to more strictly adhere to NumPy's API (i.e., drops arguments and functions which are not part of standard NumPy) [gh-92](https://github.com/IntelPython/mkl_random/pull/92)
1113

1214
### Removed

mkl_random/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,13 @@
9595

9696
from mkl_random import interfaces
9797

98+
from ._patch_numpy import (
99+
is_patched,
100+
mkl_random,
101+
patch_numpy_random,
102+
restore_numpy_random,
103+
)
104+
98105
__all__ = [
99106
"MKLRandomState",
100107
"RandomState",
@@ -147,6 +154,10 @@
147154
"shuffle",
148155
"permutation",
149156
"interfaces",
157+
"mkl_random",
158+
"patch_numpy_random",
159+
"restore_numpy_random",
160+
"is_patched",
150161
]
151162

152163
del _init_helper

mkl_random/_patch_numpy.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
# Copyright (c) 2019, Intel Corporation
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions are met:
5+
#
6+
# * Redistributions of source code must retain the above copyright notice,
7+
# this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of Intel Corporation nor the names of its contributors
12+
# may be used to endorse or promote products derived from this software
13+
# without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
19+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
26+
"""Define functions for patching NumPy with MKL-based NumPy interface."""
27+
28+
from contextlib import ContextDecorator
29+
from threading import Lock, local
30+
31+
import numpy as np
32+
33+
import mkl_random.interfaces.numpy_random as _nrand
34+
35+
36+
class _GlobalPatch:
37+
def __init__(self):
38+
self._lock = Lock()
39+
self._patch_count = 0
40+
self._restore_dict = {}
41+
# make _patched_functions a tuple (immutable)
42+
self._patched_functions = tuple(_nrand.__all__)
43+
self._tls = local()
44+
45+
def _register_func(self, name, func):
46+
if name not in self._patched_functions:
47+
raise ValueError(f"{name} not an mkl_random function.")
48+
if name not in self._restore_dict:
49+
self._restore_dict[name] = getattr(np.random, name)
50+
setattr(np.random, name, func)
51+
52+
def _restore_func(self, name, verbose=False):
53+
if name not in self._patched_functions:
54+
raise ValueError(f"{name} not an mkl_random function.")
55+
try:
56+
val = self._restore_dict[name]
57+
except KeyError:
58+
if verbose:
59+
print(f"failed to restore {name}")
60+
return
61+
else:
62+
if verbose:
63+
print(f"found and restoring {name}...")
64+
setattr(np.random, name, val)
65+
66+
def do_patch(self, verbose=False):
67+
with self._lock:
68+
local_count = getattr(self._tls, "local_count", 0)
69+
if self._patch_count == 0:
70+
if verbose:
71+
print(
72+
"Now patching NumPy random submodule with mkl_random "
73+
"NumPy interface."
74+
)
75+
print(
76+
"Please direct bug reports to "
77+
"https://github.com/IntelPython/mkl_random"
78+
)
79+
for f in self._patched_functions:
80+
self._register_func(f, getattr(_nrand, f))
81+
self._patch_count += 1
82+
self._tls.local_count = local_count + 1
83+
84+
def do_restore(self, verbose=False):
85+
with self._lock:
86+
local_count = getattr(self._tls, "local_count", 0)
87+
if local_count <= 0:
88+
if verbose:
89+
print(
90+
"Warning: restore_numpy_random called more times than "
91+
"patch_numpy_random in this thread."
92+
)
93+
return
94+
self._tls.local_count -= 1
95+
self._patch_count -= 1
96+
if self._patch_count == 0:
97+
if verbose:
98+
print("Now restoring original NumPy random submodule.")
99+
for name in tuple(self._restore_dict):
100+
self._restore_func(name, verbose=verbose)
101+
self._restore_dict.clear()
102+
103+
def is_patched(self):
104+
with self._lock:
105+
return self._patch_count > 0
106+
107+
108+
_patch = _GlobalPatch()
109+
110+
111+
def patch_numpy_random(verbose=False):
112+
"""
113+
Patch NumPy's random submodule with mkl_random's numpy_interface.
114+
115+
Parameters
116+
----------
117+
verbose : bool, optional
118+
print message when starting the patching process.
119+
120+
Notes
121+
-----
122+
This function uses reference-counted semantics. Each call increments a
123+
global patch counter. Restoration requires a matching number of calls
124+
between `patch_numpy_random` and `restore_numpy_random`.
125+
126+
In multi-threaded programs, prefer the `mkl_random` context manager.
127+
128+
"""
129+
_patch.do_patch(verbose=verbose)
130+
131+
132+
def restore_numpy_random(verbose=False):
133+
"""
134+
Restore NumPy's random submodule to its original implementations.
135+
136+
Parameters
137+
----------
138+
verbose : bool, optional
139+
print message when starting restoration process.
140+
141+
Notes
142+
-----
143+
This function uses reference-counted semantics. Each call decrements a
144+
global patch counter. Restoration requires a matching number of calls
145+
between `patch_numpy_random` and `restore_numpy_random`.
146+
147+
In multi-threaded programs, prefer the `mkl_random` context manager.
148+
149+
"""
150+
_patch.do_restore(verbose=verbose)
151+
152+
153+
def is_patched():
154+
"""Return True if NumPy's random sm is currently patched by mkl_random."""
155+
return _patch.is_patched()
156+
157+
158+
class mkl_random(ContextDecorator):
159+
"""
160+
Context manager and decorator to temporarily patch NumPy random submodule
161+
with MKL-based implementations.
162+
163+
Examples
164+
--------
165+
>>> import mkl_random
166+
>>> mkl_random.is_patched()
167+
# False
168+
169+
>>> with mkl_random.mkl_random(): # Enable mkl_random in NumPy
170+
>>> print(mkl_random.is_patched())
171+
# True
172+
173+
>>> mkl_random.is_patched()
174+
# False
175+
176+
"""
177+
178+
def __enter__(self):
179+
patch_numpy_random()
180+
return self
181+
182+
def __exit__(self, *exc):
183+
restore_numpy_random()
184+
return False

mkl_random/tests/test_patch.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright (c) 2017, Intel Corporation
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions are met:
5+
#
6+
# * Redistributions of source code must retain the above copyright notice,
7+
# this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of Intel Corporation nor the names of its contributors
12+
# may be used to endorse or promote products derived from this software
13+
# without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
19+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
26+
import numpy as np
27+
28+
import mkl_random
29+
import mkl_random.interfaces.numpy_random as _nrand
30+
31+
32+
def test_is_patched():
33+
"""Test that is_patched() returns correct status."""
34+
assert not mkl_random.is_patched()
35+
try:
36+
mkl_random.patch_numpy_random()
37+
assert mkl_random.is_patched()
38+
mkl_random.restore_numpy_random()
39+
assert not mkl_random.is_patched()
40+
finally:
41+
while mkl_random.is_patched():
42+
mkl_random.restore_numpy_random()
43+
44+
45+
def test_patch():
46+
old_module = np.random.normal.__module__
47+
assert not mkl_random.is_patched()
48+
49+
try:
50+
mkl_random.patch_numpy_random() # Enable mkl_random in NumPy
51+
assert mkl_random.is_patched()
52+
assert np.random.normal.__module__ == _nrand.normal.__module__
53+
54+
mkl_random.restore_numpy_random() # Disable mkl_random in NumPy
55+
assert not mkl_random.is_patched()
56+
assert np.random.normal.__module__ == old_module
57+
finally:
58+
while mkl_random.is_patched():
59+
mkl_random.restore_numpy_random()
60+
61+
62+
def test_patch_redundant_patching():
63+
old_module = np.random.normal.__module__
64+
assert not mkl_random.is_patched()
65+
66+
try:
67+
mkl_random.patch_numpy_random()
68+
mkl_random.patch_numpy_random()
69+
70+
assert mkl_random.is_patched()
71+
assert np.random.normal.__module__ == _nrand.normal.__module__
72+
73+
mkl_random.restore_numpy_random()
74+
assert mkl_random.is_patched()
75+
assert np.random.normal.__module__ == _nrand.normal.__module__
76+
77+
mkl_random.restore_numpy_random()
78+
assert not mkl_random.is_patched()
79+
assert np.random.normal.__module__ == old_module
80+
finally:
81+
while mkl_random.is_patched():
82+
mkl_random.restore_numpy_random()
83+
84+
85+
def test_patch_reentrant():
86+
old_module = np.random.normal.__module__
87+
assert not mkl_random.is_patched()
88+
89+
try:
90+
with mkl_random.mkl_random():
91+
assert mkl_random.is_patched()
92+
assert np.random.normal.__module__ == _nrand.normal.__module__
93+
94+
with mkl_random.mkl_random():
95+
assert mkl_random.is_patched()
96+
assert np.random.normal.__module__ == _nrand.normal.__module__
97+
98+
assert mkl_random.is_patched()
99+
assert np.random.normal.__module__ == _nrand.normal.__module__
100+
101+
assert not mkl_random.is_patched()
102+
assert np.random.normal.__module__ == old_module
103+
finally:
104+
while mkl_random.is_patched():
105+
mkl_random.restore_numpy_random()

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,12 @@ line_length = 80
9393
multi_line_output = 3
9494
use_parentheses = true
9595

96+
[tool.pylint.main]
97+
extension-pkg-allow-list = ["numpy", "mkl_random.mklrand"]
98+
99+
[tool.pylint.typecheck]
100+
generated-members = ["RandomState", "min", "max"]
101+
96102
[tool.setuptools]
97103
include-package-data = true
98104

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def extensions():
9292
extra_compile_args=eca,
9393
define_macros=defs + [("NDEBUG", None)],
9494
language="c++",
95-
)
95+
),
9696
]
9797

9898
return exts

0 commit comments

Comments
 (0)