@@ -40,6 +40,16 @@ from libc.stdlib cimport free, malloc
4040
4141cnp.import_umath()
4242
43+ cdef extern from * :
44+ """
45+ #include "numpy/ufuncobject.h"
46+ static inline char* _get_ufunc_types(PyObject *u) {
47+ return (char *)((PyUFuncObject *)u)->types;
48+ }
49+ """
50+ char * _get_ufunc_types(object u) noexcept
51+
52+
4353ctypedef struct function_info:
4454 cnp.PyUFuncGenericFunction original_function
4555 cnp.PyUFuncGenericFunction patch_function
@@ -49,70 +59,89 @@ ctypedef struct function_info:
4959cdef class _patch_impl:
5060 cdef int functions_count
5161 cdef function_info* functions
52-
53- functions_dict = dict ()
62+ cdef dict functions_dict
5463
5564 def __cinit__ (self ):
56- cdef int pi, oi
65+ self .functions_dict = {}
66+ cdef int pi, oi, i, nargs
67+ cdef int expected_count
68+ cdef char * patch_types
69+ cdef char * orig_types
5770
58- umaths = [i for i in dir (mu) if isinstance ( getattr (mu, i), np.ufunc)]
71+ self .functions = NULL
5972 self .functions_count = 0
73+
74+ umaths = [x for x in dir (mu) if isinstance (getattr (mu, x), np.ufunc)]
75+ expected_count = 0
6076 for umath in umaths:
6177 mkl_umath_func = getattr (mu, umath)
62- self .functions_count += mkl_umath_func.ntypes
63-
64- self .functions = < function_info * > malloc(
65- self .functions_count * sizeof(function_info)
66- )
78+ expected_count += mkl_umath_func.ntypes
79+
80+ if expected_count > 0 :
81+ self .functions = < function_info * > malloc(
82+ expected_count * sizeof(function_info)
83+ )
84+ if self .functions is NULL :
85+ raise MemoryError (
86+ " Failed to allocate memory for function_info array"
87+ )
6788
68- func_number = 0
6989 for umath in umaths:
7090 patch_umath = getattr (mu, umath)
7191 c_patch_umath = < cnp.ufunc> patch_umath
7292 c_orig_umath = < cnp.ufunc> getattr (np, umath)
93+ # nargs must be >=0 as no ufuncs have no arguments
7394 nargs = c_patch_umath.nargs
95+ if nargs <= 0 :
96+ raise RuntimeError (
97+ f" Invalid number of arguments for ufunc {umath}: {nargs}"
98+ )
99+ patch_types = _get_ufunc_types(c_patch_umath)
100+ orig_types = _get_ufunc_types(c_orig_umath)
74101 for pi in range (c_patch_umath.ntypes):
75102 oi = 0
76103 while oi < c_orig_umath.ntypes:
77104 found = True
78- for i in range (c_patch_umath. nargs):
105+ for i in range (nargs):
79106 if (
80- c_patch_umath.types [pi * nargs + i]
81- != c_orig_umath.types [oi * nargs + i]
107+ patch_types [pi * nargs + i]
108+ != orig_types [oi * nargs + i]
82109 ):
83110 found = False
84111 break
85112 if found is True :
86113 break
87114 oi = oi + 1
88115 if oi < c_orig_umath.ntypes:
89- self .functions[func_number ].original_function = (
116+ self .functions[self .functions_count ].original_function = (
90117 c_orig_umath.functions[oi]
91118 )
92- self .functions[func_number ].patch_function = (
119+ self .functions[self .functions_count ].patch_function = (
93120 c_patch_umath.functions[pi]
94121 )
95- self .functions[func_number ].signature = (
122+ self .functions[self .functions_count ].signature = (
96123 < int * > malloc(nargs * sizeof(int ))
97124 )
98125 for i in range (nargs):
99- self .functions[func_number ].signature[i] = (
100- c_patch_umath.types [pi * nargs + i]
126+ self .functions[self .functions_count ].signature[i] = (
127+ patch_types [pi * nargs + i]
101128 )
102129 self .functions_dict[(umath, patch_umath.types[pi])] = (
103- func_number
130+ self .functions_count
104131 )
105- func_number = func_number + 1
132+ self .functions_count += 1
106133 else :
107134 raise RuntimeError (
108135 f" Unable to find original function for: {umath} "
109136 f" {patch_umath.types[pi]}"
110137 )
111138
112139 def __dealloc__ (self ):
113- for i in range (self .functions_count):
114- free(self .functions[i].signature)
115- free(self .functions)
140+ if self .functions is not NULL :
141+ for i in range (self .functions_count):
142+ if self .functions[i].signature is not NULL :
143+ free(self .functions[i].signature)
144+ free(self .functions)
116145
117146 cdef int _replace_loop(
118147 self ,
0 commit comments