Skip to content

Commit ac44cf0

Browse files
YuzhiLiu-aiYuzhi
authored and committed
feat: add charge density prediction support
- Add DensityFittingNet and GridDensityModel for grid-based charge density - Add DPDensityAtomicModel with grid neighbor list handling - Add GridDensityLoss for training on density labels - Support density data loading (grid.npy / density.npy) - Support model inference with grid input (DeepPot / DeepEval) - Add density model argument checking (fitting_density, loss_grid_density) - Skip change_out_bias for density models (grid-based output) - Add example configs and evaluation script for density models
1 parent f39a081 commit ac44cf0

35 files changed

Lines changed: 2144 additions & 8 deletions

deepmd/infer/deep_pot.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,11 @@ def eval(
212212
aparam=aparam,
213213
**kwargs,
214214
)
215+
# TODO: when a grid is requested, the density is returned directly, without reshaping the results to energy, force and virial. Consider returning the density under a separate key of the results dict instead.
216+
if "grid" in kwargs:
217+
result = results["density"].reshape(nframes, -1)
218+
return result
219+
215220
energy = results["energy_redu"].reshape(nframes, 1)
216221
force = results["energy_derv_r"].reshape(nframes, natoms, 3)
217222
virial = results["energy_derv_c_redu"].reshape(nframes, 9)

