Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepmd/pd/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def __init__(
self.sel = self.descriptor.get_sel()
self.fitting_net = fitting
super().init_out_stat()
self.add_chg_spin_ebd: bool = getattr(
self.descriptor, "add_chg_spin_ebd", False
)
self.enable_eval_descriptor_hook = False
self.enable_eval_fitting_last_layer_hook = False
self.eval_descriptor_list = []
Expand Down Expand Up @@ -329,6 +332,7 @@ def forward_atomic(
nlist,
mapping=mapping,
comm_dict=comm_dict,
fparam=fparam if self.add_chg_spin_ebd else None,
)
assert descriptor is not None
if self.enable_eval_descriptor_hook:
Expand Down
1 change: 1 addition & 0 deletions deepmd/pd/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,7 @@ def forward(
nlist: paddle.Tensor,
mapping: paddle.Tensor | None = None,
comm_dict: list[paddle.Tensor] | None = None,
fparam: paddle.Tensor | None = None,
) -> paddle.Tensor:
Comment thread
HydrogenSulfate marked this conversation as resolved.
Outdated
"""Compute the descriptor.

Expand Down
1 change: 1 addition & 0 deletions deepmd/pd/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,7 @@ def forward(
nlist: paddle.Tensor,
mapping: paddle.Tensor | None = None,
comm_dict: list[paddle.Tensor] | None = None,
fparam: paddle.Tensor | None = None,
) -> paddle.Tensor:
Comment thread
njzjz marked this conversation as resolved.
Outdated
"""Compute the descriptor.

Expand Down
83 changes: 78 additions & 5 deletions deepmd/pd/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
UpdateSel,
)
from deepmd.pd.utils.utils import (
ActivationFn,
to_numpy_array,
)
from deepmd.utils.data_system import (
Expand Down Expand Up @@ -120,6 +121,7 @@ def __init__(
use_tebd_bias: bool = False,
use_loc_mapping: bool = True,
type_map: list[str] | None = None,
add_chg_spin_ebd: bool = False,
) -> None:
super().__init__()

Expand Down Expand Up @@ -174,6 +176,7 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:
)

self.use_econf_tebd = use_econf_tebd
self.add_chg_spin_ebd = add_chg_spin_ebd
self.use_loc_mapping = use_loc_mapping
self.use_tebd_bias = use_tebd_bias
self.type_map = type_map
Expand All @@ -196,6 +199,34 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:
self.concat_output_tebd = concat_output_tebd
self.precision = precision
self.prec = PRECISION_DICT[self.precision]

if self.add_chg_spin_ebd:
self.act = ActivationFn(activation_function)
# -100 ~ 100 is a conservative bound
self.chg_embedding = TypeEmbedNet(
200,
self.tebd_dim,
precision=precision,
seed=child_seed(seed, 3),
)
# 100 is a conservative upper bound
self.spin_embedding = TypeEmbedNet(
100,
Comment thread
HydrogenSulfate marked this conversation as resolved.
self.tebd_dim,
precision=precision,
seed=child_seed(seed, 4),
)
self.mix_cs_mlp = MLPLayer(
2 * self.tebd_dim,
self.tebd_dim,
precision=precision,
seed=child_seed(seed, 5),
)
else:
self.chg_embedding = None
self.spin_embedding = None
self.mix_cs_mlp = None

