Skip to content

Commit 7c58b9e

Browse files
committed
feat: add edge readout
1 parent ba4e2ad commit 7c58b9e

10 files changed

Lines changed: 328 additions & 11 deletions

File tree

deepmd/dpmodel/descriptor/make_base_descriptor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,13 @@ def compute_input_stats(
154154
"""Update mean and stddev for descriptor elements."""
155155
raise NotImplementedError
156156

157+
def get_norm_fact(self) -> list[float]:
158+
"""Returns the norm factor."""
159+
raise NotImplementedError
160+
161+
def get_additional_output_for_fitting(self) -> dict[str, Optional[Array]]:
162+
raise NotImplementedError
163+
157164
def enable_compression(
158165
self,
159166
min_nbor_dist: float,

deepmd/dpmodel/fitting/make_base_fitting.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ def compute_output_stats(self, merged: Any) -> NoReturn:
6868
"""Update the output bias for fitting net."""
6969
raise NotImplementedError
7070

71+
def need_additional_input(self) -> bool:
72+
return False
73+
7174
@abstractmethod
7275
def get_type_map(self) -> list[str]:
7376
"""Get the name to each type of atoms."""

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -264,15 +264,29 @@ def forward_atomic(
264264
if self.enable_eval_descriptor_hook:
265265
self.eval_descriptor_list.append(descriptor.detach())
266266
# energy, force
267-
fit_ret = self.fitting_net(
268-
descriptor,
269-
atype,
270-
gr=rot_mat,
271-
g2=g2,
272-
h2=h2,
273-
fparam=fparam,
274-
aparam=aparam,
275-
)
267+
if not self.fitting_net.need_additional_input():
268+
fit_ret = self.fitting_net(
269+
descriptor,
270+
atype,
271+
gr=rot_mat,
272+
g2=g2,
273+
h2=h2,
274+
fparam=fparam,
275+
aparam=aparam,
276+
)
277+
else:
278+
add_input = self.descriptor.get_additional_output_for_fitting()
279+
fit_ret = self.fitting_net(
280+
descriptor,
281+
atype,
282+
gr=rot_mat,
283+
g2=g2,
284+
h2=h2,
285+
fparam=fparam,
286+
aparam=aparam,
287+
sw=sw,
288+
edge_index=add_input.get("edge_index", None),
289+
)
276290
if self.enable_eval_fitting_last_layer_hook:
277291
assert "middle_output" in fit_ret, (
278292
"eval_fitting_last_layer not supported for this fitting net!"

deepmd/pt/model/descriptor/dpa1.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,13 @@ def get_dim_out(self) -> int:
357357
def get_dim_emb(self) -> int:
358358
return self.se_atten.dim_emb
359359

360+
def get_norm_fact(self) -> list[float]:
361+
"""Returns the norm factor."""
362+
return [float(self.get_nnei())]
363+
364+
def get_additional_output_for_fitting(self) -> dict[str, Optional[torch.Tensor]]:
365+
return {}
366+
360367
def mixed_types(self) -> bool:
361368
"""If true, the descriptor
362369
1. assumes total number of atoms aligned across frames;

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,13 @@ def get_dim_emb(self) -> int:
253253
"""Returns the embedding dimension of this descriptor."""
254254
return self.repflows.dim_emb
255255

256+
def get_norm_fact(self) -> list[float]:
257+
"""Returns the norm factor."""
258+
return self.repflows.get_norm_fact()
259+
260+
def get_additional_output_for_fitting(self) -> dict[str, Optional[torch.Tensor]]:
261+
return self.repflows.get_additional_output_for_fitting()
262+
256263
def mixed_types(self) -> bool:
257264
"""If true, the descriptor
258265
1. assumes total number of atoms aligned across frames;

deepmd/pt/model/descriptor/repflows.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,8 @@ def __init__(
261261
self.use_exp_switch = use_exp_switch
262262
self.use_dynamic_sel = use_dynamic_sel
263263
self.sel_reduce_factor = sel_reduce_factor
264+
self.dynamic_e_sel = self.nnei / self.sel_reduce_factor
265+
self.dynamic_a_sel = self.a_sel / self.sel_reduce_factor
264266
if self.use_dynamic_sel and not self.smooth_edge_update:
265267
raise NotImplementedError(
266268
"smooth_edge_update must be True when use_dynamic_sel is True!"
@@ -344,6 +346,7 @@ def __init__(
344346
)
345347
)
346348
self.layers = torch.nn.ModuleList(layers)
349+
self.additional_output_for_fitting: dict[str, Optional[torch.Tensor]] = {}
347350

348351
wanted_shape = (self.ntypes, self.nnei, 4)
349352
mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE)
@@ -354,6 +357,8 @@ def __init__(
354357
self.register_buffer("stddev", stddev)
355358
self.stats = None
356359

360+
additional_output_for_fitting: dict[str, Optional[torch.Tensor]]
361+
357362
def get_rcut(self) -> float:
358363
"""Returns the cut-off radius."""
359364
return self.e_rcut
@@ -386,6 +391,16 @@ def get_dim_emb(self) -> int:
386391
"""Returns the embedding dimension e_dim."""
387392
return self.e_dim
388393

394+
def get_additional_output_for_fitting(self) -> dict[str, Optional[torch.Tensor]]:
395+
return self.additional_output_for_fitting
396+
397+
def get_norm_fact(self) -> list[float]:
398+
"""Returns the norm factor."""
399+
return [
400+
float(self.dynamic_e_sel if self.use_dynamic_sel else self.nnei),
401+
# float(self.dynamic_a_sel if self.use_dynamic_sel else self.a_sel),
402+
]
403+
389404
def __setitem__(self, key: str, value: Any) -> None:
390405
if key in ("avg", "data_avg", "davg"):
391406
self.mean = value
@@ -564,10 +579,12 @@ def forward(
564579
angle_input = angle_input[a_nlist_mask]
565580
# n_angle x 1
566581
a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask]
582+
self.additional_output_for_fitting["edge_index"] = edge_index
567583
else:
568584
# avoid jit assertion
569585
edge_index = torch.zeros([2, 1], device=nlist.device, dtype=nlist.dtype)
570586
angle_index = torch.zeros([3, 1], device=nlist.device, dtype=nlist.dtype)
587+
self.additional_output_for_fitting["edge_index"] = None
571588
# get edge and angle embedding
572589
# nb x nloc x nnei x e_dim [OR] n_edge x e_dim
573590
if not self.edge_init_use_dist:

deepmd/pt/model/model/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,10 @@ def _get_standard_model_components(model_params: dict, ntypes: int) -> tuple:
9191
fitting_net["ntypes"] = descriptor.get_ntypes()
9292
fitting_net["type_map"] = copy.deepcopy(model_params["type_map"])
9393
fitting_net["mixed_types"] = descriptor.mixed_types()
94-
if fitting_net["type"] in ["dipole", "polar"]:
94+
if fitting_net["type"] in ["dipole", "polar", "ener_readout"]:
9595
fitting_net["embedding_width"] = descriptor.get_dim_emb()
96+
if fitting_net["type"] in ["ener_readout"]:
97+
fitting_net["norm_fact"] = descriptor.get_norm_fact()
9698
fitting_net["dim_descrpt"] = descriptor.get_dim_out()
9799
grad_force = "direct" not in fitting_net["type"]
98100
if not grad_force:
@@ -265,7 +267,7 @@ def get_standard_model(model_params: dict) -> BaseModel:
265267
modelcls = PolarModel
266268
elif fitting_net_type == "dos":
267269
modelcls = DOSModel
268-
elif fitting_net_type in ["ener", "direct_force_ener"]:
270+
elif fitting_net_type in ["ener", "direct_force_ener", "ener_readout"]:
269271
modelcls = EnergyModel
270272
elif fitting_net_type == "property":
271273
modelcls = PropertyModel

deepmd/pt/model/task/ener.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,19 @@
1414
OutputVariableDef,
1515
fitting_check_output,
1616
)
17+
from deepmd.dpmodel.utils.seed import (
18+
child_seed,
19+
)
20+
from deepmd.pt.model.network.mlp import (
21+
FittingNet,
22+
NetworkCollection,
23+
)
1724
from deepmd.pt.model.network.network import (
1825
ResidualDeep,
1926
)
27+
from deepmd.pt.model.network.utils import (
28+
aggregate,
29+
)
2030
from deepmd.pt.model.task.fitting import (
2131
Fitting,
2232
GeneralFitting,
@@ -260,3 +270,155 @@ def forward(
260270
"energy": outs.to(env.GLOBAL_PT_FLOAT_PRECISION),
261271
"dforce": vec_out,
262272
}
273+
274+
275+
@Fitting.register("ener_readout")
276+
@fitting_check_output
277+
class EnergyFittingNetReadout(InvarFitting):
278+
def __init__(
279+
self,
280+
ntypes: int,
281+
dim_descrpt: int,
282+
neuron: list[int] = [128, 128, 128],
283+
bias_atom_e: Optional[torch.Tensor] = None,
284+
resnet_dt: bool = True,
285+
numb_fparam: int = 0,
286+
numb_aparam: int = 0,
287+
dim_case_embd: int = 0,
288+
embedding_width: int = 128,
289+
activation_function: str = "tanh",
290+
precision: str = DEFAULT_PRECISION,
291+
mixed_types: bool = True,
292+
seed: Optional[Union[int, list[int]]] = None,
293+
type_map: Optional[list[str]] = None,
294+
norm_fact: list[float] = [120.0],
295+
add_edge_readout: bool = True,
296+
slim_edge_readout: bool = False,
297+
**kwargs: Any,
298+
) -> None:
299+
"""Construct a fitting net for energy.
300+
301+
Args:
302+
- ntypes: Element count.
303+
- embedding_width: Embedding width per atom.
304+
- neuron: Number of neurons in each hidden layers of the fitting net.
305+
- bias_atom_e: Average energy per atom for each element.
306+
- resnet_dt: Using time-step in the ResNet construction.
307+
"""
308+
self.add_edge_readout = add_edge_readout
309+
super().__init__(
310+
"energy",
311+
ntypes,
312+
dim_descrpt,
313+
1,
314+
neuron=neuron,
315+
bias_atom_e=bias_atom_e,
316+
resnet_dt=resnet_dt,
317+
numb_fparam=numb_fparam,
318+
numb_aparam=numb_aparam,
319+
dim_case_embd=dim_case_embd,
320+
activation_function=activation_function,
321+
precision=precision,
322+
mixed_types=mixed_types,
323+
seed=seed,
324+
type_map=type_map,
325+
**kwargs,
326+
)
327+
328+
# embedding for edge readout
329+
self.embedding_width = embedding_width
330+
self.slim_edge_readout = slim_edge_readout
331+
self.norm_e_fact = norm_fact[0]
332+
333+
if self.add_edge_readout:
334+
self.edge_embed = NetworkCollection(
335+
1 if not self.mixed_types else 0,
336+
self.ntypes,
337+
network_type="fitting_network",
338+
networks=[
339+
FittingNet(
340+
self.embedding_width,
341+
1,
342+
self.neuron if not self.slim_edge_readout else self.neuron[:1],
343+
self.activation_function,
344+
self.resnet_dt,
345+
self.precision,
346+
bias_out=True,
347+
seed=child_seed(self.seed + 100, ii),
348+
)
349+
for ii in range(self.ntypes if not self.mixed_types else 1)
350+
],
351+
)
352+
else:
353+
self.edge_embed = None
354+
355+
# set trainable
356+
for param in self.parameters():
357+
param.requires_grad = self.trainable
358+
359+
# make jit happy with torch 2.0.0
360+
exclude_types: list[int]
361+
362+
def need_additional_input(self) -> bool:
363+
return True
364+
365+
def serialize(self) -> dict:
366+
raise NotImplementedError
367+
368+
@classmethod
369+
def deserialize(cls, data: dict) -> "EnergyFittingNetReadout":
370+
raise NotImplementedError
371+
372+
def forward(
373+
self,
374+
descriptor: torch.Tensor,
375+
atype: torch.Tensor,
376+
gr: Optional[torch.Tensor] = None,
377+
g2: Optional[torch.Tensor] = None,
378+
h2: Optional[torch.Tensor] = None,
379+
fparam: Optional[torch.Tensor] = None,
380+
aparam: Optional[torch.Tensor] = None,
381+
sw: Optional[torch.Tensor] = None,
382+
edge_index: Optional[torch.Tensor] = None,
383+
) -> dict[str, torch.Tensor]:
384+
"""Based on embedding net output, alculate total energy.
385+
386+
Args:
387+
- inputs: Embedding matrix. Its shape is [nframes, natoms[0], self.dim_descrpt].
388+
- natoms: Tell atom count and element count. Its shape is [2+self.ntypes].
389+
390+
Returns
391+
-------
392+
- `torch.Tensor`: Total energy with shape [nframes, natoms[0]].
393+
"""
394+
out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
395+
self.var_name
396+
]
397+
nf, nloc, _ = descriptor.shape
398+
399+
if self.add_edge_readout:
400+
assert g2 is not None
401+
assert sw is not None
402+
assert self.edge_embed is not None
403+
# nf x nloc x nnei x d [OR] nedge x d
404+
edge_feature = g2
405+
# nf x nloc x nnei x 1 [OR] nedge x 1
406+
edge_atomic_contrib = self.edge_embed.networks[0](edge_feature)
407+
# nf x nloc x nnei x 1 [OR] nedge x 1
408+
edge_atomic_contrib = edge_atomic_contrib * sw.unsqueeze(-1)
409+
if edge_index is not None:
410+
# use dynamic sel
411+
n2e_index, n_ext2e_index = edge_index[0], edge_index[1]
412+
# nf x nloc x 1
413+
edge_energy = aggregate(
414+
edge_atomic_contrib,
415+
n2e_index,
416+
average=False,
417+
num_owner=nf * nloc,
418+
).reshape(nf, nloc, 1)
419+
else:
420+
# nf x nloc x 1
421+
edge_energy = torch.sum(edge_atomic_contrib, dim=-2)
422+
# energy
423+
out = out + edge_energy / self.norm_e_fact
424+
return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}

deepmd/pt/model/task/invar_fitting.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ def forward(
176176
h2: Optional[torch.Tensor] = None,
177177
fparam: Optional[torch.Tensor] = None,
178178
aparam: Optional[torch.Tensor] = None,
179+
sw: Optional[torch.Tensor] = None,
180+
edge_index: Optional[torch.Tensor] = None,
179181
) -> dict[str, torch.Tensor]:
180182
"""Based on embedding net output, alculate total energy.
181183

0 commit comments

Comments
 (0)