Skip to content

Commit d597b5c

Browse files
committed
MAINT: move expand_dims tests to a class
1 parent 54e5862 commit d597b5c

1 file changed

Lines changed: 64 additions & 63 deletions

File tree

array_api_tests/test_manipulation_functions.py

Lines changed: 64 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -122,74 +122,75 @@ def test_concat(dtypes, base_shape, data):
122122
raise
123123

124124

125-
@pytest.mark.unvectorized
126-
@given(
127-
x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes()),
128-
axis=shared_shapes().flatmap(
129-
# Generate both valid and invalid axis
130-
lambda s: st.integers(2 * (-len(s) - 1), 2 * len(s))
131-
),
132-
)
133-
def test_expand_dims(x, axis):
134-
if axis < -x.ndim - 1 or axis > x.ndim:
135-
with pytest.raises(IndexError):
136-
xp.expand_dims(x, axis=axis)
137-
return
138-
139-
repro_snippet = ph.format_snippet(f"xp.expand_dims({x!r}, axis={axis!r})")
140-
try:
141-
out = xp.expand_dims(x, axis=axis)
125+
class TestExpandDims:
126+
@pytest.mark.unvectorized
127+
@given(
128+
x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes()),
129+
axis=shared_shapes().flatmap(
130+
# Generate both valid and invalid axis
131+
lambda s: st.integers(2 * (-len(s) - 1), 2 * len(s))
132+
),
133+
)
134+
def test_expand_dims(self, x, axis):
135+
if axis < -x.ndim - 1 or axis > x.ndim:
136+
with pytest.raises(IndexError):
137+
xp.expand_dims(x, axis=axis)
138+
return
142139

143-
ph.assert_dtype("expand_dims", in_dtype=x.dtype, out_dtype=out.dtype)
140+
repro_snippet = ph.format_snippet(f"xp.expand_dims({x!r}, axis={axis!r})")
141+
try:
142+
out = xp.expand_dims(x, axis=axis)
144143

145-
shape = [side for side in x.shape]
146-
index = axis if axis >= 0 else x.ndim + axis + 1
147-
shape.insert(index, 1)
148-
shape = tuple(shape)
149-
ph.assert_result_shape("expand_dims", in_shapes=[x.shape], out_shape=out.shape, expected=shape)
144+
ph.assert_dtype("expand_dims", in_dtype=x.dtype, out_dtype=out.dtype)
150145

151-
assert_array_ndindex(
152-
"expand_dims", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape)
146+
shape = [side for side in x.shape]
147+
index = axis if axis >= 0 else x.ndim + axis + 1
148+
shape.insert(index, 1)
149+
shape = tuple(shape)
150+
ph.assert_result_shape("expand_dims", in_shapes=[x.shape], out_shape=out.shape, expected=shape)
151+
152+
assert_array_ndindex(
153+
"expand_dims", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape)
154+
)
155+
except Exception as exc:
156+
ph.add_note(exc, repro_snippet)
157+
raise
158+
159+
160+
@given(
161+
x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes(max_dims=4)),
162+
axes=shared_shapes().flatmap(
163+
lambda s: st.lists(
164+
st.integers(2*(-len(s)-1), 2*len(s)),
165+
min_size=0 if len(s)==0 else 1,
166+
max_size=len(s)
167+
).map(tuple)
153168
)
154-
except Exception as exc:
155-
ph.add_note(exc, repro_snippet)
156-
raise
157-
158-
159-
@given(
160-
x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes(max_dims=4)),
161-
axes=shared_shapes().flatmap(
162-
lambda s: st.lists(
163-
st.integers(2*(-len(s)-1), 2*len(s)),
164-
min_size=0 if len(s)==0 else 1,
165-
max_size=len(s)
166-
).map(tuple)
167169
)
168-
)
169-
def test_expand_dims_tuples(x, axes):
170-
# normalize the axes
171-
y_ndim = x.ndim + len(axes)
172-
n_axes = tuple(ax + y_ndim if ax < 0 else ax for ax in axes)
173-
unique_axes = set(n_axes)
174-
175-
if any(ax < 0 or ax >= y_ndim for ax in n_axes) or len(n_axes) != len(unique_axes):
176-
with pytest.raises((IndexError, ValueError)):
177-
xp.expand_dims(x, axis=axes)
178-
return
179-
180-
repro_snippet = ph.format_snippet(f"xp.expand_dims({x!r}, axis={axes!r})")
181-
try:
182-
y = xp.expand_dims(x, axis=axes)
183-
184-
ye = x
185-
for ax in sorted(n_axes):
186-
ye = xp.expand_dims(ye, axis=ax)
187-
assert y.shape == ye.shape
188-
# TODO value tests; check that y.shape is 1s and items from x.shape, in order
189-
190-
except Exception as exc:
191-
ph.add_note(exc, repro_snippet)
192-
raise
170+
def test_expand_dims_tuples(self, x, axes):
171+
# normalize the axes
172+
y_ndim = x.ndim + len(axes)
173+
n_axes = tuple(ax + y_ndim if ax < 0 else ax for ax in axes)
174+
unique_axes = set(n_axes)
175+
176+
if any(ax < 0 or ax >= y_ndim for ax in n_axes) or len(n_axes) != len(unique_axes):
177+
with pytest.raises((IndexError, ValueError)):
178+
xp.expand_dims(x, axis=axes)
179+
return
180+
181+
repro_snippet = ph.format_snippet(f"xp.expand_dims({x!r}, axis={axes!r})")
182+
try:
183+
y = xp.expand_dims(x, axis=axes)
184+
185+
ye = x
186+
for ax in sorted(n_axes):
187+
ye = xp.expand_dims(ye, axis=ax)
188+
assert y.shape == ye.shape
189+
# TODO value tests; check that y.shape is 1s and items from x.shape, in order
190+
191+
except Exception as exc:
192+
ph.add_note(exc, repro_snippet)
193+
raise
193194

194195

195196
@pytest.mark.min_version("2023.12")

0 commit comments

Comments
 (0)