From 8356629395cd1e183373f31f76cc32675c609287 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 22 Feb 2026 16:10:11 +0800 Subject: [PATCH 1/3] feat(pt_expt): add dos, dipole, polar and property fittings --- .../dpmodel/fitting/polarizability_fitting.py | 4 +- deepmd/pt_expt/fitting/__init__.py | 16 +++ deepmd/pt_expt/fitting/dipole_fitting.py | 23 ++++ deepmd/pt_expt/fitting/dos_fitting.py | 23 ++++ .../pt_expt/fitting/polarizability_fitting.py | 23 ++++ deepmd/pt_expt/fitting/property_fitting.py | 25 ++++ .../tests/consistent/fitting/test_dipole.py | 25 ++++ source/tests/consistent/fitting/test_dos.py | 36 +++++ source/tests/consistent/fitting/test_polar.py | 25 ++++ .../tests/consistent/fitting/test_property.py | 39 ++++++ .../pt_expt/fitting/test_dipole_fitting.py | 128 ++++++++++++++++++ .../tests/pt_expt/fitting/test_dos_fitting.py | 120 ++++++++++++++++ ...g_ener_fitting.py => test_ener_fitting.py} | 0 ...invar_fitting.py => test_invar_fitting.py} | 0 .../pt_expt/fitting/test_polar_fitting.py | 125 +++++++++++++++++ .../pt_expt/fitting/test_property_fitting.py | 120 ++++++++++++++++ 16 files changed, 731 insertions(+), 1 deletion(-) create mode 100644 deepmd/pt_expt/fitting/dipole_fitting.py create mode 100644 deepmd/pt_expt/fitting/dos_fitting.py create mode 100644 deepmd/pt_expt/fitting/polarizability_fitting.py create mode 100644 deepmd/pt_expt/fitting/property_fitting.py create mode 100644 source/tests/pt_expt/fitting/test_dipole_fitting.py create mode 100644 source/tests/pt_expt/fitting/test_dos_fitting.py rename source/tests/pt_expt/fitting/{test_fitting_ener_fitting.py => test_ener_fitting.py} (100%) rename source/tests/pt_expt/fitting/{test_fitting_invar_fitting.py => test_invar_fitting.py} (100%) create mode 100644 source/tests/pt_expt/fitting/test_polar_fitting.py create mode 100644 source/tests/pt_expt/fitting/test_property_fitting.py diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index f3e6318ba5..dff86f04cb 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -328,7 +328,9 @@ def call( ) # (nframes, nloc, 1) bias = bias[..., None] * scale_atype - eye = xp.eye(3, dtype=descriptor.dtype) + eye = xp.eye( + 3, dtype=descriptor.dtype, device=array_api_compat.device(descriptor) + ) eye = xp.tile(eye, (nframes, nloc, 1, 1)) # (nframes, nloc, 3, 3) bias = bias[..., None] * eye diff --git a/deepmd/pt_expt/fitting/__init__.py b/deepmd/pt_expt/fitting/__init__.py index 4a7c8100de..3b69392cfd 100644 --- a/deepmd/pt_expt/fitting/__init__.py +++ b/deepmd/pt_expt/fitting/__init__.py @@ -2,15 +2,31 @@ from .base_fitting import ( BaseFitting, ) +from .dipole_fitting import ( + DipoleFitting, +) +from .dos_fitting import ( + DOSFittingNet, +) from .ener_fitting import ( EnergyFittingNet, ) from .invar_fitting import ( InvarFitting, ) +from .polarizability_fitting import ( + PolarFitting, +) +from .property_fitting import ( + PropertyFittingNet, +) __all__ = [ "BaseFitting", + "DOSFittingNet", + "DipoleFitting", "EnergyFittingNet", "InvarFitting", + "PolarFitting", + "PropertyFittingNet", ] diff --git a/deepmd/pt_expt/fitting/dipole_fitting.py b/deepmd/pt_expt/fitting/dipole_fitting.py new file mode 100644 index 0000000000..a16a96fe72 --- /dev/null +++ b/deepmd/pt_expt/fitting/dipole_fitting.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingDP +from deepmd.pt_expt.common import ( + register_dpmodel_mapping, + torch_module, +) + +from .base_fitting import ( + BaseFitting, +) + + +@BaseFitting.register("dipole") +@torch_module +class DipoleFitting(DipoleFittingDP): + pass + + +register_dpmodel_mapping( + DipoleFittingDP, + lambda v: DipoleFitting.deserialize(v.serialize()), +) diff --git a/deepmd/pt_expt/fitting/dos_fitting.py b/deepmd/pt_expt/fitting/dos_fitting.py new file mode 100644 index 0000000000..8c51fcc0eb --- /dev/null +++ b/deepmd/pt_expt/fitting/dos_fitting.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP +from deepmd.pt_expt.common import ( + register_dpmodel_mapping, + torch_module, +) + +from .base_fitting import ( + BaseFitting, +) + + +@BaseFitting.register("dos") +@torch_module +class DOSFittingNet(DOSFittingNetDP): + pass + + +register_dpmodel_mapping( + DOSFittingNetDP, + lambda v: DOSFittingNet.deserialize(v.serialize()), +) diff --git a/deepmd/pt_expt/fitting/polarizability_fitting.py b/deepmd/pt_expt/fitting/polarizability_fitting.py new file mode 100644 index 0000000000..564df7e0d7 --- /dev/null +++ b/deepmd/pt_expt/fitting/polarizability_fitting.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +from deepmd.dpmodel.fitting.polarizability_fitting import PolarFitting as PolarFittingDP +from deepmd.pt_expt.common import ( + register_dpmodel_mapping, + torch_module, +) + +from .base_fitting import ( + BaseFitting, +) + + +@BaseFitting.register("polar") +@torch_module +class PolarFitting(PolarFittingDP): + pass + + +register_dpmodel_mapping( + PolarFittingDP, + lambda v: PolarFitting.deserialize(v.serialize()), +) diff --git a/deepmd/pt_expt/fitting/property_fitting.py b/deepmd/pt_expt/fitting/property_fitting.py new file mode 100644 index 0000000000..318e30fad6 --- /dev/null +++ b/deepmd/pt_expt/fitting/property_fitting.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +from deepmd.dpmodel.fitting.property_fitting import ( + PropertyFittingNet as PropertyFittingNetDP, +) +from deepmd.pt_expt.common import ( + register_dpmodel_mapping, + torch_module, +) + +from .base_fitting import ( + BaseFitting, +) + + +@BaseFitting.register("property") +@torch_module +class PropertyFittingNet(PropertyFittingNetDP): + pass + + +register_dpmodel_mapping( + PropertyFittingNetDP, + lambda v: PropertyFittingNet.deserialize(v.serialize()), +) diff --git a/source/tests/consistent/fitting/test_dipole.py b/source/tests/consistent/fitting/test_dipole.py index c81499611f..245744a93e 100644 --- a/source/tests/consistent/fitting/test_dipole.py +++ b/source/tests/consistent/fitting/test_dipole.py @@ -18,6 +18,7 @@ INSTALLED_ARRAY_API_STRICT, INSTALLED_JAX, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, parameterized, @@ -33,6 +34,13 @@ from deepmd.pt.utils.env import DEVICE as PT_DEVICE else: DipoleFittingPT = object +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.fitting.dipole_fitting import ( + DipoleFitting as DipoleFittingPTExpt, + ) + from deepmd.pt_expt.utils.env import DEVICE as PT_EXPT_DEVICE +else: + DipoleFittingPTExpt = None if INSTALLED_TF: from deepmd.tf.fit.dipole import DipoleFittingSeA as DipoleFittingTF else: @@ -116,12 +124,17 @@ def skip_pt(self) -> bool: tf_class = DipoleFittingTF dp_class = DipoleFittingDP pt_class = DipoleFittingPT + pt_expt_class = DipoleFittingPTExpt jax_class = DipoleFittingJAX array_api_strict_class = DipoleFittingArrayAPIStrict args = fitting_dipole() skip_jax = not INSTALLED_JAX skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + @property + def skip_pt_expt(self) -> bool: + return CommonTest.skip_pt_expt + def setUp(self) -> None: CommonTest.setUp(self) @@ -184,6 +197,18 @@ def eval_pt(self, pt_obj: Any) -> Any: .numpy() ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return ( + pt_expt_obj( + torch.from_numpy(self.inputs).to(device=PT_EXPT_DEVICE), + torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_EXPT_DEVICE), + gr=torch.from_numpy(self.gr).to(device=PT_EXPT_DEVICE), + )["dipole"] + .detach() + .cpu() + .numpy() + ) + def eval_dp(self, dp_obj: Any) -> Any: ( resnet_dt, diff --git a/source/tests/consistent/fitting/test_dos.py b/source/tests/consistent/fitting/test_dos.py index a77bc28c17..f758c9d317 100644 --- a/source/tests/consistent/fitting/test_dos.py +++ b/source/tests/consistent/fitting/test_dos.py @@ -18,6 +18,7 @@ INSTALLED_ARRAY_API_STRICT, INSTALLED_JAX, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, parameterized, @@ -33,6 +34,11 @@ from deepmd.pt.utils.env import DEVICE as PT_DEVICE else: DOSFittingPT = object +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.fitting.dos_fitting import DOSFittingNet as DOSFittingPTExpt + from deepmd.pt_expt.utils.env import DEVICE as PT_EXPT_DEVICE +else: + DOSFittingPTExpt = None if INSTALLED_TF: from deepmd.tf.fit.dos import DOSFitting as DOSFittingTF else: @@ -106,9 +112,14 @@ def skip_jax(self) -> bool: def skip_array_api_strict(self) -> bool: return not INSTALLED_ARRAY_API_STRICT + @property + def skip_pt_expt(self) -> bool: + return CommonTest.skip_pt_expt + tf_class = DOSFittingTF dp_class = DOSFittingDP pt_class = DOSFittingPT + pt_expt_class = DOSFittingPTExpt jax_class = DOSFittingJAX array_api_strict_class = DOSFittingStrict args = fitting_dos() @@ -187,6 +198,31 @@ def eval_pt(self, pt_obj: Any) -> Any: .numpy() ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_aparam, + numb_dos, + ) = self.param + return ( + pt_expt_obj( + torch.from_numpy(self.inputs).to(device=PT_EXPT_DEVICE), + torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_EXPT_DEVICE), + fparam=torch.from_numpy(self.fparam).to(device=PT_EXPT_DEVICE) + if numb_fparam + else None, + aparam=torch.from_numpy(self.aparam).to(device=PT_EXPT_DEVICE) + if numb_aparam + else None, + )["dos"] + .detach() + .cpu() + .numpy() + ) + def eval_dp(self, dp_obj: Any) -> Any: ( resnet_dt, diff --git a/source/tests/consistent/fitting/test_polar.py b/source/tests/consistent/fitting/test_polar.py index a52beea0c7..142cbefdc8 100644 --- a/source/tests/consistent/fitting/test_polar.py +++ b/source/tests/consistent/fitting/test_polar.py @@ -18,6 +18,7 @@ INSTALLED_ARRAY_API_STRICT, INSTALLED_JAX, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, parameterized, @@ -33,6 +34,13 @@ from deepmd.pt.utils.env import DEVICE as PT_DEVICE else: PolarFittingPT = object +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.fitting.polarizability_fitting import ( + PolarFitting as PolarFittingPTExpt, + ) + from deepmd.pt_expt.utils.env import DEVICE as PT_EXPT_DEVICE +else: + PolarFittingPTExpt = None if INSTALLED_TF: from deepmd.tf.fit.polar import PolarFittingSeA as PolarFittingTF else: @@ -90,12 +98,17 @@ def skip_pt(self) -> bool: tf_class = PolarFittingTF dp_class = PolarFittingDP pt_class = PolarFittingPT + pt_expt_class = PolarFittingPTExpt jax_class = PolarFittingJAX array_api_strict_class = PolarFittingArrayAPIStrict args = fitting_polar() skip_jax = not INSTALLED_JAX skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + @property + def skip_pt_expt(self) -> bool: + return CommonTest.skip_pt_expt + def setUp(self) -> None: CommonTest.setUp(self) @@ -155,6 +168,18 @@ def eval_pt(self, pt_obj: Any) -> Any: .numpy() ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return ( + pt_expt_obj( + torch.from_numpy(self.inputs).to(device=PT_EXPT_DEVICE), + torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_EXPT_DEVICE), + gr=torch.from_numpy(self.gr).to(device=PT_EXPT_DEVICE), + )["polarizability"] + .detach() + .cpu() + .numpy() + ) + def eval_dp(self, dp_obj: Any) -> Any: ( resnet_dt, diff --git a/source/tests/consistent/fitting/test_property.py b/source/tests/consistent/fitting/test_property.py index bccd20bd54..a9da348410 100644 --- a/source/tests/consistent/fitting/test_property.py +++ b/source/tests/consistent/fitting/test_property.py @@ -23,6 +23,7 @@ INSTALLED_ARRAY_API_STRICT, INSTALLED_JAX, INSTALLED_PT, + INSTALLED_PT_EXPT, CommonTest, parameterized, ) @@ -37,6 +38,13 @@ from deepmd.pt.utils.env import DEVICE as PT_DEVICE else: PropertyFittingPT = object +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.fitting.property_fitting import ( + PropertyFittingNet as PropertyFittingPTExpt, + ) + from deepmd.pt_expt.utils.env import DEVICE as PT_EXPT_DEVICE +else: + PropertyFittingPTExpt = None if INSTALLED_JAX: from deepmd.jax.env import ( jnp, @@ -110,9 +118,14 @@ def skip_tf(self) -> bool: skip_jax = not INSTALLED_JAX skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + @property + def skip_pt_expt(self) -> bool: + return CommonTest.skip_pt_expt + tf_class = PropertyFittingTF dp_class = PropertyFittingDP pt_class = PropertyFittingPT + pt_expt_class = PropertyFittingPTExpt jax_class = PropertyFittingJAX array_api_strict_class = PropertyFittingStrict args = fitting_property() @@ -194,6 +207,32 @@ def eval_pt(self, pt_obj: Any) -> Any: .numpy() ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_aparam, + task_dim, + intensive, + ) = self.param + return ( + pt_expt_obj( + torch.from_numpy(self.inputs).to(device=PT_EXPT_DEVICE), + torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_EXPT_DEVICE), + fparam=torch.from_numpy(self.fparam).to(device=PT_EXPT_DEVICE) + if numb_fparam + else None, + aparam=torch.from_numpy(self.aparam).to(device=PT_EXPT_DEVICE) + if numb_aparam + else None, + )[pt_expt_obj.var_name] + .detach() + .cpu() + .numpy() + ) + def eval_dp(self, dp_obj: Any) -> Any: ( resnet_dt, diff --git a/source/tests/pt_expt/fitting/test_dipole_fitting.py b/source/tests/pt_expt/fitting/test_dipole_fitting.py new file mode 100644 index 0000000000..8d25b55075 --- /dev/null +++ b/source/tests/pt_expt/fitting/test_dipole_fitting.py @@ -0,0 +1,128 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + DipoleFitting, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestDipoleFitting(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_self_consistency(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + atype = self.atype_ext[:, :nloc] + + # dd[0]: descriptor, dd[1]: gr (rotation matrix, nf x nloc x nnei x 3... but + # for se_a, gr shape is nf x nloc x m1 x 3) + embedding_width = ds.get_dim_emb() + + for nfp, nap in [(0, 0), (3, 0), (0, 4), (3, 4)]: + fn0 = DipoleFitting( + self.nt, + ds.dim_out, + embedding_width, + numb_fparam=nfp, + numb_aparam=nap, + ).to(self.device) + fn1 = DipoleFitting.deserialize(fn0.serialize()).to(self.device) + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) + else: + ifp = None + if nap > 0: + iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( + self.device + ) + else: + iap = None + ret0 = fn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + gr=torch.from_numpy(dd[1]).to(self.device), + fparam=ifp, + aparam=iap, + ) + ret1 = fn1( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + gr=torch.from_numpy(dd[1]).to(self.device), + fparam=ifp, + aparam=iap, + ) + np.testing.assert_allclose( + ret0["dipole"].detach().cpu().numpy(), + ret1["dipole"].detach().cpu().numpy(), + ) + + def test_serialize_has_correct_type(self) -> None: + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + embedding_width = ds.get_dim_emb() + fn = DipoleFitting( + self.nt, + ds.dim_out, + embedding_width, + ).to(self.device) + serialized = fn.serialize() + self.assertEqual(serialized["type"], "dipole") + fn2 = DipoleFitting.deserialize(serialized).to(self.device) + self.assertIsInstance(fn2, DipoleFitting) + + def test_torch_export_simple(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + embedding_width = ds.get_dim_emb() + rng = np.random.default_rng(GLOBAL_SEED) + + fn = DipoleFitting( + self.nt, + ds.dim_out, + embedding_width, + numb_fparam=0, + numb_aparam=0, + ).to(self.device) + + descriptor = torch.from_numpy(dd[0]).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + gr = torch.from_numpy(dd[1]).to(self.device) + + ret = fn(descriptor, atype, gr=gr) + self.assertIn("dipole", ret) + + exported = torch.export.export( + fn, + (descriptor, atype), + kwargs={"gr": gr}, + strict=False, + ) + self.assertIsNotNone(exported) + + ret_exported = exported.module()(descriptor, atype, gr=gr) + np.testing.assert_allclose( + ret["dipole"].detach().cpu().numpy(), + ret_exported["dipole"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) diff --git a/source/tests/pt_expt/fitting/test_dos_fitting.py b/source/tests/pt_expt/fitting/test_dos_fitting.py new file mode 100644 index 0000000000..e16b6f6569 --- /dev/null +++ b/source/tests/pt_expt/fitting/test_dos_fitting.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + DOSFittingNet, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestDOSFittingNet(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_self_consistency(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + atype = self.atype_ext[:, :nloc] + + for nfp, nap in [(0, 0), (3, 0), (0, 4), (3, 4)]: + fn0 = DOSFittingNet( + self.nt, + ds.dim_out, + numb_dos=10, + numb_fparam=nfp, + numb_aparam=nap, + ).to(self.device) + fn1 = DOSFittingNet.deserialize(fn0.serialize()).to(self.device) + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) + else: + ifp = None + if nap > 0: + iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( + self.device + ) + else: + iap = None + ret0 = fn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + ret1 = fn1( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + np.testing.assert_allclose( + ret0["dos"].detach().cpu().numpy(), + ret1["dos"].detach().cpu().numpy(), + ) + + def test_serialize_has_correct_type(self) -> None: + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + fn = DOSFittingNet( + self.nt, + ds.dim_out, + numb_dos=10, + ).to(self.device) + serialized = fn.serialize() + self.assertEqual(serialized["type"], "dos") + fn2 = DOSFittingNet.deserialize(serialized).to(self.device) + self.assertIsInstance(fn2, DOSFittingNet) + + def test_torch_export_simple(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + rng = np.random.default_rng(GLOBAL_SEED) + + fn = DOSFittingNet( + self.nt, + ds.dim_out, + numb_dos=10, + numb_fparam=0, + numb_aparam=0, + ).to(self.device) + + descriptor = torch.from_numpy( + rng.standard_normal((self.nf, self.nloc, ds.dim_out)) + ).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + + ret = fn(descriptor, atype) + self.assertIn("dos", ret) + + exported = torch.export.export( + fn, + (descriptor, atype), + kwargs={}, + strict=False, + ) + self.assertIsNotNone(exported) + + ret_exported = exported.module()(descriptor, atype) + np.testing.assert_allclose( + ret["dos"].detach().cpu().numpy(), + ret_exported["dos"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) diff --git a/source/tests/pt_expt/fitting/test_fitting_ener_fitting.py b/source/tests/pt_expt/fitting/test_ener_fitting.py similarity index 100% rename from source/tests/pt_expt/fitting/test_fitting_ener_fitting.py rename to source/tests/pt_expt/fitting/test_ener_fitting.py diff --git a/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py b/source/tests/pt_expt/fitting/test_invar_fitting.py similarity index 100% rename from source/tests/pt_expt/fitting/test_fitting_invar_fitting.py rename to source/tests/pt_expt/fitting/test_invar_fitting.py diff --git a/source/tests/pt_expt/fitting/test_polar_fitting.py b/source/tests/pt_expt/fitting/test_polar_fitting.py new file mode 100644 index 0000000000..24b38b1fe9 --- /dev/null +++ b/source/tests/pt_expt/fitting/test_polar_fitting.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + PolarFitting, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestPolarFitting(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_self_consistency(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + atype = self.atype_ext[:, :nloc] + + embedding_width = ds.get_dim_emb() + + for nfp, nap in [(0, 0), (3, 0), (0, 4), (3, 4)]: + fn0 = PolarFitting( + self.nt, + ds.dim_out, + embedding_width, + numb_fparam=nfp, + numb_aparam=nap, + ).to(self.device) + fn1 = PolarFitting.deserialize(fn0.serialize()).to(self.device) + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) + else: + ifp = None + if nap > 0: + iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( + self.device + ) + else: + iap = None + ret0 = fn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + gr=torch.from_numpy(dd[1]).to(self.device), + fparam=ifp, + aparam=iap, + ) + ret1 = fn1( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + gr=torch.from_numpy(dd[1]).to(self.device), + fparam=ifp, + aparam=iap, + ) + np.testing.assert_allclose( + ret0["polarizability"].detach().cpu().numpy(), + ret1["polarizability"].detach().cpu().numpy(), + ) + + def test_serialize_has_correct_type(self) -> None: + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + embedding_width = ds.get_dim_emb() + fn = PolarFitting( + self.nt, + ds.dim_out, + embedding_width, + ).to(self.device) + serialized = fn.serialize() + self.assertEqual(serialized["type"], "polar") + fn2 = PolarFitting.deserialize(serialized).to(self.device) + self.assertIsInstance(fn2, PolarFitting) + + def test_torch_export_simple(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + embedding_width = ds.get_dim_emb() + + fn = PolarFitting( + self.nt, + ds.dim_out, + embedding_width, + numb_fparam=0, + numb_aparam=0, + ).to(self.device) + + descriptor = torch.from_numpy(dd[0]).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + gr = torch.from_numpy(dd[1]).to(self.device) + + ret = fn(descriptor, atype, gr=gr) + self.assertIn("polarizability", ret) + + exported = torch.export.export( + fn, + (descriptor, atype), + kwargs={"gr": gr}, + strict=False, + ) + self.assertIsNotNone(exported) + + ret_exported = exported.module()(descriptor, atype, gr=gr) + np.testing.assert_allclose( + ret["polarizability"].detach().cpu().numpy(), + ret_exported["polarizability"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) diff --git a/source/tests/pt_expt/fitting/test_property_fitting.py b/source/tests/pt_expt/fitting/test_property_fitting.py new file mode 100644 index 0000000000..44a499ed9d --- /dev/null +++ b/source/tests/pt_expt/fitting/test_property_fitting.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + PropertyFittingNet, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestPropertyFittingNet(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_self_consistency(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + atype = self.atype_ext[:, :nloc] + + for nfp, nap in [(0, 0), (3, 0), (0, 4), (3, 4)]: + fn0 = PropertyFittingNet( + self.nt, + ds.dim_out, + task_dim=3, + numb_fparam=nfp, + numb_aparam=nap, + ).to(self.device) + fn1 = PropertyFittingNet.deserialize(fn0.serialize()).to(self.device) + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) + else: + ifp = None + if nap > 0: + iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( + self.device + ) + else: + iap = None + ret0 = fn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + ret1 = fn1( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + np.testing.assert_allclose( + ret0["property"].detach().cpu().numpy(), + ret1["property"].detach().cpu().numpy(), + ) + + def test_serialize_has_correct_type(self) -> None: + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + fn = PropertyFittingNet( + self.nt, + ds.dim_out, + task_dim=3, + ).to(self.device) + serialized = fn.serialize() + self.assertEqual(serialized["type"], "property") + fn2 = PropertyFittingNet.deserialize(serialized).to(self.device) + self.assertIsInstance(fn2, PropertyFittingNet) + + def test_torch_export_simple(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + rng = np.random.default_rng(GLOBAL_SEED) + + fn = PropertyFittingNet( + self.nt, + ds.dim_out, + task_dim=3, + numb_fparam=0, + numb_aparam=0, + ).to(self.device) + + descriptor = torch.from_numpy( + rng.standard_normal((self.nf, self.nloc, ds.dim_out)) + ).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + + ret = fn(descriptor, atype) + self.assertIn("property", ret) + + exported = torch.export.export( + fn, + (descriptor, atype), + kwargs={}, + strict=False, + ) + self.assertIsNotNone(exported) + + ret_exported = exported.module()(descriptor, atype) + np.testing.assert_allclose( + ret["property"].detach().cpu().numpy(), + ret_exported["property"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) From 292fa724576af4cfed07a7b2bb1975e8fa2cf3f6 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 22 Feb 2026 17:03:28 +0800 Subject: [PATCH 2/3] add make_fx, mv itertools to parameterized --- .../pt_expt/fitting/test_dipole_fitting.py | 139 ++++++++++++------ .../tests/pt_expt/fitting/test_dos_fitting.py | 131 +++++++++++------ .../pt_expt/fitting/test_polar_fitting.py | 136 +++++++++++------ .../pt_expt/fitting/test_property_fitting.py | 131 +++++++++++------ 4 files changed, 360 insertions(+), 177 deletions(-) diff --git a/source/tests/pt_expt/fitting/test_dipole_fitting.py b/source/tests/pt_expt/fitting/test_dipole_fitting.py index 8d25b55075..f5ac7ba177 100644 --- a/source/tests/pt_expt/fitting/test_dipole_fitting.py +++ b/source/tests/pt_expt/fitting/test_dipole_fitting.py @@ -1,8 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import unittest import numpy as np +import pytest import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) from deepmd.dpmodel.descriptor import ( DescrptSeA, @@ -22,59 +25,57 @@ ) -class TestDipoleFitting(unittest.TestCase, TestCaseSingleFrameWithNlist): - def setUp(self) -> None: +class TestDipoleFitting(TestCaseSingleFrameWithNlist): + def setup_method(self) -> None: TestCaseSingleFrameWithNlist.setUp(self) self.device = env.DEVICE - def test_self_consistency(self) -> None: + @pytest.mark.parametrize("nfp", [0, 3]) # numb_fparam + @pytest.mark.parametrize("nap", [0, 4]) # numb_aparam + def test_self_consistency(self, nfp, nap) -> None: rng = np.random.default_rng(GLOBAL_SEED) nf, nloc, nnei = self.nlist.shape ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) atype = self.atype_ext[:, :nloc] - - # dd[0]: descriptor, dd[1]: gr (rotation matrix, nf x nloc x nnei x 3... but - # for se_a, gr shape is nf x nloc x m1 x 3) embedding_width = ds.get_dim_emb() - for nfp, nap in [(0, 0), (3, 0), (0, 4), (3, 4)]: - fn0 = DipoleFitting( - self.nt, - ds.dim_out, - embedding_width, - numb_fparam=nfp, - numb_aparam=nap, - ).to(self.device) - fn1 = DipoleFitting.deserialize(fn0.serialize()).to(self.device) - if nfp > 0: - ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) - else: - ifp = None - if nap > 0: - iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( - self.device - ) - else: - iap = None - ret0 = fn0( - torch.from_numpy(dd[0]).to(self.device), - torch.from_numpy(atype).to(self.device), - gr=torch.from_numpy(dd[1]).to(self.device), - fparam=ifp, - aparam=iap, - ) - ret1 = fn1( - torch.from_numpy(dd[0]).to(self.device), - torch.from_numpy(atype).to(self.device), - gr=torch.from_numpy(dd[1]).to(self.device), - fparam=ifp, - aparam=iap, - ) - np.testing.assert_allclose( - ret0["dipole"].detach().cpu().numpy(), - ret1["dipole"].detach().cpu().numpy(), + fn0 = DipoleFitting( + self.nt, + ds.dim_out, + embedding_width, + numb_fparam=nfp, + numb_aparam=nap, + ).to(self.device) + fn1 = DipoleFitting.deserialize(fn0.serialize()).to(self.device) + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) + else: + ifp = None + if nap > 0: + iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( + self.device ) + else: + iap = None + ret0 = fn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + gr=torch.from_numpy(dd[1]).to(self.device), + fparam=ifp, + aparam=iap, + ) + ret1 = fn1( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + gr=torch.from_numpy(dd[1]).to(self.device), + fparam=ifp, + aparam=iap, + ) + np.testing.assert_allclose( + ret0["dipole"].detach().cpu().numpy(), + ret1["dipole"].detach().cpu().numpy(), + ) def test_serialize_has_correct_type(self) -> None: ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) @@ -85,16 +86,15 @@ def test_serialize_has_correct_type(self) -> None: embedding_width, ).to(self.device) serialized = fn.serialize() - self.assertEqual(serialized["type"], "dipole") + assert serialized["type"] == "dipole" fn2 = DipoleFitting.deserialize(serialized).to(self.device) - self.assertIsInstance(fn2, DipoleFitting) + assert isinstance(fn2, DipoleFitting) def test_torch_export_simple(self) -> None: nf, nloc, nnei = self.nlist.shape ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) embedding_width = ds.get_dim_emb() - rng = np.random.default_rng(GLOBAL_SEED) fn = DipoleFitting( self.nt, @@ -109,7 +109,7 @@ def test_torch_export_simple(self) -> None: gr = torch.from_numpy(dd[1]).to(self.device) ret = fn(descriptor, atype, gr=gr) - self.assertIn("dipole", ret) + assert "dipole" in ret exported = torch.export.export( fn, @@ -117,7 +117,7 @@ def test_torch_export_simple(self) -> None: kwargs={"gr": gr}, strict=False, ) - self.assertIsNotNone(exported) + assert exported is not None ret_exported = exported.module()(descriptor, atype, gr=gr) np.testing.assert_allclose( @@ -126,3 +126,46 @@ def test_torch_export_simple(self) -> None: rtol=1e-10, atol=1e-10, ) + + def test_make_fx(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + embedding_width = ds.get_dim_emb() + + fn0 = ( + DipoleFitting( + self.nt, + ds.dim_out, + embedding_width, + precision="float64", + ) + .to(self.device) + .eval() + ) + + descriptor = torch.from_numpy(dd[0]).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + gr = torch.from_numpy(dd[1]).to(self.device) + + def fn(descriptor, atype, gr): + descriptor = descriptor.detach().requires_grad_(True) + ret = fn0(descriptor, atype, gr=gr)["dipole"] + grad = torch.autograd.grad(ret.sum(), descriptor, create_graph=False)[0] + return ret, grad + + ret_eager, grad_eager = fn(descriptor, atype, gr) + traced = make_fx(fn)(descriptor, atype, gr) + ret_traced, grad_traced = traced(descriptor, atype, gr) + np.testing.assert_allclose( + ret_eager.detach().cpu().numpy(), + ret_traced.detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) diff --git a/source/tests/pt_expt/fitting/test_dos_fitting.py b/source/tests/pt_expt/fitting/test_dos_fitting.py index e16b6f6569..3fe06a8618 100644 --- a/source/tests/pt_expt/fitting/test_dos_fitting.py +++ b/source/tests/pt_expt/fitting/test_dos_fitting.py @@ -1,8 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import unittest import numpy as np +import pytest import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) from deepmd.dpmodel.descriptor import ( DescrptSeA, @@ -22,53 +25,54 @@ ) -class TestDOSFittingNet(unittest.TestCase, TestCaseSingleFrameWithNlist): - def setUp(self) -> None: +class TestDOSFittingNet(TestCaseSingleFrameWithNlist): + def setup_method(self) -> None: TestCaseSingleFrameWithNlist.setUp(self) self.device = env.DEVICE - def test_self_consistency(self) -> None: + @pytest.mark.parametrize("nfp", [0, 3]) # numb_fparam + @pytest.mark.parametrize("nap", [0, 4]) # numb_aparam + def test_self_consistency(self, nfp, nap) -> None: rng = np.random.default_rng(GLOBAL_SEED) nf, nloc, nnei = self.nlist.shape ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) atype = self.atype_ext[:, :nloc] - for nfp, nap in [(0, 0), (3, 0), (0, 4), (3, 4)]: - fn0 = DOSFittingNet( - self.nt, - ds.dim_out, - numb_dos=10, - numb_fparam=nfp, - numb_aparam=nap, - ).to(self.device) - fn1 = DOSFittingNet.deserialize(fn0.serialize()).to(self.device) - if nfp > 0: - ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) - else: - ifp = None - if nap > 0: - iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( - self.device - ) - else: - iap = None - ret0 = fn0( - torch.from_numpy(dd[0]).to(self.device), - torch.from_numpy(atype).to(self.device), - fparam=ifp, - aparam=iap, - ) - ret1 = fn1( - torch.from_numpy(dd[0]).to(self.device), - torch.from_numpy(atype).to(self.device), - fparam=ifp, - aparam=iap, - ) - np.testing.assert_allclose( - ret0["dos"].detach().cpu().numpy(), - ret1["dos"].detach().cpu().numpy(), + fn0 = DOSFittingNet( + self.nt, + ds.dim_out, + numb_dos=10, + numb_fparam=nfp, + numb_aparam=nap, + ).to(self.device) + fn1 = DOSFittingNet.deserialize(fn0.serialize()).to(self.device) + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) + else: + ifp = None + if nap > 0: + iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( + self.device ) + else: + iap = None + ret0 = fn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + ret1 = fn1( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + np.testing.assert_allclose( + ret0["dos"].detach().cpu().numpy(), + ret1["dos"].detach().cpu().numpy(), + ) def test_serialize_has_correct_type(self) -> None: ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) @@ -78,9 +82,9 @@ def test_serialize_has_correct_type(self) -> None: numb_dos=10, ).to(self.device) serialized = fn.serialize() - self.assertEqual(serialized["type"], "dos") + assert serialized["type"] == "dos" fn2 = DOSFittingNet.deserialize(serialized).to(self.device) - self.assertIsInstance(fn2, DOSFittingNet) + assert isinstance(fn2, DOSFittingNet) def test_torch_export_simple(self) -> None: nf, nloc, nnei = self.nlist.shape @@ -101,7 +105,7 @@ def test_torch_export_simple(self) -> None: atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) ret = fn(descriptor, atype) - self.assertIn("dos", ret) + assert "dos" in ret exported = torch.export.export( fn, @@ -109,7 +113,7 @@ def test_torch_export_simple(self) -> None: kwargs={}, strict=False, ) - self.assertIsNotNone(exported) + assert exported is not None ret_exported = exported.module()(descriptor, atype) np.testing.assert_allclose( @@ -118,3 +122,46 @@ def test_torch_export_simple(self) -> None: rtol=1e-10, atol=1e-10, ) + + def test_make_fx(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + rng = np.random.default_rng(GLOBAL_SEED) + + fn0 = ( + DOSFittingNet( + self.nt, + ds.dim_out, + numb_dos=10, + precision="float64", + ) + .to(self.device) + .eval() + ) + + descriptor = torch.from_numpy( + rng.standard_normal((self.nf, self.nloc, ds.dim_out)) + ).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + + def fn(descriptor, atype): + descriptor = descriptor.detach().requires_grad_(True) + ret = fn0(descriptor, atype)["dos"] + grad = torch.autograd.grad(ret.sum(), descriptor, create_graph=False)[0] + return ret, grad + + ret_eager, grad_eager = fn(descriptor, atype) + traced = make_fx(fn)(descriptor, atype) + ret_traced, grad_traced = traced(descriptor, atype) + np.testing.assert_allclose( + ret_eager.detach().cpu().numpy(), + ret_traced.detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) diff --git a/source/tests/pt_expt/fitting/test_polar_fitting.py b/source/tests/pt_expt/fitting/test_polar_fitting.py index 24b38b1fe9..1c150f7154 100644 --- a/source/tests/pt_expt/fitting/test_polar_fitting.py +++ b/source/tests/pt_expt/fitting/test_polar_fitting.py @@ -1,8 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import unittest import numpy as np +import pytest import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) from deepmd.dpmodel.descriptor import ( DescrptSeA, @@ -22,57 +25,57 @@ ) -class TestPolarFitting(unittest.TestCase, TestCaseSingleFrameWithNlist): - def setUp(self) -> None: +class TestPolarFitting(TestCaseSingleFrameWithNlist): + def setup_method(self) -> None: TestCaseSingleFrameWithNlist.setUp(self) self.device = env.DEVICE - def test_self_consistency(self) -> None: + @pytest.mark.parametrize("nfp", [0, 3]) # numb_fparam + @pytest.mark.parametrize("nap", [0, 4]) # numb_aparam + def test_self_consistency(self, nfp, nap) -> None: rng = np.random.default_rng(GLOBAL_SEED) nf, nloc, nnei = self.nlist.shape ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) atype = self.atype_ext[:, :nloc] - embedding_width = ds.get_dim_emb() - for nfp, nap in [(0, 0), (3, 0), (0, 4), (3, 4)]: - fn0 = PolarFitting( - self.nt, - ds.dim_out, - embedding_width, - numb_fparam=nfp, - numb_aparam=nap, - ).to(self.device) - fn1 = PolarFitting.deserialize(fn0.serialize()).to(self.device) - if nfp > 0: - ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) - else: - ifp = None - if nap > 0: - iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( - self.device - ) - else: - iap = None - ret0 = fn0( - torch.from_numpy(dd[0]).to(self.device), - torch.from_numpy(atype).to(self.device), - gr=torch.from_numpy(dd[1]).to(self.device), - fparam=ifp, - aparam=iap, - ) - ret1 = fn1( - torch.from_numpy(dd[0]).to(self.device), - torch.from_numpy(atype).to(self.device), - gr=torch.from_numpy(dd[1]).to(self.device), - fparam=ifp, - aparam=iap, - ) - np.testing.assert_allclose( - ret0["polarizability"].detach().cpu().numpy(), - ret1["polarizability"].detach().cpu().numpy(), + fn0 = PolarFitting( + self.nt, + ds.dim_out, + embedding_width, + numb_fparam=nfp, + numb_aparam=nap, + ).to(self.device) + fn1 = PolarFitting.deserialize(fn0.serialize()).to(self.device) + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) + else: + ifp = None + if nap > 0: + iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( + self.device ) + else: + iap = None + ret0 = fn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + gr=torch.from_numpy(dd[1]).to(self.device), + fparam=ifp, + aparam=iap, + ) + ret1 = fn1( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + gr=torch.from_numpy(dd[1]).to(self.device), + fparam=ifp, + aparam=iap, + ) + np.testing.assert_allclose( + ret0["polarizability"].detach().cpu().numpy(), + ret1["polarizability"].detach().cpu().numpy(), + ) def test_serialize_has_correct_type(self) -> None: ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) @@ -83,9 +86,9 @@ def test_serialize_has_correct_type(self) -> None: embedding_width, ).to(self.device) serialized = fn.serialize() - self.assertEqual(serialized["type"], "polar") + assert serialized["type"] == "polar" fn2 = PolarFitting.deserialize(serialized).to(self.device) - self.assertIsInstance(fn2, PolarFitting) + assert isinstance(fn2, PolarFitting) def test_torch_export_simple(self) -> None: nf, nloc, nnei = self.nlist.shape @@ -106,7 +109,7 @@ def test_torch_export_simple(self) -> None: gr = torch.from_numpy(dd[1]).to(self.device) ret = fn(descriptor, atype, gr=gr) - self.assertIn("polarizability", ret) + assert "polarizability" in ret exported = torch.export.export( fn, @@ -114,7 +117,7 @@ def test_torch_export_simple(self) -> None: kwargs={"gr": gr}, strict=False, ) - self.assertIsNotNone(exported) + assert exported is not None ret_exported = exported.module()(descriptor, atype, gr=gr) np.testing.assert_allclose( @@ -123,3 +126,46 @@ def test_torch_export_simple(self) -> None: rtol=1e-10, atol=1e-10, ) + + def test_make_fx(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + embedding_width = ds.get_dim_emb() + + fn0 = ( + PolarFitting( + self.nt, + ds.dim_out, + embedding_width, + precision="float64", + ) + .to(self.device) + .eval() + ) + + descriptor = torch.from_numpy(dd[0]).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + gr = torch.from_numpy(dd[1]).to(self.device) + + def fn(descriptor, atype, gr): + descriptor = descriptor.detach().requires_grad_(True) + ret = fn0(descriptor, atype, gr=gr)["polarizability"] + grad = torch.autograd.grad(ret.sum(), descriptor, create_graph=False)[0] + return ret, grad + + ret_eager, grad_eager = fn(descriptor, atype, gr) + traced = make_fx(fn)(descriptor, atype, gr) + ret_traced, grad_traced = traced(descriptor, atype, gr) + np.testing.assert_allclose( + ret_eager.detach().cpu().numpy(), + ret_traced.detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) diff --git a/source/tests/pt_expt/fitting/test_property_fitting.py b/source/tests/pt_expt/fitting/test_property_fitting.py index 44a499ed9d..ca3dbc11af 100644 --- a/source/tests/pt_expt/fitting/test_property_fitting.py +++ b/source/tests/pt_expt/fitting/test_property_fitting.py @@ -1,8 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import unittest import numpy as np +import pytest import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) from deepmd.dpmodel.descriptor import ( DescrptSeA, @@ -22,53 +25,54 @@ ) -class TestPropertyFittingNet(unittest.TestCase, TestCaseSingleFrameWithNlist): - def setUp(self) -> None: +class TestPropertyFittingNet(TestCaseSingleFrameWithNlist): + def setup_method(self) -> None: TestCaseSingleFrameWithNlist.setUp(self) self.device = env.DEVICE - def test_self_consistency(self) -> None: + @pytest.mark.parametrize("nfp", [0, 3]) # numb_fparam + @pytest.mark.parametrize("nap", [0, 4]) # numb_aparam + def test_self_consistency(self, nfp, nap) -> None: rng = np.random.default_rng(GLOBAL_SEED) nf, nloc, nnei = self.nlist.shape ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) atype = self.atype_ext[:, :nloc] - for nfp, nap in [(0, 0), (3, 0), (0, 4), (3, 4)]: - fn0 = PropertyFittingNet( - self.nt, - ds.dim_out, - task_dim=3, - numb_fparam=nfp, - numb_aparam=nap, - ).to(self.device) - fn1 = PropertyFittingNet.deserialize(fn0.serialize()).to(self.device) - if nfp > 0: - ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) - else: - ifp = None - if nap > 0: - iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( - self.device - ) - else: - iap = None - ret0 = fn0( - torch.from_numpy(dd[0]).to(self.device), - torch.from_numpy(atype).to(self.device), - fparam=ifp, - aparam=iap, - ) - ret1 = fn1( - torch.from_numpy(dd[0]).to(self.device), - torch.from_numpy(atype).to(self.device), - fparam=ifp, - aparam=iap, - ) - np.testing.assert_allclose( - ret0["property"].detach().cpu().numpy(), - ret1["property"].detach().cpu().numpy(), + fn0 = PropertyFittingNet( + self.nt, + ds.dim_out, + task_dim=3, + numb_fparam=nfp, + numb_aparam=nap, + ).to(self.device) + fn1 = PropertyFittingNet.deserialize(fn0.serialize()).to(self.device) + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) + else: + ifp = None + if nap > 0: + iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( + self.device ) + else: + iap = None + ret0 = fn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + ret1 = fn1( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + np.testing.assert_allclose( + ret0["property"].detach().cpu().numpy(), + ret1["property"].detach().cpu().numpy(), + ) def test_serialize_has_correct_type(self) -> None: ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) @@ -78,9 +82,9 @@ def test_serialize_has_correct_type(self) -> None: task_dim=3, ).to(self.device) serialized = fn.serialize() - self.assertEqual(serialized["type"], "property") + assert serialized["type"] == "property" fn2 = PropertyFittingNet.deserialize(serialized).to(self.device) - self.assertIsInstance(fn2, PropertyFittingNet) + assert isinstance(fn2, PropertyFittingNet) def test_torch_export_simple(self) -> None: nf, nloc, nnei = self.nlist.shape @@ -101,7 +105,7 @@ def test_torch_export_simple(self) -> None: atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) ret = fn(descriptor, atype) - self.assertIn("property", ret) + assert "property" in ret exported = torch.export.export( fn, @@ -109,7 +113,7 @@ def test_torch_export_simple(self) -> None: kwargs={}, strict=False, ) - self.assertIsNotNone(exported) + assert exported is not None ret_exported = exported.module()(descriptor, atype) np.testing.assert_allclose( @@ -118,3 +122,46 @@ def test_torch_export_simple(self) -> None: rtol=1e-10, atol=1e-10, ) + + def test_make_fx(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + rng = np.random.default_rng(GLOBAL_SEED) + + fn0 = ( + PropertyFittingNet( + self.nt, + ds.dim_out, + task_dim=3, + precision="float64", + ) + .to(self.device) + .eval() + ) + + descriptor = torch.from_numpy( + rng.standard_normal((self.nf, self.nloc, ds.dim_out)) + ).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + + def fn(descriptor, atype): + descriptor = descriptor.detach().requires_grad_(True) + ret = fn0(descriptor, atype)["property"] + grad = torch.autograd.grad(ret.sum(), descriptor, create_graph=False)[0] + return ret, grad + + ret_eager, grad_eager = fn(descriptor, atype) + traced = make_fx(fn)(descriptor, atype) + ret_traced, grad_traced = traced(descriptor, atype) + np.testing.assert_allclose( + ret_eager.detach().cpu().numpy(), + ret_traced.detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) From 6a71537d6a1a1b48369d7e5aaae320d478d2f8ad Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 22 Feb 2026 22:31:46 +0800 Subject: [PATCH 3/3] rm register_dpmodel_mapping from fitting --- deepmd/pt_expt/fitting/dipole_fitting.py | 7 ------- deepmd/pt_expt/fitting/dos_fitting.py | 7 ------- deepmd/pt_expt/fitting/ener_fitting.py | 12 ------------ deepmd/pt_expt/fitting/invar_fitting.py | 7 ------- deepmd/pt_expt/fitting/polarizability_fitting.py | 7 ------- deepmd/pt_expt/fitting/property_fitting.py | 7 ------- 6 files changed, 47 deletions(-) diff --git a/deepmd/pt_expt/fitting/dipole_fitting.py b/deepmd/pt_expt/fitting/dipole_fitting.py index a16a96fe72..23a10432da 100644 --- a/deepmd/pt_expt/fitting/dipole_fitting.py +++ b/deepmd/pt_expt/fitting/dipole_fitting.py @@ -2,7 +2,6 @@ from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingDP from deepmd.pt_expt.common import ( - register_dpmodel_mapping, torch_module, ) @@ -15,9 +14,3 @@ @torch_module class DipoleFitting(DipoleFittingDP): pass - - -register_dpmodel_mapping( - DipoleFittingDP, - lambda v: DipoleFitting.deserialize(v.serialize()), -) diff --git a/deepmd/pt_expt/fitting/dos_fitting.py b/deepmd/pt_expt/fitting/dos_fitting.py index 8c51fcc0eb..c42511bfe6 100644 --- a/deepmd/pt_expt/fitting/dos_fitting.py +++ b/deepmd/pt_expt/fitting/dos_fitting.py @@ -2,7 +2,6 @@ from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP from deepmd.pt_expt.common import ( - register_dpmodel_mapping, torch_module, ) @@ -15,9 +14,3 @@ @torch_module class DOSFittingNet(DOSFittingNetDP): pass - - -register_dpmodel_mapping( - DOSFittingNetDP, - lambda v: DOSFittingNet.deserialize(v.serialize()), -) diff --git a/deepmd/pt_expt/fitting/ener_fitting.py b/deepmd/pt_expt/fitting/ener_fitting.py index f9779e44af..f778af8fec 100644 --- a/deepmd/pt_expt/fitting/ener_fitting.py +++ b/deepmd/pt_expt/fitting/ener_fitting.py @@ -2,7 +2,6 @@ from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP from deepmd.pt_expt.common import ( - register_dpmodel_mapping, torch_module, ) @@ -14,15 +13,4 @@ @BaseFitting.register("ener") @torch_module class EnergyFittingNet(EnergyFittingNetDP): - """Energy fitting net for pt_expt backend. - - This inherits from dpmodel EnergyFittingNet to get the correct serialize() method. - """ - pass - - -register_dpmodel_mapping( - EnergyFittingNetDP, - lambda v: EnergyFittingNet.deserialize(v.serialize()), -) diff --git a/deepmd/pt_expt/fitting/invar_fitting.py b/deepmd/pt_expt/fitting/invar_fitting.py index ab908ebe0d..f13fe2afbb 100644 --- a/deepmd/pt_expt/fitting/invar_fitting.py +++ b/deepmd/pt_expt/fitting/invar_fitting.py @@ -2,7 +2,6 @@ from deepmd.dpmodel.fitting.invar_fitting import InvarFitting as InvarFittingDP from deepmd.pt_expt.common import ( - register_dpmodel_mapping, torch_module, ) from deepmd.pt_expt.fitting.base_fitting import ( @@ -14,9 +13,3 @@ @torch_module class InvarFitting(InvarFittingDP): pass - - -register_dpmodel_mapping( - InvarFittingDP, - lambda v: InvarFitting.deserialize(v.serialize()), -) diff --git a/deepmd/pt_expt/fitting/polarizability_fitting.py b/deepmd/pt_expt/fitting/polarizability_fitting.py index 564df7e0d7..e86b1224ef 100644 --- a/deepmd/pt_expt/fitting/polarizability_fitting.py +++ b/deepmd/pt_expt/fitting/polarizability_fitting.py @@ -2,7 +2,6 @@ from deepmd.dpmodel.fitting.polarizability_fitting import PolarFitting as PolarFittingDP from deepmd.pt_expt.common import ( - register_dpmodel_mapping, torch_module, ) @@ -15,9 +14,3 @@ @torch_module class PolarFitting(PolarFittingDP): pass - - -register_dpmodel_mapping( - PolarFittingDP, - lambda v: PolarFitting.deserialize(v.serialize()), -) diff --git a/deepmd/pt_expt/fitting/property_fitting.py b/deepmd/pt_expt/fitting/property_fitting.py index 318e30fad6..f1bd9becbf 100644 --- a/deepmd/pt_expt/fitting/property_fitting.py +++ b/deepmd/pt_expt/fitting/property_fitting.py @@ -4,7 +4,6 @@ PropertyFittingNet as PropertyFittingNetDP, ) from deepmd.pt_expt.common import ( - register_dpmodel_mapping, torch_module, ) @@ -17,9 +16,3 @@ @torch_module class PropertyFittingNet(PropertyFittingNetDP): pass - - -register_dpmodel_mapping( - PropertyFittingNetDP, - lambda v: PropertyFittingNet.deserialize(v.serialize()), -)