Skip to content

Commit 7845b3c

Browse files
committed
Add more tests to cover different use cases
1 parent 383c39c commit 7845b3c

File tree

1 file changed

+38
-3
lines changed

1 file changed

+38
-3
lines changed

dpnp/tests/test_indexing.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def test_errors(self):
118118
assert_raises(ValueError, a.diagonal, axis1=1, axis2=1)
119119
assert_raises(ValueError, a.diagonal, axis1=1, axis2=-1)
120120

121-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
121+
@pytest.mark.parametrize("dt", get_all_dtypes(no_none=True))
122122
@pytest.mark.parametrize(
123123
"shape, offset",
124124
[
@@ -129,8 +129,8 @@ def test_errors(self):
129129
((3, 3, 4), 5), # 3D array, offset >= m
130130
],
131131
)
132-
def test_empty_strides(self, dtype, shape, offset):
133-
a = generate_random_numpy_array(shape=shape, dtype=dtype)
132+
def test_empty_strides(self, dt, shape, offset):
133+
a = generate_random_numpy_array(shape=shape, dtype=dt)
134134
ia = dpnp.array(a)
135135

136136
expected = numpy.diagonal(a, offset)
@@ -141,6 +141,41 @@ def test_empty_strides(self, dtype, shape, offset):
141141
assert expected.strides == result.strides
142142
assert_array_equal(expected, result)
143143

144+
@pytest.mark.parametrize("dt", get_all_dtypes(no_none=True))
145+
def test_view(self, dt):
146+
a = generate_random_numpy_array(shape=(3, 4), dtype=dt)
147+
a = dpnp.array(a)
148+
ia = a.copy()
149+
150+
diag = dpnp.diagonal(a)
151+
diag[1] = 17 # modify a diagonal element
152+
ia[1, 1] = 17 # do the same in original copy of the array
153+
154+
assert (a == ia).all()
155+
156+
@pytest.mark.parametrize("dt", get_all_dtypes(no_none=True))
157+
@pytest.mark.parametrize(
158+
"slice_spec, offset",
159+
[
160+
((slice(None), slice(None, None, 2)), 0), # skip columns
161+
((slice(None, None, 2), slice(None)), 1), # skip rows
162+
((slice(None, None, 2), slice(None, None, 2)), 0), # skip both
163+
],
164+
)
165+
def test_noncontiguous(self, dt, slice_spec, offset):
166+
a = generate_random_numpy_array(shape=(4, 6), dtype=dt)
167+
a_sliced = a[slice_spec]
168+
ia = dpnp.array(a)
169+
ia_sliced = ia[slice_spec]
170+
171+
expected = numpy.diagonal(a_sliced, offset=offset)
172+
result = dpnp.diagonal(ia_sliced, offset=offset)
173+
174+
# Check strides match for non-contiguous arrays
175+
assert expected.shape == result.shape
176+
assert expected.strides == result.strides
177+
assert_array_equal(expected, result)
178+
144179

145180
class TestExtins:
146181
@pytest.mark.parametrize("dt", get_all_dtypes(no_none=True))

0 commit comments

Comments
 (0)