Skip to content

Commit ea7f828

Browse files
Copilotnjzjz
andauthored
fix(tf): fix serialization of dipole fitting with sel_type (#4934)
Fix #3672. Fixes backend conversion issues for dipole models when using the `sel_type` parameter. The `dp convert-backend` command was failing due to missing serialization support for `None` networks and incomplete dipole fitting serialization. - [x] Fix NetworkCollection serialization to handle `None` networks - [x] Add missing `@variables` dictionary for DipoleFittingSeA PyTorch compatibility - [x] Include `sel_type` in serialized data for proper backend conversion - [x] Fix TF fitting deserialization to skip `None` networks - [x] Add comprehensive tests for `sel_type` parameter - [x] Remove duplicate test classes and merge parameterized tests - [x] Clean up accidentally committed test output files - [x] Refactor additional_data property to return dictionary directly - [x] Resolve merge conflicts in .gitignore after rebase All tests pass and the `dp convert-backend` command now works for dipole models with `sel_type` parameter. The branch has been successfully rebased against the latest devel branch with all conflicts resolved. <!-- START COPILOT CODING AGENT TIPS --> --- ✨ Let Copilot coding agent [set things up for you](https://github.com/deepmodeling/deepmd-kit/issues/new?title=✨+Set+up+Copilot+instructions&body=Configure%20instructions%20for%20this%20repository%20as%20documented%20in%20%5BBest%20practices%20for%20Copilot%20coding%20agent%20in%20your%20repository%5D%28https://gh.io/copilot-coding-agent-tips%29%2E%0A%0A%3COnboard%20this%20repo%3E&assignees=copilot) — coding agent works faster and does higher quality work when set up for your repo. --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> Co-authored-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
1 parent aca4e8c commit ea7f828

5 files changed

Lines changed: 108 additions & 6 deletions

File tree

.gitignore

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,12 @@ node_modules/
5959
test_dp_test/
6060
test_dp_test_*.out
6161
*_detail.out
62+
63+
# Training and model output files
64+
*.pth
65+
*.ckpt*
66+
checkpoint
67+
lcurve.out
6268
out.json
69+
input_v2_compat.json
70+
frozen_model.*

deepmd/dpmodel/utils/network.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -935,7 +935,7 @@ def __init__(
935935
self._networks = [None for ii in range(ntypes**ndim)]
936936
for ii, network in enumerate(networks):
937937
self[ii] = network
938-
if len(networks):
938+
if len(networks) and all(net is not None for net in networks):
939939
self.check_completeness()
940940

941941
def check_completeness(self) -> None:
@@ -969,7 +969,9 @@ def __getitem__(self, key):
969969
return self._networks[self._convert_key(key)]
970970

971971
def __setitem__(self, key, value) -> None:
972-
if isinstance(value, self.network_type):
972+
if value is None:
973+
pass
974+
elif isinstance(value, self.network_type):
973975
pass
974976
elif isinstance(value, dict):
975977
value = self.network_type.deserialize(value)
@@ -993,7 +995,9 @@ def serialize(self) -> dict:
993995
"ndim": self.ndim,
994996
"ntypes": self.ntypes,
995997
"network_type": network_type_name,
996-
"networks": [nn.serialize() for nn in self._networks],
998+
"networks": [
999+
nn.serialize() if nn is not None else None for nn in self._networks
1000+
],
9971001
}
9981002

9991003
@classmethod

deepmd/tf/fit/dipole.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55

66
import numpy as np
77

8+
from deepmd.env import (
9+
GLOBAL_NP_FLOAT_PRECISION,
10+
)
811
from deepmd.tf.common import (
912
cast_precision,
1013
get_activation_func,
@@ -421,7 +424,9 @@ def serialize(self, suffix: str) -> dict:
421424
"dim_case_embd": self.dim_case_embd,
422425
"activation_function": self.activation_function_name,
423426
"precision": self.fitting_precision.name,
424-
"exclude_types": [],
427+
"exclude_types": []
428+
if self.sel_type is None
429+
else [ii for ii in range(self.ntypes) if ii not in self.sel_type],
425430
"nets": self.serialize_network(
426431
ntypes=self.ntypes,
427432
ndim=0 if self.mixed_types else 1,
@@ -434,6 +439,16 @@ def serialize(self, suffix: str) -> dict:
434439
trainable=self.trainable,
435440
suffix=suffix,
436441
),
442+
"@variables": {
443+
"fparam_avg": self.fparam_avg,
444+
"fparam_inv_std": self.fparam_inv_std,
445+
"aparam_avg": self.aparam_avg,
446+
"aparam_inv_std": self.aparam_inv_std,
447+
"case_embd": None,
448+
"bias_atom_e": np.zeros(
449+
(self.ntypes, self.dim_rot_mat_1), dtype=GLOBAL_NP_FLOAT_PRECISION
450+
),
451+
},
437452
"type_map": self.type_map,
438453
}
439454
return data
@@ -454,6 +469,11 @@ def deserialize(cls, data: dict, suffix: str):
454469
"""
455470
data = data.copy()
456471
check_version_compatibility(data.pop("@version", 1), 3, 1)
472+
exclude_types = data.pop("exclude_types", [])
473+
if len(exclude_types) > 0:
474+
data["sel_type"] = [
475+
ii for ii in range(data["ntypes"]) if ii not in exclude_types
476+
]
457477
fitting = cls(**data)
458478
fitting.fitting_net_variables = cls.deserialize_network(
459479
data["nets"],

deepmd/tf/fit/fitting.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,9 @@ def deserialize_network(cls, data: dict, suffix: str = "") -> dict:
244244
else:
245245
raise ValueError(f"Invalid ndim: {fittings.ndim}")
246246
network = fittings[net_idx]
247-
assert network is not None
247+
if network is None:
248+
# Skip types that are not selected (when sel_type is used)
249+
continue
248250
for layer_idx, layer in enumerate(network.layers):
249251
if layer_idx == len(network.layers) - 1:
250252
layer_name = "final_layer"

source/tests/consistent/fitting/test_dipole.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
(True, False), # resnet_dt
6262
("float64", "float32"), # precision
6363
(True, False), # mixed_types
64+
(None, [0]), # sel_type
6465
)
6566
class TestDipole(CommonTest, DipoleFittingTest, unittest.TestCase):
6667
@property
@@ -69,20 +70,45 @@ def data(self) -> dict:
6970
resnet_dt,
7071
precision,
7172
mixed_types,
73+
sel_type,
7274
) = self.param
73-
return {
75+
data = {
7476
"neuron": [5, 5, 5],
7577
"resnet_dt": resnet_dt,
7678
"precision": precision,
79+
"sel_type": sel_type,
7780
"seed": 20240217,
7881
}
82+
return data
83+
84+
def pass_data_to_cls(self, cls, data) -> Any:
85+
"""Pass data to the class."""
86+
if cls not in (self.tf_class,):
87+
sel_type = data.pop("sel_type", None)
88+
if sel_type is not None:
89+
all_types = list(range(self.ntypes))
90+
exclude_types = [t for t in all_types if t not in sel_type]
91+
data["exclude_types"] = exclude_types
92+
return cls(**data, **self.additional_data)
93+
94+
@property
95+
def skip_tf(self) -> bool:
96+
(
97+
resnet_dt,
98+
precision,
99+
mixed_types,
100+
sel_type,
101+
) = self.param
102+
# mixed_types + sel_type is not supported
103+
return CommonTest.skip_tf or (mixed_types and sel_type is not None)
79104

80105
@property
81106
def skip_pt(self) -> bool:
82107
(
83108
resnet_dt,
84109
precision,
85110
mixed_types,
111+
sel_type,
86112
) = self.param
87113
return CommonTest.skip_pt
88114

@@ -112,6 +138,7 @@ def additional_data(self) -> dict:
112138
resnet_dt,
113139
precision,
114140
mixed_types,
141+
sel_type,
115142
) = self.param
116143
return {
117144
"ntypes": self.ntypes,
@@ -125,6 +152,7 @@ def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
125152
resnet_dt,
126153
precision,
127154
mixed_types,
155+
sel_type,
128156
) = self.param
129157
return self.build_tf_fitting(
130158
obj,
@@ -141,6 +169,7 @@ def eval_pt(self, pt_obj: Any) -> Any:
141169
resnet_dt,
142170
precision,
143171
mixed_types,
172+
sel_type,
144173
) = self.param
145174
return (
146175
pt_obj(
@@ -159,6 +188,7 @@ def eval_dp(self, dp_obj: Any) -> Any:
159188
resnet_dt,
160189
precision,
161190
mixed_types,
191+
sel_type,
162192
) = self.param
163193
return dp_obj(
164194
self.inputs,
@@ -200,6 +230,7 @@ def rtol(self) -> float:
200230
resnet_dt,
201231
precision,
202232
mixed_types,
233+
sel_type,
203234
) = self.param
204235
if precision == "float64":
205236
return 1e-10
@@ -215,10 +246,47 @@ def atol(self) -> float:
215246
resnet_dt,
216247
precision,
217248
mixed_types,
249+
sel_type,
218250
) = self.param
219251
if precision == "float64":
220252
return 1e-10
221253
elif precision == "float32":
222254
return 1e-4
223255
else:
224256
raise ValueError(f"Unknown precision: {precision}")
257+
258+
def test_tf_consistent_with_ref(self) -> None:
259+
"""Test whether TF and reference are consistent."""
260+
# Special handle for sel_types
261+
if self.skip_tf:
262+
self.skipTest("Unsupported backend")
263+
ref_backend = self.get_reference_backend()
264+
if ref_backend == self.RefBackend.TF:
265+
self.skipTest("Reference is self")
266+
ret1, data1 = self.get_reference_ret_serialization(ref_backend)
267+
ret1 = self.extract_ret(ret1, ref_backend)
268+
self.reset_unique_id()
269+
tf_obj = self.tf_class.deserialize(data1, suffix=self.unique_id)
270+
ret2, data2 = self.get_tf_ret_serialization_from_cls(tf_obj)
271+
ret2 = self.extract_ret(ret2, self.RefBackend.TF)
272+
if tf_obj.__class__.__name__.startswith(("Polar", "Dipole", "DOS")):
273+
# tf, pt serialization mismatch
274+
common_keys = set(data1.keys()) & set(data2.keys())
275+
data1 = {k: data1[k] for k in common_keys}
276+
data2 = {k: data2[k] for k in common_keys}
277+
278+
# not comparing version
279+
data1.pop("@version")
280+
data2.pop("@version")
281+
282+
if tf_obj.__class__.__name__.startswith("Polar"):
283+
data1["@variables"].pop("bias_atom_e")
284+
for ii, networks in enumerate(data2["nets"]["networks"]):
285+
if networks is None:
286+
data1["nets"]["networks"][ii] = None
287+
np.testing.assert_equal(data1, data2)
288+
for rr1, rr2 in zip(ret1, ret2):
289+
np.testing.assert_allclose(
290+
rr1.ravel()[: rr2.size], rr2.ravel(), rtol=self.rtol, atol=self.atol
291+
)
292+
assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"

0 commit comments

Comments
 (0)