diff --git a/deepmd/pt_expt/model/dipole_model.py b/deepmd/pt_expt/model/dipole_model.py index 73ebba6bac..79ae26024e 100644 --- a/deepmd/pt_expt/model/dipole_model.py +++ b/deepmd/pt_expt/model/dipole_model.py @@ -124,6 +124,7 @@ def forward_lower_exportable( fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, + **make_fx_kwargs: Any, ) -> torch.nn.Module: model = self @@ -146,6 +147,6 @@ def fn( do_atomic_virial=do_atomic_virial, ) - return make_fx(fn)( + return make_fx(fn, **make_fx_kwargs)( extended_coord, extended_atype, nlist, mapping, fparam, aparam ) diff --git a/deepmd/pt_expt/model/dos_model.py b/deepmd/pt_expt/model/dos_model.py index 137c2b2901..2e69d90ab3 100644 --- a/deepmd/pt_expt/model/dos_model.py +++ b/deepmd/pt_expt/model/dos_model.py @@ -104,6 +104,7 @@ def forward_lower_exportable( fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, + **make_fx_kwargs: Any, ) -> torch.nn.Module: model = self @@ -126,6 +127,6 @@ def fn( do_atomic_virial=do_atomic_virial, ) - return make_fx(fn)( + return make_fx(fn, **make_fx_kwargs)( extended_coord, extended_atype, nlist, mapping, fparam, aparam ) diff --git a/deepmd/pt_expt/model/dp_linear_model.py b/deepmd/pt_expt/model/dp_linear_model.py index 134d738f66..46790c877e 100644 --- a/deepmd/pt_expt/model/dp_linear_model.py +++ b/deepmd/pt_expt/model/dp_linear_model.py @@ -129,6 +129,7 @@ def forward_lower_exportable( fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, + **make_fx_kwargs: Any, ) -> torch.nn.Module: model = self @@ -151,7 +152,7 @@ def fn( do_atomic_virial=do_atomic_virial, ) - return make_fx(fn)( + return make_fx(fn, **make_fx_kwargs)( extended_coord, extended_atype, nlist, mapping, fparam, aparam ) diff --git a/deepmd/pt_expt/model/dp_zbl_model.py b/deepmd/pt_expt/model/dp_zbl_model.py index c4bb668353..b7f164114b 100644 --- 
a/deepmd/pt_expt/model/dp_zbl_model.py +++ b/deepmd/pt_expt/model/dp_zbl_model.py @@ -126,6 +126,7 @@ def forward_lower_exportable( fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, + **make_fx_kwargs: Any, ) -> torch.nn.Module: model = self @@ -148,6 +149,6 @@ def fn( do_atomic_virial=do_atomic_virial, ) - return make_fx(fn)( + return make_fx(fn, **make_fx_kwargs)( extended_coord, extended_atype, nlist, mapping, fparam, aparam ) diff --git a/deepmd/pt_expt/model/polar_model.py b/deepmd/pt_expt/model/polar_model.py index 2bec72d4f7..d421bb76a4 100644 --- a/deepmd/pt_expt/model/polar_model.py +++ b/deepmd/pt_expt/model/polar_model.py @@ -104,6 +104,7 @@ def forward_lower_exportable( fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, + **make_fx_kwargs: Any, ) -> torch.nn.Module: model = self @@ -126,6 +127,6 @@ def fn( do_atomic_virial=do_atomic_virial, ) - return make_fx(fn)( + return make_fx(fn, **make_fx_kwargs)( extended_coord, extended_atype, nlist, mapping, fparam, aparam ) diff --git a/deepmd/pt_expt/model/property_model.py b/deepmd/pt_expt/model/property_model.py index 50f8f0eeb4..72a327fb03 100644 --- a/deepmd/pt_expt/model/property_model.py +++ b/deepmd/pt_expt/model/property_model.py @@ -111,6 +111,7 @@ def forward_lower_exportable( fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, + **make_fx_kwargs: Any, ) -> torch.nn.Module: model = self @@ -133,6 +134,6 @@ def fn( do_atomic_virial=do_atomic_virial, ) - return make_fx(fn)( + return make_fx(fn, **make_fx_kwargs)( extended_coord, extended_atype, nlist, mapping, fparam, aparam ) diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py index c8678b4d8d..f23d0bb025 100644 --- a/deepmd/pt_expt/utils/serialization.py +++ b/deepmd/pt_expt/utils/serialization.py @@ -330,12 +330,17 @@ def _trace_and_export( metadata = 
_collect_metadata(model) # 3. Create sample inputs on CPU for tracing - # Use nframes=2 so make_fx doesn't specialize on nframes=1 + # Use nframes=5 to avoid two specialization traps: + # - nframes=1 causes make_fx to specialize on the scalar case + # - nframes=N where N == numb_fparam or numb_aparam causes PyTorch's + # symbolic tracer to merge symbols (e.g. fparam.shape=(2,2) when + # nframes=2 and numb_fparam=2), so a guard on one dim constrains + # the other. 5 is unlikely to collide with typical param counts. _orig_device = _env.DEVICE _env.DEVICE = torch.device("cpu") try: ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = _make_sample_inputs( - model, nframes=2 + model, nframes=5 ) finally: _env.DEVICE = _orig_device diff --git a/source/tests/pt_expt/descriptor/test_dpa1.py b/source/tests/pt_expt/descriptor/test_dpa1.py index 0524c6c98a..e90c67bc82 100644 --- a/source/tests/pt_expt/descriptor/test_dpa1.py +++ b/source/tests/pt_expt/descriptor/test_dpa1.py @@ -25,6 +25,10 @@ from ...seed import ( GLOBAL_SEED, ) +from ..export_helpers import ( + export_save_load_and_compare, + make_descriptor_dynamic_shapes, +) class TestDescrptDPA1(TestCaseSingleFrameWithNlist): @@ -223,3 +227,15 @@ def fn(coord_ext, atype_ext, nlist): rtol=rtol, atol=atol, ) + + # --- symbolic trace + export + .pte round-trip --- + dynamic_shapes = make_descriptor_dynamic_shapes(has_mapping=False) + inputs = (coord_ext, atype_ext, nlist) + export_save_load_and_compare( + fn, + inputs, + (rd_eager, grad_eager), + dynamic_shapes, + rtol=rtol, + atol=atol, + ) diff --git a/source/tests/pt_expt/descriptor/test_dpa2.py b/source/tests/pt_expt/descriptor/test_dpa2.py index a3794052f4..fb0005e13a 100644 --- a/source/tests/pt_expt/descriptor/test_dpa2.py +++ b/source/tests/pt_expt/descriptor/test_dpa2.py @@ -29,6 +29,10 @@ from ...seed import ( GLOBAL_SEED, ) +from ..export_helpers import ( + export_save_load_and_compare, + make_descriptor_dynamic_shapes, +) class 
TestDescrptDPA2(TestCaseSingleFrameWithNlist): @@ -410,3 +414,15 @@ def fn(coord_ext, atype_ext, nlist, mapping): rtol=rtol, atol=atol, ) + + # --- symbolic trace + export + .pte round-trip --- + dynamic_shapes = make_descriptor_dynamic_shapes(has_mapping=True) + inputs = (coord_ext, atype_ext, nlist, mapping) + export_save_load_and_compare( + fn, + inputs, + (rd_eager, grad_eager), + dynamic_shapes, + rtol=rtol, + atol=atol, + ) diff --git a/source/tests/pt_expt/descriptor/test_dpa3.py b/source/tests/pt_expt/descriptor/test_dpa3.py index 7cdcd6ced7..ecc94d24f5 100644 --- a/source/tests/pt_expt/descriptor/test_dpa3.py +++ b/source/tests/pt_expt/descriptor/test_dpa3.py @@ -28,6 +28,10 @@ from ...seed import ( GLOBAL_SEED, ) +from ..export_helpers import ( + export_save_load_and_compare, + make_descriptor_dynamic_shapes, +) class TestDescrptDPA3(TestCaseSingleFrameWithNlist): @@ -244,3 +248,15 @@ def fn(coord_ext, atype_ext, nlist, mapping): rtol=rtol, atol=atol, ) + + # --- symbolic trace + export + .pte round-trip --- + dynamic_shapes = make_descriptor_dynamic_shapes(has_mapping=True) + inputs = (coord_ext, atype_ext, nlist, mapping) + export_save_load_and_compare( + fn, + inputs, + (rd_eager, grad_eager), + dynamic_shapes, + rtol=rtol, + atol=atol, + ) diff --git a/source/tests/pt_expt/descriptor/test_hybrid.py b/source/tests/pt_expt/descriptor/test_hybrid.py index 3185733ddf..a3c673d774 100644 --- a/source/tests/pt_expt/descriptor/test_hybrid.py +++ b/source/tests/pt_expt/descriptor/test_hybrid.py @@ -31,6 +31,10 @@ from ...seed import ( GLOBAL_SEED, ) +from ..export_helpers import ( + export_save_load_and_compare, + make_descriptor_dynamic_shapes, +) class TestDescrptHybrid(TestCaseSingleFrameWithNlist): @@ -215,3 +219,15 @@ def fn(coord_ext, atype_ext, nlist): rtol=rtol, atol=atol, ) + + # --- symbolic trace + export + .pte round-trip --- + dynamic_shapes = make_descriptor_dynamic_shapes(has_mapping=False) + inputs = (coord_ext, atype_ext, nlist) + 
export_save_load_and_compare( + fn, + inputs, + (rd_eager, grad_eager), + dynamic_shapes, + rtol=rtol, + atol=atol, + ) diff --git a/source/tests/pt_expt/descriptor/test_se_atten_v2.py b/source/tests/pt_expt/descriptor/test_se_atten_v2.py index 01ead9e179..326a78acad 100644 --- a/source/tests/pt_expt/descriptor/test_se_atten_v2.py +++ b/source/tests/pt_expt/descriptor/test_se_atten_v2.py @@ -25,6 +25,10 @@ from ...seed import ( GLOBAL_SEED, ) +from ..export_helpers import ( + export_save_load_and_compare, + make_descriptor_dynamic_shapes, +) class TestDescrptSeAttenV2(TestCaseSingleFrameWithNlist): @@ -218,3 +222,15 @@ def fn(coord_ext, atype_ext, nlist): rtol=rtol, atol=atol, ) + + # --- symbolic trace + export + .pte round-trip --- + dynamic_shapes = make_descriptor_dynamic_shapes(has_mapping=False) + inputs = (coord_ext, atype_ext, nlist) + export_save_load_and_compare( + fn, + inputs, + (rd_eager, grad_eager), + dynamic_shapes, + rtol=rtol, + atol=atol, + ) diff --git a/source/tests/pt_expt/descriptor/test_se_e2_a.py b/source/tests/pt_expt/descriptor/test_se_e2_a.py index 0efb90b1bb..e4bd1e385e 100644 --- a/source/tests/pt_expt/descriptor/test_se_e2_a.py +++ b/source/tests/pt_expt/descriptor/test_se_e2_a.py @@ -28,6 +28,10 @@ from ...seed import ( GLOBAL_SEED, ) +from ..export_helpers import ( + export_save_load_and_compare, + make_descriptor_dynamic_shapes, +) class TestDescrptSeA(TestCaseSingleFrameWithNlist): @@ -205,3 +209,15 @@ def fn(coord_ext, atype_ext, nlist): rtol=rtol, atol=atol, ) + + # --- symbolic trace + export + .pte round-trip --- + dynamic_shapes = make_descriptor_dynamic_shapes(has_mapping=False) + inputs = (coord_ext, atype_ext, nlist) + export_save_load_and_compare( + fn, + inputs, + (rd_eager, grad_eager), + dynamic_shapes, + rtol=rtol, + atol=atol, + ) diff --git a/source/tests/pt_expt/descriptor/test_se_r.py b/source/tests/pt_expt/descriptor/test_se_r.py index 0bac2251b4..cde3295e7a 100644 --- 
a/source/tests/pt_expt/descriptor/test_se_r.py +++ b/source/tests/pt_expt/descriptor/test_se_r.py @@ -25,6 +25,10 @@ from ...seed import ( GLOBAL_SEED, ) +from ..export_helpers import ( + export_save_load_and_compare, + make_descriptor_dynamic_shapes, +) class TestDescrptSeR(TestCaseSingleFrameWithNlist): @@ -200,3 +204,15 @@ def fn(coord_ext, atype_ext, nlist): rtol=rtol, atol=atol, ) + + # --- symbolic trace + export + .pte round-trip --- + dynamic_shapes = make_descriptor_dynamic_shapes(has_mapping=False) + inputs = (coord_ext, atype_ext, nlist) + export_save_load_and_compare( + fn, + inputs, + (rd_eager, grad_eager), + dynamic_shapes, + rtol=rtol, + atol=atol, + ) diff --git a/source/tests/pt_expt/descriptor/test_se_t.py b/source/tests/pt_expt/descriptor/test_se_t.py index 3da6ef02f3..bb1f9a4b3f 100644 --- a/source/tests/pt_expt/descriptor/test_se_t.py +++ b/source/tests/pt_expt/descriptor/test_se_t.py @@ -25,6 +25,10 @@ from ...seed import ( GLOBAL_SEED, ) +from ..export_helpers import ( + export_save_load_and_compare, + make_descriptor_dynamic_shapes, +) class TestDescrptSeT(TestCaseSingleFrameWithNlist): @@ -204,3 +208,15 @@ def fn(coord_ext, atype_ext, nlist): rtol=rtol, atol=atol, ) + + # --- symbolic trace + export + .pte round-trip --- + dynamic_shapes = make_descriptor_dynamic_shapes(has_mapping=False) + inputs = (coord_ext, atype_ext, nlist) + export_save_load_and_compare( + fn, + inputs, + (rd_eager, grad_eager), + dynamic_shapes, + rtol=rtol, + atol=atol, + ) diff --git a/source/tests/pt_expt/descriptor/test_se_t_tebd.py b/source/tests/pt_expt/descriptor/test_se_t_tebd.py index bb4b1dc80d..30808f5070 100644 --- a/source/tests/pt_expt/descriptor/test_se_t_tebd.py +++ b/source/tests/pt_expt/descriptor/test_se_t_tebd.py @@ -25,6 +25,10 @@ from ...seed import ( GLOBAL_SEED, ) +from ..export_helpers import ( + export_save_load_and_compare, + make_descriptor_dynamic_shapes, +) class TestDescrptSeTTebd(TestCaseSingleFrameWithNlist): @@ -235,3 +239,15 @@ def 
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Helpers for enhanced per-module export tests.

Provides symbolic tracing, torch.export with dynamic shapes, and .pte
save/load round-trip verification used by descriptor, fitting, and model
test_make_fx / test_forward_lower_exportable methods.
"""

import tempfile

import numpy as np
import torch
from torch.fx.experimental.proxy_tensor import (
    make_fx,
)


def export_save_load_and_compare(
    fn,
    inputs: tuple,
    eager_outputs: tuple,
    dynamic_shapes: tuple,
    rtol: float = 0.0,
    atol: float = 1e-12,
):
    """Symbolic trace -> export with dynamic shapes -> .pte save/load -> compare.

    Parameters
    ----------
    fn : callable
        The function to trace (same one used for eager and concrete make_fx).
    inputs : tuple of torch.Tensor
        Input tensors for the function.
    eager_outputs : tuple of torch.Tensor
        Reference outputs from eager execution.
    dynamic_shapes : tuple
        Dynamic shape specs for torch.export.export.
    rtol, atol : float
        Tolerances for np.testing.assert_allclose.

    Returns
    -------
    loaded_module : torch.nn.Module
        The module loaded from the .pte round-trip, for further testing.
    """
    # 1. Symbolic make_fx trace
    # _allow_non_fake_inputs is a private make_fx flag; presumably needed so
    # the real (non-fake) parameters closed over by `fn` can be traced —
    # NOTE(review): confirm against the torch version pinned by CI.
    traced_sym = make_fx(fn, tracing_mode="symbolic", _allow_non_fake_inputs=True)(
        *inputs
    )

    # 2. Compare symbolic-traced output vs eager
    # Normalize both sides to tuples so single-output fns compare uniformly.
    sym_outputs = traced_sym(*inputs)
    if not isinstance(sym_outputs, tuple):
        sym_outputs = (sym_outputs,)
    if not isinstance(eager_outputs, tuple):
        eager_outputs = (eager_outputs,)
    for sym_out, eager_out in zip(sym_outputs, eager_outputs, strict=True):
        np.testing.assert_allclose(
            sym_out.detach().cpu().numpy(),
            eager_out.detach().cpu().numpy(),
            rtol=rtol,
            atol=atol,
        )

    # 3. torch.export.export with dynamic shapes
    exported = torch.export.export(
        traced_sym,
        inputs,
        dynamic_shapes=dynamic_shapes,
        strict=False,
        prefer_deferred_runtime_asserts_over_guards=True,
    )

    # 4. .pte save -> load round-trip
    # NOTE(review): reopening a NamedTemporaryFile by name while it is still
    # open is not portable to Windows — fine for the current CI targets,
    # verify if Windows runners are ever added.
    with tempfile.NamedTemporaryFile(suffix=".pte") as f:
        torch.export.save(exported, f.name)
        loaded = torch.export.load(f.name)

    loaded_module = loaded.module()

    # 5. Compare loaded output vs eager (same shapes)
    loaded_outputs = loaded_module(*inputs)
    if not isinstance(loaded_outputs, tuple):
        loaded_outputs = (loaded_outputs,)
    for loaded_out, eager_out in zip(loaded_outputs, eager_outputs, strict=True):
        np.testing.assert_allclose(
            loaded_out.detach().cpu().numpy(),
            eager_out.detach().cpu().numpy(),
            rtol=rtol,
            atol=atol,
        )

    # 6. Compare loaded output vs eager (different nframes via nf=1 slice)
    # Exercises the dynamic batch dim: the exported graph was traced at the
    # original nframes, so an nf=1 call only works if the dim stayed symbolic.
    inputs_1f = tuple(t[0:1] if t is not None else None for t in inputs)
    eager_1f = fn(*inputs_1f)
    loaded_1f = loaded_module(*inputs_1f)
    if not isinstance(eager_1f, tuple):
        eager_1f = (eager_1f,)
    if not isinstance(loaded_1f, tuple):
        loaded_1f = (loaded_1f,)
    for eager_out, loaded_out in zip(eager_1f, loaded_1f, strict=True):
        np.testing.assert_allclose(
            eager_out.detach().cpu().numpy(),
            loaded_out.detach().cpu().numpy(),
            rtol=rtol,
            atol=atol,
        )

    return loaded_module


def model_forward_lower_export_round_trip(
    md_pt,
    ext_coord,
    ext_atype,
    nlist_t,
    mapping_t,
    fparam,
    aparam,
    output_keys: tuple[str, ...],
    rtol: float = 1e-10,
    atol: float = 1e-10,
):
    """Full forward_lower_exportable test: concrete trace + export + symbolic + .pte.

    Performs the complete export pipeline test:
    1. Eager reference via forward_lower
    2. Concrete trace via forward_lower_exportable
    3. torch.export.export (no dynamic shapes)
    4. Compare traced/exported vs eager
    5. Symbolic trace + dynamic shapes + .pte save/load round-trip
    6. Compare loaded vs eager (nf=1 — different shapes)

    Parameters
    ----------
    md_pt : torch.nn.Module
        The model (already deserialized and in eval mode).
    ext_coord, ext_atype, nlist_t, mapping_t : torch.Tensor
        Extended coordinates, atom types, neighbor list, mapping.
    fparam, aparam : torch.Tensor or None
        Frame and atom parameters.
    output_keys : tuple of str
        Output dictionary keys to verify.
    rtol, atol : float
        Tolerances for np.testing.assert_allclose.
    """
    # Local import: avoids importing the private serialization helper at
    # collection time for tests that never call this function.
    from deepmd.pt_expt.utils.serialization import (
        _build_dynamic_shapes,
    )

    # 1. Eager reference
    # NOTE(review): requires_grad_ mutates the caller's ext_coord in place —
    # confirm callers do not reuse the tensor expecting requires_grad=False.
    ret_eager = md_pt.forward_lower(
        ext_coord.requires_grad_(True),
        ext_atype,
        nlist_t,
        mapping_t,
        fparam=fparam,
        aparam=aparam,
    )

    # 2. Concrete trace
    traced = md_pt.forward_lower_exportable(
        ext_coord,
        ext_atype,
        nlist_t,
        mapping_t,
        fparam=fparam,
        aparam=aparam,
    )
    assert isinstance(traced, torch.nn.Module)

    # 3. Basic export (no dynamic shapes)
    exported = torch.export.export(
        traced,
        (ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam),
        strict=False,
    )
    assert exported is not None

    # 4. Compare traced and exported vs eager
    ret_traced = traced(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam)
    ret_exported = exported.module()(
        ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam
    )
    for key in output_keys:
        np.testing.assert_allclose(
            ret_eager[key].detach().cpu().numpy(),
            ret_traced[key].detach().cpu().numpy(),
            rtol=rtol,
            atol=atol,
            err_msg=f"traced vs eager: {key}",
        )
        np.testing.assert_allclose(
            ret_eager[key].detach().cpu().numpy(),
            ret_exported[key].detach().cpu().numpy(),
            rtol=rtol,
            atol=atol,
            err_msg=f"exported vs eager: {key}",
        )

    # 5. Symbolic trace + dynamic shapes + .pte round-trip
    # Trace at 2 frames (inputs duplicated along dim 0) so batch-size-1
    # special cases cannot be baked into the graph.
    inputs_2f = tuple(
        torch.cat([t, t], dim=0) if t is not None else None
        for t in (ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam)
    )
    traced_sym = md_pt.forward_lower_exportable(
        inputs_2f[0],
        inputs_2f[1],
        inputs_2f[2],
        inputs_2f[3],
        fparam=inputs_2f[4],
        aparam=inputs_2f[5],
        tracing_mode="symbolic",
        _allow_non_fake_inputs=True,
    )
    dynamic_shapes = _build_dynamic_shapes(*inputs_2f)
    exported_dyn = torch.export.export(
        traced_sym,
        inputs_2f,
        dynamic_shapes=dynamic_shapes,
        strict=False,
        prefer_deferred_runtime_asserts_over_guards=True,
    )
    with tempfile.NamedTemporaryFile(suffix=".pte") as f:
        torch.export.save(exported_dyn, f.name)
        loaded = torch.export.load(f.name).module()

    # 6. Compare loaded vs eager (nf=1 — different shapes)
    # The original (single-frame) inputs differ in batch size from the
    # 2-frame trace inputs, exercising the dynamic nframes dim.
    ret_loaded_1f = loaded(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam)
    for key in output_keys:
        np.testing.assert_allclose(
            ret_eager[key].detach().cpu().numpy(),
            ret_loaded_1f[key].detach().cpu().numpy(),
            rtol=rtol,
            atol=atol,
            err_msg=f"loaded vs eager (nf=1): {key}",
        )


def make_descriptor_dynamic_shapes(has_mapping: bool = False) -> tuple:
    """Build dynamic shapes for descriptor inputs (coord_ext, atype_ext, nlist[, mapping]).

    Note: coord_ext is in flattened form (nframes, nall*3), not (nframes, nall, 3).

    Parameters
    ----------
    has_mapping : bool
        Whether the descriptor takes a mapping argument.
    """
    nframes_dim = torch.export.Dim("nframes", min=1)
    nall_dim = torch.export.Dim("nall", min=1)
    nloc_dim = torch.export.Dim("nloc", min=1)

    # coord_ext's second dim is the derived dim 3*nall, tying it to atype_ext.
    shapes = (
        {0: nframes_dim, 1: 3 * nall_dim},  # coord_ext: (nframes, nall*3)
        {0: nframes_dim, 1: nall_dim},  # atype_ext: (nframes, nall)
        {0: nframes_dim, 1: nloc_dim},  # nlist: (nframes, nloc, nnei); nnei static
    )
    if has_mapping:
        shapes = (*shapes, {0: nframes_dim, 1: nall_dim})  # + mapping: (nframes, nall)
    return shapes


def make_fitting_dynamic_shapes(
    has_gr: bool = False,
    has_fparam: bool = False,
    has_aparam: bool = False,
) -> tuple:
    """Build dynamic shapes for fitting inputs (descriptor, atype[, gr][, fparam][, aparam]).

    Only nframes is marked dynamic. Fitting nets tested in isolation may
    specialize on nloc during symbolic tracing, making nloc incompatible
    with dynamic dim specs. In the full model pipeline, nloc comes from
    the descriptor output and remains dynamic; here we only test nframes.

    Parameters
    ----------
    has_gr : bool
        Whether the fitting takes a gr (rotation matrix) argument.
    has_fparam : bool
        Whether the fitting takes fparam.
    has_aparam : bool
        Whether the fitting takes aparam.
    """
    nframes_dim = torch.export.Dim("nframes", min=1)

    shapes: list = [
        {0: nframes_dim},  # descriptor: (nframes, nloc, dim_descrpt)
        {0: nframes_dim},  # atype: (nframes, nloc)
    ]
    if has_gr:
        shapes.append({0: nframes_dim})  # gr: (nframes, nloc, nnei, 3)
    if has_fparam:
        shapes.append({0: nframes_dim})  # fparam: (nframes, nfp)
    if has_aparam:
        shapes.append({0: nframes_dim})  # aparam: (nframes, nloc, nap)
    return tuple(shapes)
export_save_load_and_compare( + fn, + inputs, + (ret_eager, grad_eager), + dynamic_shapes, + rtol=1e-10, + atol=1e-10, + ) diff --git a/source/tests/pt_expt/fitting/test_ener_fitting.py b/source/tests/pt_expt/fitting/test_ener_fitting.py index fe55bd628a..fe61daf6bc 100644 --- a/source/tests/pt_expt/fitting/test_ener_fitting.py +++ b/source/tests/pt_expt/fitting/test_ener_fitting.py @@ -3,6 +3,9 @@ import numpy as np import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) from deepmd.dpmodel.descriptor import ( DescrptSeA, @@ -20,6 +23,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..export_helpers import ( + export_save_load_and_compare, +) class TestEnergyFittingNet(unittest.TestCase, TestCaseSingleFrameWithNlist): @@ -89,6 +95,81 @@ def test_serialize_has_correct_type(self) -> None: efn2 = EnergyFittingNet.deserialize(serialized).to(self.device) self.assertIsInstance(efn2, EnergyFittingNet) + def test_make_fx(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + rng = np.random.default_rng(GLOBAL_SEED) + + for nfp, nap in [(0, 0), (3, 4)]: + efn = ( + EnergyFittingNet( + self.nt, + ds.dim_out, + numb_fparam=nfp, + numb_aparam=nap, + precision="float64", + ) + .to(self.device) + .eval() + ) + + descriptor = torch.from_numpy( + rng.standard_normal((self.nf, self.nloc, ds.dim_out)) + ).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + fparam = ( + torch.from_numpy(rng.standard_normal((self.nf, nfp))).to(self.device) + if nfp > 0 + else None + ) + aparam = ( + torch.from_numpy(rng.standard_normal((self.nf, self.nloc, nap))).to( + self.device + ) + if nap > 0 + else None + ) + + def fn(descriptor, atype, fparam, aparam): + descriptor = descriptor.detach().requires_grad_(True) + ret = efn(descriptor, atype, fparam=fparam, aparam=aparam)["energy"] + grad = torch.autograd.grad(ret.sum(), descriptor, create_graph=False)[0] + return ret, grad + + ret_eager, 
grad_eager = fn(descriptor, atype, fparam, aparam) + traced = make_fx(fn)(descriptor, atype, fparam, aparam) + ret_traced, grad_traced = traced(descriptor, atype, fparam, aparam) + np.testing.assert_allclose( + ret_eager.detach().cpu().numpy(), + ret_traced.detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + + # --- symbolic trace + export + .pte round-trip --- + nframes_dim = torch.export.Dim("nframes", min=1) + dynamic_shapes = ( + {0: nframes_dim}, # descriptor + {0: nframes_dim}, # atype + {0: nframes_dim} if fparam is not None else None, # fparam + {0: nframes_dim} if aparam is not None else None, # aparam + ) + inputs = (descriptor, atype, fparam, aparam) + export_save_load_and_compare( + fn, + inputs, + (ret_eager, grad_eager), + dynamic_shapes, + rtol=1e-10, + atol=1e-10, + ) + def test_torch_export_simple(self) -> None: """Test that EnergyFittingNet can be exported with torch.export.""" nf, nloc, nnei = self.nlist.shape diff --git a/source/tests/pt_expt/fitting/test_invar_fitting.py b/source/tests/pt_expt/fitting/test_invar_fitting.py index cf54ea500b..72f23f8e05 100644 --- a/source/tests/pt_expt/fitting/test_invar_fitting.py +++ b/source/tests/pt_expt/fitting/test_invar_fitting.py @@ -4,6 +4,9 @@ import numpy as np import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) from deepmd.dpmodel.descriptor import ( DescrptSeA, @@ -21,6 +24,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..export_helpers import ( + export_save_load_and_compare, +) class TestInvarFitting(unittest.TestCase, TestCaseSingleFrameWithNlist): @@ -226,6 +232,84 @@ def test_get_set(self) -> None: foo, ifn0[ii].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 ) + def test_make_fx(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + rng = np.random.default_rng(GLOBAL_SEED) + + 
for nfp, nap in [(0, 0), (3, 4)]: + ifn0 = ( + InvarFitting( + "energy", + self.nt, + ds.dim_out, + 1, + numb_fparam=nfp, + numb_aparam=nap, + mixed_types=True, + precision="float64", + ) + .to(self.device) + .eval() + ) + + descriptor = torch.from_numpy( + rng.standard_normal((self.nf, self.nloc, ds.dim_out)) + ).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + fparam = ( + torch.from_numpy(rng.standard_normal((self.nf, nfp))).to(self.device) + if nfp > 0 + else None + ) + aparam = ( + torch.from_numpy(rng.standard_normal((self.nf, self.nloc, nap))).to( + self.device + ) + if nap > 0 + else None + ) + + def fn(descriptor, atype, fparam, aparam): + descriptor = descriptor.detach().requires_grad_(True) + ret = ifn0(descriptor, atype, fparam=fparam, aparam=aparam)["energy"] + grad = torch.autograd.grad(ret.sum(), descriptor, create_graph=False)[0] + return ret, grad + + ret_eager, grad_eager = fn(descriptor, atype, fparam, aparam) + traced = make_fx(fn)(descriptor, atype, fparam, aparam) + ret_traced, grad_traced = traced(descriptor, atype, fparam, aparam) + np.testing.assert_allclose( + ret_eager.detach().cpu().numpy(), + ret_traced.detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + + # --- symbolic trace + export + .pte round-trip --- + nframes_dim = torch.export.Dim("nframes", min=1) + dynamic_shapes = ( + {0: nframes_dim}, # descriptor + {0: nframes_dim}, # atype + {0: nframes_dim} if fparam is not None else None, # fparam + {0: nframes_dim} if aparam is not None else None, # aparam + ) + inputs = (descriptor, atype, fparam, aparam) + export_save_load_and_compare( + fn, + inputs, + (ret_eager, grad_eager), + dynamic_shapes, + rtol=1e-10, + atol=1e-10, + ) + def test_torch_export_simple(self) -> None: """Test that InvarFitting can be exported with torch.export.""" nf, nloc, nnei = 
self.nlist.shape diff --git a/source/tests/pt_expt/fitting/test_polar_fitting.py b/source/tests/pt_expt/fitting/test_polar_fitting.py index e11b4455e5..ae4799e164 100644 --- a/source/tests/pt_expt/fitting/test_polar_fitting.py +++ b/source/tests/pt_expt/fitting/test_polar_fitting.py @@ -23,6 +23,10 @@ from ...seed import ( GLOBAL_SEED, ) +from ..export_helpers import ( + export_save_load_and_compare, + make_fitting_dynamic_shapes, +) class TestPolarFitting(TestCaseSingleFrameWithNlist): @@ -169,3 +173,15 @@ def fn(descriptor, atype, gr): rtol=1e-10, atol=1e-10, ) + + # --- symbolic trace + export + .pte round-trip --- + dynamic_shapes = make_fitting_dynamic_shapes(has_gr=True) + inputs = (descriptor, atype, gr) + export_save_load_and_compare( + fn, + inputs, + (ret_eager, grad_eager), + dynamic_shapes, + rtol=1e-10, + atol=1e-10, + ) diff --git a/source/tests/pt_expt/fitting/test_property_fitting.py b/source/tests/pt_expt/fitting/test_property_fitting.py index 19177be849..7950931412 100644 --- a/source/tests/pt_expt/fitting/test_property_fitting.py +++ b/source/tests/pt_expt/fitting/test_property_fitting.py @@ -23,6 +23,10 @@ from ...seed import ( GLOBAL_SEED, ) +from ..export_helpers import ( + export_save_load_and_compare, + make_fitting_dynamic_shapes, +) class TestPropertyFittingNet(TestCaseSingleFrameWithNlist): @@ -165,3 +169,15 @@ def fn(descriptor, atype): rtol=1e-10, atol=1e-10, ) + + # --- symbolic trace + export + .pte round-trip --- + dynamic_shapes = make_fitting_dynamic_shapes(has_gr=False) + inputs = (descriptor, atype) + export_save_load_and_compare( + fn, + inputs, + (ret_eager, grad_eager), + dynamic_shapes, + rtol=1e-10, + atol=1e-10, + ) diff --git a/source/tests/pt_expt/model/test_dipole_model.py b/source/tests/pt_expt/model/test_dipole_model.py index 4dafd9d0ae..a19feacf36 100644 --- a/source/tests/pt_expt/model/test_dipole_model.py +++ b/source/tests/pt_expt/model/test_dipole_model.py @@ -24,6 +24,9 @@ from ...seed import ( GLOBAL_SEED, ) 
+from ..export_helpers import ( + model_forward_lower_export_round_trip, +) class TestDipoleModel(unittest.TestCase): @@ -135,52 +138,16 @@ def test_forward_lower_exportable(self) -> None: fparam = None aparam = None - ret_eager = md_pt.forward_lower( - ext_coord.requires_grad_(True), - ext_atype, - nlist_t, - mapping_t, - fparam=fparam, - aparam=aparam, - ) - - traced = md_pt.forward_lower_exportable( + model_forward_lower_export_round_trip( + md_pt, ext_coord, ext_atype, nlist_t, mapping_t, - fparam=fparam, - aparam=aparam, + fparam, + aparam, + output_keys=("dipole", "global_dipole"), ) - self.assertIsInstance(traced, torch.nn.Module) - - exported = torch.export.export( - traced, - (ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam), - strict=False, - ) - self.assertIsNotNone(exported) - - ret_traced = traced(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam) - ret_exported = exported.module()( - ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam - ) - - for key in ("dipole", "global_dipole"): - np.testing.assert_allclose( - ret_eager[key].detach().cpu().numpy(), - ret_traced[key].detach().cpu().numpy(), - rtol=1e-10, - atol=1e-10, - err_msg=f"traced vs eager: {key}", - ) - np.testing.assert_allclose( - ret_eager[key].detach().cpu().numpy(), - ret_exported[key].detach().cpu().numpy(), - rtol=1e-10, - atol=1e-10, - err_msg=f"exported vs eager: {key}", - ) if __name__ == "__main__": diff --git a/source/tests/pt_expt/model/test_dos_model.py b/source/tests/pt_expt/model/test_dos_model.py index 993c55972b..9a5b697978 100644 --- a/source/tests/pt_expt/model/test_dos_model.py +++ b/source/tests/pt_expt/model/test_dos_model.py @@ -24,6 +24,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..export_helpers import ( + model_forward_lower_export_round_trip, +) class TestDOSModel(unittest.TestCase): @@ -135,52 +138,16 @@ def test_forward_lower_exportable(self) -> None: fparam = None aparam = None - ret_eager = md_pt.forward_lower( - 
ext_coord.requires_grad_(True), - ext_atype, - nlist_t, - mapping_t, - fparam=fparam, - aparam=aparam, - ) - - traced = md_pt.forward_lower_exportable( + model_forward_lower_export_round_trip( + md_pt, ext_coord, ext_atype, nlist_t, mapping_t, - fparam=fparam, - aparam=aparam, + fparam, + aparam, + output_keys=("atom_dos", "dos"), ) - self.assertIsInstance(traced, torch.nn.Module) - - exported = torch.export.export( - traced, - (ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam), - strict=False, - ) - self.assertIsNotNone(exported) - - ret_traced = traced(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam) - ret_exported = exported.module()( - ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam - ) - - for key in ("atom_dos", "dos"): - np.testing.assert_allclose( - ret_eager[key].detach().cpu().numpy(), - ret_traced[key].detach().cpu().numpy(), - rtol=1e-10, - atol=1e-10, - err_msg=f"traced vs eager: {key}", - ) - np.testing.assert_allclose( - ret_eager[key].detach().cpu().numpy(), - ret_exported[key].detach().cpu().numpy(), - rtol=1e-10, - atol=1e-10, - err_msg=f"exported vs eager: {key}", - ) if __name__ == "__main__": diff --git a/source/tests/pt_expt/model/test_dp_zbl_model.py b/source/tests/pt_expt/model/test_dp_zbl_model.py index 1fa1d332e5..80d5f8b844 100644 --- a/source/tests/pt_expt/model/test_dp_zbl_model.py +++ b/source/tests/pt_expt/model/test_dp_zbl_model.py @@ -31,6 +31,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..export_helpers import ( + model_forward_lower_export_round_trip, +) TESTS_DIR = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) TAB_FILE = os.path.join( @@ -185,52 +188,16 @@ def test_forward_lower_exportable(self) -> None: fparam = None aparam = None - ret_eager = md_pt.forward_lower( - ext_coord.requires_grad_(True), - ext_atype, - nlist_t, - mapping_t, - fparam=fparam, - aparam=aparam, - ) - - traced = md_pt.forward_lower_exportable( + model_forward_lower_export_round_trip( + md_pt, ext_coord, ext_atype, 
nlist_t, mapping_t, - fparam=fparam, - aparam=aparam, + fparam, + aparam, + output_keys=("atom_energy", "energy"), ) - self.assertIsInstance(traced, torch.nn.Module) - - exported = torch.export.export( - traced, - (ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam), - strict=False, - ) - self.assertIsNotNone(exported) - - ret_traced = traced(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam) - ret_exported = exported.module()( - ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam - ) - - for key in ("atom_energy", "energy"): - np.testing.assert_allclose( - ret_eager[key].detach().cpu().numpy(), - ret_traced[key].detach().cpu().numpy(), - rtol=1e-10, - atol=1e-10, - err_msg=f"traced vs eager: {key}", - ) - np.testing.assert_allclose( - ret_eager[key].detach().cpu().numpy(), - ret_exported[key].detach().cpu().numpy(), - rtol=1e-10, - atol=1e-10, - err_msg=f"exported vs eager: {key}", - ) if __name__ == "__main__": diff --git a/source/tests/pt_expt/model/test_ener_model.py b/source/tests/pt_expt/model/test_ener_model.py index f4fd9106e8..cd520b5cc0 100644 --- a/source/tests/pt_expt/model/test_ener_model.py +++ b/source/tests/pt_expt/model/test_ener_model.py @@ -331,6 +331,76 @@ def test_forward_lower_exportable(self) -> None: "aparam may be baked in as a constant", ) + # --- symbolic trace + export with dynamic shapes + .pte round-trip --- + # Use nf=5 to avoid two specialization traps: + # nf=1 makes make_fx specialize on the scalar case; + # nf=N where N matches numb_fparam or numb_aparam causes + # PyTorch's symbolic tracer to merge unrelated dim symbols. 
+ import tempfile + + from deepmd.pt_expt.utils.serialization import ( + _build_dynamic_shapes, + ) + + inputs_5f = tuple( + torch.cat([t] * 5, dim=0) + for t in ( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam_zero, + aparam_zero, + ) + ) + + traced_sym = md.forward_common_lower_exportable( + inputs_5f[0], + inputs_5f[1], + inputs_5f[2], + inputs_5f[3], + fparam=inputs_5f[4], + aparam=inputs_5f[5], + do_atomic_virial=True, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + ) + + dynamic_shapes = _build_dynamic_shapes(*inputs_5f) + exported_dyn = torch.export.export( + traced_sym, + inputs_5f, + dynamic_shapes=dynamic_shapes, + strict=False, + prefer_deferred_runtime_asserts_over_guards=True, + ) + + with tempfile.NamedTemporaryFile(suffix=".pte") as f: + torch.export.save(exported_dyn, f.name) + loaded = torch.export.load(f.name).module() + + # Compare loaded vs eager at nf=1 (different shapes) + ret_common = md.forward_common_lower( + ext_coord.clone().requires_grad_(True), + ext_atype, + nlist_t, + mapping_t, + fparam=fparam_zero, + aparam=aparam_zero, + do_atomic_virial=True, + ) + ret_loaded_1f = loaded( + ext_coord, ext_atype, nlist_t, mapping_t, fparam_zero, aparam_zero + ) + for key in ret_common: + np.testing.assert_allclose( + ret_common[key].detach().cpu().numpy(), + ret_loaded_1f[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"loaded vs eager (nf=1): {key}", + ) + def test_dp_consistency(self) -> None: """Test numerical consistency with dpmodel (energy values).""" # Build dpmodel version diff --git a/source/tests/pt_expt/model/test_export_pipeline.py b/source/tests/pt_expt/model/test_export_pipeline.py new file mode 100644 index 0000000000..8bff3d7130 --- /dev/null +++ b/source/tests/pt_expt/model/test_export_pipeline.py @@ -0,0 +1,265 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Full export pipeline tests mirroring _trace_and_export in serialization.py. 
+ +Each test case exercises the complete dp freeze pipeline: + 1. Build model via get_model + 2. Serialize → deserialize round-trip + 3. Eager reference + 4. make_fx tracing with tracing_mode="symbolic" + 5. torch.export.export with dynamic shapes + 6. .pte save → load round-trip + 7. Verify loaded matches eager (same shapes) + 8. Verify loaded matches eager (different shapes) + 9. Verify fparam actually affects output (when with_fparam=True) +""" + +import tempfile + +import numpy as np +import pytest +import torch + +import deepmd.pt_expt.utils.env as _env +from deepmd.pt_expt.model.get_model import ( + get_model, +) +from deepmd.pt_expt.model.model import ( + BaseModel, +) +from deepmd.pt_expt.utils.serialization import ( + _build_dynamic_shapes, + _collect_metadata, + _make_sample_inputs, +) + +CONFIGS = { + "se_e2_a": { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": [6, 6], + "rcut": 4.0, + "rcut_smth": 0.5, + "neuron": [8, 16], + "axis_neuron": 4, + "seed": 1, + }, + "fitting_net": {"neuron": [16, 16], "seed": 1}, + }, + "dpa1": { + "type_map": ["O", "H"], + "descriptor": { + "type": "dpa1", + "sel": 12, + "rcut": 4.0, + "rcut_smth": 0.5, + "neuron": [8, 16], + "axis_neuron": 4, + "attn": 4, + "attn_layer": 1, + "attn_dotr": True, + "seed": 1, + }, + "fitting_net": {"neuron": [16, 16], "seed": 1}, + }, + "dpa2": { + "type_map": ["O", "H"], + "descriptor": { + "type": "dpa2", + "repinit": { + "rcut": 4.0, + "rcut_smth": 0.5, + "nsel": 12, + "neuron": [6, 12], + "axis_neuron": 3, + }, + "repformer": { + "rcut": 3.0, + "rcut_smth": 0.5, + "nsel": 6, + "nlayers": 2, + "g1_dim": 8, + "g2_dim": 4, + }, + }, + "fitting_net": {"neuron": [16, 16], "seed": 1}, + }, +} + + +def _get_config(descriptor_type: str, with_fparam: bool) -> dict: + """Return a deep copy of the config with optional fparam.""" + import copy + + config = copy.deepcopy(CONFIGS[descriptor_type]) + if with_fparam: + config["fitting_net"]["numb_fparam"] = 2 + return config + 
+ +class TestExportPipeline: + @pytest.mark.parametrize("descriptor_type", ["se_e2_a", "dpa1", "dpa2"]) + @pytest.mark.parametrize("with_fparam", [False, True]) # frame parameter + def test_export_pipeline(self, descriptor_type, with_fparam) -> None: + config = _get_config(descriptor_type, with_fparam) + + # 1. Build model via get_model (same as dp freeze) + model = get_model(config) + model.to("cpu") + model.eval() + + # 2. Serialize → deserialize round-trip (same as dp freeze) + model_data = model.serialize() + model2 = BaseModel.deserialize(model_data) + model2.to("cpu") + model2.eval() + + # 3. Create sample inputs on CPU for tracing (nframes=5 as in _trace_and_export) + orig_device = _env.DEVICE + _env.DEVICE = torch.device("cpu") + try: + inputs_trace = _make_sample_inputs(model2, nframes=5, nloc=7) + finally: + _env.DEVICE = orig_device + ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = inputs_trace + + # 4. Eager reference + eager_out = model2.forward_common_lower( + ext_coord.detach().requires_grad_(True), + ext_atype, + nlist_t, + mapping_t, + fparam=fparam, + aparam=aparam, + do_atomic_virial=True, + ) + + # 5. Trace with symbolic mode (same as dp freeze) + traced = model2.forward_common_lower_exportable( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam=fparam, + aparam=aparam, + do_atomic_virial=True, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + ) + + # 6. Export with dynamic shapes (same as dp freeze) + dynamic_shapes = _build_dynamic_shapes( + ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam + ) + exported = torch.export.export( + traced, + (ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam), + dynamic_shapes=dynamic_shapes, + strict=False, + prefer_deferred_runtime_asserts_over_guards=True, + ) + + # 7. .pte save → load round-trip + with tempfile.NamedTemporaryFile(suffix=".pte") as tmp: + torch.export.save(exported, tmp.name) + loaded = torch.export.load(tmp.name).module() + + # 8. 
Verify: traced output matches eager (same shapes as trace) + traced_out = traced(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam) + for key in eager_out: + np.testing.assert_allclose( + eager_out[key].detach().cpu().numpy(), + traced_out[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"traced vs eager (same shape): {key}", + ) + + # 9. Verify: loaded (.pte) output matches eager (same shapes) + loaded_out = loaded(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam) + for key in eager_out: + np.testing.assert_allclose( + eager_out[key].detach().cpu().numpy(), + loaded_out[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"loaded (.pte) vs eager (same shape): {key}", + ) + + # 10. Verify: loaded output matches eager with DIFFERENT shapes + # (tests that dynamic shapes work) + _env.DEVICE = torch.device("cpu") + try: + inputs_infer = _make_sample_inputs(model2, nframes=3, nloc=11) + finally: + _env.DEVICE = orig_device + ( + ext_coord2, + ext_atype2, + nlist_t2, + mapping_t2, + fparam2, + aparam2, + ) = inputs_infer + + eager_out2 = model2.forward_common_lower( + ext_coord2.detach().requires_grad_(True), + ext_atype2, + nlist_t2, + mapping_t2, + fparam=fparam2, + aparam=aparam2, + do_atomic_virial=True, + ) + loaded_out2 = loaded( + ext_coord2, ext_atype2, nlist_t2, mapping_t2, fparam2, aparam2 + ) + for key in eager_out2: + np.testing.assert_allclose( + eager_out2[key].detach().cpu().numpy(), + loaded_out2[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"loaded (.pte) vs eager (different shape): {key}", + ) + + # 11. Verify: metadata correctness + metadata = _collect_metadata(model2) + assert metadata["type_map"] == config["type_map"] + assert metadata["dim_fparam"] == (2 if with_fparam else 0) + assert metadata["rcut"] == model2.get_rcut() + assert metadata["sel"] == model2.get_sel() + assert metadata["mixed_types"] == model2.mixed_types() + + # 12. 
Verify: fparam actually affects output (when with_fparam=True) + if with_fparam: + fparam_ones = torch.ones_like(fparam) + eager_out_fp1 = model2.forward_common_lower( + ext_coord.detach().requires_grad_(True), + ext_atype, + nlist_t, + mapping_t, + fparam=fparam_ones, + aparam=aparam, + do_atomic_virial=True, + ) + loaded_out_fp1 = loaded( + ext_coord, ext_atype, nlist_t, mapping_t, fparam_ones, aparam + ) + # Loaded with fparam=1 should match eager with fparam=1 + for key in eager_out_fp1: + np.testing.assert_allclose( + eager_out_fp1[key].detach().cpu().numpy(), + loaded_out_fp1[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"loaded (.pte) vs eager (modified fparam): {key}", + ) + # Output with fparam=0 should differ from fparam=1 + assert not np.allclose( + eager_out["energy"].detach().cpu().numpy(), + eager_out_fp1["energy"].detach().cpu().numpy(), + ), ( + "Changing fparam did not change output — " + "fparam may be baked in as a constant" + ) diff --git a/source/tests/pt_expt/model/test_linear_ener_model.py b/source/tests/pt_expt/model/test_linear_ener_model.py deleted file mode 100644 index de276f42ae..0000000000 --- a/source/tests/pt_expt/model/test_linear_ener_model.py +++ /dev/null @@ -1,192 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import copy -import unittest -from unittest.mock import ( - patch, -) - -import numpy as np -import torch - -from deepmd.pt_expt.model.dp_linear_model import ( - LinearEnergyModel, -) -from deepmd.pt_expt.model.get_model import ( - get_linear_model, - get_standard_model, -) -from deepmd.pt_expt.utils import ( - env, -) - -_sub_model_1 = { - "descriptor": { - "type": "se_atten", - "sel": 40, - "rcut_smth": 0.5, - "rcut": 4.0, - "neuron": [3, 6], - "axis_neuron": 2, - "attn": 8, - "attn_layer": 2, - "attn_dotr": True, - "attn_mask": False, - "activation_function": "tanh", - "scaling_factor": 1.0, - "normalize": False, - "temperature": 1.0, - "set_davg_zero": True, - "type_one_side": True, - 
"seed": 1, - }, - "fitting_net": { - "neuron": [5, 5], - "resnet_dt": True, - "seed": 1, - }, -} -_sub_model_2 = copy.deepcopy(_sub_model_1) -_sub_model_2["descriptor"]["seed"] = 2 -_sub_model_2["fitting_net"]["seed"] = 2 - -_type_map = ["O", "H"] - - -class TestLinearEnerWeights(unittest.TestCase): - """Test that weights parameter affects energy, force, and virial.""" - - def setUp(self) -> None: - self.device = env.DEVICE - - # Build individual standard models for reference - std_data_1 = copy.deepcopy(_sub_model_1) - std_data_1["type_map"] = copy.deepcopy(_type_map) - std_data_2 = copy.deepcopy(_sub_model_2) - std_data_2["type_map"] = copy.deepcopy(_type_map) - self.std_model_1 = get_standard_model(std_data_1) - self.std_model_2 = get_standard_model(std_data_2) - - # Build linear models with different weights - def _make_linear(weights): - data = { - "type_map": copy.deepcopy(_type_map), - "models": [copy.deepcopy(_sub_model_1), copy.deepcopy(_sub_model_2)], - "weights": weights, - } - return get_linear_model(data) - - self.model_mean = _make_linear("mean") - self.model_sum = _make_linear("sum") - self.model_custom = _make_linear([0.3, 0.7]) - - # Sync sub-model weights so linear models use the same params as std models - for linear_model in [self.model_mean, self.model_sum, self.model_custom]: - linear_model.atomic_model.models[0].load_state_dict( - self.std_model_1.atomic_model.state_dict() - ) - linear_model.atomic_model.models[1].load_state_dict( - self.std_model_2.atomic_model.state_dict() - ) - - # Test inputs - generator = torch.Generator(device=self.device).manual_seed(20) - cell = torch.rand( - [3, 3], dtype=torch.float64, device=self.device, generator=generator - ) - cell = (cell + cell.T) + 5.0 * torch.eye( - 3, dtype=torch.float64, device=self.device - ) - self.cell = cell.unsqueeze(0) - natoms = 6 - coord = torch.rand( - [natoms, 3], - dtype=torch.float64, - device=self.device, - generator=generator, - ) - coord = torch.matmul(coord, cell) - 
self.coord = coord.unsqueeze(0) - self.atype = torch.tensor( - [[0, 0, 0, 1, 1, 1]], dtype=torch.int64, device=self.device - ) - self.box = self.cell.reshape(1, 9) - - def _eval(self, model): - coord = self.coord.clone().detach().requires_grad_(True) - ret = model( - coord, - self.atype, - box=self.box, - ) - return {k: v.detach().cpu().numpy() for k, v in ret.items()} - - def test_mean_weights(self) -> None: - ret1 = self._eval(self.std_model_1) - ret2 = self._eval(self.std_model_2) - ret_mean = self._eval(self.model_mean) - for key in ["energy", "force", "virial"]: - expected = 0.5 * ret1[key] + 0.5 * ret2[key] - np.testing.assert_allclose(ret_mean[key], expected, atol=1e-10) - - def test_sum_weights(self) -> None: - ret1 = self._eval(self.std_model_1) - ret2 = self._eval(self.std_model_2) - ret_sum = self._eval(self.model_sum) - for key in ["energy", "force", "virial"]: - expected = ret1[key] + ret2[key] - np.testing.assert_allclose(ret_sum[key], expected, atol=1e-10) - - def test_custom_weights(self) -> None: - ret1 = self._eval(self.std_model_1) - ret2 = self._eval(self.std_model_2) - ret_custom = self._eval(self.model_custom) - for key in ["energy", "force", "virial"]: - expected = 0.3 * ret1[key] + 0.7 * ret2[key] - np.testing.assert_allclose(ret_custom[key], expected, atol=1e-10) - - -class TestLinearUpdateSel(unittest.TestCase): - """Test that update_sel writes updated sub-model configs back.""" - - @patch("deepmd.pt_expt.model.dp_linear_model.DPModelCommon.update_sel") - def test_updated_sel_written_back(self, mock_update_sel) -> None: - """Verify that update_sel returns configs with updated sel values.""" - - def side_effect(train_data, type_map, sub_jdata): - updated = copy.deepcopy(sub_jdata) - updated["descriptor"]["sel"] = 99 - return updated, 0.5 - - mock_update_sel.side_effect = side_effect - - local_jdata = { - "type_map": ["O", "H"], - "models": [ - { - "descriptor": {"type": "se_atten", "sel": 10, "rcut": 4.0}, - "fitting_net": {"neuron": [5, 
5]}, - }, - { - "descriptor": {"type": "se_atten", "sel": 10, "rcut": 4.0}, - "fitting_net": {"neuron": [5, 5]}, - }, - ], - "weights": "mean", - } - - result, min_dist = LinearEnergyModel.update_sel( - train_data=None, - type_map=["O", "H"], - local_jdata=local_jdata, - ) - - for idx, sub_model in enumerate(result["models"]): - self.assertEqual( - sub_model["descriptor"]["sel"], - 99, - f"Sub-model {idx} sel was not updated in returned config", - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/source/tests/pt_expt/model/test_linear_model.py b/source/tests/pt_expt/model/test_linear_model.py new file mode 100644 index 0000000000..a18cabd9e1 --- /dev/null +++ b/source/tests/pt_expt/model/test_linear_model.py @@ -0,0 +1,498 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest +from unittest.mock import ( + patch, +) + +import numpy as np +import torch + +from deepmd.dpmodel.atomic_model import ( + DPAtomicModel, +) +from deepmd.dpmodel.atomic_model.linear_atomic_model import ( + LinearEnergyAtomicModel, +) +from deepmd.dpmodel.descriptor import DescrptDPA1 as DPDescrptDPA1 +from deepmd.dpmodel.fitting import InvarFitting as DPInvarFitting +from deepmd.dpmodel.model.make_model import ( + make_model, +) +from deepmd.dpmodel.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.dpmodel.utils.region import ( + normalize_coord, +) +from deepmd.pt_expt.model import ( + LinearEnergyModel, +) +from deepmd.pt_expt.model.dp_linear_model import ( + LinearEnergyModel as LinearEnergyModelDirect, +) +from deepmd.pt_expt.model.get_model import ( + get_linear_model, + get_standard_model, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...seed import ( + GLOBAL_SEED, +) +from ..export_helpers import ( + model_forward_lower_export_round_trip, +) + + +class TestLinearModel(unittest.TestCase): + def setUp(self) -> None: + self.device = env.DEVICE + self.natoms = 5 + self.rcut = 4.0 + self.rcut_smth = 
0.5 + self.sel = 20 + self.nt = 2 + self.type_map = ["foo", "bar"] + + generator = torch.Generator(device=self.device).manual_seed(GLOBAL_SEED) + cell = torch.rand( + [3, 3], dtype=torch.float64, device=self.device, generator=generator + ) + cell = (cell + cell.T) + 5.0 * torch.eye(3, device=self.device) + self.cell = cell.unsqueeze(0) + coord = torch.rand( + [self.natoms, 3], + dtype=torch.float64, + device=self.device, + generator=generator, + ) + coord = torch.matmul(coord, cell) + self.coord = coord.unsqueeze(0).to(self.device) + self.atype = torch.tensor( + [[0, 0, 0, 1, 1]], dtype=torch.int64, device=self.device + ) + + def _make_dp_atomic_model(self, seed: int) -> DPAtomicModel: + """Build a dpmodel DPAtomicModel with DPA1 descriptor (mixed type).""" + ds = DPDescrptDPA1( + rcut_smth=self.rcut_smth, + rcut=self.rcut, + sel=self.sel, + ntypes=self.nt, + neuron=[3, 6], + axis_neuron=2, + attn=4, + attn_layer=2, + attn_dotr=True, + attn_mask=False, + activation_function="tanh", + set_davg_zero=True, + type_one_side=True, + seed=seed, + ) + ft = DPInvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + seed=seed, + ) + return DPAtomicModel(ds, ft, type_map=self.type_map) + + def _make_dp_linear_model(self) -> LinearEnergyAtomicModel: + """Build a dpmodel LinearEnergyAtomicModel with two sub-models.""" + model1 = self._make_dp_atomic_model(seed=GLOBAL_SEED) + model2 = self._make_dp_atomic_model(seed=GLOBAL_SEED + 1) + return LinearEnergyAtomicModel( + models=[model1, model2], + type_map=self.type_map, + ) + + def _prepare_lower_inputs(self): + """Build extended coords, atype, nlist, mapping as torch tensors.""" + coord_np = self.coord.detach().cpu().numpy() + atype_np = self.atype.detach().cpu().numpy() + cell_np = self.cell.reshape(1, 9).detach().cpu().numpy() + coord_normalized = normalize_coord( + coord_np.reshape(1, self.natoms, 3), + cell_np.reshape(1, 3, 3), + ) + extended_coord, extended_atype, mapping = 
extend_coord_with_ghosts( + coord_normalized, atype_np, cell_np, self.rcut + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + self.natoms, + self.rcut, + [self.sel], + distinguish_types=False, + ) + extended_coord = extended_coord.reshape(1, -1, 3) + return ( + torch.tensor(extended_coord, dtype=torch.float64, device=self.device), + torch.tensor(extended_atype, dtype=torch.int64, device=self.device), + torch.tensor(nlist, dtype=torch.int64, device=self.device), + torch.tensor(mapping, dtype=torch.int64, device=self.device), + ) + + def test_linear_model_consistency(self) -> None: + """Create a LinearEnergyModel, run forward() and forward_lower(), + verify outputs have correct keys and shapes. + """ + md_dp = self._make_dp_linear_model() + md_pt = LinearEnergyModel.deserialize(md_dp.serialize()).to(self.device) + md_pt.eval() + + # Test forward() + coord = self.coord.clone().requires_grad_(True) + ret = md_pt(coord, self.atype, self.cell.reshape(1, 9)) + + self.assertIn("energy", ret) + self.assertIn("atom_energy", ret) + self.assertIn("force", ret) + self.assertIn("virial", ret) + + self.assertEqual(ret["energy"].shape, (1, 1)) + self.assertEqual(ret["atom_energy"].shape, (1, self.natoms, 1)) + self.assertEqual(ret["force"].shape, (1, self.natoms, 3)) + self.assertEqual(ret["virial"].shape, (1, 9)) + + # Test forward_lower() + ext_coord, ext_atype, nlist_t, mapping_t = self._prepare_lower_inputs() + ret_lower = md_pt.forward_lower( + ext_coord.requires_grad_(True), + ext_atype, + nlist_t, + mapping_t, + ) + + self.assertIn("energy", ret_lower) + self.assertIn("atom_energy", ret_lower) + self.assertIn("extended_force", ret_lower) + self.assertIn("virial", ret_lower) + + nall = ext_coord.shape[1] + self.assertEqual(ret_lower["energy"].shape, (1, 1)) + self.assertEqual(ret_lower["atom_energy"].shape, (1, self.natoms, 1)) + self.assertEqual(ret_lower["extended_force"].shape, (1, nall, 3)) + self.assertEqual(ret_lower["virial"].shape, (1, 9)) + + 
def test_linear_model_serialize(self) -> None: + """Create a LinearEnergyModel, serialize, deserialize, verify + outputs match. + """ + md_dp = self._make_dp_linear_model() + md_pt0 = LinearEnergyModel.deserialize(md_dp.serialize()).to(self.device) + md_pt0.eval() + + # Serialize and deserialize + md_pt1 = LinearEnergyModel.deserialize(md_pt0.serialize()).to(self.device) + md_pt1.eval() + + coord = self.coord.clone().requires_grad_(True) + ret0 = md_pt0(coord, self.atype, self.cell.reshape(1, 9)) + + coord = self.coord.clone().requires_grad_(True) + ret1 = md_pt1(coord, self.atype, self.cell.reshape(1, 9)) + + np.testing.assert_allclose( + ret0["energy"].detach().cpu().numpy(), + ret1["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg="energy mismatch after serialize/deserialize", + ) + np.testing.assert_allclose( + ret0["atom_energy"].detach().cpu().numpy(), + ret1["atom_energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg="atom_energy mismatch after serialize/deserialize", + ) + np.testing.assert_allclose( + ret0["force"].detach().cpu().numpy(), + ret1["force"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg="force mismatch after serialize/deserialize", + ) + np.testing.assert_allclose( + ret0["virial"].detach().cpu().numpy(), + ret1["virial"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg="virial mismatch after serialize/deserialize", + ) + + def test_linear_model_dpmodel_consistency(self) -> None: + """Compare pt_expt LinearEnergyModel output with dpmodel + LinearEnergyAtomicModel output (same weights) to verify + cross-backend consistency. 
+ """ + md_dp_atomic = self._make_dp_linear_model() + + # Build pt_expt version from the same serialized data + md_pt = LinearEnergyModel.deserialize(md_dp_atomic.serialize()).to(self.device) + md_pt.eval() + + # Use forward_lower for both backends to compare + coord_np = self.coord.detach().cpu().numpy() + atype_np = self.atype.detach().cpu().numpy() + cell_np = self.cell.reshape(1, 9).detach().cpu().numpy() + coord_normalized = normalize_coord( + coord_np.reshape(1, self.natoms, 3), + cell_np.reshape(1, 3, 3), + ) + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype_np, cell_np, self.rcut + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + self.natoms, + self.rcut, + [self.sel], + distinguish_types=False, + ) + extended_coord = extended_coord.reshape(1, -1, 3) + + # dpmodel forward_lower via make_model wrapper + DPLinearModel = make_model(LinearEnergyAtomicModel) + md_dp = DPLinearModel.deserialize(md_dp_atomic.serialize()) + ret_dp = md_dp.call_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + ) + + # pt_expt forward_lower + ext_coord = torch.tensor( + extended_coord, dtype=torch.float64, device=self.device + ) + ext_atype = torch.tensor(extended_atype, dtype=torch.int64, device=self.device) + nlist_t = torch.tensor(nlist, dtype=torch.int64, device=self.device) + mapping_t = torch.tensor(mapping, dtype=torch.int64, device=self.device) + ret_pt = md_pt.forward_lower( + ext_coord.requires_grad_(True), + ext_atype, + nlist_t, + mapping_t, + ) + + np.testing.assert_allclose( + ret_dp["energy_redu"], + ret_pt["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg="energy mismatch between dpmodel and pt_expt", + ) + np.testing.assert_allclose( + ret_dp["energy"], + ret_pt["atom_energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg="atom_energy mismatch between dpmodel and pt_expt", + ) + + def test_forward_lower_exportable(self) -> None: + """Test 
that LinearEnergyModel.forward_lower_exportable returns + an exportable module whose outputs match eager execution. + """ + md_dp = self._make_dp_linear_model() + md_pt = LinearEnergyModel.deserialize(md_dp.serialize()).to(self.device) + md_pt.eval() + + ext_coord, ext_atype, nlist_t, mapping_t = self._prepare_lower_inputs() + fparam = None + aparam = None + + model_forward_lower_export_round_trip( + md_pt, + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam, + aparam, + output_keys=("atom_energy", "energy"), + ) + + +_sub_model_1 = { + "descriptor": { + "type": "se_atten", + "sel": 40, + "rcut_smth": 0.5, + "rcut": 4.0, + "neuron": [3, 6], + "axis_neuron": 2, + "attn": 8, + "attn_layer": 2, + "attn_dotr": True, + "attn_mask": False, + "activation_function": "tanh", + "scaling_factor": 1.0, + "normalize": False, + "temperature": 1.0, + "set_davg_zero": True, + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [5, 5], + "resnet_dt": True, + "seed": 1, + }, +} +_sub_model_2 = copy.deepcopy(_sub_model_1) +_sub_model_2["descriptor"]["seed"] = 2 +_sub_model_2["fitting_net"]["seed"] = 2 + +_type_map = ["O", "H"] + + +class TestLinearEnerWeights(unittest.TestCase): + """Test that weights parameter affects energy, force, and virial.""" + + def setUp(self) -> None: + self.device = env.DEVICE + + # Build individual standard models for reference + std_data_1 = copy.deepcopy(_sub_model_1) + std_data_1["type_map"] = copy.deepcopy(_type_map) + std_data_2 = copy.deepcopy(_sub_model_2) + std_data_2["type_map"] = copy.deepcopy(_type_map) + self.std_model_1 = get_standard_model(std_data_1) + self.std_model_2 = get_standard_model(std_data_2) + + # Build linear models with different weights + def _make_linear(weights): + data = { + "type_map": copy.deepcopy(_type_map), + "models": [copy.deepcopy(_sub_model_1), copy.deepcopy(_sub_model_2)], + "weights": weights, + } + return get_linear_model(data) + + self.model_mean = _make_linear("mean") + self.model_sum 
= _make_linear("sum") + self.model_custom = _make_linear([0.3, 0.7]) + + # Sync sub-model weights so linear models use the same params as std models + for linear_model in [self.model_mean, self.model_sum, self.model_custom]: + linear_model.atomic_model.models[0].load_state_dict( + self.std_model_1.atomic_model.state_dict() + ) + linear_model.atomic_model.models[1].load_state_dict( + self.std_model_2.atomic_model.state_dict() + ) + + # Test inputs + generator = torch.Generator(device=self.device).manual_seed(20) + cell = torch.rand( + [3, 3], dtype=torch.float64, device=self.device, generator=generator + ) + cell = (cell + cell.T) + 5.0 * torch.eye( + 3, dtype=torch.float64, device=self.device + ) + self.cell = cell.unsqueeze(0) + natoms = 6 + coord = torch.rand( + [natoms, 3], + dtype=torch.float64, + device=self.device, + generator=generator, + ) + coord = torch.matmul(coord, cell) + self.coord = coord.unsqueeze(0) + self.atype = torch.tensor( + [[0, 0, 0, 1, 1, 1]], dtype=torch.int64, device=self.device + ) + self.box = self.cell.reshape(1, 9) + + def _eval(self, model): + coord = self.coord.clone().detach().requires_grad_(True) + ret = model( + coord, + self.atype, + box=self.box, + ) + return {k: v.detach().cpu().numpy() for k, v in ret.items()} + + def test_mean_weights(self) -> None: + ret1 = self._eval(self.std_model_1) + ret2 = self._eval(self.std_model_2) + ret_mean = self._eval(self.model_mean) + for key in ["energy", "force", "virial"]: + expected = 0.5 * ret1[key] + 0.5 * ret2[key] + np.testing.assert_allclose(ret_mean[key], expected, atol=1e-10) + + def test_sum_weights(self) -> None: + ret1 = self._eval(self.std_model_1) + ret2 = self._eval(self.std_model_2) + ret_sum = self._eval(self.model_sum) + for key in ["energy", "force", "virial"]: + expected = ret1[key] + ret2[key] + np.testing.assert_allclose(ret_sum[key], expected, atol=1e-10) + + def test_custom_weights(self) -> None: + ret1 = self._eval(self.std_model_1) + ret2 = 
self._eval(self.std_model_2) + ret_custom = self._eval(self.model_custom) + for key in ["energy", "force", "virial"]: + expected = 0.3 * ret1[key] + 0.7 * ret2[key] + np.testing.assert_allclose(ret_custom[key], expected, atol=1e-10) + + +class TestLinearUpdateSel(unittest.TestCase): + """Test that update_sel writes updated sub-model configs back.""" + + @patch("deepmd.pt_expt.model.dp_linear_model.DPModelCommon.update_sel") + def test_updated_sel_written_back(self, mock_update_sel) -> None: + """Verify that update_sel returns configs with updated sel values.""" + + def side_effect(train_data, type_map, sub_jdata): + updated = copy.deepcopy(sub_jdata) + updated["descriptor"]["sel"] = 99 + return updated, 0.5 + + mock_update_sel.side_effect = side_effect + + local_jdata = { + "type_map": ["O", "H"], + "models": [ + { + "descriptor": {"type": "se_atten", "sel": 10, "rcut": 4.0}, + "fitting_net": {"neuron": [5, 5]}, + }, + { + "descriptor": {"type": "se_atten", "sel": 10, "rcut": 4.0}, + "fitting_net": {"neuron": [5, 5]}, + }, + ], + "weights": "mean", + } + + result, min_dist = LinearEnergyModelDirect.update_sel( + train_data=None, + type_map=["O", "H"], + local_jdata=local_jdata, + ) + + for idx, sub_model in enumerate(result["models"]): + self.assertEqual( + sub_model["descriptor"]["sel"], + 99, + f"Sub-model {idx} sel was not updated in returned config", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt_expt/model/test_polar_model.py b/source/tests/pt_expt/model/test_polar_model.py index acfa929db2..a1a12b8702 100644 --- a/source/tests/pt_expt/model/test_polar_model.py +++ b/source/tests/pt_expt/model/test_polar_model.py @@ -24,6 +24,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..export_helpers import ( + model_forward_lower_export_round_trip, +) class TestPolarModel(unittest.TestCase): @@ -135,52 +138,16 @@ def test_forward_lower_exportable(self) -> None: fparam = None aparam = None - ret_eager = md_pt.forward_lower( - 
ext_coord.requires_grad_(True), - ext_atype, - nlist_t, - mapping_t, - fparam=fparam, - aparam=aparam, - ) - - traced = md_pt.forward_lower_exportable( + model_forward_lower_export_round_trip( + md_pt, ext_coord, ext_atype, nlist_t, mapping_t, - fparam=fparam, - aparam=aparam, + fparam, + aparam, + output_keys=("polar", "global_polar"), ) - self.assertIsInstance(traced, torch.nn.Module) - - exported = torch.export.export( - traced, - (ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam), - strict=False, - ) - self.assertIsNotNone(exported) - - ret_traced = traced(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam) - ret_exported = exported.module()( - ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam - ) - - for key in ("polar", "global_polar"): - np.testing.assert_allclose( - ret_eager[key].detach().cpu().numpy(), - ret_traced[key].detach().cpu().numpy(), - rtol=1e-10, - atol=1e-10, - err_msg=f"traced vs eager: {key}", - ) - np.testing.assert_allclose( - ret_eager[key].detach().cpu().numpy(), - ret_exported[key].detach().cpu().numpy(), - rtol=1e-10, - atol=1e-10, - err_msg=f"exported vs eager: {key}", - ) if __name__ == "__main__": diff --git a/source/tests/pt_expt/model/test_property_model.py b/source/tests/pt_expt/model/test_property_model.py index 12b12afea1..5359ee55c6 100644 --- a/source/tests/pt_expt/model/test_property_model.py +++ b/source/tests/pt_expt/model/test_property_model.py @@ -24,6 +24,9 @@ from ...seed import ( GLOBAL_SEED, ) +from ..export_helpers import ( + model_forward_lower_export_round_trip, +) class TestPropertyModel(unittest.TestCase): @@ -137,53 +140,17 @@ def test_forward_lower_exportable(self) -> None: fparam = None aparam = None - ret_eager = md_pt.forward_lower( - ext_coord.requires_grad_(True), - ext_atype, - nlist_t, - mapping_t, - fparam=fparam, - aparam=aparam, - ) - - traced = md_pt.forward_lower_exportable( + var_name = md_pt.get_var_name() + model_forward_lower_export_round_trip( + md_pt, ext_coord, 
ext_atype, nlist_t, mapping_t, - fparam=fparam, - aparam=aparam, + fparam, + aparam, + output_keys=(f"atom_{var_name}", var_name), ) - self.assertIsInstance(traced, torch.nn.Module) - - exported = torch.export.export( - traced, - (ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam), - strict=False, - ) - self.assertIsNotNone(exported) - - ret_traced = traced(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam) - ret_exported = exported.module()( - ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam - ) - - var_name = md_pt.get_var_name() - for key in (f"atom_{var_name}", var_name): - np.testing.assert_allclose( - ret_eager[key].detach().cpu().numpy(), - ret_traced[key].detach().cpu().numpy(), - rtol=1e-10, - atol=1e-10, - err_msg=f"traced vs eager: {key}", - ) - np.testing.assert_allclose( - ret_eager[key].detach().cpu().numpy(), - ret_exported[key].detach().cpu().numpy(), - rtol=1e-10, - atol=1e-10, - err_msg=f"exported vs eager: {key}", - ) if __name__ == "__main__": diff --git a/source/tests/pt_expt/model/test_spin_ener_model.py b/source/tests/pt_expt/model/test_spin_ener_model.py index 8789f944e9..5dd6dbecf2 100644 --- a/source/tests/pt_expt/model/test_spin_ener_model.py +++ b/source/tests/pt_expt/model/test_spin_ener_model.py @@ -476,6 +476,68 @@ def test_forward_lower_exportable(self) -> None: err_msg=f"exported vs eager: {key}", ) + # --- symbolic trace + export with dynamic shapes + .pte round-trip --- + import tempfile + + # Use nf=2 data for tracing to avoid nframes=1 specialization + inputs_2f = ( + torch.cat([ext_coord_t, ext_coord_t], dim=0), + torch.cat([ext_atype_t, ext_atype_t], dim=0), + torch.cat([ext_spin_t, ext_spin_t], dim=0), + torch.cat([nlist_t, nlist_t], dim=0), + torch.cat([mapping_t, mapping_t], dim=0), + None, + None, + ) + + traced_sym = model.forward_lower_exportable( + inputs_2f[0], + inputs_2f[1], + inputs_2f[2], + inputs_2f[3], + inputs_2f[4], + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + ) + + # Build 
dynamic shapes for spin model + # (ext_coord, ext_atype, ext_spin, nlist, mapping, fparam, aparam) + nframes_dim = torch.export.Dim("nframes", min=1) + nall_dim = torch.export.Dim("nall", min=1) + nloc_dim = torch.export.Dim("nloc", min=1) + dynamic_shapes = ( + {0: nframes_dim, 1: nall_dim}, # ext_coord + {0: nframes_dim, 1: nall_dim}, # ext_atype + {0: nframes_dim, 1: nall_dim}, # ext_spin + {0: nframes_dim, 1: nloc_dim}, # nlist + {0: nframes_dim, 1: nall_dim}, # mapping + None, # fparam + None, # aparam + ) + exported_dyn = torch.export.export( + traced_sym, + inputs_2f, + dynamic_shapes=dynamic_shapes, + strict=False, + prefer_deferred_runtime_asserts_over_guards=True, + ) + + with tempfile.NamedTemporaryFile(suffix=".pte") as f: + torch.export.save(exported_dyn, f.name) + loaded = torch.export.load(f.name).module() + + ret_loaded_1f = loaded( + ext_coord_t, ext_atype_t, ext_spin_t, nlist_t, mapping_t, None, None + ) + for key in output_keys: + np.testing.assert_allclose( + ret_eager[key].detach().cpu().numpy(), + ret_loaded_1f[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"loaded vs eager (nf=1): {key}", + ) + if __name__ == "__main__": unittest.main()