Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepmd/pd/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def __init__(
self.sel = self.descriptor.get_sel()
self.fitting_net = fitting
super().init_out_stat()
self.add_chg_spin_ebd: bool = getattr(
self.descriptor, "add_chg_spin_ebd", False
)
self.enable_eval_descriptor_hook = False
self.enable_eval_fitting_last_layer_hook = False
self.eval_descriptor_list = []
Expand Down Expand Up @@ -329,6 +332,7 @@ def forward_atomic(
nlist,
mapping=mapping,
comm_dict=comm_dict,
fparam=fparam if self.add_chg_spin_ebd else None,
)
assert descriptor is not None
if self.enable_eval_descriptor_hook:
Expand Down
9 changes: 8 additions & 1 deletion deepmd/pd/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,14 @@ def forward(
nlist: paddle.Tensor,
mapping: paddle.Tensor | None = None,
comm_dict: list[paddle.Tensor] | None = None,
) -> paddle.Tensor:
fparam: paddle.Tensor | None = None,
) -> tuple[
paddle.Tensor,
paddle.Tensor | None,
paddle.Tensor | None,
paddle.Tensor | None,
paddle.Tensor | None,
]:
"""Compute the descriptor.

Parameters
Expand Down
9 changes: 8 additions & 1 deletion deepmd/pd/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,14 @@ def forward(
nlist: paddle.Tensor,
mapping: paddle.Tensor | None = None,
comm_dict: list[paddle.Tensor] | None = None,
) -> paddle.Tensor:
fparam: paddle.Tensor | None = None,
) -> tuple[
paddle.Tensor,
paddle.Tensor | None,
paddle.Tensor | None,
paddle.Tensor | None,
paddle.Tensor | None,
]:
"""Compute the descriptor.

Parameters
Expand Down
83 changes: 78 additions & 5 deletions deepmd/pd/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
UpdateSel,
)
from deepmd.pd.utils.utils import (
ActivationFn,
to_numpy_array,
)
from deepmd.utils.data_system import (
Expand Down Expand Up @@ -120,6 +121,7 @@ def __init__(
use_tebd_bias: bool = False,
use_loc_mapping: bool = True,
type_map: list[str] | None = None,
add_chg_spin_ebd: bool = False,
) -> None:
super().__init__()

Expand Down Expand Up @@ -174,6 +176,7 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:
)

self.use_econf_tebd = use_econf_tebd
self.add_chg_spin_ebd = add_chg_spin_ebd
self.use_loc_mapping = use_loc_mapping
self.use_tebd_bias = use_tebd_bias
self.type_map = type_map
Expand All @@ -196,6 +199,34 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:
self.concat_output_tebd = concat_output_tebd
self.precision = precision
self.prec = PRECISION_DICT[self.precision]

if self.add_chg_spin_ebd:
self.act = ActivationFn(activation_function)
# -100 ~ 100 is a conservative bound
self.chg_embedding = TypeEmbedNet(
200,
self.tebd_dim,
precision=precision,
seed=child_seed(seed, 3),
)
# 100 is a conservative upper bound
self.spin_embedding = TypeEmbedNet(
100,
Comment thread
HydrogenSulfate marked this conversation as resolved.
self.tebd_dim,
precision=precision,
seed=child_seed(seed, 4),
)
self.mix_cs_mlp = MLPLayer(
2 * self.tebd_dim,
self.tebd_dim,
precision=precision,
seed=child_seed(seed, 5),
)
else:
self.chg_embedding = None
self.spin_embedding = None
self.mix_cs_mlp = None

