Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed
* Removed `numpy-base` dependency and `USE_NUMPY_BASE` environment variable from conda recipe [gh-200](https://github.com/IntelPython/mkl_umath/pull/200)
* Updated `mkl_umath` patching to work with changes to NumPy Cython API present in NumPy 2.5 [gh-226](https://github.com/IntelPython/mkl_umath/pull/226)

### Fixed

Expand Down
75 changes: 52 additions & 23 deletions mkl_umath/src/_patch_numpy.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ from libc.stdlib cimport free, malloc

cnp.import_umath()

cdef extern from *:
Comment thread
ndgrigorian marked this conversation as resolved.
"""
#include "numpy/ufuncobject.h"
static inline char* _get_ufunc_types(PyObject *u) {
return (char *)((PyUFuncObject *)u)->types;
}
"""
char* _get_ufunc_types(object u) noexcept


ctypedef struct function_info:
cnp.PyUFuncGenericFunction original_function
cnp.PyUFuncGenericFunction patch_function
Expand All @@ -49,70 +59,89 @@ ctypedef struct function_info:
cdef class _patch_impl:
cdef int functions_count
cdef function_info* functions

functions_dict = dict()
cdef dict functions_dict

def __cinit__(self):
cdef int pi, oi
self.functions_dict = {}
cdef int pi, oi, i, nargs
cdef int expected_count
cdef char* patch_types
cdef char* orig_types

umaths = [i for i in dir(mu) if isinstance(getattr(mu, i), np.ufunc)]
self.functions = NULL
self.functions_count = 0

umaths = [x for x in dir(mu) if isinstance(getattr(mu, x), np.ufunc)]
expected_count = 0
for umath in umaths:
mkl_umath_func = getattr(mu, umath)
self.functions_count += mkl_umath_func.ntypes

self.functions = <function_info *> malloc(
self.functions_count * sizeof(function_info)
)
expected_count += mkl_umath_func.ntypes

if expected_count > 0:
self.functions = <function_info *> malloc(
expected_count * sizeof(function_info)
)
if self.functions is NULL:
raise MemoryError(
"Failed to allocate memory for function_info array"
)

func_number = 0
for umath in umaths:
patch_umath = getattr(mu, umath)
c_patch_umath = <cnp.ufunc>patch_umath
c_orig_umath = <cnp.ufunc>getattr(np, umath)
# nargs must be > 0, since every ufunc has at least one argument
nargs = c_patch_umath.nargs
if nargs <= 0:
raise RuntimeError(
f"Invalid number of arguments for ufunc {umath}: {nargs}"
)
patch_types = _get_ufunc_types(c_patch_umath)
orig_types = _get_ufunc_types(c_orig_umath)
for pi in range(c_patch_umath.ntypes):
oi = 0
while oi < c_orig_umath.ntypes:
found = True
for i in range(c_patch_umath.nargs):
for i in range(nargs):
if (
c_patch_umath.types[pi * nargs + i]
!= c_orig_umath.types[oi * nargs + i]
patch_types[pi * nargs + i]
!= orig_types[oi * nargs + i]
):
found = False
break
if found is True:
break
oi = oi + 1
if oi < c_orig_umath.ntypes:
self.functions[func_number].original_function = (
self.functions[self.functions_count].original_function = (
c_orig_umath.functions[oi]
)
self.functions[func_number].patch_function = (
self.functions[self.functions_count].patch_function = (
c_patch_umath.functions[pi]
)
self.functions[func_number].signature = (
self.functions[self.functions_count].signature = (
<int *> malloc(nargs * sizeof(int))
)
for i in range(nargs):
self.functions[func_number].signature[i] = (
c_patch_umath.types[pi * nargs + i]
self.functions[self.functions_count].signature[i] = (
patch_types[pi * nargs + i]
)
self.functions_dict[(umath, patch_umath.types[pi])] = (
func_number
self.functions_count
)
func_number = func_number + 1
self.functions_count += 1
else:
raise RuntimeError(
f"Unable to find original function for: {umath} "
f"{patch_umath.types[pi]}"
)

def __dealloc__(self):
for i in range(self.functions_count):
free(self.functions[i].signature)
free(self.functions)
if self.functions is not NULL:
for i in range(self.functions_count):
if self.functions[i].signature is not NULL:
free(self.functions[i].signature)
free(self.functions)

cdef int _replace_loop(
self,
Expand Down
Loading