@@ -140,6 +140,7 @@ def xp_assert_equal(
140140 desired : Array ,
141141 * ,
142142 err_msg : str = "" ,
143+ verbose : bool = True ,
143144 check_dtype : bool = True ,
144145 check_shape : bool = True ,
145146 check_scalar : bool = False ,
@@ -155,6 +156,8 @@ def xp_assert_equal(
155156 The expected array (typically hardcoded).
156157 err_msg : str, optional
157158 Error message to display on failure.
159+ verbose: bool, default: True
160+ Whether to include the conflicting arrays in the error message on failure.
158161 check_dtype, check_shape : bool, default: True
159162 Whether to check agreement between actual and desired dtypes and shapes
160163 check_scalar : bool, default: False
@@ -171,14 +174,17 @@ def xp_assert_equal(
171174 return
172175 actual_np = as_numpy_array (actual , xp = xp )
173176 desired_np = as_numpy_array (desired , xp = xp )
174- np .testing .assert_array_equal (actual_np , desired_np , err_msg = err_msg )
177+ np .testing .assert_array_equal (
178+ actual_np , desired_np , err_msg = err_msg , verbose = verbose
179+ )
175180
176181
177182def xp_assert_less (
178183 x : Array ,
179184 y : Array ,
180185 * ,
181186 err_msg : str = "" ,
187+ verbose : bool = True ,
182188 check_dtype : bool = True ,
183189 check_shape : bool = True ,
184190 check_scalar : bool = False ,
@@ -192,6 +198,8 @@ def xp_assert_less(
192198 The arrays to compare according to ``x < y`` (elementwise).
193199 err_msg : str, optional
194200 Error message to display on failure.
201+ verbose: bool, default: True
202+ Whether to include the conflicting arrays in the error message on failure.
195203 check_dtype, check_shape : bool, default: True
196204 Whether to check agreement between actual and desired dtypes and shapes
197205 check_scalar : bool, default: False
@@ -208,7 +216,7 @@ def xp_assert_less(
208216 return
209217 x_np = as_numpy_array (x , xp = xp )
210218 y_np = as_numpy_array (y , xp = xp )
211- np .testing .assert_array_less (x_np , y_np , err_msg = err_msg )
219+ np .testing .assert_array_less (x_np , y_np , err_msg = err_msg , verbose = verbose )
212220
213221
214222def xp_assert_close (
@@ -217,7 +225,9 @@ def xp_assert_close(
217225 * ,
218226 rtol : float | None = None ,
219227 atol : float = 0 ,
228+ equal_nan : bool = True ,
220229 err_msg : str = "" ,
230+ verbose : bool = True ,
221231 check_dtype : bool = True ,
222232 check_shape : bool = True ,
223233 check_scalar : bool = False ,
@@ -235,8 +245,12 @@ def xp_assert_close(
235245 Relative tolerance. Default: dtype-dependent.
236246 atol : float, optional
237247 Absolute tolerance. Default: 0.
248+ equal_nan : bool, default: True
249+ Whether to consider NaNs in corresponding locations as equal.
238250 err_msg : str, optional
239251 Error message to display on failure.
252+ verbose: bool, default: True
253+ Whether to include the conflicting arrays in the error message on failure.
240254 check_dtype, check_shape : bool, default: True
241255 Whether to check agreement between actual and desired dtypes and shapes
242256 check_scalar : bool, default: False
@@ -273,7 +287,9 @@ def xp_assert_close(
273287 desired_np ,
274288 rtol = rtol , # pyright: ignore[reportArgumentType]
275289 atol = atol ,
290+ equal_nan = equal_nan ,
276291 err_msg = err_msg ,
292+ verbose = verbose ,
277293 )
278294
279295
0 commit comments