Skip to content

Commit 7a96536

Browse files
committed
fix ut
1 parent 623269f commit 7a96536

2 files changed

Lines changed: 18 additions & 4 deletions

File tree

  • deepmd/jax/descriptor
  • source/tests/array_api_strict/descriptor

deepmd/jax/descriptor/dpa3.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
flax_version,
2424
nnx,
2525
)
26+
from deepmd.jax.utils.network import (
27+
NativeLayer,
28+
)
2629
from deepmd.jax.utils.type_embed import (
2730
TypeEmbedNet,
2831
)
@@ -40,8 +43,12 @@ def __setattr__(self, name: str, value: Any) -> None:
4043
value = nnx.data(value)
4144
elif name in {"repflows"}:
4245
value = DescrptBlockRepflows.deserialize(value.serialize())
43-
elif name in {"type_embedding"}:
44-
value = TypeEmbedNet.deserialize(value.serialize())
46+
elif name in {"type_embedding", "chg_embedding", "spin_embedding"}:
47+
if value is not None:
48+
value = TypeEmbedNet.deserialize(value.serialize())
49+
elif name in {"mix_cs_mlp"}:
50+
if value is not None:
51+
value = NativeLayer.deserialize(value.serialize())
4552
else:
4653
pass
4754
return super().__setattr__(name, value)

source/tests/array_api_strict/descriptor/dpa3.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from ..common import (
99
to_array_api_strict_array,
1010
)
11+
from ..utils.network import (
12+
NativeLayer,
13+
)
1114
from ..utils.type_embed import (
1215
TypeEmbedNet,
1316
)
@@ -26,8 +29,12 @@ def __setattr__(self, name: str, value: Any) -> None:
2629
value = to_array_api_strict_array(value)
2730
elif name in {"repflows"}:
2831
value = DescrptBlockRepflows.deserialize(value.serialize())
29-
elif name in {"type_embedding"}:
30-
value = TypeEmbedNet.deserialize(value.serialize())
32+
elif name in {"type_embedding", "chg_embedding", "spin_embedding"}:
33+
if value is not None:
34+
value = TypeEmbedNet.deserialize(value.serialize())
35+
elif name in {"mix_cs_mlp"}:
36+
if value is not None:
37+
value = NativeLayer.deserialize(value.serialize())
3138
else:
3239
pass
3340
return super().__setattr__(name, value)

0 commit comments

Comments
 (0)