Skip to content

Commit fa58c60

Browse files
committed
Merge unvectorize_classes with fix
2 parents 7a9935d + 2c4cbba commit fa58c60

15 files changed

Lines changed: 1722 additions & 894 deletions

.git-blame-ignore-revs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,10 @@ e807ffe526c7330691e8f39d31347dc2b3106de3
77
bd42e84d2e5aae26ade8d70384e74effd1de89cb
88
f7e822883b7e24b5aa540e2413759a85128b42ef
99
a37f348ba27b6818e92fda8aee2406c653c671ea
10+
# gh-396
11+
ec5a3b4e185c262b0a5f5b1631b84a09f766d80e
12+
9058908b58ce627467ac34e768098a25f5863d31
13+
c80e1823c2e738381ca02f27cea1e2b89dde0ac5
14+
# gh-402
15+
bdc84e8316046cb5bdc637067460057eef17d0f1
16+

.github/workflows/lint.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ jobs:
77

88
runs-on: ubuntu-latest
99
steps:
10-
- uses: actions/checkout@v4
10+
- uses: actions/checkout@v6
1111
- name: Set up Python ${{ matrix.python-version }}
12-
uses: actions/setup-python@v5
12+
uses: actions/setup-python@v6
1313
with:
1414
python-version: "3.10"
1515
- name: Run pre-commit hook

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ jobs:
1111
python-version: ["3.10", "3.12", "3.13", "3.14"]
1212
steps:
1313
- name: Checkout array-api-tests
14-
uses: actions/checkout@v1
14+
uses: actions/checkout@v6
1515
with:
1616
submodules: 'true'
1717
- name: Set up Python ${{ matrix.python-version }}
18-
uses: actions/setup-python@v1
18+
uses: actions/setup-python@v6
1919
with:
2020
python-version: ${{ matrix.python-version }}
2121
- name: Install dependencies

array-api-strict-skips.txt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,21 @@ array_api_tests/test_data_type_functions.py::test_finfo
3232
array_api_tests/test_data_type_functions.py::test_finfo_dtype
3333
array_api_tests/test_data_type_functions.py::test_iinfo
3434
array_api_tests/test_data_type_functions.py::test_iinfo_dtype
35+
36+
37+
# complex special cases which failed "forever"
38+
array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +0) -> +infinity + 0j]
39+
array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is +infinity) -> -1 + 0j]
40+
array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +infinity) -> infinity + NaN j]
41+
array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is NaN) -> -1 + 0j]
42+
array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is NaN) -> infinity + NaN j]
43+
array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is NaN and imag(x_i) is +0) -> NaN + 0j]
44+
array_api_tests/test_special_cases.py::test_unary[expm1((real(x_i) is +0 or real(x_i) == -0) and imag(x_i) is +0) -> 0 + 0j]
45+
46+
array_api_tests/test_special_cases.py::test_unary[sign((real(x_i) is -0 or real(x_i) == +0) and (imag(x_i) is -0 or imag(x_i) == +0)) -> 0 + 0j]
47+
array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j]
48+
49+
# this acosh failure is only seen with python==3.10 and numpy==2.2.6, and not e.g. python 3.12 & numpy 2.4.1
50+
array_api_tests/test_special_cases.py::test_unary[acosh(real(x_i) is +0 and imag(x_i) is NaN) -> NaN \xb1 \u03c0j/2]
51+
52+
array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j]

array_api_tests/dtype_helpers.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,39 @@ def is_scalar(x):
199199
return isinstance(x, (int, float, complex, bool))
200200

201201

