Skip to content

Commit 0eb9b80

Browse files
authored
Merge pull request #403 from ev-br/test_broadcast_shapes
ENH: test broadcast_shapes
2 parents 0ded6da + cb782aa commit 0eb9b80

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

array_api_tests/test_data_type_functions.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,31 @@ def test_broadcast_arrays(shapes, data):
130130
raise
131131

132132

133+
@pytest.mark.min_version("2025.12")
134+
class TestBroadcastShapes:
135+
136+
@given(shapes=st.integers(1, 5).flatmap(hh.mutually_broadcastable_shapes))
137+
def test_broadcast_shapes(self, shapes):
138+
repro_snippet = ph.format_snippet(f"xp.broadcast_shapes(*shapes) with {shapes = }")
139+
try:
140+
out_shape = xp.broadcast_shapes(*shapes)
141+
expected_shape = sh.broadcast_shapes(*shapes)
142+
assert out_shape == expected_shape
143+
except Exception as exc:
144+
ph.add_note(exc, repro_snippet)
145+
raise
146+
147+
def test_empty(self):
148+
assert xp.broadcast_shapes() == ()
149+
150+
@given(shapes=hh.mutually_broadcastable_shapes(2, min_dims=1, min_side=3))
151+
def test_error(self, shapes):
152+
s1, s2 = shapes
153+
s1 = s1[:-1] + (s1[-1] + 1,)
154+
with pytest.raises(ValueError):
155+
xp.broadcast_shapes(s1, s2)
156+
157+
133158
@given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()), data=st.data())
134159
def test_broadcast_to(x, data):
135160
shape = data.draw(

0 commit comments

Comments
 (0)