Skip to content

Commit 69dc9c1

Browse files
author
Han Wang
committed
test: add torch.export test for silut activation
1 parent 6aa4f0b commit 69dc9c1

1 file changed

Lines changed: 15 additions & 0 deletions

File tree

source/tests/pt_expt/utils/test_activation.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,18 @@ def test_silut_above_threshold_is_tanh_branch(self) -> None:
103103
np.testing.assert_allclose(
104104
result.detach().numpy(), expected.detach().numpy(), rtol=1e-14, atol=1e-14
105105
)
106+
107+
def test_silut_export(self) -> None:
108+
"""torch.export.export can trace through silut activation."""
109+
110+
class SilutModule(torch.nn.Module):
111+
def forward(self, x: torch.Tensor) -> torch.Tensor:
112+
return _torch_activation(x, "silut:10.0")
113+
114+
mod = SilutModule()
115+
exported = torch.export.export(mod, (self.x_torch,))
116+
result = exported.module()(self.x_torch)
117+
expected = _torch_activation(self.x_torch, "silut:10.0")
118+
np.testing.assert_allclose(
119+
result.detach().numpy(), expected.detach().numpy(), rtol=1e-12, atol=1e-12
120+
)

0 commit comments

Comments
 (0)