Skip to content
22 changes: 22 additions & 0 deletions deepmd/pd/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,22 @@ def reinit_pair_exclude(
else:
self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types)

def has_chg_spin_ebd(self) -> bool:
"""Check if the model has charge spin embedding."""
return False

def get_dim_chg_spin(self) -> int:
"""Get the dimension of charge_spin input."""
return 0

def has_default_chg_spin(self) -> bool:
"""Check if the model has default charge_spin values."""
return False

def get_default_chg_spin(self) -> paddle.Tensor | None:
"""Get the default charge_spin values."""
return None

# to make jit happy...
def make_atom_mask(
self,
Expand Down Expand Up @@ -229,6 +245,7 @@ def forward_common_atomic(
fparam: paddle.Tensor | None = None,
aparam: paddle.Tensor | None = None,
comm_dict: list[paddle.Tensor] | None = None,
charge_spin: paddle.Tensor | None = None,
) -> dict[str, paddle.Tensor]:
"""Common interface for atomic inference.

Expand All @@ -252,6 +269,8 @@ def forward_common_atomic(
atomic parameter, shape: nf x nloc x dim_aparam
comm_dict
The data needed for communication for parallel inference.
charge_spin
charge and spin parameters, shape: nf x 2

Returns
-------
Expand Down Expand Up @@ -282,6 +301,7 @@ def forward_common_atomic(
fparam=fparam,
aparam=aparam,
comm_dict=comm_dict,
charge_spin=charge_spin,
)
ret_dict = self.apply_out_stat(ret_dict, atype)
# nf x nloc
Expand Down Expand Up @@ -311,6 +331,7 @@ def forward(
fparam: paddle.Tensor | None = None,
aparam: paddle.Tensor | None = None,
comm_dict: list[paddle.Tensor] | None = None,
charge_spin: paddle.Tensor | None = None,
) -> dict[str, paddle.Tensor]:
return self.forward_common_atomic(
extended_coord,
Expand All @@ -320,6 +341,7 @@ def forward(
fparam=fparam,
aparam=aparam,
comm_dict=comm_dict,
charge_spin=charge_spin,
)

def change_type_map(
Expand Down
35 changes: 34 additions & 1 deletion deepmd/pd/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def forward_atomic(
fparam: paddle.Tensor | None = None,
aparam: paddle.Tensor | None = None,
comm_dict: dict[str, paddle.Tensor] | None = None,
charge_spin: paddle.Tensor | None = None,
) -> dict[str, paddle.Tensor]:
"""Return atomic prediction.

Expand All @@ -315,6 +316,8 @@ def forward_atomic(
frame parameter. nf x ndf
aparam
atomic parameter. nf x nloc x nda
charge_spin
charge and spin parameters. nf x 2

Returns
-------
Expand All @@ -326,13 +329,21 @@ def forward_atomic(
atype = extended_atype[:, :nloc]
if self.do_grad_r() or self.do_grad_c():
extended_coord.stop_gradient = False

# Handle default chg_spin if descriptor supports it
if self.add_chg_spin_ebd and charge_spin is None:
default_cs_tensor = self.descriptor.get_default_chg_spin()
if default_cs_tensor is not None:
default_cs_tensor = default_cs_tensor.to(extended_coord.place)
charge_spin = paddle.tile(default_cs_tensor.unsqueeze(0), [nframes, 1])

descriptor, rot_mat, g2, h2, sw = self.descriptor(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
comm_dict=comm_dict,
fparam=fparam if self.add_chg_spin_ebd else None,
charge_spin=charge_spin if self.add_chg_spin_ebd else None,
)
assert descriptor is not None
if self.enable_eval_descriptor_hook:
Expand Down Expand Up @@ -466,3 +477,25 @@ def is_aparam_nall(self) -> bool:
If False, the shape is (nframes, nloc, ndim).
"""
return False

def has_chg_spin_ebd(self) -> bool:
"""Check if the model has charge spin embedding."""
return self.add_chg_spin_ebd

def get_dim_chg_spin(self) -> int:
"""Get the dimension of charge_spin input."""
if self.add_chg_spin_ebd:
return self.descriptor.get_dim_chg_spin()
return 0

def has_default_chg_spin(self) -> bool:
"""Check if the model has default charge_spin values."""
if self.add_chg_spin_ebd:
return self.descriptor.has_default_chg_spin()
return False

def get_default_chg_spin(self) -> paddle.Tensor | None:
"""Get the default charge_spin values as a tensor."""
if self.add_chg_spin_ebd and self.descriptor.has_default_chg_spin():
return self.descriptor.get_default_chg_spin()
return None
13 changes: 13 additions & 0 deletions deepmd/pd/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,18 @@ def get_buffer_type_map(self) -> paddle.Tensor:
"""
return self.buffer_type_map

def get_dim_chg_spin(self) -> int:
"""Returns the dimension of charge_spin input (0 if not supported)."""
return 0

def has_default_chg_spin(self) -> bool:
"""Returns whether the descriptor has a default charge_spin value."""
return False

def get_default_chg_spin(self) -> None:
"""Returns the default charge_spin value, or None."""
return None

def get_dim_out(self) -> int:
"""Returns the output dimension."""
ret = self.se_atten.get_dim_out()
Expand Down Expand Up @@ -627,6 +639,7 @@ def forward(
mapping: paddle.Tensor | None = None,
comm_dict: list[paddle.Tensor] | None = None,
fparam: paddle.Tensor | None = None,
charge_spin: paddle.Tensor | None = None,
) -> tuple[
paddle.Tensor,
paddle.Tensor | None,
Expand Down
16 changes: 13 additions & 3 deletions deepmd/pd/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,18 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:
param.stop_gradient = not trainable
self.compress = False

def get_dim_chg_spin(self) -> int:
"""Returns the dimension of charge_spin input (0 if not supported)."""
return 0

def has_default_chg_spin(self) -> bool:
"""Returns whether the descriptor has a default charge_spin value."""
return False

def get_default_chg_spin(self) -> None:
"""Returns the default charge_spin value, or None."""
return None

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
return self.rcut
Expand Down Expand Up @@ -734,10 +746,8 @@ def forward(
mapping: paddle.Tensor | None = None,
comm_dict: list[paddle.Tensor] | None = None,
fparam: paddle.Tensor | None = None,
charge_spin: paddle.Tensor | None = None,
) -> tuple[
paddle.Tensor,
paddle.Tensor | None,
paddle.Tensor | None,
paddle.Tensor | None,
paddle.Tensor | None,
]:
Expand Down
32 changes: 28 additions & 4 deletions deepmd/pd/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def __init__(
use_loc_mapping: bool = True,
type_map: list[str] | None = None,
add_chg_spin_ebd: bool = False,
default_chg_spin: list[float] | None = None,
) -> None:
super().__init__()

Expand Down Expand Up @@ -177,6 +178,11 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:

self.use_econf_tebd = use_econf_tebd
self.add_chg_spin_ebd = add_chg_spin_ebd
if default_chg_spin is not None and len(default_chg_spin) != 2:
raise ValueError(
"default_chg_spin must have exactly 2 values [charge, spin]"
)
self.default_chg_spin = default_chg_spin
self.use_loc_mapping = use_loc_mapping
self.use_tebd_bias = use_tebd_bias
self.type_map = type_map
Expand Down Expand Up @@ -447,6 +453,23 @@ def get_stat_mean_and_stddev(
stddev_list = [self.repflows.stddev]
return mean_list, stddev_list

def get_dim_chg_spin(self) -> int:
"""Returns the dimension of charge_spin input."""
return 2 if self.add_chg_spin_ebd else 0

def has_default_chg_spin(self) -> bool:
"""Returns whether default charge_spin values are set."""
return self.default_chg_spin is not None

def get_default_chg_spin(self) -> paddle.Tensor | None:
"""Get the default charge_spin values as a tensor."""
if self.default_chg_spin is None:
return None
return paddle.to_tensor(
self.default_chg_spin,
dtype=self.prec,
)

def serialize(self) -> dict:
repflows = self.repflows
data = {
Expand All @@ -465,6 +488,7 @@ def serialize(self) -> dict:
"use_tebd_bias": self.use_tebd_bias,
"use_loc_mapping": self.use_loc_mapping,
"add_chg_spin_ebd": self.add_chg_spin_ebd,
"default_chg_spin": self.default_chg_spin,
"type_map": self.type_map,
"type_embedding": self.type_embedding.embedding.serialize(),
}
Expand Down Expand Up @@ -541,7 +565,7 @@ def forward(
nlist: paddle.Tensor,
mapping: paddle.Tensor | None = None,
comm_dict: list[paddle.Tensor] | None = None,
fparam: paddle.Tensor | None = None,
charge_spin: paddle.Tensor | None = None,
) -> tuple[
paddle.Tensor,
paddle.Tensor | None,
Expand Down Expand Up @@ -593,11 +617,11 @@ def forward(
node_ebd_ext = self.type_embedding(extended_atype)

if self.add_chg_spin_ebd:
assert fparam is not None
assert charge_spin is not None
assert self.chg_embedding is not None
assert self.spin_embedding is not None
charge = fparam[:, 0].to(dtype=paddle.int64) + 100
spin = fparam[:, 1].to(dtype=paddle.int64)
charge = charge_spin[:, 0].to(dtype=paddle.int64) + 100
spin = charge_spin[:, 1].to(dtype=paddle.int64)
chg_ebd = self.chg_embedding(charge)
spin_ebd = self.spin_embedding(spin)
sys_cs_embd = self.act(
Expand Down
13 changes: 13 additions & 0 deletions deepmd/pd/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,18 @@ def __init__(
seed=seed,
)

def get_dim_chg_spin(self) -> int:
"""Returns the dimension of charge_spin input (0 if not supported)."""
return 0

def has_default_chg_spin(self) -> bool:
"""Returns whether the descriptor has a default charge_spin value."""
return False

def get_default_chg_spin(self) -> None:
"""Returns the default charge_spin value, or None."""
return None

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
return self.sea.get_rcut()
Expand Down Expand Up @@ -289,6 +301,7 @@ def forward(
mapping: paddle.Tensor | None = None,
comm_dict: list[paddle.Tensor] | None = None,
fparam: paddle.Tensor | None = None,
charge_spin: paddle.Tensor | None = None,
) -> tuple[
paddle.Tensor,
paddle.Tensor | None,
Expand Down
13 changes: 13 additions & 0 deletions deepmd/pd/model/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,18 @@ def __init__(
for param in self.parameters():
param.stop_gradient = not trainable

def get_dim_chg_spin(self) -> int:
"""Returns the dimension of charge_spin input (0 if not supported)."""
return 0

def has_default_chg_spin(self) -> bool:
"""Returns whether the descriptor has a default charge_spin value."""
return False

def get_default_chg_spin(self) -> None:
"""Returns the default charge_spin value, or None."""
return None

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
return self.se_ttebd.get_rcut()
Expand Down Expand Up @@ -438,6 +450,7 @@ def forward(
mapping: paddle.Tensor | None = None,
comm_dict: list[paddle.Tensor] | None = None,
fparam: paddle.Tensor | None = None,
charge_spin: paddle.Tensor | None = None,
) -> paddle.Tensor:
"""Compute the descriptor.

Expand Down
4 changes: 4 additions & 0 deletions deepmd/pd/model/model/ener_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def forward(
fparam: paddle.Tensor | None = None,
aparam: paddle.Tensor | None = None,
do_atomic_virial: bool = False,
charge_spin: paddle.Tensor | None = None,
) -> dict[str, paddle.Tensor]:
model_ret = self.forward_common(
coord,
Expand All @@ -81,6 +82,7 @@ def forward(
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
charge_spin=charge_spin,
)
if self.get_fitting_net() is not None:
model_predict = {}
Expand Down Expand Up @@ -117,6 +119,7 @@ def forward_lower(
aparam: paddle.Tensor | None = None,
do_atomic_virial: bool = False,
comm_dict: list[paddle.Tensor] | None = None,
charge_spin: paddle.Tensor | None = None,
) -> dict[str, paddle.Tensor]:
model_ret = self.forward_common_lower(
extended_coord,
Expand All @@ -128,6 +131,7 @@ def forward_lower(
do_atomic_virial=do_atomic_virial,
comm_dict=comm_dict,
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
charge_spin=charge_spin,
)
if self.get_fitting_net() is not None:
model_predict = {}
Expand Down
Loading
Loading