self.exclude_types = exclude_types
self.env_protection = env_protection
self.trainable = trainable
Expand Down Expand Up @@ -433,9 +464,14 @@ def serialize(self) -> dict:
"use_econf_tebd": self.use_econf_tebd,
"use_tebd_bias": self.use_tebd_bias,
"use_loc_mapping": self.use_loc_mapping,
"add_chg_spin_ebd": self.add_chg_spin_ebd,
"type_map": self.type_map,
"type_embedding": self.type_embedding.embedding.serialize(),
}
if self.add_chg_spin_ebd:
data["chg_embedding"] = self.chg_embedding.embedding.serialize()
data["spin_embedding"] = self.spin_embedding.embedding.serialize()
data["mix_cs_mlp"] = self.mix_cs_mlp.serialize()
Comment thread
coderabbitai[bot] marked this conversation as resolved.
repflow_variable = {
"edge_embd": repflows.edge_embd.serialize(),
"angle_embd": repflows.angle_embd.serialize(),
Expand All @@ -462,12 +498,24 @@ def deserialize(cls, data: dict) -> "DescrptDPA3":
data.pop("type")
repflow_variable = data.pop("repflow_variable").copy()
type_embedding = data.pop("type_embedding")
chg_embedding = data.pop("chg_embedding", None)
spin_embedding = data.pop("spin_embedding", None)
mix_cs_mlp = data.pop("mix_cs_mlp", None)
data["repflow"] = RepFlowArgs(**data.pop("repflow_args"))
obj = cls(**data)
obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize(
type_embedding
)

if obj.add_chg_spin_ebd and chg_embedding is not None:
obj.chg_embedding.embedding = TypeEmbedNetConsistent.deserialize(
chg_embedding
)
obj.spin_embedding.embedding = TypeEmbedNetConsistent.deserialize(
spin_embedding
)
obj.mix_cs_mlp = MLPLayer.deserialize(mix_cs_mlp)
Comment thread
HydrogenSulfate marked this conversation as resolved.

def t_cvt(xx: Any) -> paddle.Tensor:
return paddle.to_tensor(xx, dtype=obj.repflows.prec, place=env.DEVICE)

Expand All @@ -493,7 +541,14 @@ def forward(
nlist: paddle.Tensor,
mapping: paddle.Tensor | None = None,
comm_dict: list[paddle.Tensor] | None = None,
) -> paddle.Tensor:
fparam: paddle.Tensor | None = None,
) -> tuple[
paddle.Tensor,
paddle.Tensor | None,
paddle.Tensor | None,
paddle.Tensor | None,
paddle.Tensor | None,
]:
"""Compute the descriptor.

Parameters
Expand Down Expand Up @@ -536,6 +591,20 @@ def forward(
node_ebd_ext = self.type_embedding(extended_atype[:, :nloc])
else:
node_ebd_ext = self.type_embedding(extended_atype)

if self.add_chg_spin_ebd:
assert fparam is not None
assert self.chg_embedding is not None
assert self.spin_embedding is not None
charge = fparam[:, 0].to(dtype=paddle.int64) + 100
spin = fparam[:, 1].to(dtype=paddle.int64)
chg_ebd = self.chg_embedding(charge)
spin_ebd = self.spin_embedding(spin)
sys_cs_embd = self.act(
self.mix_cs_mlp(paddle.concat([chg_ebd, spin_ebd], axis=-1))
)
Comment thread
HydrogenSulfate marked this conversation as resolved.
node_ebd_ext = node_ebd_ext + sys_cs_embd.unsqueeze(1)

node_ebd_inp = node_ebd_ext[:, :nloc, :]
# repflows
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(
Expand All @@ -550,10 +619,14 @@ def forward(
node_ebd = paddle.concat([node_ebd, node_ebd_inp], axis=-1)
return (
node_ebd.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
edge_ebd.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
h2.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
sw.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION)
if rot_mat is not None
else None,
edge_ebd.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION)
if edge_ebd is not None
else None,
h2.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION) if h2 is not None else None,
sw.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION) if sw is not None else None,
)

@classmethod
Expand Down
9 changes: 8 additions & 1 deletion deepmd/pd/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,14 @@ def forward(
nlist: paddle.Tensor,
mapping: paddle.Tensor | None = None,
comm_dict: list[paddle.Tensor] | None = None,
) -> paddle.Tensor:
fparam: paddle.Tensor | None = None,
) -> tuple[
paddle.Tensor,
paddle.Tensor | None,
paddle.Tensor | None,
paddle.Tensor | None,
paddle.Tensor | None,
]:
"""Compute the descriptor.

Parameters
Expand Down
10 changes: 9 additions & 1 deletion deepmd/pd/model/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ def forward(
nlist: paddle.Tensor,
mapping: paddle.Tensor | None = None,
comm_dict: list[paddle.Tensor] | None = None,
fparam: paddle.Tensor | None = None,
) -> paddle.Tensor:
Comment thread
HydrogenSulfate marked this conversation as resolved.
"""Compute the descriptor.

Expand Down Expand Up @@ -789,7 +790,14 @@ def forward(
extended_atype_embd: paddle.Tensor | None = None,
mapping: paddle.Tensor | None = None,
type_embedding: paddle.Tensor | None = None,
) -> paddle.Tensor:
fparam: paddle.Tensor | None = None,
) -> tuple[
paddle.Tensor,
paddle.Tensor | None,
paddle.Tensor | None,
paddle.Tensor | None,
paddle.Tensor | None,
]:
"""Compute the descriptor.

