Skip to content

Commit f007ddd

Browse files
committed
task: add patch methods for mkl_random
1 parent 5941905 commit f007ddd

File tree

4 files changed

+380
-0
lines changed

4 files changed

+380
-0
lines changed

mkl_random/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,6 @@
4242
test = PytestTester(__name__)
4343
del PytestTester
4444

45+
from ._patch import monkey_patch, use_in_numpy, restore, is_patched, patched_names, mkl_random
46+
4547
del _init_helper

mkl_random/src/_patch.pyx

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
# Copyright (c) 2019, Intel Corporation
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions are met:
5+
#
6+
# * Redistributions of source code must retain the above copyright notice,
7+
# this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of Intel Corporation nor the names of its contributors
12+
# may be used to endorse or promote products derived from this software
13+
# without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
19+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
26+
# distutils: language = c
27+
# cython: language_level=3
28+
29+
"""
30+
Patch NumPy's `numpy.random` symbols to use mkl_random implementations.
31+
32+
This is attribute-level monkey patching. It can replace legacy APIs like
33+
`numpy.random.RandomState` and global distribution functions, but it does not
34+
replace NumPy's `Generator`/`default_rng()` unless mkl_random provides fully
35+
compatible replacements.
36+
"""
37+
38+
from threading import local as threading_local
39+
from contextlib import ContextDecorator
40+
41+
import numpy as _np
42+
from . import mklrand as _mr
43+
44+
45+
cdef tuple _DEFAULT_NAMES = (
46+
# Legacy seeding / state
47+
"seed",
48+
"get_state",
49+
"set_state",
50+
"RandomState",
51+
52+
# Common global sampling helpers
53+
"random",
54+
"random_sample",
55+
"sample",
56+
"rand",
57+
"randn",
58+
"bytes",
59+
60+
# Integers
61+
"randint",
62+
63+
# Common distributions (only patched if present on both sides)
64+
"standard_normal",
65+
"normal",
66+
"uniform",
67+
"exponential",
68+
"gamma",
69+
"beta",
70+
"chisquare",
71+
"f",
72+
"lognormal",
73+
"laplace",
74+
"logistic",
75+
"multivariate_normal",
76+
"poisson",
77+
"power",
78+
"rayleigh",
79+
"triangular",
80+
"vonmises",
81+
"wald",
82+
"weibull",
83+
"zipf",
84+
85+
# Permutations / choices
86+
"choice",
87+
"permutation",
88+
"shuffle",
89+
)
90+
91+
92+
cdef class patch:
93+
cdef bint _is_patched
94+
cdef object _numpy_module
95+
cdef object _originals # dict: name -> original object
96+
cdef object _patched # list of names actually patched
97+
98+
def __cinit__(self):
99+
self._is_patched = False
100+
self._numpy_module = None
101+
self._originals = {}
102+
self._patched = []
103+
104+
def do_patch(self, numpy_module=None, names=None, bint strict=False):
105+
"""
106+
Patch the given numpy module (default: imported numpy) in-place.
107+
108+
Parameters
109+
----------
110+
numpy_module : module, optional
111+
The numpy module to patch (e.g. `import numpy as np; use_in_numpy(np)`).
112+
names : iterable[str], optional
113+
Attributes under `numpy_module.random` to patch. Defaults to _DEFAULT_NAMES.
114+
strict : bool
115+
If True, raise if any requested symbol cannot be patched.
116+
"""
117+
if numpy_module is None:
118+
numpy_module = _np
119+
if names is None:
120+
names = _DEFAULT_NAMES
121+
122+
if not hasattr(numpy_module, "random"):
123+
raise TypeError("Expected a numpy-like module with a `.random` attribute.")
124+
125+
# If already patched, only allow idempotent re-entry for the same numpy module.
126+
if self._is_patched:
127+
if self._numpy_module is numpy_module:
128+
return
129+
raise RuntimeError("Already patched a different numpy module; call restore() first.")
130+
131+
np_random = numpy_module.random
132+
133+
originals = {}
134+
patched = []
135+
missing = []
136+
137+
for name in names:
138+
if not hasattr(np_random, name) or not hasattr(_mr, name):
139+
missing.append(name)
140+
continue
141+
originals[name] = getattr(np_random, name)
142+
setattr(np_random, name, getattr(_mr, name))
143+
patched.append(name)
144+
145+
if strict and missing:
146+
# revert partial patch before raising
147+
for n, v in originals.items():
148+
setattr(np_random, n, v)
149+
raise AttributeError(
150+
"Could not patch these names (missing on numpy.random or mkl_random.mklrand): "
151+
+ ", ".join([str(x) for x in missing])
152+
)
153+
154+
self._numpy_module = numpy_module
155+
self._originals = originals
156+
self._patched = patched
157+
self._is_patched = True
158+
159+
def do_unpatch(self):
160+
"""
161+
Restore the previously patched numpy module.
162+
"""
163+
if not self._is_patched:
164+
return
165+
numpy_module = self._numpy_module
166+
np_random = numpy_module.random
167+
for n, v in self._originals.items():
168+
setattr(np_random, n, v)
169+
170+
self._numpy_module = None
171+
self._originals = {}
172+
self._patched = []
173+
self._is_patched = False
174+
175+
def is_patched(self):
176+
return self._is_patched
177+
178+
def patched_names(self):
179+
"""
180+
Returns list of names that were actually patched.
181+
"""
182+
return list(self._patched)
183+
184+
185+
_tls = threading_local()
186+
187+
188+
def _is_tls_initialized():
189+
return (getattr(_tls, "initialized", None) is not None) and (_tls.initialized is True)
190+
191+
192+
def _initialize_tls():
193+
_tls.patch = patch()
194+
_tls.initialized = True
195+
196+
197+
def monkey_patch(numpy_module=None, names=None, strict=False):
198+
"""
199+
Enables using mkl_random in the given NumPy module by patching `numpy.random`.
200+
201+
Examples
202+
--------
203+
>>> import numpy as np
204+
>>> import mkl_random
205+
>>> mkl_random.is_patched()
206+
False
207+
>>> mkl_random.monkey_patch(np)
208+
>>> mkl_random.is_patched()
209+
True
210+
>>> mkl_random.restore()
211+
>>> mkl_random.is_patched()
212+
False
213+
"""
214+
if not _is_tls_initialized():
215+
_initialize_tls()
216+
_tls.patch.do_patch(numpy_module=numpy_module, names=names, strict=bool(strict))
217+
218+
219+
def use_in_numpy(numpy_module=None, names=None, strict=False):
220+
"""
221+
Backward-compatible alias for monkey_patch().
222+
"""
223+
monkey_patch(numpy_module=numpy_module, names=names, strict=strict)
224+
225+
226+
def restore():
227+
"""
228+
Disables using mkl_random in NumPy by restoring the original `numpy.random` symbols.
229+
"""
230+
if not _is_tls_initialized():
231+
_initialize_tls()
232+
_tls.patch.do_unpatch()
233+
234+
235+
def is_patched():
236+
"""
237+
Returns whether NumPy has been patched with mkl_random.
238+
"""
239+
if not _is_tls_initialized():
240+
_initialize_tls()
241+
return bool(_tls.patch.is_patched())
242+
243+
244+
def patched_names():
245+
"""
246+
Returns the names actually patched in `numpy.random`.
247+
"""
248+
if not _is_tls_initialized():
249+
_initialize_tls()
250+
return _tls.patch.patched_names()
251+
252+
253+
class mkl_random(ContextDecorator):
254+
"""
255+
Context manager and decorator to temporarily patch NumPy's `numpy.random`.
256+
257+
Examples
258+
--------
259+
>>> import numpy as np
260+
>>> import mkl_random
261+
>>> with mkl_random.mkl_random():
262+
... x = np.random.normal(size=10)
263+
"""
264+
def __init__(self, numpy_module=None, names=None, strict=False):
265+
self._numpy_module = numpy_module
266+
self._names = names
267+
self._strict = strict
268+
269+
def __enter__(self):
270+
monkey_patch(numpy_module=self._numpy_module, names=self._names, strict=self._strict)
271+
return self
272+
273+
def __exit__(self, *exc):
274+
restore()
275+
return False

