Skip to content

Commit 5450066

Browse files
authored
feat(pt): add hook to last fitting layer output (#4789)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added the ability to evaluate and retrieve the output of the last hidden layer in fitting neural networks, providing access to intermediate model outputs. - Extended evaluation interfaces to support fetching intermediate fitting outputs for both standard and mixed-type models. - **Improvements** - Enhanced output dictionaries to optionally include intermediate network outputs when enabled, allowing for more detailed inspection during evaluation. - **Tests** - Introduced tests to verify correctness and consistency of fitting last layer evaluations across supported model types. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 4ad67d5 commit 5450066

10 files changed

Lines changed: 319 additions & 5 deletions

File tree

deepmd/dpmodel/utils/network.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,25 @@ def call(self, x):
642642
x = layer(x)
643643
return x
644644

645+
def call_until_last(self, x):
    """Feed ``x`` through every layer except the final one.

    Parameters
    ----------
    x : np.ndarray
        The input.

    Returns
    -------
    np.ndarray
        The activation produced just before the last layer.
    """
    # Deliberately loop instead of slicing (self.layers[:-1]) so the
    # code stays jit-scriptable.
    last = len(self.layers) - 1
    for idx, layer in enumerate(self.layers):
        if idx < last:
            x = layer(x)
    return x
663+
645664
def clear(self) -> None:
646665
"""Clear the network parameters to zero."""
647666
for layer in self.layers:

deepmd/infer/deep_eval.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,48 @@ def eval_descriptor(
215215
"""
216216
raise NotImplementedError
217217

218+
def eval_fitting_last_layer(
    self,
    coords: np.ndarray,
    cells: Optional[np.ndarray],
    atom_types: np.ndarray,
    fparam: Optional[np.ndarray] = None,
    aparam: Optional[np.ndarray] = None,
    **kwargs: Any,
) -> np.ndarray:
    """Evaluate fitting before last layer by using this DP.

    Backends that support inspecting the fitting network's last hidden
    layer must override this method; the base implementation only
    signals that the feature is unavailable.

    Parameters
    ----------
    coords
        The coordinates of atoms.
        The array should be of size nframes x natoms x 3
    cells
        The cell of the region.
        If None then non-PBC is assumed, otherwise using PBC.
        The array should be of size nframes x 9
    atom_types
        The atom types
        The list should contain natoms ints
    fparam
        The frame parameter.
        The array can be of size :
        - nframes x dim_fparam.
        - dim_fparam. Then all frames are assumed to be provided with the same fparam.
    aparam
        The atomic parameter
        The array can be of size :
        - nframes x natoms x dim_aparam.
        - natoms x dim_aparam. Then all frames are assumed to be provided with the same aparam.
        - dim_aparam. Then all frames and atoms are provided with the same aparam.

    Returns
    -------
    fitting
        Fitting output before last layer.

    Raises
    ------
    NotImplementedError
        Always, unless a backend overrides this method.
    """
    raise NotImplementedError
259+
218260
def eval_typeebd(self) -> np.ndarray:
219261
"""Evaluate output of type embedding network by using this model.
220262
@@ -467,6 +509,73 @@ def eval_descriptor(
467509
)
468510
return descriptor
469511

512+
def eval_fitting_last_layer(
    self,
    coords: np.ndarray,
    cells: Optional[np.ndarray],
    atom_types: np.ndarray,
    fparam: Optional[np.ndarray] = None,
    aparam: Optional[np.ndarray] = None,
    mixed_type: bool = False,
    **kwargs: Any,
) -> np.ndarray:
    """Evaluate fitting before last layer by using this DP.

    Parameters
    ----------
    coords
        The coordinates of atoms.
        The array should be of size nframes x natoms x 3
    cells
        The cell of the region.
        If None then non-PBC is assumed, otherwise using PBC.
        The array should be of size nframes x 9
    atom_types
        The atom types
        The list should contain natoms ints
    fparam
        The frame parameter.
        The array can be of size :
        - nframes x dim_fparam.
        - dim_fparam. Then all frames are assumed to be provided with the same fparam.
    aparam
        The atomic parameter
        The array can be of size :
        - nframes x natoms x dim_aparam.
        - natoms x dim_aparam. Then all frames are assumed to be provided with the same aparam.
        - dim_aparam. Then all frames and atoms are provided with the same aparam.
    mixed_type
        Whether to perform the mixed_type mode.
        If True, the input data has the mixed_type format (see doc/model/train_se_atten.md),
        in which frames in a system may have different natoms_vec(s), with the same nloc.
    **kwargs
        Other parameters forwarded to the backend evaluator.

    Returns
    -------
    fitting
        Fitting output before last layer.
    """
    # NOTE: the previous docstring documented an ``efield`` parameter that
    # does not exist in this signature (copied from eval_descriptor); it
    # has been removed.
    # Normalize/reshape the raw inputs into the backend's expected layout.
    (
        coords,
        cells,
        atom_types,
        fparam,
        aparam,
        nframes,
        natoms,
    ) = self._standard_input(coords, cells, atom_types, fparam, aparam, mixed_type)
    fitting = self.deep_eval.eval_fitting_last_layer(
        coords,
        cells,
        atom_types,
        fparam=fparam,
        aparam=aparam,
        **kwargs,
    )
    return fitting
578+
470579
def eval_typeebd(self) -> np.ndarray:
471580
"""Evaluate output of type embedding network by using this model.
472581

deepmd/pt/infer/deep_eval.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,3 +722,58 @@ def eval_descriptor(
722722
descriptor = model.eval_descriptor()
723723
model.set_eval_descriptor_hook(False)
724724
return to_numpy_array(descriptor)
725+
726+
def eval_fitting_last_layer(
    self,
    coords: np.ndarray,
    cells: Optional[np.ndarray],
    atom_types: np.ndarray,
    fparam: Optional[np.ndarray] = None,
    aparam: Optional[np.ndarray] = None,
    **kwargs: Any,
) -> np.ndarray:
    """Evaluate fitting before last layer by using this DP.

    Parameters
    ----------
    coords
        The coordinates of atoms.
        The array should be of size nframes x natoms x 3
    cells
        The cell of the region.
        If None then non-PBC is assumed, otherwise using PBC.
        The array should be of size nframes x 9
    atom_types
        The atom types
        The list should contain natoms ints
    fparam
        The frame parameter.
        The array can be of size :
        - nframes x dim_fparam.
        - dim_fparam. Then all frames are assumed to be provided with the same fparam.
    aparam
        The atomic parameter
        The array can be of size :
        - nframes x natoms x dim_aparam.
        - natoms x dim_aparam. Then all frames are assumed to be provided with the same aparam.
        - dim_aparam. Then all frames and atoms are provided with the same aparam.

    Returns
    -------
    fitting
        Fitting output before last layer.
    """
    model = self.dp.model["Default"]
    model.set_eval_fitting_last_layer_hook(True)
    try:
        # A regular evaluation; the hook records the fitting net's
        # last-hidden-layer output as a side effect.
        self.eval(
            coords,
            cells,
            atom_types,
            atomic=False,
            fparam=fparam,
            aparam=aparam,
            **kwargs,
        )
        fitting_net = model.eval_fitting_last_layer()
    finally:
        # BUGFIX: disable the hook even when eval() or the cache read
        # raises (e.g. "not supported for this fitting net"); otherwise
        # the hook would stay enabled and pollute later evaluations.
        model.set_eval_fitting_last_layer_hook(False)
    return to_numpy_array(fitting_net)

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,12 @@ def __init__(
6262
self.fitting_net = fitting
6363
super().init_out_stat()
6464
self.enable_eval_descriptor_hook = False
65+
self.enable_eval_fitting_last_layer_hook = False
6566
self.eval_descriptor_list = []
67+
self.eval_fitting_last_layer_list = []
6668

6769
eval_descriptor_list: list[torch.Tensor]
70+
eval_fitting_last_layer_list: list[torch.Tensor]
6871

6972
def set_eval_descriptor_hook(self, enable: bool) -> None:
7073
"""Set the hook for evaluating descriptor and clear the cache for descriptor list."""
@@ -76,6 +79,17 @@ def eval_descriptor(self) -> torch.Tensor:
7679
"""Evaluate the descriptor."""
7780
return torch.concat(self.eval_descriptor_list)
7881

82+
def set_eval_fitting_last_layer_hook(self, enable: bool) -> None:
    """Toggle caching of the fitting net's last-hidden-layer output.

    Also resets the cached output list so a new evaluation starts clean.
    """
    # Rebinding with `= []` breaks TorchScript; mutate in place (see #4533).
    self.eval_fitting_last_layer_list.clear()
    self.enable_eval_fitting_last_layer_hook = enable
    self.fitting_net.set_return_middle_output(enable)
88+
89+
def eval_fitting_last_layer(self) -> torch.Tensor:
    """Concatenate and return the cached last-hidden-layer outputs."""
    # One tensor per forward call accumulated by the hook; join them.
    return torch.cat(self.eval_fitting_last_layer_list)
92+
7993
@torch.jit.export
8094
def fitting_output_def(self) -> FittingOutputDef:
8195
"""Get the output def of the fitting net."""
@@ -255,6 +269,13 @@ def forward_atomic(
255269
fparam=fparam,
256270
aparam=aparam,
257271
)
272+
if self.enable_eval_fitting_last_layer_hook:
273+
assert "middle_output" in fit_ret, (
274+
"eval_fitting_last_layer not supported for this fitting net!"
275+
)
276+
self.eval_fitting_last_layer_list.append(
277+
fit_ret.pop("middle_output").detach()
278+
)
258279
return fit_ret
259280

260281
def get_out_bias(self) -> torch.Tensor:

deepmd/pt/model/model/dp_model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,13 @@ def set_eval_descriptor_hook(self, enable: bool) -> None:
6464
def eval_descriptor(self) -> torch.Tensor:
6565
"""Evaluate the descriptor."""
6666
return self.atomic_model.eval_descriptor()
67+
68+
@torch.jit.export
def set_eval_fitting_last_layer_hook(self, enable: bool) -> None:
    """Enable or disable caching of the fitting net's last-hidden-layer
    output, delegating to the wrapped atomic model (which also clears
    its cached output list).
    """
    self.atomic_model.set_eval_fitting_last_layer_hook(enable)

@torch.jit.export
def eval_fitting_last_layer(self) -> torch.Tensor:
    """Return the cached last-hidden-layer output from the atomic model."""
    return self.atomic_model.eval_fitting_last_layer()

deepmd/pt/model/task/fitting.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,8 @@ def __init__(
329329
for param in self.parameters():
330330
param.requires_grad = self.trainable
331331

332+
self.eval_return_middle_output = False
333+
332334
def reinit_exclude(
333335
self,
334336
exclude_types: list[int] = [],
@@ -450,6 +452,9 @@ def set_case_embd(self, case_idx: int):
450452
case_idx
451453
]
452454

455+
def set_return_middle_output(self, return_middle_output: bool = True) -> None:
456+
self.eval_return_middle_output = return_middle_output
457+
453458
def __setitem__(self, key, value) -> None:
454459
if key in ["bias_atom_e"]:
455460
value = value.view([self.ntypes, self._net_out_dim()])
@@ -598,14 +603,37 @@ def _forward_common(
598603
dtype=self.prec,
599604
device=descriptor.device,
600605
) # jit assertion
606+
results = {}
607+
601608
if self.mixed_types:
602609
atom_property = self.filter_layers.networks[0](xx)
610+
if self.eval_return_middle_output:
611+
results["middle_output"] = self.filter_layers.networks[
612+
0
613+
].call_until_last(xx)
603614
if xx_zeros is not None:
604615
atom_property -= self.filter_layers.networks[0](xx_zeros)
605616
outs = (
606617
outs + atom_property + self.bias_atom_e[atype].to(self.prec)
607618
) # Shape is [nframes, natoms[0], net_dim_out]
608619
else:
620+
if self.eval_return_middle_output:
621+
outs_middle = torch.zeros(
622+
(nf, nloc, self.neuron[-1]),
623+
dtype=self.prec,
624+
device=descriptor.device,
625+
) # jit assertion
626+
for type_i, ll in enumerate(self.filter_layers.networks):
627+
mask = (atype == type_i).unsqueeze(-1)
628+
mask = torch.tile(mask, (1, 1, net_dim_out))
629+
middle_output_type = ll.call_until_last(xx)
630+
middle_output_type = torch.where(
631+
torch.tile(mask, (1, 1, self.neuron[-1])),
632+
middle_output_type,
633+
0.0,
634+
)
635+
outs_middle = outs_middle + middle_output_type
636+
results["middle_output"] = outs_middle
609637
for type_i, ll in enumerate(self.filter_layers.networks):
610638
mask = (atype == type_i).unsqueeze(-1)
611639
mask = torch.tile(mask, (1, 1, net_dim_out))
@@ -627,4 +655,5 @@ def _forward_common(
627655
mask = self.emask(atype).to(torch.bool)
628656
# nf x nloc x nod
629657
outs = torch.where(mask[:, :, None], outs, 0.0)
630-
return {self.var_name: outs}
658+
results.update({self.var_name: outs})
659+
return results

deepmd/pt/model/task/invar_fitting.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,17 @@ def forward(
181181
-------
182182
- `torch.Tensor`: Total energy with shape [nframes, natoms[0]].
183183
"""
184-
out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
185-
self.var_name
186-
]
187-
return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}
184+
out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)
185+
result = {self.var_name: out[self.var_name].to(env.GLOBAL_PT_FLOAT_PRECISION)}
186+
if "middle_output" in out:
187+
result.update(
188+
{
189+
"middle_output": out["middle_output"].to(
190+
env.GLOBAL_PT_FLOAT_PRECISION
191+
)
192+
}
193+
)
194+
return result
188195

189196
# make jit happy with torch 2.0.0
190197
exclude_types: list[int]

source/tests/infer/case.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,13 @@ def __init__(self, data: dict) -> None:
125125
else:
126126
self.descriptor = None
127127

128+
if "fit_ll" in data:
129+
self.fit_ll = np.array(data["fit_ll"], dtype=np.float64).reshape(
130+
self.nloc, -1
131+
)
132+
else:
133+
self.fit_ll = None
134+
128135

129136
class Case:
130137
"""Test case.

0 commit comments

Comments
 (0)