Parameters
Expand Down
44 changes: 31 additions & 13 deletions source/tests/pd/model/test_dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def test_consistency(
nme,
prec,
ect,
add_chg_spin,
) in itertools.product(
[True, False], # update_angle
["res_residual"], # update_style
Expand All @@ -65,15 +66,13 @@ def test_consistency(
[1, 2], # n_multi_edge_message
["float64"], # precision
[False], # use_econf_tebd
[False, True], # add_chg_spin_ebd
):
dtype = PRECISION_DICT[prec]
rtol, atol = get_tols(prec)
if prec == "float64":
atol = 1e-8 # marginal GPU test cases...
coord_ext = np.concatenate([self.coord_ext[:1], self.coord_ext[:1]], axis=0)
atype_ext = np.concatenate([self.atype_ext[:1], self.atype_ext[:1]], axis=0)
nlist = np.concatenate([self.nlist[:1], self.nlist[:1]], axis=0)
mapping = np.concatenate([self.mapping[:1], self.mapping[:1]], axis=0)

repflow = RepFlowArgs(
n_dim=20,
e_dim=10,
Expand Down Expand Up @@ -105,24 +104,37 @@ def test_consistency(
precision=prec,
use_econf_tebd=ect,
type_map=["O", "H"] if ect else None,
add_chg_spin_ebd=add_chg_spin,
seed=GLOBAL_SEED,
).to(env.DEVICE)

dd0.repflows.mean = paddle.to_tensor(davg, dtype=dtype, place=env.DEVICE)
dd0.repflows.stddev = paddle.to_tensor(dstd, dtype=dtype, place=env.DEVICE)

# Prepare fparam if needed
fparam = None
fparam_np = None
if add_chg_spin:
fparam = paddle.to_tensor(
[[5, 1]], dtype=dtype, place=env.DEVICE
).expand(nf, -1)
fparam_np = np.array([[5, 1]], dtype=np.float64).repeat(nf, axis=0)

rd0, _, _, _, _ = dd0(
paddle.to_tensor(coord_ext, dtype=dtype, place=env.DEVICE),
paddle.to_tensor(atype_ext, dtype=paddle.int64, place=env.DEVICE),
paddle.to_tensor(nlist, dtype=paddle.int64, place=env.DEVICE),
paddle.to_tensor(mapping, dtype=paddle.int64, place=env.DEVICE),
paddle.to_tensor(self.coord_ext, dtype=dtype, place=env.DEVICE),
paddle.to_tensor(self.atype_ext, dtype=paddle.int64, place=env.DEVICE),
paddle.to_tensor(self.nlist, dtype=paddle.int64, place=env.DEVICE),
paddle.to_tensor(self.mapping, dtype=paddle.int64, place=env.DEVICE),
fparam=fparam,
)
# serialization
dd1 = DescrptDPA3.deserialize(dd0.serialize())
rd1, _, _, _, _ = dd1(
paddle.to_tensor(coord_ext, dtype=dtype, place=env.DEVICE),
paddle.to_tensor(atype_ext, dtype=paddle.int64, place=env.DEVICE),
paddle.to_tensor(nlist, dtype=paddle.int64, place=env.DEVICE),
paddle.to_tensor(mapping, dtype=paddle.int64, place=env.DEVICE),
paddle.to_tensor(self.coord_ext, dtype=dtype, place=env.DEVICE),
paddle.to_tensor(self.atype_ext, dtype=paddle.int64, place=env.DEVICE),
paddle.to_tensor(self.nlist, dtype=paddle.int64, place=env.DEVICE),
paddle.to_tensor(self.mapping, dtype=paddle.int64, place=env.DEVICE),
fparam=fparam,
)
np.testing.assert_allclose(
rd0.numpy(),
Expand All @@ -132,7 +144,13 @@ def test_consistency(
)
# dp impl
dd2 = DPDescrptDPA3.deserialize(dd0.serialize())
rd2, _, _, _, _ = dd2.call(coord_ext, atype_ext, nlist, mapping)
rd2, _, _, _, _ = dd2.call(
self.coord_ext,
self.atype_ext,
self.nlist,
self.mapping,
fparam=fparam_np,
)
np.testing.assert_allclose(
rd0.numpy(),
rd2,
Expand Down
Loading