Skip to content

Commit 9b07ca4

Browse files
feat(pd): add add_chg_spin_ebd parameter to DescrptDPA3
Co-authored-by: HydrogenSulfate <23737287+HydrogenSulfate@users.noreply.github.com> Agent-Logs-Url: https://github.com/HydrogenSulfate/deepmd-kit/sessions/730a0b97-f969-4779-8394-1758329031b6
1 parent 4ce5c73 commit 9b07ca4

7 files changed

Lines changed: 117 additions & 18 deletions

File tree

deepmd/pd/model/atomic_model/dp_atomic_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ def __init__(
6565
self.sel = self.descriptor.get_sel()
6666
self.fitting_net = fitting
6767
super().init_out_stat()
68+
self.add_chg_spin_ebd: bool = getattr(
69+
self.descriptor, "add_chg_spin_ebd", False
70+
)
6871
self.enable_eval_descriptor_hook = False
6972
self.enable_eval_fitting_last_layer_hook = False
7073
self.eval_descriptor_list = []
@@ -329,6 +332,7 @@ def forward_atomic(
329332
nlist,
330333
mapping=mapping,
331334
comm_dict=comm_dict,
335+
fparam=fparam if self.add_chg_spin_ebd else None,
332336
)
333337
assert descriptor is not None
334338
if self.enable_eval_descriptor_hook:

deepmd/pd/model/descriptor/dpa1.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,7 @@ def forward(
625625
nlist: paddle.Tensor,
626626
mapping: paddle.Tensor | None = None,
627627
comm_dict: list[paddle.Tensor] | None = None,
628+
fparam: paddle.Tensor | None = None,
628629
) -> paddle.Tensor:
629630
"""Compute the descriptor.
630631

deepmd/pd/model/descriptor/dpa2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,7 @@ def forward(
732732
nlist: paddle.Tensor,
733733
mapping: paddle.Tensor | None = None,
734734
comm_dict: list[paddle.Tensor] | None = None,
735+
fparam: paddle.Tensor | None = None,
735736
) -> paddle.Tensor:
736737
"""Compute the descriptor.
737738

deepmd/pd/model/descriptor/dpa3.py

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
UpdateSel,
3333
)
3434
from deepmd.pd.utils.utils import (
35+
ActivationFn,
3536
to_numpy_array,
3637
)
3738
from deepmd.utils.data_system import (
@@ -120,6 +121,7 @@ def __init__(
120121
use_tebd_bias: bool = False,
121122
use_loc_mapping: bool = True,
122123
type_map: list[str] | None = None,
124+
add_chg_spin_ebd: bool = False,
123125
) -> None:
124126
super().__init__()
125127

@@ -174,6 +176,7 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:
174176
)
175177

