Skip to content

Commit f7ddf6b

Browse files
committed
BUG: torch: fix up clip
1 parent dd6d3e8 commit f7ddf6b

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from functools import reduce as _reduce, wraps as _wraps
55
from builtins import all as _builtin_all, any as _builtin_any
66
from typing import Any, Literal
7+
import math
78

89
import torch
910

@@ -839,7 +840,6 @@ def clip(
839840
/,
840841
min: int | float | Array | None = None,
841842
max: int | float | Array | None = None,
842-
**kwargs
843843
) -> Array:
844844
def _isscalar(a: object):
845845
return isinstance(a, int | float) or a is None
@@ -857,13 +857,24 @@ def _isscalar(a: object):
857857
min_is_scalar = _isscalar(min)
858858
max_is_scalar = _isscalar(max)
859859

860-
if min is not None and max is not None:
861-
if min_is_scalar and not max_is_scalar:
862-
min = torch.as_tensor(min, dtype=x.dtype, device=x.device)
863-
if max_is_scalar and not min_is_scalar:
864-
max = torch.as_tensor(max, dtype=x.dtype, device=x.device)
860+
if min_is_scalar and max_is_scalar:
861+
if (min is not None and math.isnan(min)) or (max is not None and math.isnan(max)):
862+
# edge case: torch.clamp(torch.zeros(1), float('nan')) -> tensor(0.)
863+
# https://github.com/pytorch/pytorch/issues/172067
864+
return torch.full_like(x, fill_value=torch.nan)
865+
return torch.clamp(x, min, max)
865866

866-
return torch.clamp(x, min, max, **kwargs)
867+
# pytorch has (tensor, tensor, tensor) and (tensor, scalar, scalar) signatures,
868+
# but does not accept (tensor, scalar, tensor)
869+
a_min = min
870+
if min is not None and min_is_scalar:
871+
a_min = torch.as_tensor(min, dtype=x.dtype, device=x.device)
872+
873+
a_max = max
874+
if max is not None and max_is_scalar:
875+
a_max = torch.as_tensor(max, dtype=x.dtype, device=x.device)
876+
877+
return torch.clamp(x, a_min, a_max)
867878

868879

869880
def sign(x: Array, /) -> Array:

0 commit comments

Comments
 (0)