@@ -158,49 +158,37 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class ActivationFn(torch.nn.Module):
    """Apply an activation function selected by name.

    The name is resolved once, in the constructor, to a module stored in
    ``self.act_module``; ``forward`` simply delegates to that module.

    Supported names (matched case-insensitively):

    * ``"relu"``, ``"relu6"``, ``"tanh"``, ``"softplus"``, ``"sigmoid"``,
      ``"silu"``
    * ``"gelu"`` / ``"gelu_tf"`` -- both map to the tanh-approximated GELU
    * ``"silut[:<threshold>]"`` / ``"custom_silu[:<threshold>]"`` -- the
      value after the last ``":"`` is parsed as a float threshold, which
      defaults to 3.0 when no ``":"`` is present
    * ``"linear"`` / ``"none"`` / ``None`` -- identity
    """

    def __init__(self, activation: Optional[str]) -> None:
        """Build the activation module for ``activation``.

        Parameters
        ----------
        activation : str or None
            Name of the activation function; ``None`` means ``"linear"``.

        Raises
        ------
        RuntimeError
            If ``activation`` is not one of the supported names.
        ValueError
            If the threshold suffix of a ``silut`` / ``custom_silu`` name
            cannot be parsed as a float.
        """
        super().__init__()
        name = activation.lower() if activation is not None else "linear"
        if name == "relu":
            self.act_module = torch.nn.ReLU()
        elif name == "gelu" or name == "gelu_tf":
            self.act_module = torch.nn.GELU(approximate="tanh")
        elif name == "tanh":
            self.act_module = torch.nn.Tanh()
        elif name == "relu6":
            self.act_module = torch.nn.ReLU6()
        elif name == "softplus":
            self.act_module = torch.nn.Softplus()
        elif name == "sigmoid":
            self.act_module = torch.nn.Sigmoid()
        elif name == "silu":
            self.act_module = torch.nn.SiLU()
        elif name.startswith(("silut", "custom_silu")):
            threshold = float(name.split(":")[-1]) if ":" in name else 3.0
            if env.CUSTOM_OP_USE_JIT:
                # for efficient training but can not be jit
                self.act_module = SiLUTScript(threshold=threshold)
            else:
                # for jit freeze
                self.act_module = SiLUT(threshold=threshold)
        elif name == "linear" or name == "none":
            self.act_module = torch.nn.Identity()
        else:
            # Report the name exactly as the caller passed it. No
            # ``self.activation`` attribute is set on this module, so reading
            # it here would raise AttributeError instead of this error.
            raise RuntimeError(f"activation function {activation} not supported")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Returns the tensor after applying activation function."""
        return self.act_module(x)
204192
205193
206194@overload
0 commit comments