@@ -37,7 +37,7 @@ def _check_ns_shape_dtype(
3737 check_dtype : bool ,
3838 check_shape : bool ,
3939 check_scalar : bool ,
40- ) -> ModuleType : # numpydoc ignore=RT03
40+ ) -> tuple [ Array , Array , ModuleType ] : # numpydoc ignore=RT03
4141 """
4242 Assert that namespace, shape and dtype of the two arrays match.
4343
@@ -55,24 +55,35 @@ def _check_ns_shape_dtype(
5555
5656 Returns
5757 -------
58- Arrays namespace.
58+ Actual array, desired array, and their namespace.
5959 """
60- actual_xp = array_namespace (actual ) # Raises on scalars and lists
60+ actual_xp = array_namespace (actual ) # Raises on Python scalars and lists
6161 desired_xp = array_namespace (desired )
6262
6363 msg = f"namespaces do not match: { actual_xp } != f{ desired_xp } "
6464 assert actual_xp == desired_xp , msg
6565
66+ if is_numpy_namespace (actual_xp ) and check_scalar :
67+ # only NumPy distinguishes between scalars and arrays; we do if check_scalar.
68+ _msg = (
69+ "array-ness does not match:\n Actual: "
70+ f"{ type (actual )} \n Desired: { type (desired )} "
71+ )
72+ assert np .isscalar (actual ) == np .isscalar (desired ), _msg
73+
6674 # Dask uses nan instead of None for unknown shapes
6775 actual_shape = cast (tuple [float , ...], actual .shape )
6876 desired_shape = cast (tuple [float , ...], desired .shape )
6977 assert None not in actual_shape # Requires explicit support
7078 assert None not in desired_shape
79+
7180 if is_dask_namespace (desired_xp ):
7281 if any (math .isnan (i ) for i in actual_shape ):
73- actual_shape = actual .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
82+ actual .compute_chunk_sizes () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
83+ actual_shape = cast (tuple [float , ...], actual .shape )
7484 if any (math .isnan (i ) for i in desired_shape ):
75- desired_shape = desired .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
85+ desired .compute_chunk_sizes () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
86+ desired_shape = cast (tuple [float , ...], desired .shape )
7687
7788 if check_shape :
7889 msg = f"shapes do not match: { actual_shape } != f{ desired_shape } "
@@ -82,24 +93,16 @@ def _check_ns_shape_dtype(
8293 # np.testing.assert_array_equal etc even when strict=False, but not for
8394 # non-materializable arrays.
8495 # This check excludes 0d arrays as they are special-cased in NumPy.
85- actual_size = math .prod (actual_shape ) # pyright: ignore[reportUnknownArgumentType]
86- desired_size = math .prod (desired_shape ) # pyright: ignore[reportUnknownArgumentType]
96+ actual_size = math .prod (actual_shape )
97+ desired_size = math .prod (desired_shape )
8798 msg = f"sizes do not match: { actual_size } != f{ desired_size } "
8899 assert actual_size == desired_size , msg
89100
90101 if check_dtype :
91102 msg = f"dtypes do not match: { actual .dtype } != { desired .dtype } "
92103 assert actual .dtype == desired .dtype , msg
93-
94- if is_numpy_namespace (actual_xp ) and check_scalar :
95- # only NumPy distinguishes between scalars and arrays; we do if check_scalar.
96- _msg = (
97- "array-ness does not match:\n Actual: "
98- f"{ type (actual )} \n Desired: { type (desired )} "
99- )
100- assert np .isscalar (actual ) == np .isscalar (desired ), _msg
101-
102- return desired_xp
104+ desired = desired_xp .broadcast_to (desired , actual_shape )
105+ return actual , desired , desired_xp
103106
104107
105108def _is_materializable (x : Array ) -> bool :
@@ -169,7 +172,9 @@ def xp_assert_equal(
169172 xp_assert_close : Similar function for inexact equality checks.
170173 numpy.testing.assert_array_equal : Similar function for NumPy arrays.
171174 """
172- xp = _check_ns_shape_dtype (actual , desired , check_dtype , check_shape , check_scalar )
175+ actual , desired , xp = _check_ns_shape_dtype (
176+ actual , desired , check_dtype , check_shape , check_scalar
177+ )
173178 if not _is_materializable (actual ):
174179 return
175180 actual_np = as_numpy_array (actual , xp = xp )
@@ -211,7 +216,7 @@ def xp_assert_less(
211216 xp_assert_close : Similar function for inexact equality checks.
212217 numpy.testing.assert_array_equal : Similar function for NumPy arrays.
213218 """
214- xp = _check_ns_shape_dtype (x , y , check_dtype , check_shape , check_scalar )
219+ x , y , xp = _check_ns_shape_dtype (x , y , check_dtype , check_shape , check_scalar )
215220 if not _is_materializable (x ):
216221 return
217222 x_np = as_numpy_array (x , xp = xp )
@@ -267,7 +272,9 @@ def xp_assert_close(
267272 -----
268273 The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
269274 """
270- xp = _check_ns_shape_dtype (actual , desired , check_dtype , check_shape , check_scalar )
275+ actual , desired , xp = _check_ns_shape_dtype (
276+ actual , desired , check_dtype , check_shape , check_scalar
277+ )
271278 if not _is_materializable (actual ):
272279 return
273280
0 commit comments