Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mkl_random/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,6 @@
test = PytestTester(__name__)
del PytestTester

from ._patch import monkey_patch, use_in_numpy, restore, is_patched, patched_names, mkl_random

del _init_helper
275 changes: 275 additions & 0 deletions mkl_random/src/_patch.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
# Copyright (c) 2019, Intel Corporation
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of Intel Corporation nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# distutils: language = c
# cython: language_level=3

"""
Patch NumPy's `numpy.random` symbols to use mkl_random implementations.

This is attribute-level monkey patching. It can replace legacy APIs like
`numpy.random.RandomState` and global distribution functions, but it does not
replace NumPy's `Generator`/`default_rng()` unless mkl_random provides fully
compatible replacements.
"""

from threading import local as threading_local
from contextlib import ContextDecorator

import numpy as _np
from . import mklrand as _mr


cdef tuple _DEFAULT_NAMES = (
# Legacy seeding / state
"seed",
"get_state",
"set_state",
"RandomState",

# Common global sampling helpers
"random",
"random_sample",
"sample",
"rand",
"randn",
"bytes",

# Integers
"randint",

# Common distributions (only patched if present on both sides)
"standard_normal",
"normal",
"uniform",
"exponential",
"gamma",
"beta",
"chisquare",
"f",
"lognormal",
"laplace",
"logistic",
"multivariate_normal",
"poisson",
"power",
"rayleigh",
"triangular",
"vonmises",
"wald",
"weibull",
"zipf",

# Permutations / choices
"choice",
"permutation",
"shuffle",
)


cdef class patch:
cdef bint _is_patched
cdef object _numpy_module
cdef object _originals # dict: name -> original object
cdef object _patched # list of names actually patched

def __cinit__(self):
self._is_patched = False
self._numpy_module = None
self._originals = {}
self._patched = []

def do_patch(self, numpy_module=None, names=None, bint strict=False):
"""
Patch the given numpy module (default: imported numpy) in-place.

Parameters
----------
numpy_module : module, optional
The numpy module to patch (e.g. `import numpy as np; use_in_numpy(np)`).
names : iterable[str], optional
Attributes under `numpy_module.random` to patch. Defaults to _DEFAULT_NAMES.
strict : bool
If True, raise if any requested symbol cannot be patched.
"""
if numpy_module is None:
numpy_module = _np
if names is None:
names = _DEFAULT_NAMES

if not hasattr(numpy_module, "random"):
raise TypeError("Expected a numpy-like module with a `.random` attribute.")

# If already patched, only allow idempotent re-entry for the same numpy module.
if self._is_patched:
if self._numpy_module is numpy_module:
Comment thread
jharlow-intel marked this conversation as resolved.
Outdated
return
raise RuntimeError("Already patched a different numpy module; call restore() first.")

np_random = numpy_module.random

originals = {}
patched = []
missing = []

for name in names:
if not hasattr(np_random, name) or not hasattr(_mr, name):
missing.append(name)
continue
originals[name] = getattr(np_random, name)
setattr(np_random, name, getattr(_mr, name))
patched.append(name)

if strict and missing:
# revert partial patch before raising
for n, v in originals.items():
setattr(np_random, n, v)
raise AttributeError(
"Could not patch these names (missing on numpy.random or mkl_random.mklrand): "
+ ", ".join([str(x) for x in missing])
)

self._numpy_module = numpy_module
self._originals = originals
self._patched = patched
self._is_patched = True

def do_unpatch(self):
"""
Restore the previously patched numpy module.
"""
if not self._is_patched:
return
numpy_module = self._numpy_module
np_random = numpy_module.random
for n, v in self._originals.items():
setattr(np_random, n, v)

self._numpy_module = None
self._originals = {}
self._patched = []
self._is_patched = False

Comment thread
jharlow-intel marked this conversation as resolved.
Outdated
def is_patched(self):
return self._is_patched

def patched_names(self):
"""
Returns list of names that were actually patched.
"""
return list(self._patched)


_tls = threading_local()


def _is_tls_initialized():
return (getattr(_tls, "initialized", None) is not None) and (_tls.initialized is True)
Comment thread
jharlow-intel marked this conversation as resolved.
Outdated


def _initialize_tls():
_tls.patch = patch()
_tls.initialized = True
Comment thread
jharlow-intel marked this conversation as resolved.
Outdated


def monkey_patch(numpy_module=None, names=None, strict=False):
"""
Enables using mkl_random in the given NumPy module by patching `numpy.random`.

Examples
--------
>>> import numpy as np
>>> import mkl_random
>>> mkl_random.is_patched()
False
>>> mkl_random.monkey_patch(np)
>>> mkl_random.is_patched()
True
>>> mkl_random.restore()
>>> mkl_random.is_patched()
False
"""
if not _is_tls_initialized():
_initialize_tls()
_tls.patch.do_patch(numpy_module=numpy_module, names=names, strict=bool(strict))


