Skip to content

Commit fe9f626

Browse files
Sync Paddle DPA3 with dpmodel: add default_chg_spin, rename fparam to charge_spin
Co-authored-by: HydrogenSulfate <23737287+HydrogenSulfate@users.noreply.github.com>
1 parent f39a081 commit fe9f626

1 file changed

Lines changed: 23 additions & 4 deletions

File tree

deepmd/pd/model/descriptor/dpa3.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def __init__(
122122
use_loc_mapping: bool = True,
123123
type_map: list[str] | None = None,
124124
add_chg_spin_ebd: bool = False,
125+
default_chg_spin: list[float] | None = None,
125126
) -> None:
126127
super().__init__()
127128

@@ -177,6 +178,11 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:
177178

178179
self.use_econf_tebd = use_econf_tebd
179180
self.add_chg_spin_ebd = add_chg_spin_ebd
181+
if default_chg_spin is not None and len(default_chg_spin) != 2:
182+
raise ValueError(
183+
"default_chg_spin must have exactly 2 values [charge, spin]"
184+
)
185+
self.default_chg_spin = default_chg_spin
180186
self.use_loc_mapping = use_loc_mapping
181187
self.use_tebd_bias = use_tebd_bias
182188
self.type_map = type_map
@@ -447,6 +453,18 @@ def get_stat_mean_and_stddev(
447453
stddev_list = [self.repflows.stddev]
448454
return mean_list, stddev_list
449455

456+
def get_dim_chg_spin(self) -> int:
457+
"""Returns the dimension of charge_spin input."""
458+
return 2 if self.add_chg_spin_ebd else 0
459+
460+
def has_default_chg_spin(self) -> bool:
461+
"""Returns whether default charge_spin values are set."""
462+
return self.default_chg_spin is not None
463+
464+
def get_default_chg_spin(self) -> list[float] | None:
465+
"""Returns the default charge_spin values."""
466+
return self.default_chg_spin
467+
450468
def serialize(self) -> dict:
451469
repflows = self.repflows
452470
data = {
@@ -465,6 +483,7 @@ def serialize(self) -> dict:
465483
"use_tebd_bias": self.use_tebd_bias,
466484
"use_loc_mapping": self.use_loc_mapping,
467485
"add_chg_spin_ebd": self.add_chg_spin_ebd,
486+
"default_chg_spin": self.default_chg_spin,
468487
"type_map": self.type_map,
469488
"type_embedding": self.type_embedding.embedding.serialize(),
470489
}
@@ -541,7 +560,7 @@ def forward(
541560
nlist: paddle.Tensor,
542561
mapping: paddle.Tensor | None = None,
543562
comm_dict: list[paddle.Tensor] | None = None,
544-
fparam: paddle.Tensor | None = None,
563+
charge_spin: paddle.Tensor | None = None,
545564
) -> tuple[
546565
paddle.Tensor,
547566
paddle.Tensor | None,
@@ -593,11 +612,11 @@ def forward(
593612
node_ebd_ext = self.type_embedding(extended_atype)
594613

595614
if self.add_chg_spin_ebd:
596-
assert fparam is not None
615+
assert charge_spin is not None
597616
assert self.chg_embedding is not None
598617
assert self.spin_embedding is not None
599-
charge = fparam[:, 0].to(dtype=paddle.int64) + 100
600-
spin = fparam[:, 1].to(dtype=paddle.int64)
618+
charge = charge_spin[:, 0].to(dtype=paddle.int64) + 100
619+
spin = charge_spin[:, 1].to(dtype=paddle.int64)
601620
chg_ebd = self.chg_embedding(charge)
602621
spin_ebd = self.spin_embedding(spin)
603622
sys_cs_embd = self.act(

0 commit comments

Comments
 (0)