176178
self.use_econf_tebd = use_econf_tebd
179+
self.add_chg_spin_ebd = add_chg_spin_ebd
177180
self.use_loc_mapping = use_loc_mapping
178181
self.use_tebd_bias = use_tebd_bias
179182
self.type_map = type_map
@@ -196,6 +199,34 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:
196199
self.concat_output_tebd = concat_output_tebd
197200
self.precision = precision
198201
self.prec = PRECISION_DICT[self.precision]
202+
203+
if self.add_chg_spin_ebd:
204+
self.act = ActivationFn(activation_function)
205+
# -100 ~ 100 is a conservative bound
206+
self.chg_embedding = TypeEmbedNet(
207+
200,
208+
self.tebd_dim,
209+
precision=precision,
210+
seed=child_seed(seed, 3),
211+
)
212+
# 100 is a conservative upper bound
213+
self.spin_embedding = TypeEmbedNet(
214+
100,
215+
self.tebd_dim,
216+
precision=precision,
217+
seed=child_seed(seed, 4),
218+
)
219+
self.mix_cs_mlp = MLPLayer(
220+
2 * self.tebd_dim,
221+
self.tebd_dim,
222+
precision=precision,
223+
seed=child_seed(seed, 5),
224+
)
225+
else:
226+
self.chg_embedding = None
227+
self.spin_embedding = None
228+
self.mix_cs_mlp = None
229+
199230
self.exclude_types = exclude_types
200231
self.env_protection = env_protection
201232
self.trainable = trainable
@@ -433,9 +464,14 @@ def serialize(self) -> dict:
433464
"use_econf_tebd": self.use_econf_tebd,
434465
"use_tebd_bias": self.use_tebd_bias,
435466
"use_loc_mapping": self.use_loc_mapping,
467+
"add_chg_spin_ebd": self.add_chg_spin_ebd,
436468
"type_map": self.type_map,
437469
"type_embedding": self.type_embedding.embedding.serialize(),
438470
}
471+
if self.add_chg_spin_ebd:
472+
data["chg_embedding"] = self.chg_embedding.embedding.serialize()
473+
data["spin_embedding"] = self.spin_embedding.embedding.serialize()
474+
data["mix_cs_mlp"] = self.mix_cs_mlp.serialize()
439475
repflow_variable = {
440476
"edge_embd": repflows.edge_embd.serialize(),
441477
"angle_embd": repflows.angle_embd.serialize(),
@@ -462,12 +498,24 @@ def deserialize(cls, data: dict) -> "DescrptDPA3":
462498
data.pop("type")
463499
repflow_variable = data.pop("repflow_variable").copy()
464500
type_embedding = data.pop("type_embedding")
501+
chg_embedding = data.pop("chg_embedding", None)
502+
spin_embedding = data.pop("spin_embedding", None)
503+
mix_cs_mlp = data.pop("mix_cs_mlp", None)
465504
data["repflow"] = RepFlowArgs(**data.pop("repflow_args"))
466505
obj = cls(**data)
467506
obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize(
468507
type_embedding
469508
)
470509

510+
if obj.add_chg_spin_ebd and chg_embedding is not None:
511+
obj.chg_embedding.embedding = TypeEmbedNetConsistent.deserialize(
512+
chg_embedding
513+
)
514+
obj.spin_embedding.embedding = TypeEmbedNetConsistent.deserialize(
515+
spin_embedding
516+
)
517+
obj.mix_cs_mlp = MLPLayer.deserialize(mix_cs_mlp)
518+
471519
def t_cvt(xx: Any) -> paddle.Tensor:
472520
return paddle.to_tensor(xx, dtype=obj.repflows.prec, place=env.DEVICE)
473521

@@ -493,7 +541,14 @@ def forward(
493541
nlist: paddle.Tensor,
494542
mapping: paddle.Tensor | None = None,
495543
comm_dict: list[paddle.Tensor] | None = None,
496-
) -> paddle.Tensor:
544+
fparam: paddle.Tensor | None = None,
545+
) -> tuple[
546+
paddle.Tensor,
547+
paddle.Tensor | None,
548+
paddle.Tensor | None,
549+
paddle.Tensor | None,
550+
paddle.Tensor | None,
551+
]:
497552
"""Compute the descriptor.
498553
499554
Parameters
@@ -536,6 +591,20 @@ def forward(
536591
node_ebd_ext = self.type_embedding(extended_atype[:, :nloc])
537592
else:
538593
node_ebd_ext = self.type_embedding(extended_atype)
594+
595+
if self.add_chg_spin_ebd:
596+
assert fparam is not None
597+
assert self.chg_embedding is not None
598+
assert self.spin_embedding is not None
599+
charge = fparam[:, 0].to(dtype=paddle.int64) + 100
600+
spin = fparam[:, 1].to(dtype=paddle.int64)
601+
chg_ebd = self.chg_embedding(charge)
602+
spin_ebd = self.spin_embedding(spin)
603+
sys_cs_embd = self.act(
604+
self.mix_cs_mlp(paddle.concat([chg_ebd, spin_ebd], axis=-1))
605+
)
606+
node_ebd_ext = node_ebd_ext + sys_cs_embd.unsqueeze(1)
607+
539608
node_ebd_inp = node_ebd_ext[:, :nloc, :]
540609
# repflows
541610
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(
@@ -550,10 +619,14 @@ def forward(
550619
node_ebd = paddle.concat([node_ebd, node_ebd_inp], axis=-1)
551620
return (
552621
node_ebd.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
553-
rot_mat.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
554-
edge_ebd.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
555-
h2.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
556-
sw.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
622+
rot_mat.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION)
623+
if rot_mat is not None
624+
else None,
625+
edge_ebd.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION)
626+
if edge_ebd is not None
627+
else None,
628+
h2.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION) if h2 is not None else None,
629+
sw.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION) if sw is not None else None,
557630
)
558631

559632
@classmethod

deepmd/pd/model/descriptor/se_a.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def forward(
288288
nlist: paddle.Tensor,
289289
mapping: paddle.Tensor | None = None,
290290
comm_dict: list[paddle.Tensor] | None = None,
291+
fparam: paddle.Tensor | None = None,
291292
) -> paddle.Tensor:
292293
"""Compute the descriptor.
293294

