diff --git a/.gitignore b/.gitignore index 4fde04f91b..7528c5c2f2 100644 --- a/.gitignore +++ b/.gitignore @@ -59,4 +59,12 @@ node_modules/ test_dp_test/ test_dp_test_*.out *_detail.out + +# Training and model output files +*.pth +*.ckpt* +checkpoint +lcurve.out out.json +input_v2_compat.json +frozen_model.* diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 9c51d70778..64cbb6e8e7 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -935,7 +935,7 @@ def __init__( self._networks = [None for ii in range(ntypes**ndim)] for ii, network in enumerate(networks): self[ii] = network - if len(networks): + if len(networks) and all(net is not None for net in networks): self.check_completeness() def check_completeness(self) -> None: @@ -969,7 +969,9 @@ def __getitem__(self, key): return self._networks[self._convert_key(key)] def __setitem__(self, key, value) -> None: - if isinstance(value, self.network_type): + if value is None: + pass + elif isinstance(value, self.network_type): pass elif isinstance(value, dict): value = self.network_type.deserialize(value) @@ -993,7 +995,9 @@ def serialize(self) -> dict: "ndim": self.ndim, "ntypes": self.ntypes, "network_type": network_type_name, - "networks": [nn.serialize() for nn in self._networks], + "networks": [ + nn.serialize() if nn is not None else None for nn in self._networks + ], } @classmethod diff --git a/deepmd/tf/fit/dipole.py b/deepmd/tf/fit/dipole.py index d9cb0002cb..a081f38b17 100644 --- a/deepmd/tf/fit/dipole.py +++ b/deepmd/tf/fit/dipole.py @@ -5,6 +5,9 @@ import numpy as np +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) from deepmd.tf.common import ( cast_precision, get_activation_func, @@ -421,7 +424,9 @@ def serialize(self, suffix: str) -> dict: "dim_case_embd": self.dim_case_embd, "activation_function": self.activation_function_name, "precision": self.fitting_precision.name, - "exclude_types": [], + "exclude_types": [] + if self.sel_type is None + else [ii for ii in range(self.ntypes) if ii not in self.sel_type], "nets": self.serialize_network( ntypes=self.ntypes, ndim=0 if self.mixed_types else 1, @@ -434,6 +439,16 @@ def serialize(self, suffix: str) -> dict: trainable=self.trainable, suffix=suffix, ), + "@variables": { + "fparam_avg": self.fparam_avg, + "fparam_inv_std": self.fparam_inv_std, + "aparam_avg": self.aparam_avg, + "aparam_inv_std": self.aparam_inv_std, + "case_embd": None, + "bias_atom_e": np.zeros( + (self.ntypes, self.dim_rot_mat_1), dtype=GLOBAL_NP_FLOAT_PRECISION + ), + }, "type_map": self.type_map, } return data @@ -454,6 +469,11 @@ def deserialize(cls, data: dict, suffix: str): """ data = data.copy() check_version_compatibility(data.pop("@version", 1), 3, 1) + exclude_types = data.pop("exclude_types", []) + if len(exclude_types) > 0: + data["sel_type"] = [ + ii for ii in range(data["ntypes"]) if ii not in exclude_types + ] fitting = cls(**data) fitting.fitting_net_variables = cls.deserialize_network( data["nets"], diff --git a/deepmd/tf/fit/fitting.py b/deepmd/tf/fit/fitting.py index 4f7436a52c..0e109fea60 100644 --- a/deepmd/tf/fit/fitting.py +++ b/deepmd/tf/fit/fitting.py @@ -244,7 +244,9 @@ def deserialize_network(cls, data: dict, suffix: str = "") -> dict: else: raise ValueError(f"Invalid ndim: {fittings.ndim}") network = fittings[net_idx] - assert network is not None + if network is None: + # Skip types that are not selected (when sel_type is used) + continue for layer_idx, layer in enumerate(network.layers): if layer_idx == len(network.layers) - 1: layer_name = "final_layer" diff --git a/source/tests/consistent/fitting/test_dipole.py b/source/tests/consistent/fitting/test_dipole.py index 396ee2d492..010944d109 100644 --- a/source/tests/consistent/fitting/test_dipole.py +++ b/source/tests/consistent/fitting/test_dipole.py @@ -61,6 +61,7 @@ (True, False), # resnet_dt ("float64", "float32"), # precision (True, False), # mixed_types + (None, [0]), # sel_type ) class TestDipole(CommonTest, DipoleFittingTest, unittest.TestCase): @property @@ -69,13 +70,37 @@ def data(self) -> dict: resnet_dt, precision, mixed_types, + sel_type, ) = self.param - return { + data = { "neuron": [5, 5, 5], "resnet_dt": resnet_dt, "precision": precision, + "sel_type": sel_type, "seed": 20240217, } + return data + + def pass_data_to_cls(self, cls, data) -> Any: + """Pass data to the class.""" + if cls not in (self.tf_class,): + sel_type = data.pop("sel_type", None) + if sel_type is not None: + all_types = list(range(self.ntypes)) + exclude_types = [t for t in all_types if t not in sel_type] + data["exclude_types"] = exclude_types + return cls(**data, **self.additional_data) + + @property + def skip_tf(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + sel_type, + ) = self.param + # mixed_types + sel_type is not supported + return CommonTest.skip_tf or (mixed_types and sel_type is not None) @property def skip_pt(self) -> bool: @@ -83,6 +108,7 @@ def skip_pt(self) -> bool: resnet_dt, precision, mixed_types, + sel_type, ) = self.param return CommonTest.skip_pt @@ -112,6 +138,7 @@ def additional_data(self) -> dict: resnet_dt, precision, mixed_types, + sel_type, ) = self.param return { "ntypes": self.ntypes, @@ -125,6 +152,7 @@ def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: resnet_dt, precision, mixed_types, + sel_type, ) = self.param return self.build_tf_fitting( obj, @@ -141,6 +169,7 @@ def eval_pt(self, pt_obj: Any) -> Any: resnet_dt, precision, mixed_types, + sel_type, ) = self.param return ( pt_obj( @@ -159,6 +188,7 @@ def eval_dp(self, dp_obj: Any) -> Any: resnet_dt, precision, mixed_types, + sel_type, ) = self.param return dp_obj( self.inputs, @@ -200,6 +230,7 @@ def rtol(self) -> float: resnet_dt, precision, mixed_types, + sel_type, ) = self.param if precision == "float64": return 1e-10 @@ -215,6 +246,7 @@ def atol(self) -> float: resnet_dt, precision, mixed_types, + sel_type, ) = self.param if precision == "float64": return 1e-10 @@ -222,3 +254,39 @@ def atol(self) -> float: return 1e-4 else: raise ValueError(f"Unknown precision: {precision}") + + def test_tf_consistent_with_ref(self) -> None: + """Test whether TF and reference are consistent.""" + # Special handle for sel_types + if self.skip_tf: + self.skipTest("Unsupported backend") + ref_backend = self.get_reference_backend() + if ref_backend == self.RefBackend.TF: + self.skipTest("Reference is self") + ret1, data1 = self.get_reference_ret_serialization(ref_backend) + ret1 = self.extract_ret(ret1, ref_backend) + self.reset_unique_id() + tf_obj = self.tf_class.deserialize(data1, suffix=self.unique_id) + ret2, data2 = self.get_tf_ret_serialization_from_cls(tf_obj) + ret2 = self.extract_ret(ret2, self.RefBackend.TF) + if tf_obj.__class__.__name__.startswith(("Polar", "Dipole", "DOS")): + # tf, pt serialization mismatch + common_keys = set(data1.keys()) & set(data2.keys()) + data1 = {k: data1[k] for k in common_keys} + data2 = {k: data2[k] for k in common_keys} + + # not comparing version + data1.pop("@version") + data2.pop("@version") + + if tf_obj.__class__.__name__.startswith("Polar"): + data1["@variables"].pop("bias_atom_e") + for ii, networks in enumerate(data2["nets"]["networks"]): + if networks is None: + data1["nets"]["networks"][ii] = None + np.testing.assert_equal(data1, data2) + for rr1, rr2 in zip(ret1, ret2): + np.testing.assert_allclose( + rr1.ravel()[: rr2.size], rr2.ravel(), rtol=self.rtol, atol=self.atol + ) + assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"