self.exclude_types = exclude_types
self.env_protection = env_protection
self.trainable = trainable
Expand Down Expand Up @@ -433,9 +464,14 @@ def serialize(self) -> dict:
"use_econf_tebd": self.use_econf_tebd,
"use_tebd_bias": self.use_tebd_bias,
"use_loc_mapping": self.use_loc_mapping,
"add_chg_spin_ebd": self.add_chg_spin_ebd,
"type_map": self.type_map,
"type_embedding": self.type_embedding.embedding.serialize(),
}
if self.add_chg_spin_ebd:
data["chg_embedding"] = self.chg_embedding.embedding.serialize()
data["spin_embedding"] = self.spin_embedding.embedding.serialize()
data["mix_cs_mlp"] = self.mix_cs_mlp.serialize()
Comment thread
coderabbitai[bot] marked this conversation as resolved.
repflow_variable = {
"edge_embd": repflows.edge_embd.serialize(),
"angle_embd": repflows.angle_embd.serialize(),
Expand All @@ -462,12 +498,24 @@ def deserialize(cls, data: dict) -> "DescrptDPA3":
data.pop("type")
repflow_variable = data.pop("repflow_variable").copy()
type_embedding = data.pop("type_embedding")
chg_embedding = data.pop("chg_embedding", None)
spin_embedding = data.pop("spin_embedding", None)
mix_cs_mlp = data.pop("mix_cs_mlp", None)
data["repflow"] = RepFlowArgs(**data.pop("repflow_args"))
obj = cls(**data)
obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize(
type_embedding
)

if obj.add_chg_spin_ebd and chg_embedding is not None:
obj.chg_embedding.embedding = TypeEmbedNetConsistent.deserialize(
chg_embedding
)
obj.spin_embedding.embedding = TypeEmbedNetConsistent.deserialize(
spin_embedding
)
obj.mix_cs_mlp = MLPLayer.deserialize(mix_cs_mlp)
Comment thread
HydrogenSulfate marked this conversation as resolved.

def t_cvt(xx: Any) -> paddle.Tensor:
return paddle.to_tensor(xx, dtype=obj.repflows.prec, place=env.DEVICE)

Expand All @@ -493,7 +541,14 @@ def forward(
nlist: paddle.Tensor,
mapping: paddle.Tensor | None = None,
comm_dict: list[paddle.Tensor] | None = None,
) -> paddle.Tensor:
fparam: paddle.Tensor | None = None,
) -> tuple[
paddle.Tensor,
paddle.Tensor | None,
paddle.Tensor | None,
paddle.Tensor | None,
paddle.Tensor | None,
]:
"""Compute the descriptor.

Parameters
Expand Down Expand Up @@ -536,6 +591,20 @@ def forward(
node_ebd_ext = self.type_embedding(extended_atype[:, :nloc])
else:
node_ebd_ext = self.type_embedding(extended_atype)

if self.add_chg_spin_ebd:
assert fparam is not None
assert self.chg_embedding is not None
assert self.spin_embedding is not None
charge = fparam[:, 0].to(dtype=paddle.int64) + 100
spin = fparam[:, 1].to(dtype=paddle.int64)
chg_ebd = self.chg_embedding(charge)
spin_ebd = self.spin_embedding(spin)
sys_cs_embd = self.act(
self.mix_cs_mlp(paddle.concat([chg_ebd, spin_ebd], axis=-1))
)
Comment thread
HydrogenSulfate marked this conversation as resolved.
node_ebd_ext = node_ebd_ext + sys_cs_embd.unsqueeze(1)

node_ebd_inp = node_ebd_ext[:, :nloc, :]
# repflows
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(
Expand All @@ -550,10 +619,14 @@ def forward(
node_ebd = paddle.concat([node_ebd, node_ebd_inp], axis=-1)
return (
node_ebd.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
edge_ebd.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
h2.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
sw.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION)
if rot_mat is not None
else None,
edge_ebd.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION)
if edge_ebd is not None
else None,
h2.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION) if h2 is not None else None,
sw.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION) if sw is not None else None,
)

@classmethod
Expand Down
1 change: 1 addition & 0 deletions deepmd/pd/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def forward(
nlist: paddle.Tensor,
mapping: paddle.Tensor | None = None,
comm_dict: list[paddle.Tensor] | None = None,
fparam: paddle.Tensor | None = None,
) -> paddle.Tensor:
Comment thread
HydrogenSulfate marked this conversation as resolved.
Outdated
"""Compute the descriptor.

