Skip to content

Commit 2098262

Browse files
committed
feat(pt): cherry-pick fitting ll output
1 parent 563b2b2 commit 2098262

10 files changed

Lines changed: 319 additions & 5 deletions

File tree

deepmd/dpmodel/utils/network.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,25 @@ def call(self, x):
572572
x = layer(x)
573573
return x
574574

575+
def call_until_last(self, x):
    """Apply every layer except the final one and return the result.

    Parameters
    ----------
    x : np.ndarray
        The input.

    Returns
    -------
    np.ndarray
        The output just before the last layer (the input unchanged when
        the network has fewer than two layers).
    """
    # Use an explicit index test rather than slicing self.layers[:-1],
    # which is not supported under jit.
    last = len(self.layers) - 1
    for idx, net in enumerate(self.layers):
        if idx < last:
            x = net(x)
    return x
593+
575594
def clear(self) -> None:
576595
"""Clear the network parameters to zero."""
577596
for layer in self.layers:

deepmd/infer/deep_eval.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,48 @@ def eval_descriptor(
217217
"""
218218
raise NotImplementedError
219219

220+
def eval_fitting_last_layer(
    self,
    coords: np.ndarray,
    cells: Optional[np.ndarray],
    atom_types: np.ndarray,
    fparam: Optional[np.ndarray] = None,
    aparam: Optional[np.ndarray] = None,
    **kwargs: Any,
) -> np.ndarray:
    """Evaluate the fitting network up to, but excluding, its last layer.

    Parameters
    ----------
    coords
        Atom coordinates; array of size nframes x natoms x 3.
    cells
        Simulation cell; array of size nframes x 9.
        None means non-PBC, otherwise PBC is used.
    atom_types
        Atom types; a list of natoms ints.
    fparam
        Frame parameter. Size nframes x dim_fparam, or dim_fparam
        when every frame shares the same fparam.
    aparam
        Atomic parameter. Size nframes x natoms x dim_aparam;
        natoms x dim_aparam when all frames share the same aparam;
        or dim_aparam when all frames and atoms share it.

    Returns
    -------
    fitting
        Fitting output before the last layer.

    Raises
    ------
    NotImplementedError
        Always; concrete backend implementations must override this.
    """
    raise NotImplementedError
261+
220262
def eval_typeebd(self) -> np.ndarray:
221263
"""Evaluate output of type embedding network by using this model.
222264
@@ -460,6 +502,73 @@ def eval_descriptor(
460502
)
461503
return descriptor
462504

505+
def eval_fitting_last_layer(
    self,
    coords: np.ndarray,
    cells: Optional[np.ndarray],
    atom_types: np.ndarray,
    fparam: Optional[np.ndarray] = None,
    aparam: Optional[np.ndarray] = None,
    mixed_type: bool = False,
    **kwargs: Any,
) -> np.ndarray:
    """Evaluate fitting before last layer by using this DP.

    Parameters
    ----------
    coords
        The coordinates of atoms.
        The array should be of size nframes x natoms x 3
    cells
        The cell of the region.
        If None then non-PBC is assumed, otherwise using PBC.
        The array should be of size nframes x 9
    atom_types
        The atom types
        The list should contain natoms ints
    fparam
        The frame parameter.
        The array can be of size :
        - nframes x dim_fparam.
        - dim_fparam. Then all frames are assumed to be provided with the same fparam.
    aparam
        The atomic parameter
        The array can be of size :
        - nframes x natoms x dim_aparam.
        - natoms x dim_aparam. Then all frames are assumed to be provided with the same aparam.
        - dim_aparam. Then all frames and atoms are provided with the same aparam.
    mixed_type
        Whether to perform the mixed_type mode.
        If True, the input data has the mixed_type format (see doc/model/train_se_atten.md),
        in which frames in a system may have different natoms_vec(s), with the same nloc.
    **kwargs
        Extra keyword arguments forwarded to the backend evaluator.

    Returns
    -------
    fitting
        Fitting output before last layer.
    """
    # NOTE(review): removed the stray `efield` entry from the docstring —
    # the signature has no such parameter (copy-paste from eval_descriptor).
    (
        coords,
        cells,
        atom_types,
        fparam,
        aparam,
        nframes,
        natoms,
    ) = self._standard_input(coords, cells, atom_types, fparam, aparam, mixed_type)
    fitting = self.deep_eval.eval_fitting_last_layer(
        coords,
        cells,
        atom_types,
        fparam=fparam,
        aparam=aparam,
        **kwargs,
    )
    return fitting
571+
463572
def eval_typeebd(self) -> np.ndarray:
464573
"""Evaluate output of type embedding network by using this model.
465574

deepmd/pt/infer/deep_eval.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,3 +682,58 @@ def eval_descriptor(
682682
descriptor = model.eval_descriptor()
683683
model.set_eval_descriptor_hook(False)
684684
return to_numpy_array(descriptor)
685+
686+
def eval_fitting_last_layer(
    self,
    coords: np.ndarray,
    cells: Optional[np.ndarray],
    atom_types: np.ndarray,
    fparam: Optional[np.ndarray] = None,
    aparam: Optional[np.ndarray] = None,
    **kwargs: Any,
) -> np.ndarray:
    """Evaluate fitting before last layer by using this DP.

    Parameters
    ----------
    coords
        Atom coordinates; array of size nframes x natoms x 3.
    cells
        Simulation cell; array of size nframes x 9.
        None means non-PBC, otherwise PBC is used.
    atom_types
        Atom types; a list of natoms ints.
    fparam
        Frame parameter. Size nframes x dim_fparam, or dim_fparam
        when every frame shares the same fparam.
    aparam
        Atomic parameter. Size nframes x natoms x dim_aparam;
        natoms x dim_aparam when all frames share the same aparam;
        or dim_aparam when all frames and atoms share it.

    Returns
    -------
    fitting
        Fitting output before the last layer.
    """
    pt_model = self.dp.model["Default"]
    # Turn the capture hook on so the forward pass records the
    # pre-last-layer output, run one evaluation, then read the cache
    # and switch the hook back off.
    pt_model.set_eval_fitting_last_layer_hook(True)
    self.eval(
        coords,
        cells,
        atom_types,
        atomic=False,
        fparam=fparam,
        aparam=aparam,
        **kwargs,
    )
    middle_out = pt_model.eval_fitting_last_layer()
    pt_model.set_eval_fitting_last_layer_hook(False)
    return to_numpy_array(middle_out)

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,12 @@ def __init__(
6262
self.fitting_net = fitting
6363
super().init_out_stat()
6464
self.enable_eval_descriptor_hook = False
65+
self.enable_eval_fitting_last_layer_hook = False
6566
self.eval_descriptor_list = []
67+
self.eval_fitting_last_layer_list = []
6668

6769
eval_descriptor_list: list[torch.Tensor]
70+
eval_fitting_last_layer_list: list[torch.Tensor]
6871

6972
def set_eval_descriptor_hook(self, enable: bool) -> None:
7073
"""Set the hook for evaluating descriptor and clear the cache for descriptor list."""
@@ -75,6 +78,17 @@ def eval_descriptor(self) -> torch.Tensor:
7578
"""Evaluate the descriptor."""
7679
return torch.concat(self.eval_descriptor_list)
7780

81+
def set_eval_fitting_last_layer_hook(self, enable: bool) -> None:
    """Toggle capture of the fitting net's pre-last-layer output.

    Any previously cached tensors are discarded so stale results from
    an earlier evaluation cannot leak into the next one.
    """
    # In-place clear() is required; rebinding with `= []` breaks under
    # jit (see #4533).
    self.eval_fitting_last_layer_list.clear()
    self.fitting_net.set_return_middle_output(enable)
    self.enable_eval_fitting_last_layer_hook = enable
87+
88+
def eval_fitting_last_layer(self) -> torch.Tensor:
    """Return the cached pre-last-layer fitting outputs as one tensor."""
    cached = self.eval_fitting_last_layer_list
    return torch.cat(cached)
91+
7892
@torch.jit.export
7993
def fitting_output_def(self) -> FittingOutputDef:
8094
"""Get the output def of the fitting net."""
@@ -286,6 +300,13 @@ def forward_atomic(
286300
angle_index=add_input.get("angle_index", None),
287301
a_sw=add_input.get("a_sw", None),
288302
)
303+
if self.enable_eval_fitting_last_layer_hook:
304+
assert (
305+
"middle_output" in fit_ret
306+
), "eval_fitting_last_layer not supported for this fitting net!"
307+
self.eval_fitting_last_layer_list.append(
308+
fit_ret.pop("middle_output").detach()
309+
)
289310
return fit_ret
290311

291312
def get_out_bias(self) -> torch.Tensor:

deepmd/pt/model/model/dp_model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,13 @@ def set_eval_descriptor_hook(self, enable: bool) -> None:
6464
def eval_descriptor(self) -> torch.Tensor:
6565
"""Evaluate the descriptor."""
6666
return self.atomic_model.eval_descriptor()
67+
68+
@torch.jit.export
def set_eval_fitting_last_layer_hook(self, enable: bool) -> None:
    """Enable or disable caching of the fitting pre-last-layer output.

    Delegates to the wrapped atomic model, which also clears its cache.
    """
    inner = self.atomic_model
    inner.set_eval_fitting_last_layer_hook(enable)
72+
73+
@torch.jit.export
def eval_fitting_last_layer(self) -> torch.Tensor:
    """Return the cached fitting pre-last-layer output.

    Delegates to the wrapped atomic model.
    """
    inner = self.atomic_model
    return inner.eval_fitting_last_layer()

deepmd/pt/model/task/fitting.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,8 @@ def __init__(
262262
for param in self.parameters():
263263
param.requires_grad = self.trainable
264264

265+
self.eval_return_middle_output = False
266+
265267
def reinit_exclude(
266268
self,
267269
exclude_types: list[int] = [],
@@ -386,6 +388,9 @@ def set_case_embd(self, case_idx: int):
386388
case_idx
387389
]
388390

391+
def set_return_middle_output(self, return_middle_output: bool = True) -> None:
    """Set whether forward passes also return the pre-last-layer output.

    When enabled, ``_forward_common`` adds the fitting network's output
    before the last layer to its result dict under the "middle_output" key.

    Parameters
    ----------
    return_middle_output : bool
        If True, subsequent forward calls include the middle output.
    """
    self.eval_return_middle_output = return_middle_output
393+
389394
def __setitem__(self, key, value) -> None:
390395
if key in ["bias_atom_e"]:
391396
value = value.view([self.ntypes, self._net_out_dim()])
@@ -540,14 +545,37 @@ def _forward_common(
540545
dtype=self.prec,
541546
device=descriptor.device,
542547
) # jit assertion
548+
results = {}
549+
543550
if self.mixed_types:
544551
atom_property = self.filter_layers.networks[0](xx)
552+
if self.eval_return_middle_output:
553+
results["middle_output"] = self.filter_layers.networks[
554+
0
555+
].call_until_last(xx)
545556
if xx_zeros is not None:
546557
atom_property -= self.filter_layers.networks[0](xx_zeros)
547558
outs = (
548559
outs + atom_property + self.bias_atom_e[atype].to(self.prec)
549560
) # Shape is [nframes, natoms[0], net_dim_out]
550561
else:
562+
if self.eval_return_middle_output:
563+
outs_middle = torch.zeros(
564+
(nf, nloc, self.neuron[-1]),
565+
dtype=self.prec,
566+
device=descriptor.device,
567+
) # jit assertion
568+
for type_i, ll in enumerate(self.filter_layers.networks):
569+
mask = (atype == type_i).unsqueeze(-1)
570+
mask = torch.tile(mask, (1, 1, net_dim_out))
571+
middle_output_type = ll.call_until_last(xx)
572+
middle_output_type = torch.where(
573+
torch.tile(mask, (1, 1, self.neuron[-1])),
574+
middle_output_type,
575+
0.0,
576+
)
577+
outs_middle = outs_middle + middle_output_type
578+
results["middle_output"] = outs_middle
551579
for type_i, ll in enumerate(self.filter_layers.networks):
552580
mask = (atype == type_i).unsqueeze(-1)
553581
mask = torch.tile(mask, (1, 1, net_dim_out))
@@ -569,4 +597,5 @@ def _forward_common(
569597
mask = self.emask(atype).to(torch.bool)
570598
# nf x nloc x nod
571599
outs = torch.where(mask[:, :, None], outs, 0.0)
572-
return {self.var_name: outs}
600+
results.update({self.var_name: outs})
601+
return results

deepmd/pt/model/task/invar_fitting.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,17 @@ def forward(
189189
-------
190190
- `torch.Tensor`: Total energy with shape [nframes, natoms[0]].
191191
"""
192-
out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
193-
self.var_name
194-
]
195-
return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}
192+
out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)
193+
result = {self.var_name: out[self.var_name].to(env.GLOBAL_PT_FLOAT_PRECISION)}
194+
if "middle_output" in out:
195+
result.update(
196+
{
197+
"middle_output": out["middle_output"].to(
198+
env.GLOBAL_PT_FLOAT_PRECISION
199+
)
200+
}
201+
)
202+
return result
196203

197204
# make jit happy with torch 2.0.0
198205
exclude_types: list[int]

source/tests/infer/case.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,13 @@ def __init__(self, data: dict) -> None:
125125
else:
126126
self.descriptor = None
127127

128+
if "fit_ll" in data:
129+
self.fit_ll = np.array(data["fit_ll"], dtype=np.float64).reshape(
130+
self.nloc, -1
131+
)
132+
else:
133+
self.fit_ll = None
134+
128135

129136
class Case:
130137
"""Test case.

0 commit comments

Comments
 (0)