@@ -598,22 +598,23 @@ def _check_ns_shape_dtype(
598598 np = _require_numpy ()
599599
600600 actual_xp = array_namespace (actual ) # Raises on Python scalars and lists
601- desired_xp = array_namespace (desired )
602601
603602 if xp is not None :
604603 _msg = (
605- "Namespace of desired array does not match the `xp` argument.\n "
606- f"Desired array's namespace: { desired_xp .__name__ } \n "
604+ "Namespace of actual array does not match the `xp` argument.\n "
605+ f"Actual array's namespace: { actual_xp .__name__ } \n "
607606 f"Expected namespace: { xp .__name__ } ."
608607 )
609- assert desired_xp == xp , _msg
610-
611- _msg = (
612- "Namespaces of actual and desired arrays do not match.\n "
613- f"Actual: { actual_xp .__name__ } \n "
614- f"Desired: { desired_xp .__name__ } ."
615- )
616- assert actual_xp == desired_xp , _msg
608+ assert actual_xp == xp , _msg
609+ desired_xp = xp
610+ else :
611+ desired_xp = array_namespace (desired )
612+ _msg = (
613+ "Namespaces of actual and desired arrays do not match.\n "
614+ f"Actual: { actual_xp .__name__ } \n "
615+ f"Desired: { desired_xp .__name__ } ."
616+ )
617+ assert actual_xp == desired_xp , _msg
617618
618619 if is_numpy_namespace (actual_xp ) and check_scalar :
619620 # only NumPy distinguishes between scalars and arrays; we do if check_scalar.
@@ -650,6 +651,7 @@ def _check_ns_shape_dtype(
650651 msg = f"sizes do not match: { actual_size } != { desired_size } "
651652 assert actual_size == desired_size , msg
652653
654+ desired = desired_xp .asarray (desired )
653655 if check_dtype :
654656 msg = f"dtypes do not match: { actual .dtype } != { desired .dtype } "
655657 assert actual .dtype == desired .dtype , msg
0 commit comments