Expand Down
1 change: 1 addition & 0 deletions deepmd/pd/model/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ def forward(
nlist: paddle.Tensor,
mapping: paddle.Tensor | None = None,
comm_dict: list[paddle.Tensor] | None = None,
fparam: paddle.Tensor | None = None,
) -> paddle.Tensor:
Comment thread
HydrogenSulfate marked this conversation as resolved.
"""Compute the descriptor.

Expand Down
44 changes: 31 additions & 13 deletions source/tests/pd/model/test_dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def test_consistency(
nme,
prec,
ect,
add_chg_spin,
) in itertools.product(
[True, False], # update_angle
["res_residual"], # update_style
Expand All @@ -65,15 +66,13 @@ def test_consistency(
[1, 2], # n_multi_edge_message
["float64"], # precision
[False], # use_econf_tebd
[False, True], # add_chg_spin_ebd
):
dtype = PRECISION_DICT[prec]
rtol, atol = get_tols(prec)
if prec == "float64":
atol = 1e-8 # marginal GPU test cases...
coord_ext = np.concatenate([self.coord_ext[:1], self.coord_ext[:1]], axis=0)
atype_ext = np.concatenate([self.atype_ext[:1], self.atype_ext[:1]], axis=0)
nlist = np.concatenate([self.nlist[:1], self.nlist[:1]], axis=0)
mapping = np.concatenate([self.mapping[:1], self.mapping[:1]], axis=0)

repflow = RepFlowArgs(
n_dim=20,
e_dim=10,
Expand Down Expand Up @@ -105,24 +104,37 @@ def test_consistency(
precision=prec,
use_econf_tebd=ect,
type_map=["O", "H"] if ect else None,
add_chg_spin_ebd=add_chg_spin,
seed=GLOBAL_SEED,
).to(env.DEVICE)

dd0.repflows.mean = paddle.to_tensor(davg, dtype=dtype, place=env.DEVICE)
dd0.repflows.stddev = paddle.to_tensor(dstd, dtype=dtype, place=env.DEVICE)

# Prepare fparam if needed
fparam = None
fparam_np = None
if add_chg_spin:
fparam = paddle.to_tensor(
[[5, 1]], dtype=dtype, place=env.DEVICE
).expand(nf, -1)
fparam_np = np.array([[5, 1]], dtype=np.float64).repeat(nf, axis=0)

rd0, _, _, _, _ = dd0(
paddle.to_tensor(coord_ext, dtype=dtype, place=env.DEVICE),
paddle.to_tensor(atype_ext, dtype=paddle.int64, place=env.DEVICE),
paddle.to_tensor(nlist, dtype=paddle.int64, place=env.DEVICE),
paddle.to_tensor(mapping, dtype=paddle.int64, place=env.DEVICE),
paddle.to_tensor(self.coord_ext, dtype=dtype, place=env.DEVICE),
paddle.to_tensor(self.atype_ext, dtype=paddle.int64, place=env.DEVICE),
paddle.to_tensor(self.nlist, dtype=paddle.int64, place=env.DEVICE),
paddle.to_tensor(self.mapping, dtype=paddle.int64, place=env.DEVICE),
fparam=fparam,
)
# serialization
dd1 = DescrptDPA3.deserialize(dd0.serialize())
rd1, _, _, _, _ = dd1(
paddle.to_tensor(coord_ext, dtype=dtype, place=env.DEVICE),
paddle.to_tensor(atype_ext, dtype=paddle.int64, place=env.DEVICE),
paddle.to_tensor(nlist, dtype=paddle.int64, place=env.DEVICE),
paddle.to_tensor(mapping, dtype=paddle.int64, place=env.DEVICE),
paddle.to_tensor(self.coord_ext, dtype=dtype, place=env.DEVICE),
paddle.to_tensor(self.atype_ext, dtype=paddle.int64, place=env.DEVICE),
paddle.to_tensor(self.nlist, dtype=paddle.int64, place=env.DEVICE),
paddle.to_tensor(self.mapping, dtype=paddle.int64, place=env.DEVICE),
fparam=fparam,
)
np.testing.assert_allclose(
rd0.numpy(),
Expand All @@ -132,7 +144,13 @@ def test_consistency(
)
# dp impl
dd2 = DPDescrptDPA3.deserialize(dd0.serialize())
rd2, _, _, _, _ = dd2.call(coord_ext, atype_ext, nlist, mapping)
rd2, _, _, _, _ = dd2.call(
self.coord_ext,
self.atype_ext,
self.nlist,
self.mapping,
fparam=fparam_np,
)
np.testing.assert_allclose(
rd0.numpy(),
rd2,
Expand Down
Loading