mkl_random/tests/test_patch.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import numpy as np
2+
import mkl_random
3+
import pytest
4+
5+
def test_is_patched():
6+
"""
7+
Test that is_patched() returns correct status.
8+
"""
9+
assert not mkl_random.is_patched()
10+
mkl_random.monkey_patch(np)
11+
assert mkl_random.is_patched()
12+
mkl_random.restore()
13+
assert not mkl_random.is_patched()
14+
15+
def test_monkey_patch_and_restore():
16+
"""
17+
Test that monkey_patch replaces and restore brings back original functions.
18+
"""
19+
# Store original functions
20+
orig_normal = np.random.normal
21+
orig_randint = np.random.randint
22+
orig_RandomState = np.random.RandomState
23+
24+
try:
25+
mkl_random.monkey_patch(np)
26+
27+
# Check that functions are now different objects
28+
assert np.random.normal is not orig_normal
29+
assert np.random.randint is not orig_randint
30+
assert np.random.RandomState is not orig_RandomState
31+
32+
# Check that they are from mkl_random
33+
assert np.random.normal is mkl_random.mklrand.normal
34+
assert np.random.RandomState is mkl_random.mklrand.RandomState
35+
36+
finally:
37+
mkl_random.restore()
38+
39+
# Check that original functions are restored
40+
assert mkl_random.is_patched() is False
41+
assert np.random.normal is orig_normal
42+
assert np.random.randint is orig_randint
43+
assert np.random.RandomState is orig_RandomState
44+
45+
def test_context_manager():
46+
"""
47+
Test that the context manager patches and automatically restores.
48+
"""
49+
orig_uniform = np.random.uniform
50+
assert not mkl_random.is_patched()
51+
52+
with mkl_random.mkl_random(np):
53+
assert mkl_random.is_patched() is True
54+
assert np.random.uniform is not orig_uniform
55+
# Smoke test inside context
56+
arr = np.random.uniform(size=10)
57+
assert arr.shape == (10,)
58+
59+
assert not mkl_random.is_patched()
60+
assert np.random.uniform is orig_uniform
61+
62+
def test_patched_functions_callable():
63+
"""
64+
Smoke test to ensure some patched functions can be called without error.
65+
"""
66+
mkl_random.monkey_patch(np)
67+
try:
68+
# These calls should now be routed to mkl_random's implementations
69+
x = np.random.standard_normal(size=100)
70+
assert x.shape == (100,)
71+
72+
y = np.random.randint(0, 100, size=50)
73+
assert y.shape == (50,)
74+
assert np.all(y >= 0) and np.all(y < 100)
75+
76+
st = np.random.RandomState(12345)
77+
z = st.rand(10)
78+
assert z.shape == (10,)
79+
80+
finally:
81+
mkl_random.restore()
82+
83+
def test_patched_names():
84+
"""
85+
Test that patched_names() returns a list of patched symbols.
86+
"""
87+
try:
88+
mkl_random.monkey_patch(np)
89+
names = mkl_random.patched_names()
90+
assert isinstance(names, list)
91+
assert len(names) > 0
92+
assert "normal" in names
93+
assert "RandomState" in names
94+
finally:
95+
mkl_random.restore()

setup.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,14 @@ def extensions():
8181
extra_compile_args = eca,
8282
define_macros=defs + [("NDEBUG", None)],
8383
language="c++"
84+
),
85+
86+
Extension(
87+
"mkl_random._patch",
88+
sources=[join("mkl_random", "src", "_patch.pyx")],
89+
include_dirs=[np.get_include()],
90+
define_macros=defs + [("NDEBUG", None)],
91+
language="c",
8492
)
8593
]
8694

0 commit comments

Comments
 (0)