diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index a4089468f3..fabc39ae96 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -261,8 +261,18 @@ def compute_input_stats( fparam_std, ) fparam_inv_std = 1.0 / fparam_std - self.fparam_avg = fparam_avg.astype(self.fparam_avg.dtype) - self.fparam_inv_std = fparam_inv_std.astype(self.fparam_inv_std.dtype) + # Use array_api_compat to handle both numpy and torch + xp = array_api_compat.array_namespace(self.fparam_avg) + self.fparam_avg = xp.asarray( + fparam_avg, + dtype=self.fparam_avg.dtype, + device=array_api_compat.device(self.fparam_avg), + ) + self.fparam_inv_std = xp.asarray( + fparam_inv_std, + dtype=self.fparam_inv_std.dtype, + device=array_api_compat.device(self.fparam_inv_std), + ) # stat aparam if self.numb_aparam > 0: sys_sumv = [] @@ -284,8 +294,18 @@ def compute_input_stats( aparam_std, ) aparam_inv_std = 1.0 / aparam_std - self.aparam_avg = aparam_avg.astype(self.aparam_avg.dtype) - self.aparam_inv_std = aparam_inv_std.astype(self.aparam_inv_std.dtype) + # Use array_api_compat to handle both numpy and torch + xp = array_api_compat.array_namespace(self.aparam_avg) + self.aparam_avg = xp.asarray( + aparam_avg, + dtype=self.aparam_avg.dtype, + device=array_api_compat.device(self.aparam_avg), + ) + self.aparam_inv_std = xp.asarray( + aparam_inv_std, + dtype=self.aparam_inv_std.dtype, + device=array_api_compat.device(self.aparam_inv_std), + ) @abstractmethod def _net_out_dim(self) -> int: @@ -566,7 +586,9 @@ def _call_common( # calculate the prediction if not self.mixed_types: outs = xp.zeros( - [nf, nloc, net_dim_out], dtype=get_xp_precision(xp, self.precision) + [nf, nloc, net_dim_out], + dtype=get_xp_precision(xp, self.precision), + device=array_api_compat.device(descriptor), ) for type_i in range(self.ntypes): mask = xp.tile( diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index afda16a09b..4679412d4b 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -25,8 +25,6 @@ Array, xp_add_at, xp_bincount, - xp_setitem_at, - xp_sigmoid, ) from deepmd.dpmodel.common import ( to_numpy_array, @@ -41,7 +39,15 @@ def sigmoid_t(x): # noqa: ANN001, ANN201 """Sigmoid.""" - return xp_sigmoid(x) + if array_api_compat.is_jax_array(x): + from deepmd.jax.env import ( + jax, + ) + + # see https://github.com/jax-ml/jax/discussions/15617 + return jax.nn.sigmoid(x) + xp = array_api_compat.array_namespace(x) + return 1 / (1 + xp.exp(-x)) class Identity(NativeOP): @@ -1110,10 +1116,13 @@ def deserialize(cls, data: dict) -> "FittingNet": layers = data.pop("layers") obj = cls(**data) # Use type(obj.layers[0]) to respect subclass layer types - layer_type = type(obj.layers[0]) - obj.layers = type(obj.layers)( - [layer_type.deserialize(layer) for layer in layers] - ) + if obj.layers: + layer_type = type(obj.layers[0]) + obj.layers = type(obj.layers)( + [layer_type.deserialize(layer) for layer in layers] + ) + else: + obj.layers = type(obj.layers)([]) return obj @@ -1356,7 +1365,11 @@ def get_graph_index( # noqa: ANN201 # edge(ij) to angle(ijk) index_select; angle(ijk) to edge(ij) aggregate edge_id = xp.arange(n_edge, dtype=nlist.dtype) edge_index = xp.zeros((nf, nloc, nnei), dtype=nlist.dtype) - edge_index = xp_setitem_at(edge_index, xp.astype(nlist_mask, xp.bool), edge_id) + if array_api_compat.is_jax_array(nlist): + # JAX doesn't support in-place item assignment + edge_index = 
edge_index.at[xp.astype(nlist_mask, xp.bool)].set(edge_id) + else: + edge_index[xp.astype(nlist_mask, xp.bool)] = edge_id # only cut a_nnei neighbors, to avoid nnei x nnei edge_index = edge_index[:, :, :a_nnei] edge_index_ij = xp.broadcast_to( diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index c9503ec413..f8a98abd86 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -14,8 +14,8 @@ ) -@BaseDescriptor.register("se_e2_a_expt") -@BaseDescriptor.register("se_a_expt") +@BaseDescriptor.register("se_e2_a") +@BaseDescriptor.register("se_a") class DescrptSeA(DescrptSeADP, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py index f3006d38a5..0484c0dea4 100644 --- a/deepmd/pt_expt/descriptor/se_r.py +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -14,8 +14,8 @@ ) -@BaseDescriptor.register("se_e2_r_expt") -@BaseDescriptor.register("se_r_expt") +@BaseDescriptor.register("se_e2_r") +@BaseDescriptor.register("se_r") class DescrptSeR(DescrptSeRDP, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) diff --git a/deepmd/pt_expt/descriptor/se_t.py b/deepmd/pt_expt/descriptor/se_t.py index 9a438c0140..6d732790ca 100644 --- a/deepmd/pt_expt/descriptor/se_t.py +++ b/deepmd/pt_expt/descriptor/se_t.py @@ -14,9 +14,9 @@ ) -@BaseDescriptor.register("se_e3_expt") -@BaseDescriptor.register("se_at_expt") -@BaseDescriptor.register("se_a_3be_expt") +@BaseDescriptor.register("se_e3") +@BaseDescriptor.register("se_at") +@BaseDescriptor.register("se_a_3be") class DescrptSeT(DescrptSeTDP, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) diff --git a/deepmd/pt_expt/descriptor/se_t_tebd.py b/deepmd/pt_expt/descriptor/se_t_tebd.py index 7545f8c6fe..f28e1564cc 100644 --- a/deepmd/pt_expt/descriptor/se_t_tebd.py +++ b/deepmd/pt_expt/descriptor/se_t_tebd.py @@ -14,7 +14,7 @@ ) -@BaseDescriptor.register("se_e3_tebd_expt") +@BaseDescriptor.register("se_e3_tebd") class DescrptSeTTebd(DescrptSeTTebdDP, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) diff --git a/deepmd/pt_expt/fitting/__init__.py b/deepmd/pt_expt/fitting/__init__.py new file mode 100644 index 0000000000..4a7c8100de --- /dev/null +++ b/deepmd/pt_expt/fitting/__init__.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .base_fitting import ( + BaseFitting, +) +from .ener_fitting import ( + EnergyFittingNet, +) +from .invar_fitting import ( + InvarFitting, +) + +__all__ = [ + "BaseFitting", + "EnergyFittingNet", + "InvarFitting", +] diff --git a/deepmd/pt_expt/fitting/base_fitting.py b/deepmd/pt_expt/fitting/base_fitting.py new file mode 100644 index 0000000000..f42e572578 --- /dev/null +++ b/deepmd/pt_expt/fitting/base_fitting.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import torch + +from deepmd.dpmodel.fitting import ( + make_base_fitting, +) + +BaseFitting = make_base_fitting(torch.Tensor, "forward") diff --git a/deepmd/pt_expt/fitting/ener_fitting.py b/deepmd/pt_expt/fitting/ener_fitting.py new file mode 100644 index 0000000000..425040ae75 --- /dev/null +++ b/deepmd/pt_expt/fitting/ener_fitting.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from 
deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, + register_dpmodel_mapping, +) +from deepmd.pt_expt.utils.network import ( + NetworkCollection, +) + +from .base_fitting import ( + BaseFitting, +) + + +@BaseFitting.register("ener") +class EnergyFittingNet(EnergyFittingNetDP, torch.nn.Module): + """Energy fitting net for pt_expt backend. + + This inherits from dpmodel EnergyFittingNet to get the correct serialize() method. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + EnergyFittingNetDP.__init__(self, *args, **kwargs) + # Convert dpmodel NetworkCollection to pt_expt NetworkCollection + self.nets = NetworkCollection.deserialize(self.nets.serialize()) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. + return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + def forward( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: torch.Tensor | None = None, + g2: torch.Tensor | None = None, + h2: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + return self.call( + descriptor, + atype, + gr=gr, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + ) + + +register_dpmodel_mapping( + EnergyFittingNetDP, + lambda v: EnergyFittingNet.deserialize(v.serialize()), +) diff --git a/deepmd/pt_expt/fitting/invar_fitting.py b/deepmd/pt_expt/fitting/invar_fitting.py new file mode 100644 index 0000000000..aa37026284 --- /dev/null +++ b/deepmd/pt_expt/fitting/invar_fitting.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.fitting.invar_fitting import InvarFitting as InvarFittingDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, + register_dpmodel_mapping, +) +from deepmd.pt_expt.fitting.base_fitting import ( + BaseFitting, +) +from deepmd.pt_expt.utils.network import ( + NetworkCollection, +) + + +@BaseFitting.register("invar") +class InvarFitting(InvarFittingDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + InvarFittingDP.__init__(self, *args, **kwargs) + # Convert dpmodel NetworkCollection to pt_expt NetworkCollection + self.nets = NetworkCollection.deserialize(self.nets.serialize()) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. 
+ return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + def forward( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: torch.Tensor | None = None, + g2: torch.Tensor | None = None, + h2: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + return self.call( + descriptor, + atype, + gr=gr, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + ) + + +register_dpmodel_mapping( + InvarFittingDP, + lambda v: InvarFitting.deserialize(v.serialize()), +) diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py index 74e3b042ab..185a3d5801 100644 --- a/source/tests/consistent/fitting/test_ener.py +++ b/source/tests/consistent/fitting/test_ener.py @@ -19,6 +19,7 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, parameterized, @@ -34,6 +35,13 @@ from deepmd.pt.utils.env import DEVICE as PT_DEVICE else: EnerFittingPT = object +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.fitting.ener_fitting import ( + EnergyFittingNet as EnerFittingPTExpt, + ) + from deepmd.pt_expt.utils.env import DEVICE as PT_EXPT_DEVICE +else: + EnerFittingPTExpt = None if INSTALLED_TF: from deepmd.tf.fit.ener import EnerFitting as EnerFittingTF else: @@ -151,9 +159,23 @@ def skip_tf(self) -> bool: ) = self.param return not INSTALLED_TF or default_fparam is not None + @property + def skip_pt_expt(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + # PyTorch does not support bfloat16 for some operations + return CommonTest.skip_pt_expt or precision == "bfloat16" + tf_class = EnerFittingTF dp_class = EnerFittingDP pt_class = EnerFittingPT + pt_expt_class = EnerFittingPTExpt jax_class = EnerFittingJAX pd_class = EnerFittingPD array_api_strict_class = EnerFittingStrict @@ -237,6 +259,35 @@ def eval_pt(self, pt_obj: Any) -> Any: .numpy() ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + return ( + pt_expt_obj( + torch.from_numpy(self.inputs).to(device=PT_EXPT_DEVICE), + torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_EXPT_DEVICE), + fparam=( + torch.from_numpy(self.fparam).to(device=PT_EXPT_DEVICE) + if (numb_fparam and default_fparam is None) # test default_fparam + else None + ), + aparam=( + torch.from_numpy(self.aparam).to(device=PT_EXPT_DEVICE) + if numb_aparam + else None + ), + )["energy"] + .detach() + .cpu() + .numpy() + ) + def eval_dp(self, dp_obj: Any) -> Any: ( resnet_dt, @@ -367,3 +418,377 @@ def atol(self) -> float: return 1e-1 else: raise ValueError(f"Unknown precision: {precision}") + + +@parameterized( + (True,), # resnet_dt + ("float64",), # precision + (True,), # mixed_types + ((3, None),), # (numb_fparam, default_fparam) + ((3, False),), # (numb_aparam, use_aparam_as_mask) + ([],), # atom_ener +) +class TestEnerStat(CommonTest, FittingTest, unittest.TestCase): + @property + def data(self) -> dict: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + return { + "neuron": [5, 5, 5], + "resnet_dt": resnet_dt, 
+ "precision": precision, + "numb_fparam": numb_fparam, + "numb_aparam": numb_aparam, + "default_fparam": default_fparam, + "seed": 20240217, + "atom_ener": atom_ener, + "use_aparam_as_mask": use_aparam_as_mask, + } + + @property + def skip_pt(self) -> bool: + return CommonTest.skip_pt + + @property + def skip_pt_expt(self) -> bool: + return CommonTest.skip_pt_expt + + @property + def skip_tf(self) -> bool: + return True + + skip_jax = not INSTALLED_JAX + + @property + def skip_array_api_strict(self) -> bool: + return not INSTALLED_ARRAY_API_STRICT + + @property + def skip_pd(self) -> bool: + return not INSTALLED_PD + + tf_class = EnerFittingTF + dp_class = EnerFittingDP + pt_class = EnerFittingPT + pt_expt_class = EnerFittingPTExpt + jax_class = EnerFittingJAX + pd_class = EnerFittingPD + array_api_strict_class = EnerFittingStrict + args = fitting_ener() + + def setUp(self) -> None: + CommonTest.setUp(self) + + self.ntypes = 2 + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + self.inputs = np.ones((1, 6, 20), dtype=GLOBAL_NP_FLOAT_PRECISION) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + # inconsistent if not sorted + self.atype.sort() + + # Prepare data for compute_input_stats + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + + # Create fparam and aparam with correct dimensions + rng = np.random.default_rng(20240217) + self.fparam = ( + rng.normal(size=(1, numb_fparam)).astype(GLOBAL_NP_FLOAT_PRECISION) + if numb_fparam > 0 + else None + ) + self.aparam = ( + rng.normal(size=(1, 6, numb_aparam)).astype(GLOBAL_NP_FLOAT_PRECISION) + if numb_aparam > 0 + else None + ) + + self.stat_data = [ + { + "fparam": rng.normal(size=(2, numb_fparam)).astype( + GLOBAL_NP_FLOAT_PRECISION + ), + "aparam": rng.normal(size=(2, 6, numb_aparam)).astype( + GLOBAL_NP_FLOAT_PRECISION + ), + }, + { + "fparam": rng.normal(size=(3, numb_fparam)).astype( + GLOBAL_NP_FLOAT_PRECISION + ), + "aparam": rng.normal(size=(3, 6, numb_aparam)).astype( + GLOBAL_NP_FLOAT_PRECISION + ), + }, + ] + + @property + def additional_data(self) -> dict: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + return { + "ntypes": self.ntypes, + "dim_descrpt": self.inputs.shape[-1], + "mixed_types": mixed_types, + } + + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + return self.build_tf_fitting( + obj, + self.inputs.ravel(), + self.natoms, + self.atype, + self.fparam if numb_fparam else None, + self.aparam if numb_aparam else None, + suffix, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + # Convert stat_data to torch tensors for pt backend + pt_stat_data = [ + { + "fparam": torch.from_numpy(d["fparam"]).to(PT_DEVICE), + "aparam": torch.from_numpy(d["aparam"]).to(PT_DEVICE), + } + for d in self.stat_data + ] + pt_obj.compute_input_stats(pt_stat_data, protection=1e-2) + return ( + pt_obj( + torch.from_numpy(self.inputs).to(device=PT_DEVICE), + torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_DEVICE), + fparam=( + torch.from_numpy(self.fparam).to(device=PT_DEVICE) + if self.fparam is not None + else None + ), + aparam=( + 
torch.from_numpy(self.aparam).to(device=PT_DEVICE) + if self.aparam is not None + else None + ), + )["energy"] + .detach() + .cpu() + .numpy() + ) + + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + # dpmodel's compute_input_stats accepts numpy arrays + pt_expt_obj.compute_input_stats(self.stat_data, protection=1e-2) + return ( + pt_expt_obj( + torch.from_numpy(self.inputs).to(device=PT_EXPT_DEVICE), + torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_EXPT_DEVICE), + fparam=( + torch.from_numpy(self.fparam).to(device=PT_EXPT_DEVICE) + if self.fparam is not None + else None + ), + aparam=( + torch.from_numpy(self.aparam).to(device=PT_EXPT_DEVICE) + if self.aparam is not None + else None + ), + )["energy"] + .detach() + .cpu() + .numpy() + ) + + def eval_dp(self, dp_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + dp_obj.compute_input_stats(self.stat_data, protection=1e-2) + return dp_obj( + self.inputs, + self.atype.reshape(1, -1), + fparam=self.fparam, + aparam=self.aparam, + )["energy"] + + def eval_jax(self, jax_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + # Convert stat_data to jax arrays + jax_stat_data = [ + { + "fparam": jnp.asarray(d["fparam"]), + "aparam": jnp.asarray(d["aparam"]), + } + for d in self.stat_data + ] + jax_obj.compute_input_stats(jax_stat_data, protection=1e-2) + return np.asarray( + jax_obj( + jnp.asarray(self.inputs), + jnp.asarray(self.atype.reshape(1, -1)), + fparam=jnp.asarray(self.fparam) if self.fparam is not None else None, + aparam=jnp.asarray(self.aparam) if self.aparam is not None else None, + )["energy"] + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + # Convert stat_data to array_api_strict arrays + strict_stat_data = [ + { + "fparam": array_api_strict.asarray(d["fparam"]), + "aparam": array_api_strict.asarray(d["aparam"]), + } + for d in self.stat_data + ] + array_api_strict_obj.compute_input_stats(strict_stat_data, protection=1e-2) + return to_numpy_array( + array_api_strict_obj( + array_api_strict.asarray(self.inputs), + array_api_strict.asarray(self.atype.reshape(1, -1)), + fparam=array_api_strict.asarray(self.fparam) + if self.fparam is not None + else None, + aparam=array_api_strict.asarray(self.aparam) + if self.aparam is not None + else None, + )["energy"] + ) + + def eval_pd(self, pd_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + # Convert stat_data to paddle tensors + pd_stat_data = [ + { + "fparam": paddle.to_tensor(d["fparam"]).to(PD_DEVICE), + "aparam": paddle.to_tensor(d["aparam"]).to(PD_DEVICE), + } + for d in self.stat_data + ] + pd_obj.compute_input_stats(pd_stat_data, protection=1e-2) + return ( + pd_obj( + paddle.to_tensor(self.inputs).to(device=PD_DEVICE), + paddle.to_tensor(self.atype.reshape([1, -1])).to(device=PD_DEVICE), + fparam=( + paddle.to_tensor(self.fparam).to(device=PD_DEVICE) + if self.fparam is not None + else None + ), + aparam=( + 
paddle.to_tensor(self.aparam).to(device=PD_DEVICE) + if self.aparam is not None + else None + ), + )["energy"] + .detach() + .cpu() + .numpy() + ) + + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: + if backend == self.RefBackend.TF: + # shape is not same + ret = ret[0].reshape(-1, self.natoms[0], 1) + return (ret,) + + @property + def rtol(self) -> float: + """Relative tolerance for comparing the return value.""" + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") + + @property + def atol(self) -> float: + """Absolute tolerance for comparing the return value.""" + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") diff --git a/source/tests/pt_expt/fitting/__init__.py b/source/tests/pt_expt/fitting/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/pt_expt/fitting/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/pt_expt/fitting/test_fitting_ener_fitting.py b/source/tests/pt_expt/fitting/test_fitting_ener_fitting.py new file mode 100644 index 0000000000..63ae82ab9a --- /dev/null +++ b/source/tests/pt_expt/fitting/test_fitting_ener_fitting.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + EnergyFittingNet, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestEnergyFittingNet(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_self_consistency( + self, + ) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + atype = self.atype_ext[:, :nloc] + + for nfp, nap in [(0, 0), (3, 0), (0, 4), (3, 4)]: + efn0 = EnergyFittingNet( + self.nt, + ds.dim_out, + numb_fparam=nfp, + numb_aparam=nap, + ).to(self.device) + efn1 = EnergyFittingNet.deserialize(efn0.serialize()).to(self.device) + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) + else: + ifp = None + if nap > 0: + iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( + self.device + ) + else: + iap = None + ret0 = efn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + ret1 = efn1( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + np.testing.assert_allclose( + ret0["energy"].detach().cpu().numpy(), + ret1["energy"].detach().cpu().numpy(), + ) + + def test_serialize_has_correct_type(self) -> None: + """Test that EnergyFittingNet serializes with type='ener' not 'invar'.""" + nf, nloc, nnei = self.nlist.shape + ds = 
DescrptSeA(self.rcut, self.rcut_smth, self.sel) + + efn = EnergyFittingNet( + self.nt, + ds.dim_out, + ).to(self.device) + serialized = efn.serialize() + + # Check that the type is 'ener' not 'invar' + self.assertEqual(serialized["type"], "ener") + + # Check that it can be deserialized + efn2 = EnergyFittingNet.deserialize(serialized).to(self.device) + self.assertIsInstance(efn2, EnergyFittingNet) + + def test_torch_export_simple(self) -> None: + """Test that EnergyFittingNet can be exported with torch.export.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + rng = np.random.default_rng(GLOBAL_SEED) + + efn = EnergyFittingNet( + self.nt, + ds.dim_out, + numb_fparam=0, + numb_aparam=0, + ).to(self.device) + + # Prepare inputs + descriptor = torch.from_numpy( + rng.standard_normal((self.nf, self.nloc, ds.dim_out)) + ).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + + # Test forward pass works + ret = efn(descriptor, atype) + self.assertIn("energy", ret) + + # Test torch.export + exported = torch.export.export( + efn, + (descriptor, atype), + kwargs={}, + strict=False, + ) + self.assertIsNotNone(exported) + + # Test exported model produces same output + ret_exported = exported.module()(descriptor, atype) + np.testing.assert_allclose( + ret["energy"].detach().cpu().numpy(), + ret_exported["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + + def test_torch_export_with_aparam(self) -> None: + """Test that EnergyFittingNet with aparam can be exported.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + rng = np.random.default_rng(GLOBAL_SEED) + + efn = EnergyFittingNet( + self.nt, + ds.dim_out, + numb_fparam=0, + numb_aparam=4, + ).to(self.device) + + # Prepare inputs + descriptor = torch.from_numpy( + rng.normal(size=(self.nf, self.nloc, ds.dim_out)) + ).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + aparam = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, 4))).to( + self.device + ) + + # Test forward pass works + ret = efn(descriptor, atype, aparam=aparam) + self.assertIn("energy", ret) + + # Test torch.export + exported = torch.export.export( + efn, + (descriptor, atype), + kwargs={"aparam": aparam}, + strict=False, + ) + self.assertIsNotNone(exported) + + # Test exported model produces same output + ret_exported = exported.module()(descriptor, atype, aparam=aparam) + np.testing.assert_allclose( + ret["energy"].detach().cpu().numpy(), + ret_exported["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) diff --git a/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py b/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py new file mode 100644 index 0000000000..d682b37145 --- /dev/null +++ b/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py @@ -0,0 +1,311 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + InvarFitting, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestInvarFitting(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_self_consistency( + self, + ) -> None: + 
rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + atype = self.atype_ext[:, :nloc] + + for ( + mixed_types, + od, + nfp, + nap, + et, + ) in itertools.product( + [True, False], + [1, 2], + [0, 3], + [0, 4], + [[], [0], [1]], + ): + ifn0 = InvarFitting( + "energy", + self.nt, + ds.dim_out, + od, + numb_fparam=nfp, + numb_aparam=nap, + mixed_types=mixed_types, + exclude_types=et, + ).to(self.device) + ifn1 = InvarFitting.deserialize(ifn0.serialize()).to(self.device) + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) + else: + ifp = None + if nap > 0: + iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( + self.device + ) + else: + iap = None + ret0 = ifn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + ret1 = ifn1( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + np.testing.assert_allclose( + ret0["energy"].detach().cpu().numpy(), + ret1["energy"].detach().cpu().numpy(), + ) + sel_set = set(ifn0.get_sel_type()) + exclude_set = set(et) + self.assertEqual(sel_set | exclude_set, set(range(self.nt))) + self.assertEqual(sel_set & exclude_set, set()) + + def test_mask(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + atype = self.atype_ext[:, :nloc] + od = 2 + mixed_types = True + # exclude type 1 + et = [1] + ifn0 = InvarFitting( + "energy", + self.nt, + ds.dim_out, + od, + mixed_types=mixed_types, + exclude_types=et, + ).to(self.device) + ret0 = ifn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + ) + # atom index 2 is of type 1 that is excluded + zero_idx = 2 + np.testing.assert_allclose( + ret0["energy"][0, zero_idx, :].detach().cpu().numpy(), + np.zeros_like(ret0["energy"][0, zero_idx, :].detach().cpu().numpy()), + ) + zero_idx = 0 + np.testing.assert_allclose( + ret0["energy"][1, zero_idx, :].detach().cpu().numpy(), + np.zeros_like(ret0["energy"][1, zero_idx, :].detach().cpu().numpy()), + ) + + def test_self_exception( + self, + ) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + atype = self.atype_ext[:, :nloc] + + for ( + mixed_types, + od, + nfp, + nap, + ) in itertools.product( + [True, False], + [1, 2], + [0, 3], + [0, 4], + ): + ifn0 = InvarFitting( + "energy", + self.nt, + ds.dim_out, + od, + numb_fparam=nfp, + numb_aparam=nap, + mixed_types=mixed_types, + ).to(self.device) + + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) + else: + ifp = None + if nap > 0: + iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( + self.device + ) + else: + iap = None + with self.assertRaises(ValueError) as context: + ret0 = ifn0( + torch.from_numpy(dd[0][:, :, :-2]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + self.assertIn("input descriptor", str(context.exception)) + + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp - 1))).to( + self.device + ) + with self.assertRaises(ValueError) as context: + ret0 = ifn0( + 
torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + self.assertIn("input fparam", str(context.exception)) + + if nap > 0: + iap = torch.from_numpy( + rng.normal(size=(self.nf, self.nloc, nap - 1)) + ).to(self.device) + with self.assertRaises(ValueError) as context: + ifn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + self.assertIn("input aparam", str(context.exception)) + + def test_get_set(self) -> None: + ifn0 = InvarFitting( + "energy", + self.nt, + 3, + 1, + ).to(self.device) + rng = np.random.default_rng(GLOBAL_SEED) + foo = rng.normal([3, 4]) + for ii in [ + "bias_atom_e", + "fparam_avg", + "fparam_inv_std", + "aparam_avg", + "aparam_inv_std", + ]: + ifn0[ii] = torch.from_numpy(foo).to(self.device) + np.testing.assert_allclose( + foo, ifn0[ii].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + + def test_torch_export_simple(self) -> None: + """Test that InvarFitting can be exported with torch.export.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + rng = np.random.default_rng(GLOBAL_SEED) + + ifn = InvarFitting( + "energy", + self.nt, + ds.dim_out, + 1, + numb_fparam=0, + numb_aparam=0, + mixed_types=True, + ).to(self.device) + + # Prepare inputs + descriptor = torch.from_numpy( + rng.standard_normal((self.nf, self.nloc, ds.dim_out)) + ).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + + # Test forward pass works + ret = ifn(descriptor, atype) + self.assertIn("energy", ret) + + # Test torch.export + exported = torch.export.export( + ifn, + (descriptor, atype), + kwargs={}, + strict=False, # Use strict=False for now to handle dynamic shapes + ) + self.assertIsNotNone(exported) + + # Test exported model produces same output + ret_exported = exported.module()(descriptor, atype) + np.testing.assert_allclose( + ret["energy"].detach().cpu().numpy(), + ret_exported["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + + def test_torch_export_with_fparam(self) -> None: + """Test that InvarFitting with fparam can be exported.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + rng = np.random.default_rng(GLOBAL_SEED) + + ifn = InvarFitting( + "energy", + self.nt, + ds.dim_out, + 1, + numb_fparam=3, + numb_aparam=0, + mixed_types=True, + ).to(self.device) + + # Prepare inputs + descriptor = torch.from_numpy( + rng.normal(size=(self.nf, self.nloc, ds.dim_out)) + ).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + fparam = torch.from_numpy(rng.normal(size=(self.nf, 3))).to(self.device) + + # Test forward pass works + ret = ifn(descriptor, atype, fparam=fparam) + self.assertIn("energy", ret) + + # Test torch.export + exported = torch.export.export( + ifn, + (descriptor, atype), + kwargs={"fparam": fparam}, + strict=False, + ) + self.assertIsNotNone(exported) + + # Test exported model produces same output + ret_exported = exported.module()(descriptor, atype, fparam=fparam) + np.testing.assert_allclose( + ret["energy"].detach().cpu().numpy(), + ret_exported["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) diff --git a/source/tests/pt_expt/fitting/test_fitting_stat.py b/source/tests/pt_expt/fitting/test_fitting_stat.py new file mode 100644 index 0000000000..b473c9309c --- /dev/null +++ b/source/tests/pt_expt/fitting/test_fitting_stat.py @@ -0,0 +1,125 @@ +# 
SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + EnergyFittingNet, +) +from deepmd.pt_expt.utils import ( + env, +) + + +def _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds): + """Make fake data as numpy arrays for dpmodel compute_input_stats.""" + merged_output_stat = [] + nsys = len(sys_natoms) + ndof = len(avgs) + for ii in range(nsys): + sys_dict = {} + tmp_data_f = [] + tmp_data_a = [] + for jj in range(ndof): + rng = np.random.default_rng(2025 * ii + 220 * jj) + tmp_data_f.append( + rng.normal(loc=avgs[jj], scale=stds[jj], size=(sys_nframes[ii], 1)) + ) + rng = np.random.default_rng(220 * ii + 1636 * jj) + tmp_data_a.append( + rng.normal( + loc=avgs[jj], scale=stds[jj], size=(sys_nframes[ii], sys_natoms[ii]) + ) + ) + tmp_data_f = np.transpose(tmp_data_f, (1, 2, 0)) + tmp_data_a = np.transpose(tmp_data_a, (1, 2, 0)) + # dpmodel's compute_input_stats expects numpy arrays + sys_dict["fparam"] = tmp_data_f + sys_dict["aparam"] = tmp_data_a + merged_output_stat.append(sys_dict) + return merged_output_stat + + +def _brute_fparam_pt(data, ndim): + adata = [ii["fparam"] for ii in data] + all_data = [] + for ii in adata: + tmp = np.reshape(ii, [-1, ndim]) + if len(all_data) == 0: + all_data = np.array(tmp) + else: + all_data = np.concatenate((all_data, tmp), axis=0) + avg = np.average(all_data, axis=0) + std = np.std(all_data, axis=0) + return avg, std + + +def _brute_aparam_pt(data, ndim): + adata = [ii["aparam"] for ii in data] + all_data = [] + for ii in adata: + tmp = np.reshape(ii, [-1, ndim]) + if len(all_data) == 0: + all_data = np.array(tmp) + else: + all_data = np.concatenate((all_data, tmp), axis=0) + avg = np.average(all_data, axis=0) + std = np.std(all_data, axis=0) + return avg, std + + +class TestEnerFittingStat(unittest.TestCase): + def setUp(self) -> None: + self.device = env.DEVICE + + def test(self) -> None: + descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16) + fitting = EnergyFittingNet( + descrpt.get_ntypes(), + descrpt.get_dim_out(), + neuron=[240, 240, 240], + resnet_dt=True, + numb_fparam=3, + numb_aparam=3, + ).to(self.device) + avgs = [0, 10, 100] + stds = [2, 0.4, 0.00001] + sys_natoms = [10, 100] + sys_nframes = [5, 2] + all_data = _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds) + frefa, frefs = _brute_fparam_pt(all_data, len(avgs)) + arefa, arefs = _brute_aparam_pt(all_data, len(avgs)) + fitting.compute_input_stats(all_data, protection=1e-2) + frefs_inv = 1.0 / frefs + arefs_inv = 1.0 / arefs + frefs_inv[frefs_inv > 100] = 100 + arefs_inv[arefs_inv > 100] = 100 + # fparam_avg and fparam_inv_std are torch tensors on device + fparam_avg_np = ( + fitting.fparam_avg.detach().cpu().numpy() + if torch.is_tensor(fitting.fparam_avg) + else fitting.fparam_avg + ) + fparam_inv_std_np = ( + fitting.fparam_inv_std.detach().cpu().numpy() + if torch.is_tensor(fitting.fparam_inv_std) + else fitting.fparam_inv_std + ) + aparam_avg_np = ( + fitting.aparam_avg.detach().cpu().numpy() + if torch.is_tensor(fitting.aparam_avg) + else fitting.aparam_avg + ) + aparam_inv_std_np = ( + fitting.aparam_inv_std.detach().cpu().numpy() + if torch.is_tensor(fitting.aparam_inv_std) + else fitting.aparam_inv_std + ) + np.testing.assert_almost_equal(frefa, fparam_avg_np) + np.testing.assert_almost_equal(frefs_inv, fparam_inv_std_np) + np.testing.assert_almost_equal(arefa, aparam_avg_np) + 
np.testing.assert_almost_equal(arefs_inv, aparam_inv_std_np)
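
The compute_input_stats changes at the top of this patch replace numpy-specific .astype() calls with a backend-agnostic conversion keyed off the existing buffer. A minimal sketch of that pattern follows, assuming only array_api_compat and numpy; update_buffer is a hypothetical helper name, not a deepmd API.

import array_api_compat
import numpy as np


def update_buffer(buffer, new_values: np.ndarray):
    # Resolve the array namespace (numpy, torch, ...) of the existing buffer,
    # then convert the freshly computed statistics into that array family,
    # preserving the buffer's dtype and device.
    xp = array_api_compat.array_namespace(buffer)
    return xp.asarray(
        new_values,
        dtype=buffer.dtype,
        device=array_api_compat.device(buffer),
    )


# Usage: the same call works whether `buffer` is a numpy array or, say, a
# CUDA torch tensor, which is exactly why the patch prefers it to .astype().
avg = update_buffer(np.zeros(3, dtype=np.float32), np.array([1.0, 2.0, 3.0]))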
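
The get_graph_index hunk handles JAX's immutable arrays by branching between a functional .at[...].set(...) update and an ordinary boolean-mask assignment. The sketch below shows the same branch in isolation; array_api_compat.is_jax_array is the real call used in the patch, while masked_set and the toy shapes are illustrative only.

import array_api_compat
import numpy as np


def masked_set(edge_index, mask, edge_id):
    """Write edge_id into edge_index at the True positions of mask."""
    if array_api_compat.is_jax_array(edge_index):
        # JAX arrays are immutable: the scatter returns a new array.
        return edge_index.at[mask].set(edge_id)
    # NumPy / PyTorch support in-place boolean-mask assignment.
    edge_index[mask] = edge_id
    return edge_index


mask = np.array([[True, False], [False, True]])
out = masked_set(np.zeros((2, 2), dtype=np.int64), mask, np.array([7, 9]))
print(out)  # [[7 0] [0 9]]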
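
The new pt_expt fitting classes all follow one wrapper recipe: inherit the dpmodel implementation for call()/serialize(), inherit torch.nn.Module for parameter tracking, and route __call__ through nn.Module so torch.export traces an ordinary forward(). A self-contained sketch under that assumption; DemoBase and DemoNet are hypothetical stand-ins for the deepmd types, not their actual signatures.

import torch


class DemoBase:
    # Stands in for the dpmodel class that provides a backend-agnostic call().
    def call(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        return {"energy": x.sum(dim=-1, keepdim=True)}


class DemoNet(DemoBase, torch.nn.Module):
    def __init__(self) -> None:
        torch.nn.Module.__init__(self)

    def __call__(self, *args, **kwargs):
        # Ensure torch.nn.Module.__call__ drives forward() for export/tracing,
        # mirroring the __call__ override in the pt_expt fitting classes above.
        return torch.nn.Module.__call__(self, *args, **kwargs)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        return self.call(x)


net = DemoNet()
x = torch.ones(1, 4, 3)
exported = torch.export.export(net, (x,), strict=False)
assert torch.allclose(exported.module()(x)["energy"], net(x)["energy"])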