Skip to content

Commit c2d5cc7

Browse files
committed
add silut impl
1 parent 8ac8180 commit c2d5cc7

4 files changed

Lines changed: 39 additions & 1 deletion

File tree

deepmd/pt/cxx_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def load_library(module_name: str) -> bool:
8787
return False
8888

8989

90-
ENABLE_CUSTOMIZED_OP = load_library("deepmd_op_pt")
90+
ENABLE_CUSTOMIZED_OP = None#load_library("deepmd_op_pt")
9191

9292
__all__ = [
9393
"ENABLE_CUSTOMIZED_OP",

deepmd/pt/train/training.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -976,6 +976,7 @@ def log_loss_valid(_task_key="Default"):
976976

977977
elapsed_batch = self.num_steps - self.start_step
978978
if self.timing_in_training and elapsed_batch // self.disp_freq > 0:
979+
log.info(msg=f"reserved memory: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
979980
if self.start_step >= 2 * self.disp_freq:
980981
log.info(
981982
"average training time: %.4f s/batch (exclude first %d batches)",

deepmd/pt/utils/dataloader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def print_summary(
189189
name: str,
190190
prob: list[float],
191191
) -> None:
192+
return
192193
rank = dist.get_rank() if dist.is_initialized() else 0
193194
if rank == 0:
194195
print_summary(

deepmd/pt/utils/utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,44 @@
1717
)
1818
from .env import PRECISION_DICT as PT_PRECISION_DICT
1919

20+
class SiLUT(torch.nn.Module):
    """SiLU activation whose upper tail is replaced by a shifted tanh.

    Inputs below ``threshold`` go through ``F.silu``. Inputs at or above it
    are mapped to ``tanh(slope * (x - threshold)) + const``, where ``const``
    and ``slope`` are the value and the first derivative of SiLU evaluated
    at ``threshold``. The two pieces therefore agree in value and in slope
    at the switching point.

    Parameters
    ----------
    threshold
        Input value at which the activation switches from SiLU to the tanh
        tail. Defaults to 3.0.
    """

    def __init__(self, threshold=3.0):
        super().__init__()
        # Scalar sigmoid at the switching point; it is the only quantity the
        # SiLU value and SiLU derivative below depend on.
        sig_at_threshold = 1 / (1 + np.exp(-threshold))

        self.threshold = threshold
        # d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
        self.slope = float(
            sig_at_threshold + threshold * sig_at_threshold * (1 - sig_at_threshold)
        )
        # SiLU(threshold) = threshold * sigmoid(threshold)
        self.const = float(threshold * sig_at_threshold)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation element-wise and return the resulting tensor."""
        smooth = F.silu(x)
        # Skip the tanh evaluation when no element reaches the threshold.
        if not torch.any(x >= self.threshold):
            return smooth
        saturated = torch.tanh(self.slope * (x - self.threshold)) + self.const
        return torch.where(x < self.threshold, smooth, saturated)
2046

2147
class ActivationFn(torch.nn.Module):
2248
def __init__(self, activation: Optional[str]) -> None:
    """Store the activation name and build the SiLUT module when requested.

    Parameters
    ----------
    activation
        Name of the activation function. ``None`` is stored as
        ``"linear"``. Names starting with ``silut`` or ``custom_silu``
        (case-insensitive) may carry a threshold after a colon, e.g.
        ``"silut:10.0"``; without a colon the threshold is 3.0.

    Raises
    ------
    ValueError
        If the text after the last colon cannot be parsed as a float.
    """
    super().__init__()
    self.activation: str = activation if activation is not None else "linear"
    # NOTE(review): both prefixes are accepted here, but the dispatch branch
    # visible in forward only matches "custom_silu" -- confirm that the
    # "silut" spelling is handled there as well.
    if self.activation.lower().startswith(("silut", "custom_silu")):
        # Text after the last colon is the threshold; `sep` is empty when
        # the name contains no colon at all.
        _, sep, threshold_text = self.activation.rpartition(":")
        threshold = float(threshold_text) if sep else 3.0
        self.silut: Optional[SiLUT] = SiLUT(threshold=threshold)
    else:
        # forward checks `self.silut is not None`, so the attribute must
        # exist for every activation instead of raising AttributeError.
        self.silut = None
2558

2659
def forward(self, x: torch.Tensor) -> torch.Tensor:
2760
"""Returns the tensor after applying activation function corresponding to `activation`."""
@@ -43,6 +76,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
4376
return F.silu(x)
4477
elif self.activation.lower() == "linear" or self.activation.lower() == "none":
4578
return x
79+
elif self.activation.lower().startswith("custom_silu"):
80+
assert self.silut is not None
81+
return self.silut(x)
4682
else:
4783
raise RuntimeError(f"activation function {self.activation} not supported")
4884

0 commit comments

Comments
 (0)