Commit 03ef831

add ener direct fitting
1 parent 10a2e5c commit 03ef831

9 files changed

Lines changed: 360 additions & 16 deletions

deepmd/pt/infer/deep_eval.py

Lines changed: 13 additions & 4 deletions
@@ -535,10 +535,19 @@ def _eval_model(
                 out = batch_output[pt_name].reshape(shape).detach().cpu().numpy()
                 results.append(out)
             else:
-                shape = self._get_output_shape(odef, nframes, natoms)
-                results.append(
-                    np.full(np.abs(shape), np.nan, dtype=prec)
-                )  # this is kinda hacky
+                if (
+                    self._OUTDEF_DP2BACKEND[odef.name] == "force"
+                    and "dforce" in batch_output
+                ):
+                    # if no force, use dforce if possible
+                    shape = self._get_output_shape(odef, nframes, natoms)
+                    out = batch_output["dforce"].reshape(shape).detach().cpu().numpy()
+                    results.append(out)
+                else:
+                    shape = self._get_output_shape(odef, nframes, natoms)
+                    results.append(
+                        np.full(np.abs(shape), np.nan, dtype=prec)
+                    )  # this is kinda hacky
         return tuple(results)

     def _eval_model_spin(
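
Read in isolation, the new branch falls back to the direct-force prediction when a gradient-based force is not available. A minimal sketch of that logic, assuming plain NumPy arrays rather than the torch tensors used in deep_eval.py; select_output is a hypothetical helper, not DeePMD-kit API:

import numpy as np

def select_output(batch_output: dict, name: str, shape: tuple) -> np.ndarray:
    # Preferred path: the backend produced the requested quantity directly.
    if name in batch_output:
        return batch_output[name].reshape(shape)
    # New fallback: a direct-force head predicts "dforce", which can stand in
    # for a missing gradient-based "force" output.
    if name == "force" and "dforce" in batch_output:
        return batch_output["dforce"].reshape(shape)
    # Old behaviour: pad with NaNs of the expected shape.
    return np.full(shape, np.nan)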

deepmd/pt/loss/ener.py

Lines changed: 3 additions & 0 deletions
@@ -184,6 +184,9 @@ def forward(
             Other losses for display.
         """
         model_pred = model(**input_dict)
+
+        if "force" not in model_pred and "dforce" in model_pred:
+            model_pred["force"] = model_pred["dforce"]
         coef = learning_rate / self.starter_learning_rate
         pref_e = self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * coef
         pref_f = self.limit_pref_f + (self.start_pref_f - self.limit_pref_f) * coef

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 1 addition & 0 deletions
@@ -286,6 +286,7 @@ def forward_atomic(
                 aparam=aparam,
                 sw=sw,
                 edge_index=add_input.get("edge_index", None),
+                diff=add_input.get("diff", None),
             )
             if self.enable_eval_fitting_last_layer_hook:
                 assert "middle_output" in fit_ret, (

deepmd/pt/model/descriptor/repflows.py

Lines changed: 3 additions & 0 deletions
@@ -573,6 +573,8 @@ def forward(
             h2 = h2[nlist_mask]
             # n_edge x 1
             sw = sw[nlist_mask]
+            # n_edge x 3
+            diff = diff[nlist_mask]
             # nb x nloc x a_nnei x a_nnei
             a_nlist_mask = a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :]
             # n_angle x 1
@@ -585,6 +587,7 @@ def forward(
             edge_index = torch.zeros([2, 1], device=nlist.device, dtype=nlist.dtype)
             angle_index = torch.zeros([3, 1], device=nlist.device, dtype=nlist.dtype)
             self.additional_output_for_fitting["edge_index"] = None
+        self.additional_output_for_fitting["diff"] = diff
         # get edge and angle embedding
         # nb x nloc x nnei x e_dim [OR] n_edge x e_dim
         if not self.edge_init_use_dist:
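
These hunks keep the per-edge displacement vectors alongside the other edge quantities and expose them to the fitting net through additional_output_for_fitting. For reference, a minimal sketch of what such a diff tensor typically contains when built from the neighbor list; edge_displacements is a hypothetical helper, and the r_j - r_i sign convention is an assumption, not taken from repflows.py:

import torch

def edge_displacements(coord_ext: torch.Tensor, nlist: torch.Tensor) -> torch.Tensor:
    # coord_ext: nb x nall x 3 extended coordinates, nlist: nb x nloc x nnei
    nb, nloc, nnei = nlist.shape
    # masked neighbors are often marked with -1; clamp so gather stays valid
    index = nlist.clamp(min=0).long().reshape(nb, nloc * nnei, 1).expand(-1, -1, 3)
    # gather neighbor coordinates: nb x nloc x nnei x 3
    coord_j = torch.gather(coord_ext, 1, index).reshape(nb, nloc, nnei, 3)
    # displacement r_j - r_i for every kept neighbor pair
    return coord_j - coord_ext[:, :nloc].unsqueeze(-2)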

deepmd/pt/model/model/__init__.py

Lines changed: 12 additions & 7 deletions
@@ -91,16 +91,16 @@ def _get_standard_model_components(model_params: dict, ntypes: int) -> tuple:
     fitting_net["ntypes"] = descriptor.get_ntypes()
     fitting_net["type_map"] = copy.deepcopy(model_params["type_map"])
     fitting_net["mixed_types"] = descriptor.mixed_types()
-    if fitting_net["type"] in ["dipole", "polar", "ener_readout"]:
+    if fitting_net["type"] in ["dipole", "polar", "ener_readout", "ener_direct"]:
         fitting_net["embedding_width"] = descriptor.get_dim_emb()
     if fitting_net["type"] in ["ener_readout"]:
         fitting_net["norm_fact"] = descriptor.get_norm_fact()
     fitting_net["dim_descrpt"] = descriptor.get_dim_out()
-    grad_force = "direct" not in fitting_net["type"]
-    if not grad_force:
-        fitting_net["out_dim"] = descriptor.get_dim_emb()
-        if "ener" in fitting_net["type"]:
-            fitting_net["return_energy"] = True
+    # grad_force = "direct" not in fitting_net["type"]
+    # if not grad_force:
+    #     fitting_net["out_dim"] = descriptor.get_dim_emb()
+    #     if "ener" in fitting_net["type"]:
+    #         fitting_net["return_energy"] = True
     fitting = BaseFitting(**fitting_net)
     return descriptor, fitting, fitting_net["type"]

@@ -267,7 +267,12 @@ def get_standard_model(model_params: dict) -> BaseModel:
         modelcls = PolarModel
     elif fitting_net_type == "dos":
         modelcls = DOSModel
-    elif fitting_net_type in ["ener", "direct_force_ener", "ener_readout"]:
+    elif fitting_net_type in [
+        "ener",
+        "direct_force_ener",
+        "ener_readout",
+        "ener_direct",
+    ]:
         modelcls = EnergyModel
     elif fitting_net_type == "property":
         modelcls = PropertyModel
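
With this routing in place, a model section along these lines would now be dispatched to EnergyModel with the new head. Apart from fitting_net["type"], the keys and values below are illustrative assumptions, not a verified input script:

model_params = {
    "type_map": ["O", "H"],
    "descriptor": {
        "type": "dpa3",  # assumed: any descriptor exposing edge features / get_dim_emb()
        # ... descriptor options elided ...
    },
    "fitting_net": {
        "type": "ener_direct",  # new value accepted by get_standard_model
        "neuron": [128, 128, 128],
        "resnet_dt": True,
    },
}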

deepmd/pt/model/model/ener_model.py

Lines changed: 6 additions & 5 deletions
@@ -120,8 +120,8 @@ def forward(
                     model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(
                         -3
                     )
-            else:
-                model_predict["force"] = model_ret["dforce"]
+            if "dforce" in model_ret:
+                model_predict["dforce"] = model_ret["dforce"]
             if "mask" in model_ret:
                 model_predict["mask"] = model_ret["mask"]
             if self._hessian_enabled:
@@ -160,15 +160,16 @@ def forward_lower(
             model_predict["energy"] = model_ret["energy_redu"]
             if self.do_grad_r("energy"):
                 model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2)
+            else:
+                assert model_ret["dforce"] is not None
+                model_predict["dforce"] = model_ret["dforce"]
+
             if self.do_grad_c("energy"):
                 model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
                 if do_atomic_virial:
                     model_predict["extended_virial"] = model_ret[
                         "energy_derv_c"
                     ].squeeze(-3)
-            else:
-                assert model_ret["dforce"] is not None
-                model_predict["dforce"] = model_ret["dforce"]
         else:
             model_predict = model_ret
         return model_predict

deepmd/pt/model/task/ener.py

Lines changed: 236 additions & 0 deletions
@@ -422,3 +422,239 @@ def forward(
         # energy
         out = out + edge_energy / self.norm_e_fact
         return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}
+
+
+@Fitting.register("ener_direct")
+@fitting_check_output
+class EnergyFittingNetDirectHead(InvarFitting):
+    def __init__(
+        self,
+        ntypes: int,
+        dim_descrpt: int,
+        neuron: list[int] = [128, 128, 128],
+        bias_atom_e: Optional[torch.Tensor] = None,
+        resnet_dt: bool = True,
+        numb_fparam: int = 0,
+        numb_aparam: int = 0,
+        dim_case_embd: int = 0,
+        embedding_width: int = 128,
+        activation_function: str = "tanh",
+        precision: str = DEFAULT_PRECISION,
+        mixed_types: bool = True,
+        seed: Optional[Union[int, list[int]]] = None,
+        type_map: Optional[list[str]] = None,
+        additional_gradient: bool = False,
+        additional_noise_head: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        """Construct a fitting net for energy.
+
+        Args:
+        - ntypes: Element count.
+        - embedding_width: Embedding width per atom.
+        - neuron: Number of neurons in each hidden layer of the fitting net.
+        - bias_atom_e: Average energy per atom for each element.
+        - resnet_dt: Using time-step in the ResNet construction.
+        """
+        self.additional_gradient = additional_gradient
+        self.additional_noise_head = additional_noise_head
+        super().__init__(
+            "energy",
+            ntypes,
+            dim_descrpt,
+            1,
+            neuron=neuron,
+            bias_atom_e=bias_atom_e,
+            resnet_dt=resnet_dt,
+            numb_fparam=numb_fparam,
+            numb_aparam=numb_aparam,
+            dim_case_embd=dim_case_embd,
+            activation_function=activation_function,
+            precision=precision,
+            mixed_types=mixed_types,
+            seed=seed,
+            type_map=type_map,
+            **kwargs,
+        )
+
+        # embedding for direct force
+        self.force_input_dim = embedding_width  # can add force embedding if needed
+        self.force_embed = NetworkCollection(
+            1 if not self.mixed_types else 0,
+            self.ntypes,
+            network_type="fitting_network",
+            networks=[
+                FittingNet(
+                    self.force_input_dim,
+                    1,
+                    self.neuron,
+                    self.activation_function,
+                    self.resnet_dt,
+                    self.precision,
+                    bias_out=True,
+                    seed=child_seed(self.seed + 100, ii),
+                )
+                for ii in range(self.ntypes if not self.mixed_types else 1)
+            ],
+        )
+        # additional noise head
+        self.noise_input_dim = embedding_width  # can add noise embedding if needed
+        if self.additional_noise_head:
+            # dforce for force; dnoise for noise
+            self.noise_embed = NetworkCollection(
+                1 if not self.mixed_types else 0,
+                self.ntypes,
+                network_type="fitting_network",
+                networks=[
+                    FittingNet(
+                        self.noise_input_dim,
+                        1,
+                        self.neuron,
+                        self.activation_function,
+                        self.resnet_dt,
+                        self.precision,
+                        bias_out=True,
+                        seed=child_seed(self.seed + 200, ii),
+                    )
+                    for ii in range(self.ntypes if not self.mixed_types else 1)
+                ],
+            )
+        else:
+            # dforce for noise
+            self.noise_embed = None
+
+        # set trainable
+        for param in self.parameters():
+            param.requires_grad = self.trainable
+
+    def output_def(self) -> FittingOutputDef:
+        out_list = [
+            OutputVariableDef(
+                self.var_name,
+                [self.dim_out],
+                reducible=True,
+                r_differentiable=self.additional_gradient,
+                c_differentiable=self.additional_gradient,
+            ),
+            OutputVariableDef(
+                "dforce",
+                [3],
+                reducible=False,
+                r_differentiable=False,
+                c_differentiable=False,
+            ),
+        ]
+        if self.additional_noise_head:
+            out_list.append(
+                OutputVariableDef(
+                    "dnoise",
+                    [3],
+                    reducible=False,
+                    r_differentiable=False,
+                    c_differentiable=False,
+                )
+            )
+
+        return FittingOutputDef(out_list)
+
+    # make jit happy with torch 2.0.0
+    exclude_types: list[int]
+
+    def need_additional_input(self) -> bool:
+        return True
+
+    def serialize(self) -> dict:
+        raise NotImplementedError
+
+    @classmethod
+    def deserialize(cls, data: dict) -> "EnergyFittingNetDirectHead":
+        raise NotImplementedError
+
+    def change_type_map(
+        self, type_map: list[str], model_with_new_type_stat: Optional[Any] = None
+    ) -> None:
+        raise NotImplementedError
+
+    def get_type_map(self) -> list[str]:
+        raise NotImplementedError
+
+    def forward(
+        self,
+        descriptor: torch.Tensor,
+        atype: torch.Tensor,
+        gr: Optional[torch.Tensor] = None,
+        g2: Optional[torch.Tensor] = None,
+        h2: Optional[torch.Tensor] = None,
+        fparam: Optional[torch.Tensor] = None,
+        aparam: Optional[torch.Tensor] = None,
+        diff: Optional[torch.Tensor] = None,
+        edge_index: Optional[torch.Tensor] = None,
+        sw: Optional[torch.Tensor] = None,
+    ) -> dict[str, torch.Tensor]:
+        """Based on embedding net output, calculate total energy.
+
+        Args:
+        - inputs: Embedding matrix. Its shape is [nframes, natoms[0], self.dim_descrpt].
+        - natoms: Tell atom count and element count. Its shape is [2+self.ntypes].
+
+        Returns
+        -------
+        - `torch.Tensor`: Total energy with shape [nframes, natoms[0]].
+        """
+        out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
+            self.var_name
+        ]
+        # energy
+        result = {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}
+
+        # direct force
+        assert diff is not None
+        assert g2 is not None
+
+        nf, nloc, _ = descriptor.shape
+
+        # nf x nloc x nnei x 3 [OR] nedge x 3
+        edge_vec = diff
+        # nf x nloc x nnei x d [OR] nedge x d
+        edge_feature = g2
+        # nf x nloc x nnei x 1 [OR] nedge x 1
+        edge_weight = self.force_embed.networks[0](edge_feature)
+        # nf x nloc x nnei x 3 [OR] nedge x 3
+        fij = edge_weight * edge_vec
+        if edge_index is not None:
+            # use dynamic sel
+            n2e_index, n_ext2e_index = edge_index[:, 0], edge_index[:, 1]
+            # nf x nloc x 3
+            fi = aggregate(
+                fij,
+                n2e_index,
+                average=False,
+                num_owner=nf * nloc,
+            ).reshape(nf, nloc, 3)
+        else:
+            # nf x nloc x 3
+            fi = torch.sum(fij, dim=-2)
+
+        result["dforce"] = fi
+
+        if self.additional_noise_head:
+            assert self.noise_embed is not None
+            edge_weight = self.noise_embed.networks[0](edge_feature)
+            # nf x nloc x nnei x 3 [OR] nedge x 3
+            nij = edge_weight * edge_vec
+            if edge_index is not None:
+                # use dynamic sel
+                n2e_index, n_ext2e_index = edge_index[:, 0], edge_index[:, 1]
+                # nf x nloc x 3
+                ni = aggregate(
+                    nij,
+                    n2e_index,
+                    average=False,
+                    num_owner=nf * nloc,
+                ).reshape(nf, nloc, 3)
+            else:
+                # nf x nloc x 3
+                ni = torch.sum(nij, dim=-2)
+            result["dnoise"] = ni
+
+        return result
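
The direct-force part of the new head boils down to this: an MLP maps each edge feature to a scalar weight, the weight scales that edge's displacement vector, and summing over an atom's neighbors gives the per-atom prediction f_i = sum_j w_ij * diff_ij. A minimal dense-shape sketch (no dynamic sel); direct_force and weight_net are illustrative stand-ins for the pass through force_embed.networks[0], not the class above:

import torch

def direct_force(edge_feat: torch.Tensor, diff: torch.Tensor,
                 weight_net: torch.nn.Module) -> torch.Tensor:
    # edge_feat: nf x nloc x nnei x d, diff: nf x nloc x nnei x 3
    w = weight_net(edge_feat)  # nf x nloc x nnei x 1, one scalar weight per edge
    fij = w * diff             # nf x nloc x nnei x 3, per-edge force contribution
    return fij.sum(dim=-2)     # nf x nloc x 3, summed over each atom's neighbors

# e.g. a tiny per-edge weight network for d = 64 edge features:
# weight_net = torch.nn.Sequential(
#     torch.nn.Linear(64, 32), torch.nn.Tanh(), torch.nn.Linear(32, 1)
# )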

deepmd/pt/model/task/invar_fitting.py

Lines changed: 1 addition & 0 deletions
@@ -178,6 +178,7 @@ def forward(
         aparam: Optional[torch.Tensor] = None,
         sw: Optional[torch.Tensor] = None,
         edge_index: Optional[torch.Tensor] = None,
+        diff: Optional[torch.Tensor] = None,
     ) -> dict[str, torch.Tensor]:
         """Based on embedding net output, calculate total energy.
