-
Notifications
You must be signed in to change notification settings - Fork 14
task: add patch methods for mkl_random #90
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
ndgrigorian
merged 6 commits into
IntelPython:master
from
jharlow-intel:task/patch-numpy
Mar 10, 2026
Merged
Changes from 1 commit
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
f007ddd
task: add patch methods for mkl_random
jharlow-intel 15404d2
Merge branch 'master' into task/patch-numpy
jharlow-intel 38fe23d
fix: patching to match mkl_fft, lint, and review
jharlow-intel 0a94878
fix: testing
jharlow-intel 8a8b942
chore: update CHANGELOG
jharlow-intel 4055e3c
task: review fixes
jharlow-intel File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,275 @@ | ||
| # Copyright (c) 2019, Intel Corporation | ||
| # | ||
| # Redistribution and use in source and binary forms, with or without | ||
| # modification, are permitted provided that the following conditions are met: | ||
| # | ||
| # * Redistributions of source code must retain the above copyright notice, | ||
| # this list of conditions and the following disclaimer. | ||
| # * Redistributions in binary form must reproduce the above copyright | ||
| # notice, this list of conditions and the following disclaimer in the | ||
| # documentation and/or other materials provided with the distribution. | ||
| # * Neither the name of Intel Corporation nor the names of its contributors | ||
| # may be used to endorse or promote products derived from this software | ||
| # without specific prior written permission. | ||
| # | ||
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
| # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
| # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE | ||
| # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | ||
| # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | ||
| # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | ||
| # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | ||
| # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
| # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
|
|
||
| # distutils: language = c | ||
| # cython: language_level=3 | ||
|
|
||
| """ | ||
| Patch NumPy's `numpy.random` symbols to use mkl_random implementations. | ||
|
|
||
| This is attribute-level monkey patching. It can replace legacy APIs like | ||
| `numpy.random.RandomState` and global distribution functions, but it does not | ||
| replace NumPy's `Generator`/`default_rng()` unless mkl_random provides fully | ||
| compatible replacements. | ||
| """ | ||
|
|
||
| from threading import local as threading_local | ||
| from contextlib import ContextDecorator | ||
|
|
||
| import numpy as _np | ||
| from . import mklrand as _mr | ||
|
|
||
|
|
||
| cdef tuple _DEFAULT_NAMES = ( | ||
| # Legacy seeding / state | ||
| "seed", | ||
| "get_state", | ||
| "set_state", | ||
| "RandomState", | ||
|
|
||
| # Common global sampling helpers | ||
| "random", | ||
| "random_sample", | ||
| "sample", | ||
| "rand", | ||
| "randn", | ||
| "bytes", | ||
|
|
||
| # Integers | ||
| "randint", | ||
|
|
||
| # Common distributions (only patched if present on both sides) | ||
| "standard_normal", | ||
| "normal", | ||
| "uniform", | ||
| "exponential", | ||
| "gamma", | ||
| "beta", | ||
| "chisquare", | ||
| "f", | ||
| "lognormal", | ||
| "laplace", | ||
| "logistic", | ||
| "multivariate_normal", | ||
| "poisson", | ||
| "power", | ||
| "rayleigh", | ||
| "triangular", | ||
| "vonmises", | ||
| "wald", | ||
| "weibull", | ||
| "zipf", | ||
|
|
||
| # Permutations / choices | ||
| "choice", | ||
| "permutation", | ||
| "shuffle", | ||
| ) | ||
|
|
||
|
|
||
| cdef class patch: | ||
| cdef bint _is_patched | ||
| cdef object _numpy_module | ||
| cdef object _originals # dict: name -> original object | ||
| cdef object _patched # list of names actually patched | ||
|
|
||
| def __cinit__(self): | ||
| self._is_patched = False | ||
| self._numpy_module = None | ||
| self._originals = {} | ||
| self._patched = [] | ||
|
|
||
| def do_patch(self, numpy_module=None, names=None, bint strict=False): | ||
| """ | ||
| Patch the given numpy module (default: imported numpy) in-place. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| numpy_module : module, optional | ||
| The numpy module to patch (e.g. `import numpy as np; use_in_numpy(np)`). | ||
| names : iterable[str], optional | ||
| Attributes under `numpy_module.random` to patch. Defaults to _DEFAULT_NAMES. | ||
| strict : bool | ||
| If True, raise if any requested symbol cannot be patched. | ||
| """ | ||
| if numpy_module is None: | ||
| numpy_module = _np | ||
| if names is None: | ||
| names = _DEFAULT_NAMES | ||
|
|
||
| if not hasattr(numpy_module, "random"): | ||
| raise TypeError("Expected a numpy-like module with a `.random` attribute.") | ||
|
|
||
| # If already patched, only allow idempotent re-entry for the same numpy module. | ||
| if self._is_patched: | ||
| if self._numpy_module is numpy_module: | ||
| return | ||
| raise RuntimeError("Already patched a different numpy module; call restore() first.") | ||
|
|
||
| np_random = numpy_module.random | ||
|
|
||
| originals = {} | ||
| patched = [] | ||
| missing = [] | ||
|
|
||
| for name in names: | ||
| if not hasattr(np_random, name) or not hasattr(_mr, name): | ||
| missing.append(name) | ||
| continue | ||
| originals[name] = getattr(np_random, name) | ||
| setattr(np_random, name, getattr(_mr, name)) | ||
| patched.append(name) | ||
|
|
||
| if strict and missing: | ||
| # revert partial patch before raising | ||
| for n, v in originals.items(): | ||
| setattr(np_random, n, v) | ||
| raise AttributeError( | ||
| "Could not patch these names (missing on numpy.random or mkl_random.mklrand): " | ||
| + ", ".join([str(x) for x in missing]) | ||
| ) | ||
|
|
||
| self._numpy_module = numpy_module | ||
| self._originals = originals | ||
| self._patched = patched | ||
| self._is_patched = True | ||
|
|
||
| def do_unpatch(self): | ||
| """ | ||
| Restore the previously patched numpy module. | ||
| """ | ||
| if not self._is_patched: | ||
| return | ||
| numpy_module = self._numpy_module | ||
| np_random = numpy_module.random | ||
| for n, v in self._originals.items(): | ||
| setattr(np_random, n, v) | ||
|
|
||
| self._numpy_module = None | ||
| self._originals = {} | ||
| self._patched = [] | ||
| self._is_patched = False | ||
|
|
||
|
jharlow-intel marked this conversation as resolved.
Outdated
|
||
| def is_patched(self): | ||
| return self._is_patched | ||
|
|
||
| def patched_names(self): | ||
| """ | ||
| Returns list of names that were actually patched. | ||
| """ | ||
| return list(self._patched) | ||
|
|
||
|
|
||
| _tls = threading_local() | ||
|
|
||
|
|
||
| def _is_tls_initialized(): | ||
| return (getattr(_tls, "initialized", None) is not None) and (_tls.initialized is True) | ||
|
jharlow-intel marked this conversation as resolved.
Outdated
|
||
|
|
||
|
|
||
| def _initialize_tls(): | ||
| _tls.patch = patch() | ||
| _tls.initialized = True | ||
|
jharlow-intel marked this conversation as resolved.
Outdated
|
||
|
|
||
|
|
||
| def monkey_patch(numpy_module=None, names=None, strict=False): | ||
| """ | ||
| Enables using mkl_random in the given NumPy module by patching `numpy.random`. | ||
|
|
||
| Examples | ||
| -------- | ||
| >>> import numpy as np | ||
| >>> import mkl_random | ||
| >>> mkl_random.is_patched() | ||
| False | ||
| >>> mkl_random.monkey_patch(np) | ||
| >>> mkl_random.is_patched() | ||
| True | ||
| >>> mkl_random.restore() | ||
| >>> mkl_random.is_patched() | ||
| False | ||
| """ | ||
| if not _is_tls_initialized(): | ||
| _initialize_tls() | ||
| _tls.patch.do_patch(numpy_module=numpy_module, names=names, strict=bool(strict)) | ||
|
|
||
|
|
||
| def use_in_numpy(numpy_module=None, names=None, strict=False): | ||
| """ | ||
| Backward-compatible alias for monkey_patch(). | ||
| """ | ||
| monkey_patch(numpy_module=numpy_module, names=names, strict=strict) | ||
|
|
||
|
|
||
| def restore(): | ||
| """ | ||
| Disables using mkl_random in NumPy by restoring the original `numpy.random` symbols. | ||
| """ | ||
| if not _is_tls_initialized(): | ||
| _initialize_tls() | ||
| _tls.patch.do_unpatch() | ||
|
|
||
|
|
||
| def is_patched(): | ||
| """ | ||
| Returns whether NumPy has been patched with mkl_random. | ||
| """ | ||
| if not _is_tls_initialized(): | ||
| _initialize_tls() | ||
| return bool(_tls.patch.is_patched()) | ||
|
|
||
|
|
||
| def patched_names(): | ||
| """ | ||
| Returns the names actually patched in `numpy.random`. | ||
| """ | ||
| if not _is_tls_initialized(): | ||
| _initialize_tls() | ||
| return _tls.patch.patched_names() | ||
|
|
||
|
|
||
| class mkl_random(ContextDecorator): | ||
| """ | ||
| Context manager and decorator to temporarily patch NumPy's `numpy.random`. | ||
|
|
||
| Examples | ||
| -------- | ||
| >>> import numpy as np | ||
| >>> import mkl_random | ||
| >>> with mkl_random.mkl_random(): | ||
|
jharlow-intel marked this conversation as resolved.
Outdated
|
||
| ... x = np.random.normal(size=10) | ||
| """ | ||
| def __init__(self, numpy_module=None, names=None, strict=False): | ||
| self._numpy_module = numpy_module | ||
| self._names = names | ||
| self._strict = strict | ||
|
|
||
| def __enter__(self): | ||
| monkey_patch(numpy_module=self._numpy_module, names=self._names, strict=self._strict) | ||
| return self | ||
|
|
||
| def __exit__(self, *exc): | ||
| restore() | ||
| return False | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,95 @@ | ||
| import numpy as np | ||
|
jharlow-intel marked this conversation as resolved.
|
||
| import mkl_random | ||
| import pytest | ||
|
|
||
| def test_is_patched(): | ||
| """ | ||
| Test that is_patched() returns correct status. | ||
| """ | ||
| assert not mkl_random.is_patched() | ||
| mkl_random.monkey_patch(np) | ||
| assert mkl_random.is_patched() | ||
| mkl_random.restore() | ||
| assert not mkl_random.is_patched() | ||
|
|
||
| def test_monkey_patch_and_restore(): | ||
| """ | ||
| Test that monkey_patch replaces and restore brings back original functions. | ||
| """ | ||
| # Store original functions | ||
| orig_normal = np.random.normal | ||
| orig_randint = np.random.randint | ||
| orig_RandomState = np.random.RandomState | ||
|
|
||
| try: | ||
| mkl_random.monkey_patch(np) | ||
|
|
||
| # Check that functions are now different objects | ||
| assert np.random.normal is not orig_normal | ||
| assert np.random.randint is not orig_randint | ||
| assert np.random.RandomState is not orig_RandomState | ||
|
|
||
| # Check that they are from mkl_random | ||
| assert np.random.normal is mkl_random.mklrand.normal | ||
| assert np.random.RandomState is mkl_random.mklrand.RandomState | ||
|
|
||
| finally: | ||
| mkl_random.restore() | ||
|
|
||
| # Check that original functions are restored | ||
| assert mkl_random.is_patched() is False | ||
| assert np.random.normal is orig_normal | ||
| assert np.random.randint is orig_randint | ||
| assert np.random.RandomState is orig_RandomState | ||
|
|
||
| def test_context_manager(): | ||
| """ | ||
| Test that the context manager patches and automatically restores. | ||
| """ | ||
| orig_uniform = np.random.uniform | ||
| assert not mkl_random.is_patched() | ||
|
|
||
| with mkl_random.mkl_random(np): | ||
| assert mkl_random.is_patched() is True | ||
| assert np.random.uniform is not orig_uniform | ||
| # Smoke test inside context | ||
| arr = np.random.uniform(size=10) | ||
| assert arr.shape == (10,) | ||
|
|
||
| assert not mkl_random.is_patched() | ||
| assert np.random.uniform is orig_uniform | ||
|
|
||
| def test_patched_functions_callable(): | ||
| """ | ||
| Smoke test to ensure some patched functions can be called without error. | ||
| """ | ||
| mkl_random.monkey_patch(np) | ||
| try: | ||
| # These calls should now be routed to mkl_random's implementations | ||
| x = np.random.standard_normal(size=100) | ||
| assert x.shape == (100,) | ||
|
|
||
| y = np.random.randint(0, 100, size=50) | ||
| assert y.shape == (50,) | ||
| assert np.all(y >= 0) and np.all(y < 100) | ||
|
|
||
| st = np.random.RandomState(12345) | ||
| z = st.rand(10) | ||
| assert z.shape == (10,) | ||
|
|
||
| finally: | ||
| mkl_random.restore() | ||
|
|
||
| def test_patched_names(): | ||
| """ | ||
| Test that patched_names() returns a list of patched symbols. | ||
| """ | ||
| try: | ||
| mkl_random.monkey_patch(np) | ||
| names = mkl_random.patched_names() | ||
| assert isinstance(names, list) | ||
| assert len(names) > 0 | ||
| assert "normal" in names | ||
| assert "RandomState" in names | ||
| finally: | ||
| mkl_random.restore() | ||
|
jharlow-intel marked this conversation as resolved.
Outdated
jharlow-intel marked this conversation as resolved.
Outdated
jharlow-intel marked this conversation as resolved.
Outdated
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.