Skip to content

Commit 5836d3a

Browse files
committed
feat(pt): add edge readout
1 parent 68f0d21 commit 5836d3a

10 files changed

Lines changed: 328 additions & 11 deletions

File tree

deepmd/dpmodel/descriptor/make_base_descriptor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,13 @@ def compute_input_stats(
148148
"""Update mean and stddev for descriptor elements."""
149149
raise NotImplementedError
150150

151+
def get_norm_fact(self) -> list[float]:
152+
"""Returns the norm factor."""
153+
raise NotImplementedError
154+
155+
def get_additional_output_for_fitting(self):
156+
raise NotImplementedError
157+
151158
def enable_compression(
152159
self,
153160
min_nbor_dist: float,

deepmd/dpmodel/fitting/make_base_fitting.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ def compute_output_stats(self, merged) -> NoReturn:
6767
"""Update the output bias for fitting net."""
6868
raise NotImplementedError
6969

70+
def need_additional_input(self) -> bool:
71+
return False
72+
7073
@abstractmethod
7174
def get_type_map(self) -> list[str]:
7275
"""Get the name to each type of atoms."""

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -260,15 +260,29 @@ def forward_atomic(
260260
if self.enable_eval_descriptor_hook:
261261
self.eval_descriptor_list.append(descriptor.detach())
262262
# energy, force
263-
fit_ret = self.fitting_net(
264-
descriptor,
265-
atype,
266-
gr=rot_mat,
267-
g2=g2,
268-
h2=h2,
269-
fparam=fparam,
270-
aparam=aparam,
271-
)
263+
if not self.fitting_net.need_additional_input():
264+
fit_ret = self.fitting_net(
265+
descriptor,
266+
atype,
267+
gr=rot_mat,
268+
g2=g2,
269+
h2=h2,
270+
fparam=fparam,
271+
aparam=aparam,
272+
)
273+
else:
274+
add_input = self.descriptor.get_additional_output_for_fitting()
275+
fit_ret = self.fitting_net(
276+
descriptor,
277+
atype,
278+
gr=rot_mat,
279+
g2=g2,
280+
h2=h2,
281+
fparam=fparam,
282+
aparam=aparam,
283+
sw=sw,
284+
edge_index=add_input.get("edge_index", None),
285+
)
272286
if self.enable_eval_fitting_last_layer_hook:
273287
assert "middle_output" in fit_ret, (
274288
"eval_fitting_last_layer not supported for this fitting net!"

deepmd/pt/model/descriptor/dpa1.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,13 @@ def get_dim_out(self) -> int:
356356
def get_dim_emb(self) -> int:
357357
return self.se_atten.dim_emb
358358

359+
def get_norm_fact(self) -> list[float]:
360+
"""Returns the norm factor."""
361+
return [float(self.get_nnei())]
362+
363+
def get_additional_output_for_fitting(self) -> dict[str, Optional[torch.Tensor]]:
364+
return {}
365+
359366
def mixed_types(self) -> bool:
360367
"""If true, the descriptor
361368
1. assumes total number of atoms aligned across frames;

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,13 @@ def get_dim_emb(self) -> int:
248248
"""Returns the embedding dimension of this descriptor."""
249249
return self.repflows.dim_emb
250250

251+
def get_norm_fact(self) -> list[float]:
252+
"""Returns the norm factor."""
253+
return self.repflows.get_norm_fact()
254+
255+
def get_additional_output_for_fitting(self):
256+
return self.repflows.get_additional_output_for_fitting()
257+
251258
def mixed_types(self) -> bool:
252259
"""If true, the descriptor
253260
1. assumes total number of atoms aligned across frames;

deepmd/pt/model/descriptor/repflows.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,8 @@ def __init__(
257257
self.use_exp_switch = use_exp_switch
258258
self.use_dynamic_sel = use_dynamic_sel
259259
self.sel_reduce_factor = sel_reduce_factor
260+
self.dynamic_e_sel = self.nnei / self.sel_reduce_factor
261+
self.dynamic_a_sel = self.a_sel / self.sel_reduce_factor
260262
if self.use_dynamic_sel and not self.smooth_edge_update:
261263
raise NotImplementedError(
262264
"smooth_edge_update must be True when use_dynamic_sel is True!"
@@ -334,6 +336,7 @@ def __init__(
334336
)
335337
)
336338
self.layers = torch.nn.ModuleList(layers)
339+
self.additional_output_for_fitting: dict[str, Optional[torch.Tensor]] = {}
337340

338341
wanted_shape = (self.ntypes, self.nnei, 4)
339342
mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE)
@@ -344,6 +347,8 @@ def __init__(
344347
self.register_buffer("stddev", stddev)
345348
self.stats = None
346349

350+
additional_output_for_fitting: dict[str, Optional[torch.Tensor]]
351+
347352
def get_rcut(self) -> float:
348353
"""Returns the cut-off radius."""
349354
return self.e_rcut
@@ -376,6 +381,16 @@ def get_dim_emb(self) -> int:
376381
"""Returns the embedding dimension e_dim."""
377382
return self.e_dim
378383

384+
def get_additional_output_for_fitting(self):
385+
return self.additional_output_for_fitting
386+
387+
def get_norm_fact(self) -> list[float]:
388+
"""Returns the norm factor."""
389+
return [
390+
float(self.dynamic_e_sel if self.use_dynamic_sel else self.nnei),
391+
# float(self.dynamic_a_sel if self.use_dynamic_sel else self.a_sel),
392+
]
393+
379394
def __setitem__(self, key, value) -> None:
380395
if key in ("avg", "data_avg", "davg"):
381396
self.mean = value
@@ -548,10 +563,12 @@ def forward(
548563
angle_input = angle_input[a_nlist_mask]
549564
# n_angle x 1
550565
a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask]
566+
self.additional_output_for_fitting["edge_index"] = edge_index
551567
else:
552568
# avoid jit assertion
553569
edge_index = torch.zeros([2, 1], device=nlist.device, dtype=nlist.dtype)
554570
angle_index = torch.zeros([3, 1], device=nlist.device, dtype=nlist.dtype)
571+
self.additional_output_for_fitting["edge_index"] = None
555572
# get edge and angle embedding
556573
# nb x nloc x nnei x e_dim [OR] n_edge x e_dim
557574
if not self.edge_init_use_dist:

deepmd/pt/model/model/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,10 @@ def _get_standard_model_components(model_params, ntypes):
9090
fitting_net["ntypes"] = descriptor.get_ntypes()
9191
fitting_net["type_map"] = copy.deepcopy(model_params["type_map"])
9292
fitting_net["mixed_types"] = descriptor.mixed_types()
93-
if fitting_net["type"] in ["dipole", "polar"]:
93+
if fitting_net["type"] in ["dipole", "polar", "ener_readout"]:
9494
fitting_net["embedding_width"] = descriptor.get_dim_emb()
95+
if fitting_net["type"] in ["ener_readout"]:
96+
fitting_net["norm_fact"] = descriptor.get_norm_fact()
9597
fitting_net["dim_descrpt"] = descriptor.get_dim_out()
9698
grad_force = "direct" not in fitting_net["type"]
9799
if not grad_force:
@@ -262,7 +264,7 @@ def get_standard_model(model_params):
262264
modelcls = PolarModel
263265
elif fitting_net_type == "dos":
264266
modelcls = DOSModel
265-
elif fitting_net_type in ["ener", "direct_force_ener"]:
267+
elif fitting_net_type in ["ener", "direct_force_ener", "ener_readout"]:
266268
modelcls = EnergyModel
267269
elif fitting_net_type == "property":
268270
modelcls = PropertyModel

deepmd/pt/model/task/ener.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,19 @@
1313
OutputVariableDef,
1414
fitting_check_output,
1515
)
16+
from deepmd.dpmodel.utils.seed import (
17+
child_seed,
18+
)
19+
from deepmd.pt.model.network.mlp import (
20+
FittingNet,
21+
NetworkCollection,
22+
)
1623
from deepmd.pt.model.network.network import (
1724
ResidualDeep,
1825
)
26+
from deepmd.pt.model.network.utils import (
27+
aggregate,
28+
)
1929
from deepmd.pt.model.task.fitting import (
2030
Fitting,
2131
GeneralFitting,
@@ -257,3 +267,155 @@ def forward(
257267
"energy": outs.to(env.GLOBAL_PT_FLOAT_PRECISION),
258268
"dforce": vec_out,
259269
}
270+
271+
272+
@Fitting.register("ener_readout")
@fitting_check_output
class EnergyFittingNetReadout(InvarFitting):
    """Energy fitting net with an optional per-edge energy readout.

    On top of the standard per-atom invariant fitting, each edge embedding
    ``g2`` can be mapped to a scalar, weighted by the switch function ``sw``,
    summed per owner atom, and added (divided by ``norm_fact[0]``) to the
    atomic energy.
    """

    def __init__(
        self,
        ntypes: int,
        dim_descrpt: int,
        neuron: Optional[list[int]] = None,
        bias_atom_e: Optional[torch.Tensor] = None,
        resnet_dt: bool = True,
        numb_fparam: int = 0,
        numb_aparam: int = 0,
        dim_case_embd: int = 0,
        embedding_width: int = 128,
        activation_function: str = "tanh",
        precision: str = DEFAULT_PRECISION,
        mixed_types: bool = True,
        seed: Optional[Union[int, list[int]]] = None,
        type_map: Optional[list[str]] = None,
        norm_fact: Optional[list[float]] = None,
        add_edge_readout: bool = True,
        slim_edge_readout: bool = False,
        **kwargs,
    ) -> None:
        """Construct a fitting net for energy.

        Args:
        - ntypes: Element count.
        - dim_descrpt: Dimension of the input descriptor.
        - neuron: Number of neurons in each hidden layer of the fitting net.
          Defaults to ``[128, 128, 128]``.
        - bias_atom_e: Average energy per atom for each element.
        - resnet_dt: Using time-step in the ResNet construction.
        - embedding_width: Width of the edge embedding fed to the edge readout.
        - norm_fact: Normalization factor(s) for the edge energy; only the
          first entry is used. Defaults to ``[120.0]``.
        - add_edge_readout: Whether to add the per-edge energy readout.
        - slim_edge_readout: If True, only the first hidden layer is used for
          the edge readout network.
        """
        # Use None sentinels: list defaults in the signature would be shared
        # (and mutable) across all instances.
        if neuron is None:
            neuron = [128, 128, 128]
        if norm_fact is None:
            norm_fact = [120.0]
        self.add_edge_readout = add_edge_readout
        super().__init__(
            "energy",
            ntypes,
            dim_descrpt,
            1,
            neuron=neuron,
            bias_atom_e=bias_atom_e,
            resnet_dt=resnet_dt,
            numb_fparam=numb_fparam,
            numb_aparam=numb_aparam,
            dim_case_embd=dim_case_embd,
            activation_function=activation_function,
            precision=precision,
            mixed_types=mixed_types,
            seed=seed,
            type_map=type_map,
            **kwargs,
        )

        # embedding for edge readout
        self.embedding_width = embedding_width
        self.slim_edge_readout = slim_edge_readout
        self.norm_e_fact = norm_fact[0]

        if self.add_edge_readout:
            self.edge_embed = NetworkCollection(
                1 if not self.mixed_types else 0,
                self.ntypes,
                network_type="fitting_network",
                networks=[
                    FittingNet(
                        self.embedding_width,
                        1,
                        self.neuron if not self.slim_edge_readout else self.neuron[:1],
                        self.activation_function,
                        self.resnet_dt,
                        self.precision,
                        bias_out=True,
                        # Nested child_seed handles seed being None, int, or
                        # list[int]; the previous `self.seed + 100` raised
                        # TypeError for None and list seeds.
                        seed=child_seed(child_seed(self.seed, 100), ii),
                    )
                    for ii in range(self.ntypes if not self.mixed_types else 1)
                ],
            )
        else:
            self.edge_embed = None

        # set trainable
        for param in self.parameters():
            param.requires_grad = self.trainable

    # make jit happy with torch 2.0.0
    exclude_types: list[int]

    def need_additional_input(self) -> bool:
        """This fitting consumes extra descriptor outputs (sw, edge_index)."""
        return True

    def serialize(self) -> dict:
        raise NotImplementedError

    @classmethod
    def deserialize(cls, data: dict) -> "EnergyFittingNetReadout":
        raise NotImplementedError

    def forward(
        self,
        descriptor: torch.Tensor,
        atype: torch.Tensor,
        gr: Optional[torch.Tensor] = None,
        g2: Optional[torch.Tensor] = None,
        h2: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
        sw: Optional[torch.Tensor] = None,
        edge_index: Optional[torch.Tensor] = None,
    ):
        """Based on embedding net output, calculate total energy.

        Args:
        - descriptor: Descriptor matrix of shape [nframes, nloc, dim_descrpt].
        - atype: Atom types.
        - g2: Edge embeddings, [nf, nloc, nnei, d] or [n_edge, d] under
          dynamic selection.
        - sw: Switch-function values matching g2's leading shape.
        - edge_index: [2, n_edge] owner/extended indices when dynamic
          selection is used; None otherwise.

        Returns
        -------
        - dict mapping ``self.var_name`` to the atomic energies,
          shape [nframes, nloc, 1].
        """
        out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
            self.var_name
        ]
        nf, nloc, _ = descriptor.shape

        if self.add_edge_readout:
            assert g2 is not None
            assert sw is not None
            assert self.edge_embed is not None
            # nf x nloc x nnei x d [OR] n_edge x d
            edge_feature = g2
            # NOTE(review): networks[0] is always used; with
            # mixed_types=False the per-type networks beyond index 0 are
            # never exercised — confirm this is intended.
            # nf x nloc x nnei x 1 [OR] n_edge x 1
            edge_atomic_contrib = self.edge_embed.networks[0](edge_feature)
            # apply the switch weight so contributions vanish at the cutoff
            edge_atomic_contrib = edge_atomic_contrib * sw.unsqueeze(-1)
            if edge_index is not None:
                # dynamic sel: scatter-sum edge contributions onto owner atoms
                n2e_index, n_ext2e_index = edge_index[0], edge_index[1]
                # nf x nloc x 1
                edge_energy = aggregate(
                    edge_atomic_contrib,
                    n2e_index,
                    average=False,
                    num_owner=nf * nloc,
                ).reshape(nf, nloc, 1)
            else:
                # nf x nloc x 1
                edge_energy = torch.sum(edge_atomic_contrib, dim=-2)
            # energy
            out = out + edge_energy / self.norm_e_fact
        return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}

deepmd/pt/model/task/invar_fitting.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ def forward(
170170
h2: Optional[torch.Tensor] = None,
171171
fparam: Optional[torch.Tensor] = None,
172172
aparam: Optional[torch.Tensor] = None,
173+
sw: Optional[torch.Tensor] = None,
174+
edge_index: Optional[torch.Tensor] = None,
173175
):
174176
"""Based on embedding net output, calculate total energy.
175177

0 commit comments

Comments
 (0)