Skip to content

Commit 96c722c

Browse files
Disallow numeric scalar conversion for non-0d arrays
1 parent e041fca commit 96c722c

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,14 @@ cdef object _as_zero_dim_ndarray(object usm_ary):
124124
return view
125125

126126

127+
cdef inline void _check_0d_scalar_conversion(object usm_ary) except *:
128+
"Raise TypeError if array cannot be converted to a Python scalar"
129+
if (usm_ary.ndim != 0):
130+
raise TypeError(
131+
"only 0-dimensional arrays can be converted to Python scalars"
132+
)
133+
134+
127135
cdef int _copy_writable(int lhs_flags, int rhs_flags):
128136
"Copy the WRITABLE flag to lhs_flags from rhs_flags"
129137
return (lhs_flags & ~USM_ARRAY_WRITABLE) | (rhs_flags & USM_ARRAY_WRITABLE)
@@ -1147,6 +1155,7 @@ cdef class usm_ndarray:
11471155

11481156
def __float__(self):
11491157
if self.size == 1:
1158+
_check_0d_scalar_conversion(self)
11501159
view = _as_zero_dim_ndarray(self)
11511160
return view.__float__()
11521161

@@ -1156,6 +1165,7 @@ cdef class usm_ndarray:
11561165

11571166
def __complex__(self):
11581167
if self.size == 1:
1168+
_check_0d_scalar_conversion(self)
11591169
view = _as_zero_dim_ndarray(self)
11601170
return view.__complex__()
11611171

@@ -1165,6 +1175,7 @@ cdef class usm_ndarray:
11651175

11661176
def __int__(self):
11671177
if self.size == 1:
1178+
_check_0d_scalar_conversion(self)
11681179
view = _as_zero_dim_ndarray(self)
11691180
return view.__int__()
11701181

0 commit comments

Comments
 (0)