Skip to content

Commit 0359ebd

Browse files
committed
Add casting for input strides and none buffer
1 parent 51b5603 commit 0359ebd

4 files changed

Lines changed: 18 additions & 10 deletions

File tree

dpnp/dpnp_array.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,15 +102,19 @@ def __init__(
102102

103103
if dtype is None and hasattr(buffer, "dtype"):
104104
dtype = buffer.dtype
105-
106-
if not (strides is None or dtype is None):
107-
# dpctl expects strides as elements displacement in memory,
108-
# while dpnp (and numpy as well) relies on bytes displacement
109-
it_sz = dpnp.dtype(dtype).itemsize
110-
strides = tuple(el // it_sz for el in strides)
111105
else:
112106
buffer = usm_type
113107

108+
if strides is not None:
109+
# dpctl expects strides as elements displacement in memory,
110+
# while dpnp (and numpy as well) relies on bytes displacement
111+
if dtype is None:
112+
dtype = dpnp.default_float_type(
113+
device=device, sycl_queue=sycl_queue
114+
)
115+
it_sz = dpnp.dtype(dtype).itemsize
116+
strides = tuple(el // it_sz for el in strides)
117+
114118
sycl_queue_normalized = dpnp.get_normalized_queue_device(
115119
device=device, sycl_queue=sycl_queue
116120
)

dpnp/dpnp_iface_arraycreation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _get_empty_array(
105105
elif a.flags.c_contiguous:
106106
order = "C"
107107
else:
108-
strides = _get_strides_for_order_k(a, _shape)
108+
strides = _get_strides_for_order_k(a, _dtype, shape=_shape)
109109
order = "C"
110110
elif order not in "cfCF":
111111
raise ValueError(
@@ -122,15 +122,15 @@ def _get_empty_array(
122122
)
123123

124124

125-
def _get_strides_for_order_k(x, shape=None):
125+
def _get_strides_for_order_k(x, dtype, shape=None):
126126
"""
127127
Calculate strides when order='K' for empty_like, ones_like, zeros_like,
128128
and full_like where `shape` is ``None`` or len(shape) == x.ndim.
129129
130130
"""
131131
stride_and_index = sorted([(abs(s), -i) for i, s in enumerate(x.strides)])
132132
strides = [0] * x.ndim
133-
stride = 1
133+
stride = dpnp.dtype(dtype).itemsize
134134
for _, i in stride_and_index:
135135
strides[-i] = stride
136136
stride *= shape[-i] if shape else x.shape[-i]

dpnp/fft/dpnp_utils_fft.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,10 @@ def _compute_result(dsc, a, out, forward, c2c, out_strides):
224224
if a.dtype == dpnp.complex64
225225
else dpnp.float64
226226
)
227+
# cast to expected strides format
228+
out_strides = tuple(
229+
el * dpnp.dtype(out_dtype).itemsize for el in out_strides
230+
)
227231
result = dpnp_array(
228232
out_shape,
229233
dtype=out_dtype,

dpnp/tests/test_ndarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def test_flags_strides(dtype, order, strides):
290290
(4, 4), dtype=dtype, order=order, strides=strides
291291
)
292292
a = numpy.ndarray((4, 4), dtype=dtype, order=order, strides=numpy_strides)
293-
ia = dpnp.ndarray((4, 4), dtype=dtype, order=order, strides=strides)
293+
ia = dpnp.ndarray((4, 4), dtype=dtype, order=order, strides=numpy_strides)
294294
assert usm_array.flags == ia.flags
295295
assert a.flags.c_contiguous == ia.flags.c_contiguous
296296
assert a.flags.f_contiguous == ia.flags.f_contiguous

0 commit comments

Comments
 (0)