Skip to content

Commit 7c62be2

Browse files
committed
add ut
1 parent 68b7e25 commit 7c62be2

5 files changed

Lines changed: 80 additions & 4 deletions

File tree

deepmd/pt/infer/deep_eval.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ def __init__(
130130
] = state_dict[item].clone()
131131
state_dict = state_dict_head
132132
model = get_model(self.input_param).to(DEVICE)
133-
# TODO fix jit
134133
if not self.input_param.get("hessian_mode"):
135134
model = torch.jit.script(model)
136135
self.dp = ModelWrapper(model)

deepmd/pt/model/task/fitting.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -617,9 +617,22 @@ def _forward_common(
617617
) # Shape is [nframes, natoms[0], net_dim_out]
618618
else:
619619
if self.eval_return_middle_output:
620-
raise NotImplementedError(
621-
"Middle output is only supported for mixed types!"
622-
)
620+
outs_middle = torch.zeros(
621+
(nf, nloc, self.neuron[-1]),
622+
dtype=self.prec,
623+
device=descriptor.device,
624+
) # jit assertion
625+
for type_i, ll in enumerate(self.filter_layers.networks):
626+
mask = (atype == type_i).unsqueeze(-1)
627+
mask = torch.tile(mask, (1, 1, net_dim_out))
628+
middle_output_type = ll.call_until_last(xx)
629+
middle_output_type = torch.where(
630+
torch.tile(mask, (1, 1, self.neuron[-1])),
631+
middle_output_type,
632+
0.0,
633+
)
634+
outs_middle = outs_middle + middle_output_type
635+
results["middle_output"] = outs_middle
623636
for type_i, ll in enumerate(self.filter_layers.networks):
624637
mask = (atype == type_i).unsqueeze(-1)
625638
mask = torch.tile(mask, (1, 1, net_dim_out))

source/tests/infer/case.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,13 @@ def __init__(self, data: dict) -> None:
125125
else:
126126
self.descriptor = None
127127

128+
if "fit_ll" in data:
129+
self.fit_ll = np.array(data["fit_ll"], dtype=np.float64).reshape(
130+
self.nloc, -1
131+
)
132+
else:
133+
self.fit_ll = None
134+
128135

129136
class Case:
130137
"""Test case.

source/tests/infer/deeppot-testcase.yaml

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,45 @@ results:
350350
1.391094495316195001e+00,
351351
7.036614101584164338e-01,
352352
]
353+
fit_ll:
354+
[
355+
-1.930622006643730598e-02,
356+
7.105172146387829235e-01,
357+
8.063835335367619539e-01,
358+
-8.414936892447275607e-01,
359+
1.076881365346436414e+00,
360+
-5.058153291569045251e-01,
361+
-3.104797373867691779e-02,
362+
7.915138025598530414e-01,
363+
8.704498369678651537e-01,
364+
-9.394329433114724237e-01,
365+
1.081177674358831053e+00,
366+
-5.122829163516022799e-01,
367+
5.307913125575804136e-03,
368+
7.644783775007328863e-01,
369+
8.548853566716824171e-01,
370+
-9.264496186379944653e-01,
371+
1.087178488222722672e+00,
372+
-4.893627623467682874e-01,
373+
-1.098746804357388085e-01,
374+
8.092546382430507723e-01,
375+
8.757043853926992361e-01,
376+
-9.036627000544070754e-01,
377+
1.064706190677472852e+00,
378+
-5.670533963064982030e-01,
379+
-1.270062329805081158e-01,
380+
8.618261193779762630e-01,
381+
8.979592934126284787e-01,
382+
-9.939941754957831721e-01,
383+
1.072078883192923771e+00,
384+
-5.780043831847785363e-01,
385+
-8.617331266742107865e-02,
386+
8.388158674801169390e-01,
387+
8.904977456468012864e-01,
388+
-9.751383339999978306e-01,
389+
1.075378146084344344e+00,
390+
-5.508880199511664300e-01,
391+
]
353392
- coord:
354393
[
355394
12.83,

source/tests/infer/test_models.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,24 @@ def test_descriptor(self) -> None:
164164
expected_descpt = result.descriptor
165165
np.testing.assert_almost_equal(descpt.ravel(), expected_descpt.ravel())
166166

167+
def test_fitting_last_layer(self) -> None:
168+
_, extension = self.param
169+
if extension == ".pb":
170+
self.skipTest("fitting_last_layer not supported for TensorFlow models")
171+
for ii, result in enumerate(self.case.results):
172+
if result.fit_ll is None:
173+
continue
174+
fit_ll = self.dp.eval_fitting_last_layer(
175+
result.coord, result.box, result.atype
176+
)
177+
expected_fit_ll = result.fit_ll
178+
np.testing.assert_almost_equal(fit_ll.ravel(), expected_fit_ll.ravel())
179+
fit_ll = self.dp.eval_fitting_last_layer(
180+
result.coord, result.box, result.atype
181+
)
182+
expected_fit_ll = result.fit_ll
183+
np.testing.assert_almost_equal(fit_ll.ravel(), expected_fit_ll.ravel())
184+
167185
def test_2frame_atm(self) -> None:
168186
for ii, result in enumerate(self.case.results):
169187
coords2 = np.concatenate((result.coord, result.coord))

0 commit comments

Comments
 (0)