Skip to content

Commit 0fa716a

Browse files
committed
refactor: simplify act function selection
1 parent db22802 commit 0fa716a

1 file changed

Lines changed: 24 additions & 36 deletions

File tree

deepmd/pt/utils/utils.py

Lines changed: 24 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -158,49 +158,37 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
158158
class ActivationFn(torch.nn.Module):
159159
def __init__(self, activation: Optional[str]) -> None:
160160
super().__init__()
161-
self.activation: str = activation if activation is not None else "linear"
162-
if self.activation.lower().startswith(
163-
"silut"
164-
) or self.activation.lower().startswith("custom_silu"):
165-
threshold = (
166-
float(self.activation.split(":")[-1]) if ":" in self.activation else 3.0
167-
)
161+
activation = activation.lower() if activation is not None else "linear"
162+
if activation == "relu":
163+
self.act_module = torch.nn.ReLU()
164+
elif activation == "gelu" or activation == "gelu_tf":
165+
self.act_module = torch.nn.GELU(approximate="tanh")
166+
elif activation == "tanh":
167+
self.act_module = torch.nn.Tanh()
168+
elif activation == "relu6":
169+
self.act_module = torch.nn.ReLU6()
170+
elif activation == "softplus":
171+
self.act_module = torch.nn.Softplus()
172+
elif activation == "sigmoid":
173+
self.act_module = torch.nn.Sigmoid()
174+
elif activation == "silu":
175+
self.act_module = torch.nn.SiLU()
176+
elif activation.startswith("silut") or activation.startswith("custom_silu"):
177+
threshold = float(activation.split(":")[-1]) if ":" in activation else 3.0
168178
if env.CUSTOM_OP_USE_JIT:
169179
# for efficient training but can not be jit
170-
self.silut = SiLUTScript(threshold=threshold)
180+
self.act_module = SiLUTScript(threshold=threshold)
171181
else:
172182
# for jit freeze
173-
self.silut = SiLUT(threshold=threshold)
183+
self.act_module = SiLUT(threshold=threshold)
184+
elif activation == "linear" or activation == "none":
185+
self.act_module = torch.nn.Identity()
174186
else:
175-
self.silut = None
187+
raise RuntimeError(f"activation function {self.activation} not supported")
176188

177189
def forward(self, x: torch.Tensor) -> torch.Tensor:
178-
"""Returns the tensor after applying activation function corresponding to `activation`."""
179-
# See jit supported types: https://pytorch.org/docs/stable/jit_language_reference.html#supported-type
180-
181-
if self.activation.lower() == "relu":
182-
return F.relu(x)
183-
elif self.activation.lower() == "gelu" or self.activation.lower() == "gelu_tf":
184-
return F.gelu(x, approximate="tanh")
185-
elif self.activation.lower() == "tanh":
186-
return torch.tanh(x)
187-
elif self.activation.lower() == "relu6":
188-
return F.relu6(x)
189-
elif self.activation.lower() == "softplus":
190-
return F.softplus(x)
191-
elif self.activation.lower() == "sigmoid":
192-
return torch.sigmoid(x)
193-
elif self.activation.lower() == "silu":
194-
return F.silu(x)
195-
elif self.activation.lower().startswith(
196-
"silut"
197-
) or self.activation.lower().startswith("custom_silu"):
198-
assert self.silut is not None
199-
return self.silut(x)
200-
elif self.activation.lower() == "linear" or self.activation.lower() == "none":
201-
return x
202-
else:
203-
raise RuntimeError(f"activation function {self.activation} not supported")
190+
"""Returns the tensor after applying activation function."""
191+
return self.act_module(x)
204192

205193

206194
@overload

0 commit comments

Comments
 (0)