def use_in_numpy(numpy_module=None, names=None, strict=False):
"""
Backward-compatible alias for monkey_patch().
"""
monkey_patch(numpy_module=numpy_module, names=names, strict=strict)


def restore():
"""
Disables using mkl_random in NumPy by restoring the original `numpy.random` symbols.
"""
if not _is_tls_initialized():
_initialize_tls()
_tls.patch.do_unpatch()


def is_patched():
"""
Returns whether NumPy has been patched with mkl_random.
"""
if not _is_tls_initialized():
_initialize_tls()
return bool(_tls.patch.is_patched())


def patched_names():
"""
Returns the names actually patched in `numpy.random`.
"""
if not _is_tls_initialized():
_initialize_tls()
return _tls.patch.patched_names()


class mkl_random(ContextDecorator):
"""
Context manager and decorator to temporarily patch NumPy's `numpy.random`.

Examples
--------
>>> import numpy as np
>>> import mkl_random
>>> with mkl_random.mkl_random():
Comment thread
jharlow-intel marked this conversation as resolved.
Outdated
... x = np.random.normal(size=10)
"""
def __init__(self, numpy_module=None, names=None, strict=False):
self._numpy_module = numpy_module
self._names = names
self._strict = strict

def __enter__(self):
monkey_patch(numpy_module=self._numpy_module, names=self._names, strict=self._strict)
return self

def __exit__(self, *exc):
restore()
return False
95 changes: 95 additions & 0 deletions mkl_random/tests/test_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import numpy as np
Comment thread
jharlow-intel marked this conversation as resolved.
import mkl_random
import pytest

def test_is_patched():
"""
Test that is_patched() returns correct status.
"""
assert not mkl_random.is_patched()
mkl_random.monkey_patch(np)
assert mkl_random.is_patched()
mkl_random.restore()
assert not mkl_random.is_patched()

def test_monkey_patch_and_restore():
"""
Test that monkey_patch replaces and restore brings back original functions.
"""
# Store original functions
orig_normal = np.random.normal
orig_randint = np.random.randint
orig_RandomState = np.random.RandomState

try:
mkl_random.monkey_patch(np)

# Check that functions are now different objects
assert np.random.normal is not orig_normal
assert np.random.randint is not orig_randint
assert np.random.RandomState is not orig_RandomState

# Check that they are from mkl_random
assert np.random.normal is mkl_random.mklrand.normal
assert np.random.RandomState is mkl_random.mklrand.RandomState

finally:
mkl_random.restore()

# Check that original functions are restored
assert mkl_random.is_patched() is False
assert np.random.normal is orig_normal
assert np.random.randint is orig_randint
assert np.random.RandomState is orig_RandomState

def test_context_manager():
"""
Test that the context manager patches and automatically restores.
"""
orig_uniform = np.random.uniform
assert not mkl_random.is_patched()

with mkl_random.mkl_random(np):
assert mkl_random.is_patched() is True
assert np.random.uniform is not orig_uniform
# Smoke test inside context
arr = np.random.uniform(size=10)
assert arr.shape == (10,)

assert not mkl_random.is_patched()
assert np.random.uniform is orig_uniform

def test_patched_functions_callable():
"""
Smoke test to ensure some patched functions can be called without error.
"""
mkl_random.monkey_patch(np)
try:
# These calls should now be routed to mkl_random's implementations
x = np.random.standard_normal(size=100)
assert x.shape == (100,)

y = np.random.randint(0, 100, size=50)
assert y.shape == (50,)
assert np.all(y >= 0) and np.all(y < 100)

st = np.random.RandomState(12345)
z = st.rand(10)
assert z.shape == (10,)

finally:
mkl_random.restore()

def test_patched_names():
"""
Test that patched_names() returns a list of patched symbols.
"""
try:
mkl_random.monkey_patch(np)
names = mkl_random.patched_names()
assert isinstance(names, list)
assert len(names) > 0
assert "normal" in names
assert "RandomState" in names
finally:
mkl_random.restore()
Comment thread
jharlow-intel marked this conversation as resolved.
Outdated
Comment thread
jharlow-intel marked this conversation as resolved.
Outdated
Comment thread
jharlow-intel marked this conversation as resolved.
Outdated
8 changes: 8 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ def extensions():
extra_compile_args = eca,
define_macros=defs + [("NDEBUG", None)],
language="c++"
),

Extension(
"mkl_random._patch",
sources=[join("mkl_random", "src", "_patch.pyx")],
include_dirs=[np.get_include()],
define_macros=defs + [("NDEBUG", None)],
language="c",
)
Comment thread
jharlow-intel marked this conversation as resolved.
Outdated
]

Expand Down
Loading