Skip to content

Commit 26a73cf

Browse files
author
Han Wang
committed
dpa1 compatible with torch.export
1 parent a4729b0 commit 26a73cf

5 files changed

Lines changed: 32 additions & 27 deletions

File tree

deepmd/pt/model/descriptor/se_atten.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,6 @@ def forward(
686686
self.filter_neuron[-1],
687687
self.is_sorted,
688688
)[0]
689-
# to make torchscript happy
690689
gg = torch.empty(
691690
nframes,
692691
nloc,

source/tests/pt/model/test_dpa1.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def test_consistency(
108108
err_msg=err_msg,
109109
)
110110

111-
def test_jit(
111+
def test_export(
112112
self,
113113
) -> None:
114114
rng = np.random.default_rng(GLOBAL_SEED)
@@ -132,8 +132,6 @@ def test_jit(
132132
[False, True], # use_econf_tebd
133133
):
134134
dtype = PRECISION_DICT[prec]
135-
rtol, atol = get_tols(prec)
136-
err_msg = f"idt={idt} prec={prec}"
137135
# dpa1 new impl
138136
dd0 = DescrptDPA1(
139137
self.rcut,
@@ -151,6 +149,9 @@ def test_jit(
151149
)
152150
dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
153151
dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
154-
# dd1 = DescrptDPA1.deserialize(dd0.serialize())
155-
model = torch.jit.script(dd0)
156-
# model = torch.jit.script(dd1)
152+
153+
coord_ext = torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE)
154+
atype_ext = torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE)
155+
nlist = torch.tensor(self.nlist, dtype=int, device=env.DEVICE)
156+
157+
_ = torch.export.export(dd0, (coord_ext, atype_ext, nlist))

source/tests/pt/model/test_export.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
)
3131

3232
from .test_permutation import (
33+
model_dpa1,
3334
model_se_e2_a,
3435
)
3536

@@ -167,6 +168,22 @@ def tearDown(self) -> None:
167168
os.remove(f)
168169

169170

171+
class TestEnergyModelDPA1IntegrationExport(unittest.TestCase, ExportIntegrationTest):
172+
def setUp(self) -> None:
173+
input_json = str(Path(__file__).parent / "water/se_atten.json")
174+
with open(input_json) as f:
175+
self.config = json.load(f)
176+
data_file = [str(Path(__file__).parent / "water/data/data_0")]
177+
self.config["training"]["training_data"]["systems"] = data_file
178+
self.config["training"]["validation_data"]["systems"] = data_file
179+
self.config["model"] = deepcopy(model_dpa1)
180+
self.config["training"]["numb_steps"] = 2
181+
self.config["training"]["save_freq"] = 2
182+
183+
def tearDown(self) -> None:
184+
ExportIntegrationTest.tearDown(self)
185+
186+
170187
class TestEnergyModelSeAIntegrationExport(unittest.TestCase, ExportIntegrationTest):
171188
def setUp(self) -> None:
172189
input_json = str(Path(__file__).parent / "water/se_atten.json")

source/tests/pt/model/test_jit.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
from .test_permutation import (
2323
model_dos,
24-
model_dpa1,
2524
model_dpa2,
2625
model_hybrid,
2726
)
@@ -62,22 +61,6 @@ def tearDown(self) -> None:
6261
JITTest.tearDown(self)
6362

6463

65-
class TestEnergyModelDPA1(unittest.TestCase, JITTest):
66-
def setUp(self) -> None:
67-
input_json = str(Path(__file__).parent / "water/se_atten.json")
68-
with open(input_json) as f:
69-
self.config = json.load(f)
70-
data_file = [str(Path(__file__).parent / "water/data/data_0")]
71-
self.config["training"]["training_data"]["systems"] = data_file
72-
self.config["training"]["validation_data"]["systems"] = data_file
73-
self.config["model"] = deepcopy(model_dpa1)
74-
self.config["training"]["numb_steps"] = 10
75-
self.config["training"]["save_freq"] = 10
76-
77-
def tearDown(self) -> None:
78-
JITTest.tearDown(self)
79-
80-
8164
class TestEnergyModelDPA2(unittest.TestCase, JITTest):
8265
def setUp(self) -> None:
8366
input_json = str(Path(__file__).parent / "water/se_atten.json")

source/tests/pt/model/test_se_atten_v2.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def test_consistency(
104104
err_msg=err_msg,
105105
)
106106

107-
def test_jit(
107+
def test_export(
108108
self,
109109
) -> None:
110110
rng = np.random.default_rng()
@@ -140,5 +140,10 @@ def test_jit(
140140
seed=GLOBAL_SEED,
141141
)
142142
dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
143-
dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
144-
_ = torch.jit.script(dd0)
143+
dd0.se_atten.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
144+
145+
coord_ext = torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE)
146+
atype_ext = torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE)
147+
nlist = torch.tensor(self.nlist, dtype=int, device=env.DEVICE)
148+
149+
_ = torch.export.export(dd0, (coord_ext, atype_ext, nlist))

0 commit comments

Comments
 (0)