Skip to content

Commit a71706f

Browse files
authored
Merge pull request #434 from ev-br/torch_mps
ENH: allow running on torch.mps device, which does not have float64/complex128
2 parents a75721f + 99e085b commit a71706f

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,8 @@ def finite_matrices(draw, shape=matrix_shapes(), dtype=floating_dtypes, bound=No
341341

342342

343343
rtol_shared_matrix_shapes = shared(matrix_shapes())
344-
# Should we set a max_value here?
345-
_rtol_float_kw = dict(allow_nan=False, allow_infinity=False, min_value=0)
344+
# Arbitrary max_value for rtols, to avoid overflows when float64 is not available
345+
_rtol_float_kw = dict(allow_nan=False, allow_infinity=False, min_value=0, max_value=42)
346346
rtols = one_of(floats(**_rtol_float_kw),
347347
arrays(dtype=real_floating_dtypes,
348348
shape=rtol_shared_matrix_shapes.map(lambda shape: shape[:-2]),

array_api_tests/test_signatures.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,10 @@ def make_pretty_func(func_name: str, *args: Any, **kwargs: Any) -> str:
156156
array_argnames -= set(func_to_specified_arg_exprs[func_name].keys())
157157
if len(array_argnames) > 0:
158158
in_dtypes = dh.func_in_dtypes[func_name]
159-
for dtype_name in ["float64", "bool", "int64", "complex128"]:
159+
# use "float64" if available, "float32" otherwise; ditto for complex128/complex64
160+
float_name = dh.dtype_to_name[dh.widest_real_dtype]
161+
cmplx_name = dh.dtype_to_name[dh.widest_complex_dtype]
162+
for dtype_name in [float_name, "bool", "int64", cmplx_name]:
160163
# We try float64 first because uninspectable numerical functions
161164
# tend to support float inputs first-and-foremost (e.g. PyTorch)
162165
try:

0 commit comments

Comments (0)