Skip to content

Commit e01426c

Browse files
prady0tlucascolley
andauthored
BUG: testing: fix check_shape=False with broadcasting (#735)
* Fixing check_shape and scalar inputs Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> * Update src/array_api_extra/_lib/_testing.py Co-authored-by: Lucas Colley <lucas.colley8@gmail.com> * Update src/array_api_extra/_lib/_testing.py Co-authored-by: Lucas Colley <lucas.colley8@gmail.com> * Update _testing.py * special case for dask arrays Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> * typing --------- Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> Co-authored-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> Co-authored-by: Lucas Colley <lucas.colley8@gmail.com>
1 parent 76c3af8 commit e01426c

1 file changed

Lines changed: 27 additions & 20 deletions

File tree

src/array_api_extra/_lib/_testing.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def _check_ns_shape_dtype(
3737
check_dtype: bool,
3838
check_shape: bool,
3939
check_scalar: bool,
40-
) -> ModuleType: # numpydoc ignore=RT03
40+
) -> tuple[Array, Array, ModuleType]: # numpydoc ignore=RT03
4141
"""
4242
Assert that namespace, shape and dtype of the two arrays match.
4343
@@ -55,24 +55,35 @@ def _check_ns_shape_dtype(
5555
5656
Returns
5757
-------
58-
Arrays namespace.
58+
Actual array, desired array, and their namespace.
5959
"""
60-
actual_xp = array_namespace(actual) # Raises on scalars and lists
60+
actual_xp = array_namespace(actual) # Raises on Python scalars and lists
6161
desired_xp = array_namespace(desired)
6262

6363
msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
6464
assert actual_xp == desired_xp, msg
6565

66+
if is_numpy_namespace(actual_xp) and check_scalar:
67+
# only NumPy distinguishes between scalars and arrays; we do if check_scalar.
68+
_msg = (
69+
"array-ness does not match:\n Actual: "
70+
f"{type(actual)}\n Desired: {type(desired)}"
71+
)
72+
assert np.isscalar(actual) == np.isscalar(desired), _msg
73+
6674
# Dask uses nan instead of None for unknown shapes
6775
actual_shape = cast(tuple[float, ...], actual.shape)
6876
desired_shape = cast(tuple[float, ...], desired.shape)
6977
assert None not in actual_shape # Requires explicit support
7078
assert None not in desired_shape
79+
7180
if is_dask_namespace(desired_xp):
7281
if any(math.isnan(i) for i in actual_shape):
73-
actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
82+
actual.compute_chunk_sizes() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
83+
actual_shape = cast(tuple[float, ...], actual.shape)
7484
if any(math.isnan(i) for i in desired_shape):
75-
desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
85+
desired.compute_chunk_sizes() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
86+
desired_shape = cast(tuple[float, ...], desired.shape)
7687

7788
if check_shape:
7889
msg = f"shapes do not match: {actual_shape} != f{desired_shape}"
@@ -82,24 +93,16 @@ def _check_ns_shape_dtype(
8293
# np.testing.assert_array_equal etc even when strict=False, but not for
8394
# non-materializable arrays.
8495
# This check excludes 0d arrays as they are special-cased in NumPy.
85-
actual_size = math.prod(actual_shape) # pyright: ignore[reportUnknownArgumentType]
86-
desired_size = math.prod(desired_shape) # pyright: ignore[reportUnknownArgumentType]
96+
actual_size = math.prod(actual_shape)
97+
desired_size = math.prod(desired_shape)
8798
msg = f"sizes do not match: {actual_size} != f{desired_size}"
8899
assert actual_size == desired_size, msg
89100

90101
if check_dtype:
91102
msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
92103
assert actual.dtype == desired.dtype, msg
93-
94-
if is_numpy_namespace(actual_xp) and check_scalar:
95-
# only NumPy distinguishes between scalars and arrays; we do if check_scalar.
96-
_msg = (
97-
"array-ness does not match:\n Actual: "
98-
f"{type(actual)}\n Desired: {type(desired)}"
99-
)
100-
assert np.isscalar(actual) == np.isscalar(desired), _msg
101-
102-
return desired_xp
104+
desired = desired_xp.broadcast_to(desired, actual_shape)
105+
return actual, desired, desired_xp
103106

104107

105108
def _is_materializable(x: Array) -> bool:
@@ -169,7 +172,9 @@ def xp_assert_equal(
169172
xp_assert_close : Similar function for inexact equality checks.
170173
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
171174
"""
172-
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
175+
actual, desired, xp = _check_ns_shape_dtype(
176+
actual, desired, check_dtype, check_shape, check_scalar
177+
)
173178
if not _is_materializable(actual):
174179
return
175180
actual_np = as_numpy_array(actual, xp=xp)
@@ -211,7 +216,7 @@ def xp_assert_less(
211216
xp_assert_close : Similar function for inexact equality checks.
212217
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
213218
"""
214-
xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar)
219+
x, y, xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar)
215220
if not _is_materializable(x):
216221
return
217222
x_np = as_numpy_array(x, xp=xp)
@@ -267,7 +272,9 @@ def xp_assert_close(
267272
-----
268273
The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
269274
"""
270-
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
275+
actual, desired, xp = _check_ns_shape_dtype(
276+
actual, desired, check_dtype, check_shape, check_scalar
277+
)
271278
if not _is_materializable(actual):
272279
return
273280

0 commit comments

Comments
 (0)