Skip to content

Commit 7ad3b25

Browse files
committed
BUG: torch: work around torch.round not supporting complex inputs
1 parent a88067a commit 7ad3b25

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,15 @@ def sign(x: Array, /) -> Array:
912912
return out
913913

914914

915+
def round(x: Array, /, **kwargs) -> Array:
916+
# torch.round fails for complex inputs
917+
# https://github.com/pytorch/pytorch/issues/58743#issuecomment-2727603845
918+
if x.dtype.is_complex:
919+
return torch.round(x.real, **kwargs) + 1j*torch.round(x.imag, **kwargs)
920+
else:
921+
return torch.round(x, **kwargs)
922+
923+
915924
def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Array, ...]:
916925
# torch <= 2.9 emits a UserWarning: "torch.meshgrid: in an upcoming release, it
917926
# will be required to pass the indexing argument."
@@ -923,7 +932,7 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Arra
923932
'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
924933
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
925934
'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
926-
'diff', 'divide',
935+
'diff', 'divide', 'round',
927936
'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
928937
'less', 'less_equal', 'logaddexp', 'maximum', 'minimum',
929938
'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max',

torch-xfails.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ array_api_tests/test_statistical_functions.py::test_var
130130

131131

132132
# These functions do not yet support complex numbers
133-
array_api_tests/test_operators_and_elementwise_functions.py::test_round
134133
array_api_tests/test_set_functions.py::test_unique_counts
135134
array_api_tests/test_set_functions.py::test_unique_values
136135

0 commit comments

Comments
 (0)