44from functools import reduce as _reduce , wraps as _wraps
55from builtins import all as _builtin_all , any as _builtin_any
66from typing import Any , Literal
7+ import math
78
89import torch
910
@@ -839,7 +840,6 @@ def clip(
839840 / ,
840841 min : int | float | Array | None = None ,
841842 max : int | float | Array | None = None ,
842- ** kwargs
843843) -> Array :
844844 def _isscalar (a : object ):
845845 return isinstance (a , int | float ) or a is None
@@ -857,13 +857,24 @@ def _isscalar(a: object):
857857 min_is_scalar = _isscalar (min )
858858 max_is_scalar = _isscalar (max )
859859
860- if min is not None and max is not None :
861- if min_is_scalar and not max_is_scalar :
862- min = torch .as_tensor (min , dtype = x .dtype , device = x .device )
863- if max_is_scalar and not min_is_scalar :
864- max = torch .as_tensor (max , dtype = x .dtype , device = x .device )
860+ if min_is_scalar and max_is_scalar :
861+ if (min is not None and math .isnan (min )) or (max is not None and math .isnan (max )):
862+ # edge case: torch.clamp(torch.zeros(1), float('nan')) -> tensor(0.)
863+ # https://github.com/pytorch/pytorch/issues/172067
864+ return torch .full_like (x , fill_value = torch .nan )
865+ return torch .clamp (x , min , max )
865866
866- return torch .clamp (x , min , max , ** kwargs )
867+ # pytorch has (tensor, tensor, tensor) and (tensor, scalar, scalar) signatures,
868+ # but does not accept (tensor, scalar, tensor)
869+ a_min = min
870+ if min is not None and min_is_scalar :
871+ a_min = torch .as_tensor (min , dtype = x .dtype , device = x .device )
872+
873+ a_max = max
874+ if max is not None and max_is_scalar :
875+ a_max = torch .as_tensor (max , dtype = x .dtype , device = x .device )
876+
877+ return torch .clamp (x , a_min , a_max )
867878
868879
869880def sign (x : Array , / ) -> Array :