Skip to content

Commit d15aea4

Browse files
authored
Merge pull request #226 from IntelPython/fix-for-numpy-2.5
Add getter for NumPy ufunc `types` to patching
2 parents 8852b0f + db4fe07 commit d15aea4

2 files changed

Lines changed: 53 additions & 23 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111

1212
### Changed
1313
* Removed `numpy-base` dependency and `USE_NUMPY_BASE` environment variable from conda recipe [gh-200](https://github.com/IntelPython/mkl_umath/pull/200)
14+
* Updated `mkl_umath` patching to work with changes to NumPy Cython API present in NumPy 2.5 [gh-226](https://github.com/IntelPython/mkl_umath/pull/226)
1415

1516
### Fixed
1617

mkl_umath/src/_patch_numpy.pyx

Lines changed: 52 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,16 @@ from libc.stdlib cimport free, malloc
4040

4141
cnp.import_umath()
4242

43+
cdef extern from *:
44+
"""
45+
#include "numpy/ufuncobject.h"
46+
static inline char* _get_ufunc_types(PyObject *u) {
47+
return (char *)((PyUFuncObject *)u)->types;
48+
}
49+
"""
50+
char* _get_ufunc_types(object u) noexcept
51+
52+
4353
ctypedef struct function_info:
4454
cnp.PyUFuncGenericFunction original_function
4555
cnp.PyUFuncGenericFunction patch_function
@@ -49,70 +59,89 @@ ctypedef struct function_info:
4959
cdef class _patch_impl:
5060
cdef int functions_count
5161
cdef function_info* functions
52-
53-
functions_dict = dict()
62+
cdef dict functions_dict
5463

5564
def __cinit__(self):
56-
cdef int pi, oi
65+
self.functions_dict = {}
66+
cdef int pi, oi, i, nargs
67+
cdef int expected_count
68+
cdef char* patch_types
69+
cdef char* orig_types
5770

58-
umaths = [i for i in dir(mu) if isinstance(getattr(mu, i), np.ufunc)]
71+
self.functions = NULL
5972
self.functions_count = 0
73+
74+
umaths = [x for x in dir(mu) if isinstance(getattr(mu, x), np.ufunc)]
75+
expected_count = 0
6076
for umath in umaths:
6177
mkl_umath_func = getattr(mu, umath)
62-
self.functions_count += mkl_umath_func.ntypes
63-
64-
self.functions = <function_info *> malloc(
65-
self.functions_count * sizeof(function_info)
66-
)
78+
expected_count += mkl_umath_func.ntypes
79+
80+
if expected_count > 0:
81+
self.functions = <function_info *> malloc(
82+
expected_count * sizeof(function_info)
83+
)
84+
if self.functions is NULL:
85+
raise MemoryError(
86+
"Failed to allocate memory for function_info array"
87+
)
6788

68-
func_number = 0
6989
for umath in umaths:
7090
patch_umath = getattr(mu, umath)
7191
c_patch_umath = <cnp.ufunc>patch_umath
7292
c_orig_umath = <cnp.ufunc>getattr(np, umath)
93+
# nargs must be >=0 as no ufuncs have no arguments
7394
nargs = c_patch_umath.nargs
95+
if nargs <= 0:
96+
raise RuntimeError(
97+
f"Invalid number of arguments for ufunc {umath}: {nargs}"
98+
)
99+
patch_types = _get_ufunc_types(c_patch_umath)
100+
orig_types = _get_ufunc_types(c_orig_umath)
74101
for pi in range(c_patch_umath.ntypes):
75102
oi = 0
76103
while oi < c_orig_umath.ntypes:
77104
found = True
78-
for i in range(c_patch_umath.nargs):
105+
for i in range(nargs):
79106
if (
80-
c_patch_umath.types[pi * nargs + i]
81-
!= c_orig_umath.types[oi * nargs + i]
107+
patch_types[pi * nargs + i]
108+
!= orig_types[oi * nargs + i]
82109
):
83110
found = False
84111
break
85112
if found is True:
86113
break
87114
oi = oi + 1
88115
if oi < c_orig_umath.ntypes:
89-
self.functions[func_number].original_function = (
116+
self.functions[self.functions_count].original_function = (
90117
c_orig_umath.functions[oi]
91118
)
92-
self.functions[func_number].patch_function = (
119+
self.functions[self.functions_count].patch_function = (
93120
c_patch_umath.functions[pi]
94121
)
95-
self.functions[func_number].signature = (
122+
self.functions[self.functions_count].signature = (
96123
<int *> malloc(nargs * sizeof(int))
97124
)
98125
for i in range(nargs):
99-
self.functions[func_number].signature[i] = (
100-
c_patch_umath.types[pi * nargs + i]
126+
self.functions[self.functions_count].signature[i] = (
127+
patch_types[pi * nargs + i]
101128
)
102129
self.functions_dict[(umath, patch_umath.types[pi])] = (
103-
func_number
130+
self.functions_count
104131
)
105-
func_number = func_number + 1
132+
self.functions_count += 1
106133
else:
107134
raise RuntimeError(
108135
f"Unable to find original function for: {umath} "
109136
f"{patch_umath.types[pi]}"
110137
)
111138

112139
def __dealloc__(self):
113-
for i in range(self.functions_count):
114-
free(self.functions[i].signature)
115-
free(self.functions)
140+
if self.functions is not NULL:
141+
for i in range(self.functions_count):
142+
if self.functions[i].signature is not NULL:
143+
free(self.functions[i].signature)
144+
free(self.functions)
116145

117146
cdef int _replace_loop(
118147
self,

0 commit comments

Comments
 (0)