Skip to content

Commit 2d400fe

Browse files
committed
Fix tests
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
1 parent 5374bd2 commit 2d400fe

2 files changed

Lines changed: 73 additions & 11 deletions

File tree

deepmd/tf/fit/dipole.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55

66
import numpy as np
77

8+
from deepmd.env import (
9+
GLOBAL_NP_FLOAT_PRECISION,
10+
)
811
from deepmd.tf.common import (
912
cast_precision,
1013
get_activation_func,
@@ -421,8 +424,9 @@ def serialize(self, suffix: str) -> dict:
421424
"dim_case_embd": self.dim_case_embd,
422425
"activation_function": self.activation_function_name,
423426
"precision": self.fitting_precision.name,
424-
"exclude_types": [],
425-
"sel_type": self.sel_type,
427+
"exclude_types": []
428+
if self.sel_type is None
429+
else [ii for ii in range(self.ntypes) if ii not in self.sel_type],
426430
"nets": self.serialize_network(
427431
ntypes=self.ntypes,
428432
ndim=0 if self.mixed_types else 1,
@@ -441,6 +445,9 @@ def serialize(self, suffix: str) -> dict:
441445
"aparam_avg": self.aparam_avg,
442446
"aparam_inv_std": self.aparam_inv_std,
443447
"case_embd": None,
448+
"bias_atom_e": np.zeros(
449+
(self.ntypes, self.dim_rot_mat_1), dtype=GLOBAL_NP_FLOAT_PRECISION
450+
),
444451
},
445452
"type_map": self.type_map,
446453
}
@@ -462,6 +469,11 @@ def deserialize(cls, data: dict, suffix: str):
462469
"""
463470
data = data.copy()
464471
check_version_compatibility(data.pop("@version", 1), 3, 1)
472+
exclude_types = data.pop("exclude_types", [])
473+
if len(exclude_types) > 0:
474+
data["sel_type"] = [
475+
ii for ii in range(data["ntypes"]) if ii not in exclude_types
476+
]
465477
fitting = cls(**data)
466478
fitting.fitting_net_variables = cls.deserialize_network(
467479
data["nets"],

source/tests/consistent/fitting/test_dipole.py

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
(True, False), # resnet_dt
6262
("float64", "float32"), # precision
6363
(True, False), # mixed_types
64-
([], [0]), # sel_type
64+
(None, [0]), # sel_type
6565
)
6666
class TestDipole(CommonTest, DipoleFittingTest, unittest.TestCase):
6767
@property
@@ -76,13 +76,32 @@ def data(self) -> dict:
7676
"neuron": [5, 5, 5],
7777
"resnet_dt": resnet_dt,
7878
"precision": precision,
79+
"sel_type": sel_type,
7980
"seed": 20240217,
8081
}
81-
# Only add sel_type if it's not empty (for TF backend compatibility)
82-
if sel_type:
83-
data["sel_type"] = sel_type
8482
return data
8583

84+
def pass_data_to_cls(self, cls, data) -> Any:
85+
"""Pass data to the class."""
86+
if cls not in (self.tf_class,):
87+
sel_type = data.pop("sel_type", None)
88+
if sel_type is not None:
89+
all_types = list(range(self.ntypes))
90+
exclude_types = [t for t in all_types if t not in sel_type]
91+
data["exclude_types"] = exclude_types
92+
return cls(**data, **self.additional_data)
93+
94+
@property
95+
def skip_tf(self) -> bool:
96+
(
97+
resnet_dt,
98+
precision,
99+
mixed_types,
100+
sel_type,
101+
) = self.param
102+
# mixed_types + sel_type is not supported
103+
return CommonTest.skip_tf or (mixed_types and sel_type is not None)
104+
86105
@property
87106
def skip_pt(self) -> bool:
88107
(
@@ -127,11 +146,6 @@ def additional_data(self) -> dict:
127146
"mixed_types": mixed_types,
128147
"embedding_width": 30,
129148
}
130-
# For DP/PT backends, use exclude_types instead of sel_type
131-
if sel_type:
132-
all_types = list(range(self.ntypes))
133-
exclude_types = [t for t in all_types if t not in sel_type]
134-
additional["exclude_types"] = exclude_types
135149
return additional
136150

137151
def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
@@ -241,3 +255,39 @@ def atol(self) -> float:
241255
return 1e-4
242256
else:
243257
raise ValueError(f"Unknown precision: {precision}")
258+
259+
def test_tf_consistent_with_ref(self) -> None:
260+
"""Test whether TF and reference are consistent."""
261+
# Special handle for sel_types
262+
if self.skip_tf:
263+
self.skipTest("Unsupported backend")
264+
ref_backend = self.get_reference_backend()
265+
if ref_backend == self.RefBackend.TF:
266+
self.skipTest("Reference is self")
267+
ret1, data1 = self.get_reference_ret_serialization(ref_backend)
268+
ret1 = self.extract_ret(ret1, ref_backend)
269+
self.reset_unique_id()
270+
tf_obj = self.tf_class.deserialize(data1, suffix=self.unique_id)
271+
ret2, data2 = self.get_tf_ret_serialization_from_cls(tf_obj)
272+
ret2 = self.extract_ret(ret2, self.RefBackend.TF)
273+
if tf_obj.__class__.__name__.startswith(("Polar", "Dipole", "DOS")):
274+
# tf, pt serialization mismatch
275+
common_keys = set(data1.keys()) & set(data2.keys())
276+
data1 = {k: data1[k] for k in common_keys}
277+
data2 = {k: data2[k] for k in common_keys}
278+
279+
# not comparing version
280+
data1.pop("@version")
281+
data2.pop("@version")
282+
283+
if tf_obj.__class__.__name__.startswith("Polar"):
284+
data1["@variables"].pop("bias_atom_e")
285+
for ii, networks in enumerate(data2["nets"]["networks"]):
286+
if networks is None:
287+
data1["nets"]["networks"][ii] = None
288+
np.testing.assert_equal(data1, data2)
289+
for rr1, rr2 in zip(ret1, ret2):
290+
np.testing.assert_allclose(
291+
rr1.ravel()[: rr2.size], rr2.ravel(), rtol=self.rtol, atol=self.atol
292+
)
293+
assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"

0 commit comments

Comments
 (0)