202+
def complex_dtype_for(dtyp):
203+
"""Complex dtype for a float or complex."""
204+
if dtyp in complex_dtypes:
205+
return dtyp
206+
if dtyp not in real_float_dtypes:
207+
raise ValueError(f"no complex dtype to match {dtyp}")
208+
209+
real_name = dtype_to_name[dtyp]
210+
complex_name = {"float32": "complex64", "float64": "complex128"}[real_name]
211+
212+
complex_dtype = _name_to_dtype.get(complex_name, None)
213+
if complex_dtype is None:
214+
raise ValueError(f"no complex dtype to match {dtyp}")
215+
return complex_dtype
216+
217+
218+
def real_dtype_for(dtyp):
219+
"""Real float dtype for a float or complex."""
220+
if dtyp in real_float_dtypes:
221+
return dtyp
222+
if dtyp not in complex_dtypes:
223+
raise ValueError(f"no real float dtype to match {dtyp}")
224+
225+
complex_name = dtype_to_name[dtyp]
226+
real_name = {"complex64": "float32", "complex128": "float64"}[complex_name]
227+
228+
real_dtype = _name_to_dtype.get(real_name, None)
229+
if real_dtype is None:
230+
raise ValueError(f"no real dtype to match {dtyp}")
231+
return real_dtype
232+
233+
234+
202235
def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping:
203236
dtype_value_pairs = []
204237
for name, value in mapping.items():

