Skip to content

Commit 24325c1

Browse files
committed
add UT for pt2tf model.deserialize
1 parent 9b51508 commit 24325c1

2 files changed

Lines changed: 16 additions & 0 deletions

File tree

source/tests/consistent/model/test_dipole.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,11 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
204204
ret[1].ravel(),
205205
)
206206
raise ValueError(f"Unknown backend: {backend}")
207+
208+
def test_atom_exclude_types(self):
209+
_ret, data = self.get_reference_ret_serialization(self.RefBackend.PT)
210+
data["atom_exclude_types"] = [1]
211+
self.reset_unique_id()
212+
tf_obj = self.tf_class.deserialize(data, suffix=self.unique_id)
213+
pt_obj = self.pt_class.deserialize(data)
214+
self.assertEqual(tf_obj.get_sel_type(), pt_obj.get_sel_type())

source/tests/consistent/model/test_polar.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,11 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
198198
ret[1].ravel(),
199199
)
200200
raise ValueError(f"Unknown backend: {backend}")
201+
202+
def test_atom_exclude_types(self):
203+
_ret, data = self.get_reference_ret_serialization(self.RefBackend.PT)
204+
data["atom_exclude_types"] = [1]
205+
self.reset_unique_id()
206+
tf_obj = self.tf_class.deserialize(data, suffix=self.unique_id)
207+
pt_obj = self.pt_class.deserialize(data)
208+
self.assertEqual(tf_obj.get_sel_type(), pt_obj.get_sel_type())

0 commit comments

Comments
 (0)