Skip to content

Commit ab441b6

Browse files
committed
Fix a way how contig flags calculated for linalg utils
1 parent 0ebd959 commit ab441b6

1 file changed

Lines changed: 3 additions & 7 deletions

File tree

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -185,18 +185,14 @@ def _define_contig_flag(x):
185185
"""
186186

187187
flag = False
188-
x_strides = x.strides
188+
x_strides = dpnp.get_usm_ndarray(x).strides
189189
x_shape = x.shape
190190
if x.ndim < 2:
191191
return True, True, True
192192

193193
x_strides = _standardize_strides_to_nonzero(x_strides, x_shape)
194-
x_is_c_contiguous = (
195-
x_strides[-1] == x.itemsize and x_strides[-2] == x_shape[-1]
196-
)
197-
x_is_f_contiguous = (
198-
x_strides[-2] == x.itemsize and x_strides[-1] == x_shape[-2]
199-
)
194+
x_is_c_contiguous = x_strides[-1] == 1 and x_strides[-2] == x_shape[-1]
195+
x_is_f_contiguous = x_strides[-2] == 1 and x_strides[-1] == x_shape[-2]
200196
if x_is_c_contiguous or x_is_f_contiguous:
201197
flag = True
202198
return flag, x_is_c_contiguous, x_is_f_contiguous

0 commit comments

Comments
 (0)