deepmd/pd/model/descriptor/se_t_tebd.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ def forward(
436436
nlist: paddle.Tensor,
437437
mapping: paddle.Tensor | None = None,
438438
comm_dict: list[paddle.Tensor] | None = None,
439+
fparam: paddle.Tensor | None = None,
439440
) -> paddle.Tensor:
440441
"""Compute the descriptor.
441442

source/tests/pd/model/test_dpa3.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def test_consistency(
5555
nme,
5656
prec,
5757
ect,
58+
add_chg_spin,
5859
) in itertools.product(
5960
[True, False], # update_angle
6061
["res_residual"], # update_style
@@ -65,15 +66,13 @@ def test_consistency(
6566
[1, 2], # n_multi_edge_message
6667
["float64"], # precision
6768
[False], # use_econf_tebd
69+
[False, True], # add_chg_spin_ebd
6870
):
6971
dtype = PRECISION_DICT[prec]
7072
rtol, atol = get_tols(prec)
7173
if prec == "float64":
7274
atol = 1e-8 # marginal GPU test cases...
73-
coord_ext = np.concatenate([self.coord_ext[:1], self.coord_ext[:1]], axis=0)
74-
atype_ext = np.concatenate([self.atype_ext[:1], self.atype_ext[:1]], axis=0)
75-
nlist = np.concatenate([self.nlist[:1], self.nlist[:1]], axis=0)
76-
mapping = np.concatenate([self.mapping[:1], self.mapping[:1]], axis=0)
75+
7776
repflow = RepFlowArgs(
7877
n_dim=20,
7978
e_dim=10,
@@ -105,24 +104,37 @@ def test_consistency(
105104
precision=prec,
106105
use_econf_tebd=ect,
107106
type_map=["O", "H"] if ect else None,
107+
add_chg_spin_ebd=add_chg_spin,
108108
seed=GLOBAL_SEED,
109109
).to(env.DEVICE)
110110

111111
dd0.repflows.mean = paddle.to_tensor(davg, dtype=dtype, place=env.DEVICE)
112112
dd0.repflows.stddev = paddle.to_tensor(dstd, dtype=dtype, place=env.DEVICE)
113+
114+
# Prepare fparam if needed
115+
fparam = None
116+
fparam_np = None
117+
if add_chg_spin:
118+
fparam = paddle.to_tensor(
119+
[[5, 1]], dtype=dtype, place=env.DEVICE
120+
).expand(nf, -1)
121+
fparam_np = np.array([[5, 1]], dtype=np.float64).repeat(nf, axis=0)
122+
113123
rd0, _, _, _, _ = dd0(
114-
paddle.to_tensor(coord_ext, dtype=dtype, place=env.DEVICE),
115-
paddle.to_tensor(atype_ext, dtype=paddle.int64, place=env.DEVICE),
116-
paddle.to_tensor(nlist, dtype=paddle.int64, place=env.DEVICE),
117-
paddle.to_tensor(mapping, dtype=paddle.int64, place=env.DEVICE),
124+
paddle.to_tensor(self.coord_ext, dtype=dtype, place=env.DEVICE),
125+
paddle.to_tensor(self.atype_ext, dtype=paddle.int64, place=env.DEVICE),
126+
paddle.to_tensor(self.nlist, dtype=paddle.int64, place=env.DEVICE),
127+
paddle.to_tensor(self.mapping, dtype=paddle.int64, place=env.DEVICE),
128+
fparam=fparam,
118129
)
119130
# serialization
120131
dd1 = DescrptDPA3.deserialize(dd0.serialize())
121132
rd1, _, _, _, _ = dd1(
122-
paddle.to_tensor(coord_ext, dtype=dtype, place=env.DEVICE),
123-
paddle.to_tensor(atype_ext, dtype=paddle.int64, place=env.DEVICE),
124-
paddle.to_tensor(nlist, dtype=paddle.int64, place=env.DEVICE),
125-
paddle.to_tensor(mapping, dtype=paddle.int64, place=env.DEVICE),
133+
paddle.to_tensor(self.coord_ext, dtype=dtype, place=env.DEVICE),
134+
paddle.to_tensor(self.atype_ext, dtype=paddle.int64, place=env.DEVICE),
135+
paddle.to_tensor(self.nlist, dtype=paddle.int64, place=env.DEVICE),
136+
paddle.to_tensor(self.mapping, dtype=paddle.int64, place=env.DEVICE),
137+
fparam=fparam,
126138
)
127139
np.testing.assert_allclose(
128140
rd0.numpy(),
@@ -132,7 +144,13 @@ def test_consistency(
132144
)
133145
# dp impl
134146
dd2 = DPDescrptDPA3.deserialize(dd0.serialize())
135-
rd2, _, _, _, _ = dd2.call(coord_ext, atype_ext, nlist, mapping)
147+
rd2, _, _, _, _ = dd2.call(
148+
self.coord_ext,
149+
self.atype_ext,
150+
self.nlist,
151+
self.mapping,
152+
fparam=fparam_np,
153+
)
136154
np.testing.assert_allclose(
137155
rd0.numpy(),
138156
rd2,

0 commit comments

Comments
 (0)