Skip to content

Commit 122c006

Browse files
committed
feat(pt): add direct fitting
1 parent 0f313d0 commit 122c006

11 files changed

Lines changed: 351 additions & 25 deletions

File tree

deepmd/dpmodel/fitting/make_base_fitting.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ def compute_output_stats(self, merged) -> NoReturn:
6767
"""Update the output bias for fitting net."""
6868
raise NotImplementedError
6969

70+
def need_additional_input(self) -> bool:
71+
return False
72+
7073
@abstractmethod
7174
def get_type_map(self) -> list[str]:
7275
"""Get the name to each type of atoms."""

deepmd/pt/infer/deep_eval.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -467,10 +467,19 @@ def _eval_model(
467467
out = batch_output[pt_name].reshape(shape).detach().cpu().numpy()
468468
results.append(out)
469469
else:
470-
shape = self._get_output_shape(odef, nframes, natoms)
471-
results.append(
472-
np.full(np.abs(shape), np.nan, dtype=prec)
473-
) # this is kinda hacky
470+
if (
471+
self._OUTDEF_DP2BACKEND[odef.name] == "force"
472+
and "dforce" in batch_output
473+
):
474+
# if no force, use dforce if possible
475+
shape = self._get_output_shape(odef, nframes, natoms)
476+
out = batch_output["dforce"].reshape(shape).detach().cpu().numpy()
477+
results.append(out)
478+
else:
479+
shape = self._get_output_shape(odef, nframes, natoms)
480+
results.append(
481+
np.full(np.abs(shape), np.nan, dtype=prec)
482+
) # this is kinda hacky
474483
return tuple(results)
475484

476485
def _eval_model_spin(

deepmd/pt/loss/ener.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,9 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
187187
Other losses for display.
188188
"""
189189
model_pred = model(**input_dict)
190+
191+
if "force" not in model_pred and "dforce" in model_pred:
192+
model_pred["force"] = model_pred["dforce"]
190193
coef = learning_rate / self.starter_learning_rate
191194
pref_e = self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * coef
192195
pref_f = self.limit_pref_f + (self.start_pref_f - self.limit_pref_f) * coef

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -245,15 +245,30 @@ def forward_atomic(
245245
if self.enable_eval_descriptor_hook:
246246
self.eval_descriptor_list.append(descriptor)
247247
# energy, force
248-
fit_ret = self.fitting_net(
249-
descriptor,
250-
atype,
251-
gr=rot_mat,
252-
g2=g2,
253-
h2=h2,
254-
fparam=fparam,
255-
aparam=aparam,
256-
)
248+
if not self.fitting_net.need_additional_input():
249+
fit_ret = self.fitting_net(
250+
descriptor,
251+
atype,
252+
gr=rot_mat,
253+
g2=g2,
254+
h2=h2,
255+
fparam=fparam,
256+
aparam=aparam,
257+
)
258+
else:
259+
add_input = self.descriptor.get_additional_output_for_fitting()
260+
fit_ret = self.fitting_net(
261+
descriptor,
262+
atype,
263+
gr=rot_mat,
264+
g2=g2,
265+
h2=h2,
266+
fparam=fparam,
267+
aparam=aparam,
268+
diff=add_input["diff"],
269+
edge_index=add_input["edge_index"],
270+
sw=add_input["sw"],
271+
)
257272
return fit_ret
258273

259274
def get_out_bias(self) -> torch.Tensor:

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,9 @@ def get_rcut(self) -> float:
252252
"""Returns the cut-off radius."""
253253
return self.rcut
254254

255+
def get_additional_output_for_fitting(self):
256+
return self.repflows.get_additional_output_for_fitting()
257+
255258
def get_rcut_smth(self) -> float:
256259
"""Returns the radius where the neighbor information starts to smoothly decay to 0."""
257260
return self.rcut_smth

deepmd/pt/model/descriptor/repflows.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,7 @@ def __init__(
447447
)
448448
)
449449
self.layers = torch.nn.ModuleList(layers)
450+
self.additional_output_for_fitting: dict[str, Optional[torch.Tensor]] = {}
450451

451452
wanted_shape = (self.ntypes, self.nnei, 4)
452453
mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE)
@@ -461,6 +462,8 @@ def get_rcut(self) -> float:
461462
"""Returns the cut-off radius."""
462463
return self.e_rcut
463464

465+
additional_output_for_fitting: dict[str, Optional[torch.Tensor]]
466+
464467
def get_rcut_smth(self) -> float:
465468
"""Returns the radius where the neighbor information starts to smoothly decay to 0."""
466469
return self.e_rcut_smth
@@ -548,6 +551,9 @@ def reinit_exclude(
548551
self.exclude_types = exclude_types
549552
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)
550553

554+
def get_additional_output_for_fitting(self):
555+
return self.additional_output_for_fitting
556+
551557
def forward(
552558
self,
553559
nlist: torch.Tensor,
@@ -782,6 +788,8 @@ def forward(
782788
sw = sw[nlist_mask]
783789
# n_edge x 4
784790
dmatrix = dmatrix[nlist_mask]
791+
# n_edge x 3
792+
diff = diff[nlist_mask]
785793

786794
if self.edge_use_esen_atom_ebd:
787795
assert source_type is not None
@@ -809,12 +817,16 @@ def forward(
809817
* d_sw[:, :, None, :, None]
810818
* d_sw[:, :, None, None, :]
811819
)[d_nlist_mask]
820+
self.additional_output_for_fitting["edge_index"] = edge_index
812821
else:
813822
# avoid jit assertion
814823
edge_index = angle_index = torch.zeros(
815824
[1, 3], device=nlist.device, dtype=nlist.dtype
816825
)
817826
dihedral_index = None
827+
self.additional_output_for_fitting["edge_index"] = None
828+
self.additional_output_for_fitting["diff"] = diff
829+
self.additional_output_for_fitting["sw"] = sw
818830
# get edge and angle embedding
819831
# nb x nloc x nnei x e_dim [OR] n_edge x e_dim
820832
if self.edge_use_esen_rbf:

deepmd/pt/model/model/__init__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,14 @@ def _get_standard_model_components(model_params, ntypes):
9090
fitting_net["ntypes"] = descriptor.get_ntypes()
9191
fitting_net["type_map"] = copy.deepcopy(model_params["type_map"])
9292
fitting_net["mixed_types"] = descriptor.mixed_types()
93-
if fitting_net["type"] in ["dipole", "polar"]:
93+
if fitting_net["type"] in ["dipole", "polar", "ener_direct"]:
9494
fitting_net["embedding_width"] = descriptor.get_dim_emb()
9595
fitting_net["dim_descrpt"] = descriptor.get_dim_out()
96-
grad_force = "direct" not in fitting_net["type"]
97-
if not grad_force:
98-
fitting_net["out_dim"] = descriptor.get_dim_emb()
99-
if "ener" in fitting_net["type"]:
100-
fitting_net["return_energy"] = True
96+
# grad_force = "direct" not in fitting_net["type"]
97+
# if not grad_force:
98+
# fitting_net["out_dim"] = descriptor.get_dim_emb()
99+
# if "ener" in fitting_net["type"]:
100+
# fitting_net["return_energy"] = True
101101
fitting = BaseFitting(**fitting_net)
102102
return descriptor, fitting, fitting_net["type"]
103103

@@ -261,7 +261,7 @@ def get_standard_model(model_params):
261261
modelcls = PolarModel
262262
elif fitting_net_type == "dos":
263263
modelcls = DOSModel
264-
elif fitting_net_type in ["ener", "direct_force_ener"]:
264+
elif fitting_net_type in ["ener", "direct_force_ener", "ener_direct"]:
265265
modelcls = EnergyModel
266266
elif fitting_net_type == "property":
267267
modelcls = PropertyModel

deepmd/pt/model/model/ener_model.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def forward(
8181
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(
8282
-3
8383
)
84-
else:
85-
model_predict["force"] = model_ret["dforce"]
84+
if "dforce" in model_ret:
85+
model_predict["dforce"] = model_ret["dforce"]
8686
if "mask" in model_ret:
8787
model_predict["mask"] = model_ret["mask"]
8888
else:
@@ -119,15 +119,16 @@ def forward_lower(
119119
model_predict["energy"] = model_ret["energy_redu"]
120120
if self.do_grad_r("energy"):
121121
model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2)
122+
else:
123+
assert model_ret["dforce"] is not None
124+
model_predict["dforce"] = model_ret["dforce"]
125+
122126
if self.do_grad_c("energy"):
123127
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
124128
if do_atomic_virial:
125129
model_predict["extended_virial"] = model_ret[
126130
"energy_derv_c"
127131
].squeeze(-3)
128-
else:
129-
assert model_ret["dforce"] is not None
130-
model_predict["dforce"] = model_ret["dforce"]
131132
else:
132133
model_predict = model_ret
133134
return model_predict

0 commit comments

Comments
 (0)