Skip to content

Commit fc0fca2

Browse files
committed
ENH: test broadcast_shapes
1 parent be9dd9e commit fc0fca2

2 files changed

Lines changed: 27 additions & 0 deletions

File tree

array_api_tests/_array_module.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def __repr__(self):
3535
_funcs += ["take", "isdtype", "conj", "imag", "real"] # TODO: bump spec and update array-api-tests to new spec layout
3636
_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS + ["fft"]
3737

38+
_top_level_attrs += ['broadcast_shapes'] # FIXME: until the spec is not updated
39+
3840
for attr in _top_level_attrs:
3941
try:
4042
globals()[attr] = getattr(xp, attr)

array_api_tests/test_data_type_functions.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,31 @@ def test_broadcast_arrays(shapes, data):
128128
raise
129129

130130

131+
132+
class TestBroadcastShapes:
133+
134+
@given(shapes=st.integers(1, 5).flatmap(hh.mutually_broadcastable_shapes))
135+
def test_broadcast_shapes(self, shapes):
136+
repro_snippet = ph.format_snippet(f"xp.broadcast_shapes(*shapes) with {shapes = }")
137+
try:
138+
out_shape = xp.broadcast_shapes(*shapes)
139+
expected_shape = sh.broadcast_shapes(*shapes)
140+
assert out_shape == expected_shape
141+
except Exception as exc:
142+
ph.add_note(exc, repro_snippet)
143+
raise
144+
145+
def test_empty(self):
146+
assert xp.broadcast_shapes() == ()
147+
148+
@given(shapes=hh.mutually_broadcastable_shapes(2, min_dims=1, min_side=3))
149+
def test_error(self, shapes):
150+
s1, s2 = shapes
151+
s1 = s1[:-1] + (s1[-1] + 1,)
152+
with pytest.raises(ValueError):
153+
xp.broadcast_shapes(s1, s2)
154+
155+
131156
@given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()), data=st.data())
132157
def test_broadcast_to(x, data):
133158
shape = data.draw(

0 commit comments

Comments
 (0)