array_api_tests/hypothesis_helpers.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -457,13 +457,12 @@ def scalars(draw, dtypes, finite=False, **kwds):
457457
dtypes should be one of the shared_* dtypes strategies.
458458
"""
459459
dtype = draw(dtypes)
460-
mM = kwds.pop('mM', None)
461460
if dh.is_int_dtype(dtype):
462-
if mM is None:
463-
m, M = dh.dtype_ranges[dtype]
464-
else:
465-
m, M = mM
466-
return draw(integers(m, M))
461+
m, M = dh.dtype_ranges[dtype]
462+
min_value = kwds.get('min_value', m)
463+
max_value = kwds.get('max_value', M)
464+
465+
return draw(integers(min_value, max_value))
467466
elif dtype == bool_dtype:
468467
return draw(booleans())
469468
elif dtype == float64:
@@ -593,20 +592,32 @@ def two_mutual_arrays(
593592

594593

595594
@composite
596-
def array_and_py_scalar(draw, dtypes, mM=None, positive=False):
595+
def array_and_py_scalar(draw, dtypes, **kwds):
597596
"""Draw a pair: (array, scalar) or (scalar, array)."""
598597
dtype = draw(sampled_from(dtypes))
599598

600-
scalar_var = draw(scalars(just(dtype), finite=True, mM=mM))
601-
if positive:
602-
assume (scalar_var > 0)
599+
# draw the scalar: for float arrays, draw a float or an int
600+
if dtype in dh.real_float_dtypes:
601+
scalar_strategy = sampled_from([xp.int32, dtype])
602+
else:
603+
scalar_strategy = just(dtype)
604+
scalar_var = draw(scalars(scalar_strategy, finite=True, **kwds))
603605

606+
# draw the array.
607+
# XXX artificially limit the range of values for floats, otherwise value testing is flaky
604608
elements={}
605609
if dtype in dh.real_float_dtypes:
606-
elements = {'allow_nan': False, 'allow_infinity': False,
607-
'min_value': 1.0 / (2<<5), 'max_value': 2<<5}
608-
if positive:
609-
elements = {'min_value': 0}
610+
elements = {
611+
'allow_nan': False,
612+
'allow_infinity': False,
613+
'min_value': kwds.get('min_value', 1.0 / (2<<5)),
614+
'max_value': kwds.get('max_value', 2<<5)
615+
}
616+
elif dtype in dh.int_dtypes:
617+
elements = {
618+
'min_value': kwds.get('min_value', None),
619+
'max_value': kwds.get('max_value', None)
620+
}
610621
array_var = draw(arrays(dtype, shape=shapes(min_dims=1), elements=elements))
611622

612623
if draw(booleans()):

array_api_tests/test_array_object.py

Lines changed: 99 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ def test_getitem(shape, dtype, data):
8686
key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key")
8787

8888
repro_snippet = ph.format_snippet(f"{x!r}[{key!r}]")
89-
9089
try:
9190
out = x[key]
9291

@@ -109,6 +108,7 @@ def test_getitem(shape, dtype, data):
109108
ph.add_note(exc, repro_snippet)
110109
raise
111110

111+
112112
@pytest.mark.unvectorized
113113
@given(
114114
shape=hh.shapes(),
@@ -133,28 +133,34 @@ def test_setitem(shape, dtypes, data):
133133
value = data.draw(value_strat, label="value")
134134

135135
res = xp.asarray(x, copy=True)
136-
res[key] = value
137-
138-
ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype")
139-
ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.shape")
140-
f_res = sh.fmt_idx("x", key)
141-
if isinstance(value, get_args(Scalar)):
142-
msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]"
143-
if cmath.isnan(value):
144-
assert xp.isnan(res[key]), msg
136+
137+
repro_snippet = ph.format_snippet(f"{res!r}[{key!r}] = {value!r}")
138+
try:
139+
res[key] = value
140+
141+
ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype")
142+
ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.shape")
143+
f_res = sh.fmt_idx("x", key)
144+
if isinstance(value, get_args(Scalar)):
145+
msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]"
146+
if cmath.isnan(value):
147+
assert xp.isnan(res[key]), msg
148+
else:
149+
assert res[key] == value, msg
145150
else:
146-
assert res[key] == value, msg
147-
else:
148-
ph.assert_array_elements("__setitem__", out=res[key], expected=value, out_repr=f_res)
149-
unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices))
150-
for idx in unaffected_indices:
151-
ph.assert_0d_equals(
152-
"__setitem__",
153-
x_repr=f"old {f_res}",
154-
x_val=x[idx],
155-
out_repr=f"modified {f_res}",
156-
out_val=res[idx],
157-
)
151+
ph.assert_array_elements("__setitem__", out=res[key], expected=value, out_repr=f_res)
152+
unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices))
153+
for idx in unaffected_indices:
154+
ph.assert_0d_equals(
155+
"__setitem__",
156+
x_repr=f"old {f_res}",
157+
x_val=x[idx],
158+
out_repr=f"modified {f_res}",
159+
out_val=res[idx],
160+
)
161+
except Exception as exc:
162+
ph.add_note(exc, repro_snippet)
163+
raise
158164

159165

160166
@pytest.mark.unvectorized
@@ -178,29 +184,34 @@ def test_getitem_masking(shape, data):
178184
x[key]
179185
return
180186

181-
out = x[key]
187+
repro_snippet = ph.format_snippet(f"out = {x!r}[{key!r}]")
188+
try:
189+
out = x[key]
182190

183-
ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype)
184-
if key.ndim == 0:
185-
expected_shape = (1,) if key else (0,)
186-
expected_shape += x.shape
187-
else:
188-
size = int(xp.sum(xp.astype(key, xp.uint8)))
189-
expected_shape = (size,) + x.shape[key.ndim :]
190-
ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape)
191-
if not any(s == 0 for s in key.shape):
192-
assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios
193-
out_indices = sh.ndindex(out.shape)
194-
for x_idx in sh.ndindex(x.shape):
195-
if key[x_idx]:
196-
out_idx = next(out_indices)
197-
ph.assert_0d_equals(
198-
"__getitem__",
199-
x_repr=f"x[{x_idx}]",
200-
x_val=x[x_idx],
201-
out_repr=f"out[{out_idx}]",
202-
out_val=out[out_idx],
203-
)
191+
ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype)
192+
if key.ndim == 0:
193+
expected_shape = (1,) if key else (0,)
194+
expected_shape += x.shape
195+
else:
196+
size = int(xp.sum(xp.astype(key, xp.uint8)))
197+
expected_shape = (size,) + x.shape[key.ndim :]
198+
ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape)
199+
if not any(s == 0 for s in key.shape):
200+
assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios
201+
out_indices = sh.ndindex(out.shape)
202+
for x_idx in sh.ndindex(x.shape):
203+
if key[x_idx]:
204+
out_idx = next(out_indices)
205+
ph.assert_0d_equals(
206+
"__getitem__",
207+
x_repr=f"x[{x_idx}]",
208+
x_val=x[x_idx],
209+
out_repr=f"out[{out_idx}]",
210+
out_val=out[out_idx],
211+
)
212+
except Exception as exc:
213+
ph.add_note(exc, repro_snippet)
214+
raise
204215

205216

206217
@pytest.mark.unvectorized
@@ -213,38 +224,44 @@ def test_setitem_masking(shape, data):
213224
)
214225

215226
res = xp.asarray(x, copy=True)
216-
res[key] = value
217-
218-
ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype")
219-
ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.dtype")
220-
scalar_type = dh.get_scalar_type(x.dtype)
221-
for idx in sh.ndindex(x.shape):
222-
if key[idx]:
223-
if isinstance(value, get_args(Scalar)):
224-
ph.assert_scalar_equals(
225-
"__setitem__",
226-
type_=scalar_type,
227-
idx=idx,
228-
out=scalar_type(res[idx]),
229-
expected=value,
230-
repr_name="modified x",
231-
)
227+
228+
repro_snippet = ph.format_snippet(f"{res}[{key!r}] = {value!r}")
229+
try:
230+
res[key] = value
231+
232+
ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype")
233+
ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.dtype")
234+
scalar_type = dh.get_scalar_type(x.dtype)
235+
for idx in sh.ndindex(x.shape):
236+
if key[idx]:
237+
if isinstance(value, get_args(Scalar)):
238+
ph.assert_scalar_equals(
239+
"__setitem__",
240+
type_=scalar_type,
241+
idx=idx,
242+
out=scalar_type(res[idx]),
243+
expected=value,
244+
repr_name="modified x",
245+
)
246+
else:
247+
ph.assert_0d_equals(
248+
"__setitem__",
249+
x_repr="value",
250+
x_val=value,
251+
out_repr=f"modified x[{idx}]",
252+
out_val=res[idx]
253+
)
232254
else:
233255
ph.assert_0d_equals(
234256
"__setitem__",
235-
x_repr="value",
236-
x_val=value,
257+
x_repr=f"old x[{idx}]",
258+
x_val=x[idx],
237259
out_repr=f"modified x[{idx}]",
238260
out_val=res[idx]
239261
)
240-
else:
241-
ph.assert_0d_equals(
242-
"__setitem__",
243-
x_repr=f"old x[{idx}]",
244-
x_val=x[idx],
245-
out_repr=f"modified x[{idx}]",
246-
out_val=res[idx]
247-
)
262+
except Exception as exc:
263+
ph.add_note(exc, repro_snippet)
264+
raise
248265

249266

250267
# ### Fancy indexing ###
@@ -309,15 +326,20 @@ def _test_getitem_arrays_and_ints(shape, data, idx_max_dims):
309326
key.append(data.draw(st.integers(-shape[i], shape[i]-1)))
310327

311328
key = tuple(key)
312-
out = x[key]
329+
repro_snippet = ph.format_snippet(f"out = {x!r}[{key!r}]")
330+
try:
331+
out = x[key]
313332

314-
arrays = [xp.asarray(k) for k in key]
315-
bcast_shape = sh.broadcast_shapes(*[arr.shape for arr in arrays])
316-
bcast_key = [xp.broadcast_to(arr, bcast_shape) for arr in arrays]
333+
arrays = [xp.asarray(k) for k in key]
334+
bcast_shape = sh.broadcast_shapes(*[arr.shape for arr in arrays])
335+
bcast_key = [xp.broadcast_to(arr, bcast_shape) for arr in arrays]
317336

318-
for idx in sh.ndindex(bcast_shape):
319-
tpl = tuple(k[idx] for k in bcast_key)
320-
assert out[idx] == x[tpl], f"failing at {idx = } w/ {key = }"
337+
for idx in sh.ndindex(bcast_shape):
338+
tpl = tuple(k[idx] for k in bcast_key)
339+
assert out[idx] == x[tpl], f"failing at {idx = } w/ {key = }"
340+
except Exception as exc:
341+
ph.add_note(exc, repro_snippet)
342+
raise
321343

322344

323345
def make_scalar_casting_param(

0 commit comments

Comments
 (0)