1717)
1818from .env import PRECISION_DICT as PT_PRECISION_DICT
1919
class SiLUT(torch.nn.Module):
    """SiLU activation whose upper tail is replaced by a shifted tanh.

    For ``x < threshold`` the output is ``F.silu(x)``.  For
    ``x >= threshold`` the output is
    ``tanh(slope * (x - threshold)) + const``, where ``slope`` and ``const``
    are the derivative and the value of SiLU at ``threshold``.  The two
    pieces therefore agree in value and in first derivative at the
    threshold.

    Parameters
    ----------
    threshold : float
        Point at which the function switches from SiLU to the tanh tail.
        Defaults to 3.0.
    """

    def __init__(self, threshold: float = 3.0) -> None:
        super().__init__()

        # Scalar sigmoid / SiLU / SiLU' evaluated once at the threshold; the
        # results are stored as plain Python floats.
        sig = 1 / (1 + np.exp(-threshold))

        self.threshold = threshold
        self.slope = float(sig + threshold * sig * (1 - sig))
        self.const = float(threshold * sig)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the piecewise activation element-wise to ``x``."""
        silu_part = F.silu(x)
        tanh_part = torch.tanh(self.slope * (x - self.threshold)) + self.const
        # Select unconditionally instead of branching on ``torch.any(...)``:
        # the values are identical (``where`` picks ``silu_part`` everywhere
        # when nothing reaches the threshold), but there is no data-dependent
        # Python branch, so tracing does not bake in one of the two paths.
        return torch.where(x < self.threshold, silu_part, tanh_part)
2046
2147class ActivationFn (torch .nn .Module ):
def __init__(self, activation: Optional[str]) -> None:
    """Record the activation name and build the SiLUT module if requested.

    Parameters
    ----------
    activation : str or None
        Name of the activation function; ``None`` is stored as
        ``"linear"``.  A name starting (case-insensitively) with
        ``"silut"`` or ``"custom_silu"`` creates a ``SiLUT`` module.  Its
        threshold is the text after the last ``":"`` in the name, or 3.0
        when the name contains no ``":"``.

    Raises
    ------
    ValueError
        If the text after the last ``":"`` cannot be converted to float.
    """
    super().__init__()
    self.activation: str = activation if activation is not None else "linear"
    # NOTE(review): the visible part of ``forward`` dispatches only on the
    # "custom_silu" prefix, while "silut" is also accepted here -- confirm
    # that both spellings are meant to reach ``self.silut``.
    if self.activation.lower().startswith(("silut", "custom_silu")):
        _, sep, tail = self.activation.rpartition(":")
        threshold = float(tail) if sep else 3.0
        self.silut = SiLUT(threshold=threshold)
    else:
        # Always define the attribute so that ``self.silut is not None``
        # checks work on instances built for any other activation.
        self.silut = None
2558
2659 def forward (self , x : torch .Tensor ) -> torch .Tensor :
2760 """Returns the tensor after applying activation function corresponding to `activation`."""
@@ -43,6 +76,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
4376 return F .silu (x )
4477 elif self .activation .lower () == "linear" or self .activation .lower () == "none" :
4578 return x
79+ elif self .activation .lower ().startswith ("custom_silu" ):
80+ assert self .silut is not None
81+ return self .silut (x )
4682 else :
4783 raise RuntimeError (f"activation function { self .activation } not supported" )
4884
0 commit comments