deepmd/pt/entrypoints/main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,10 @@ def train(
421421

422422
# Initialize DDP
423423
if os.environ.get("LOCAL_RANK") is not None:
424-
dist.init_process_group(backend="cuda:nccl,cpu:gloo")
424+
import datetime
425+
timeout = datetime.timedelta(seconds=18000) # set a longer timeout for large datasets or slow file systems
426+
dist.init_process_group(backend="cuda:nccl,cpu:gloo", timeout=timeout)
427+
425428

426429
trainer = get_trainer(
427430
config,
@@ -608,7 +611,6 @@ def change_bias(
608611
)
609612
log.info(f"Saved model to {output_path}")
610613

611-
612614
@record
613615
def main(args: list[str] | argparse.Namespace | None = None) -> None:
614616
if not isinstance(args, argparse.Namespace):

deepmd/pt/infer/deep_eval.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,8 +408,19 @@ def eval(
408408
request_defs = self._get_request_defs(atomic)
409409
if "spin" not in kwargs or kwargs["spin"] is None:
410410
out = self._eval_func(self._eval_model, numb_test, natoms)(
411-
coords, cells, atom_types, fparam, aparam, request_defs, charge_spin
411+
coords, cells, atom_types, fparam, aparam, request_defs
412412
)
413+
elif "grid" in kwargs and kwargs["grid"] is not None:
414+
out = self._eval_func(self._eval_model_density, numb_test, natoms)(
415+
coords,
416+
cells,
417+
atom_types,
418+
np.array(kwargs["grid"]),
419+
fparam,
420+
aparam,
421+
request_defs,
422+
)
423+
return {"density": out}
413424
else:
414425
out = self._eval_func(self._eval_model_spin, numb_test, natoms)(
415426
coords,
@@ -688,6 +699,81 @@ def _eval_model_spin(
688699
) # this is kinda hacky
689700
return tuple(results)
690701

702+
def _eval_model_density(
    self,
    coords: np.ndarray,
    cells: Optional[np.ndarray],
    atom_types: np.ndarray,
    grid: np.ndarray,
    fparam: Optional[np.ndarray],
    aparam: Optional[np.ndarray],
    request_defs: list[OutputVariableDef],
):
    """Run the model on a batch of frames and return the predicted density.

    Parameters
    ----------
    coords : np.ndarray
        Atomic coordinates; reshaped to ``(nframes, natoms, 3)``.
    cells : np.ndarray or None
        Cell vectors; reshaped to ``(nframes, 3, 3)`` when given, otherwise
        ``None`` is forwarded to the model as ``box``.
    atom_types : np.ndarray
        Atom types. A 1-D array is tiled so that every frame shares it.
    grid : np.ndarray
        Grid point coordinates; reshaped to ``(nframes, ngrid, 3)``.
    fparam : np.ndarray or None
        Frame parameters; reshaped to ``(nframes, dim_fparam)`` when given.
    aparam : np.ndarray or None
        Atomic parameters; reshaped to ``(nframes, natoms, dim_aparam)``
        when given.
    request_defs : list[OutputVariableDef]
        Requested outputs. Only used to decide the ``do_atomic_virial``
        flag passed to the model.

    Returns
    -------
    tuple
        A one-element tuple holding the ``"density"`` model output as a
        NumPy array reshaped to ``(nframes, ngrid)``.
    """
    net = self.dp.to(DEVICE)

    n_frames = coords.shape[0]
    if atom_types.ndim == 1:
        # A single type list is shared by all frames: replicate it per frame.
        n_atoms = len(atom_types)
        atom_types = np.tile(atom_types, n_frames).reshape(n_frames, -1)
    else:
        n_atoms = len(atom_types[0])

    # Keyword arguments shared by every floating-point input tensor.
    float_kwargs = {"dtype": GLOBAL_PT_FLOAT_PRECISION, "device": DEVICE}

    coord_t = torch.tensor(coords.reshape([n_frames, n_atoms, 3]), **float_kwargs)
    type_t = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
    grid_t = torch.tensor(grid.reshape([n_frames, -1, 3]), **float_kwargs)
    n_grid = grid_t.shape[1]

    box_t = (
        torch.tensor(cells.reshape([n_frames, 3, 3]), **float_kwargs)
        if cells is not None
        else None
    )
    fparam_t = (
        to_torch_tensor(fparam.reshape(n_frames, self.get_dim_fparam()))
        if fparam is not None
        else None
    )
    aparam_t = (
        to_torch_tensor(aparam.reshape(n_frames, n_atoms, self.get_dim_aparam()))
        if aparam is not None
        else None
    )

    wants_atomic_virial = any(
        odef.category == OutputVariableCategory.DERV_C_REDU
        for odef in request_defs
    )
    prediction = net(
        coord_t,
        type_t,
        grid=grid_t,
        box=box_t,
        do_atomic_virial=wants_atomic_virial,
        fparam=fparam_t,
        aparam=aparam_t,
    )
    # Some models return a tuple; the prediction dict is its first element.
    if isinstance(prediction, tuple):
        prediction = prediction[0]

    density = (
        prediction["density"].reshape([n_frames, n_grid]).detach().cpu().numpy()
    )
    return (density,)
775+
776+
691777
def _get_output_shape(
692778
self, odef: OutputVariableDef, nframes: int, natoms: int
693779
) -> list[int]:

deepmd/pt/loss/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from .charge import (
3+
GridDensityLoss,
4+
)
25
from .denoise import (
36
DenoiseLoss,
47
)
@@ -28,6 +31,7 @@
2831
"EnergyHessianStdLoss",
2932
"EnergySpinLoss",
3033
"EnergyStdLoss",
34+
"GridDensityLoss",
3135
"PropertyLoss",
3236
"TaskLoss",
3337
"TensorLoss",

deepmd/pt/loss/charge.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import torch
3+
4+
from deepmd.pt.loss.loss import (
5+
TaskLoss,
6+
)
7+
from deepmd.pt.utils import (
8+
env,
9+
)
10+
from deepmd.pt.utils.env import (
11+
GLOBAL_PT_FLOAT_PRECISION,
12+
)
13+
from deepmd.utils.data import (
14+
DataRequirementItem,
15+
)
16+
17+
18+
class GridDensityLoss(TaskLoss):
    """Loss on the charge density predicted at grid points.

    The optimized term is the mean absolute error between the predicted and
    the labeled density, scaled by a prefactor that is interpolated between
    ``start_pref_d`` and ``limit_pref_d`` according to the learning rate.
    The RMSE and MAE of the density are reported for display.
    """

    def __init__(
        self,
        starter_learning_rate: float = 1.0,
        start_pref_d: float = 0.0,
        limit_pref_d: float = 0.0,
        inference: bool = False,
        **kwargs,
    ) -> None:
        r"""Construct a layer to compute loss on grid density.

        Parameters
        ----------
        starter_learning_rate : float
            The learning rate at the start of the training.
        start_pref_d : float
            The prefactor of charge density loss at the start of the training.
        limit_pref_d : float
            The prefactor of charge density loss at the end of the training.
        inference : bool
            If true, the density term is always evaluated, regardless of
            the prefactors.
        **kwargs
            Other keyword arguments (ignored).
        """
        super().__init__()
        self.starter_learning_rate = starter_learning_rate
        # The density term is active only when both prefactors are non-zero,
        # or unconditionally in inference mode.
        self.has_d = (start_pref_d != 0.0 and limit_pref_d != 0.0) or inference

        self.start_pref_d = start_pref_d
        self.limit_pref_d = limit_pref_d
        self.inference = inference

    def forward(
        self,
        input_dict: dict[str, torch.Tensor],
        model: torch.nn.Module,
        label: dict[str, torch.Tensor],
        natoms: int,
        learning_rate: float,
        mae: bool = False,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor, dict[str, torch.Tensor]]:
        """Return loss on grid charge density.

        Parameters
        ----------
        input_dict : dict[str, torch.Tensor]
            Model inputs.
        model : torch.nn.Module
            Model to be used to output the predictions.
        label : dict[str, torch.Tensor]
            Labels. ``label["density"]`` is compared against the prediction;
            ``label["find_density"]`` (default 0.0) scales the prefactor.
        natoms : int
            The local atom number. Not used by this loss; kept so that the
            call signature is unchanged.
        learning_rate : float
            The current learning rate, used to interpolate the prefactor.
        mae : bool
            Not used by this loss; the density MAE is always reported.

        Returns
        -------
        model_pred: dict[str, torch.Tensor]
            Model predictions.
        loss: torch.Tensor
            Loss for model to minimize.
        more_loss: dict[str, torch.Tensor]
            Other losses for display (``rmse_d`` and ``mae_d``).
        """
        model_pred = model(**input_dict)
        coef = learning_rate / self.starter_learning_rate
        pref_d = self.limit_pref_d + (self.start_pref_d - self.limit_pref_d) * coef

        loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
        more_loss = {}
        if self.has_d and "density" in model_pred and "density" in label:
            find_density = label.get("find_density", 0.0)
            # A missing label (find_density == 0) zeroes the contribution.
            pref_d = pref_d * find_density
            density_diff = label["density"].reshape(-1) - model_pred[
                "density"
            ].reshape(-1)

            # RMSE is reported for display only; it is not optimized.
            rmse_d = torch.square(density_diff).mean().sqrt()
            more_loss["rmse_d"] = self.display_if_exist(rmse_d.detach(), find_density)

            # The optimized term is the L1 (mean absolute) error.
            l1_density_loss = torch.abs(density_diff).mean()
            loss += (pref_d * l1_density_loss).to(GLOBAL_PT_FLOAT_PRECISION)
            more_loss["mae_d"] = self.display_if_exist(
                l1_density_loss.detach(), find_density
            )
        return model_pred, loss, more_loss

    @property
    def label_requirement(self) -> list[DataRequirementItem]:
        """Return data label requirements needed for this loss calculation."""
        label_requirement = [
            # The grid coordinates are always required (3 values per entry).
            # NOTE(review): the grid is declared atomic=True; confirm how the
            # data loader sizes it when the number of grid points differs
            # from the number of atoms.
            DataRequirementItem(
                "grid",
                ndof=3,
                atomic=True,
                must=True,
                high_prec=True,
            )
        ]
        if self.has_d:
            label_requirement.append(
                DataRequirementItem(
                    "density",
                    ndof=1,
                    atomic=True,
                    must=False,
                    high_prec=True,
                )
            )
        return label_requirement

deepmd/pt/model/atomic_model/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from .base_atomic_model import (
1818
BaseAtomicModel,
1919
)
20+
from .density_atomic_model import (
21+
DPDensityAtomicModel,
22+
)
2023
from .dipole_atomic_model import (
2124
DPDipoleAtomicModel,
2225
)
@@ -52,6 +55,7 @@
5255
"DPPolarAtomicModel",
5356
"DPPropertyAtomicModel",
5457
"DPZBLLinearEnergyAtomicModel",
58+
"DPDensityAtomicModel",
5559
"LinearEnergyAtomicModel",
5660
"PairTabAtomicModel",
5761
]

0 commit comments

Comments
 (0)