@@ -598,22 +598,23 @@ def _check_ns_shape_dtype(
598598 np = _require_numpy ()
599599
600600 actual_xp = array_namespace (actual ) # Raises on Python scalars and lists
601- desired_xp = array_namespace (desired )
602601
603602 if xp is not None :
604603 _msg = (
605- "Namespace of desired array does not match the `xp` argument.\n "
606- f"Desired array's namespace: { desired_xp .__name__ } \n "
604+ "Namespace of actual array does not match the `xp` argument.\n "
605+ f"Actual array's namespace: { actual_xp .__name__ } \n "
607606 f"Expected namespace: { xp .__name__ } ."
608607 )
609- assert desired_xp == xp , _msg
610-
611- _msg = (
612- "Namespaces of actual and desired arrays do not match.\n "
613- f"Actual: { actual_xp .__name__ } \n "
614- f"Desired: { desired_xp .__name__ } ."
615- )
616- assert actual_xp == desired_xp , _msg
608+ assert actual_xp == xp , _msg
609+ desired_xp = xp
610+ else :
611+ desired_xp = array_namespace (desired )
612+ _msg = (
613+ "Namespaces of actual and desired arrays do not match.\n "
614+ f"Actual: { actual_xp .__name__ } \n "
615+ f"Desired: { desired_xp .__name__ } ."
616+ )
617+ assert actual_xp == desired_xp , _msg
617618
618619 if is_numpy_namespace (actual_xp ) and check_scalar :
619620 # only NumPy distinguishes between scalars and arrays; we do if check_scalar.
@@ -653,7 +654,7 @@ def _check_ns_shape_dtype(
653654 if check_dtype :
654655 msg = f"dtypes do not match: { actual .dtype } != { desired .dtype } "
655656 assert actual .dtype == desired .dtype , msg
656- desired = desired_xp .broadcast_to (desired , actual_shape )
657+ desired = desired_xp .broadcast_to (desired_xp . asarray ( desired ) , actual_shape )
657658 return actual , desired , desired_xp , np
658659
659660
0 commit comments