Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion deepmd/pt_expt/model/dipole_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def forward_lower_exportable(
fparam: torch.Tensor | None = None,
aparam: torch.Tensor | None = None,
do_atomic_virial: bool = False,
**make_fx_kwargs: Any,
) -> torch.nn.Module:
model = self

Expand All @@ -146,6 +147,6 @@ def fn(
do_atomic_virial=do_atomic_virial,
)

return make_fx(fn)(
return make_fx(fn, **make_fx_kwargs)(
extended_coord, extended_atype, nlist, mapping, fparam, aparam
)
3 changes: 2 additions & 1 deletion deepmd/pt_expt/model/dos_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def forward_lower_exportable(
fparam: torch.Tensor | None = None,
aparam: torch.Tensor | None = None,
do_atomic_virial: bool = False,
**make_fx_kwargs: Any,
) -> torch.nn.Module:
model = self

Expand All @@ -126,6 +127,6 @@ def fn(
do_atomic_virial=do_atomic_virial,
)

return make_fx(fn)(
return make_fx(fn, **make_fx_kwargs)(
extended_coord, extended_atype, nlist, mapping, fparam, aparam
)
3 changes: 2 additions & 1 deletion deepmd/pt_expt/model/dp_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def forward_lower_exportable(
fparam: torch.Tensor | None = None,
aparam: torch.Tensor | None = None,
do_atomic_virial: bool = False,
**make_fx_kwargs: Any,
) -> torch.nn.Module:
model = self

Expand All @@ -151,7 +152,7 @@ def fn(
do_atomic_virial=do_atomic_virial,
)

return make_fx(fn)(
return make_fx(fn, **make_fx_kwargs)(
extended_coord, extended_atype, nlist, mapping, fparam, aparam
)

Expand Down
3 changes: 2 additions & 1 deletion deepmd/pt_expt/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def forward_lower_exportable(
fparam: torch.Tensor | None = None,
aparam: torch.Tensor | None = None,
do_atomic_virial: bool = False,
**make_fx_kwargs: Any,
) -> torch.nn.Module:
model = self

Expand All @@ -148,6 +149,6 @@ def fn(
do_atomic_virial=do_atomic_virial,
)

return make_fx(fn)(
return make_fx(fn, **make_fx_kwargs)(
extended_coord, extended_atype, nlist, mapping, fparam, aparam
)
3 changes: 2 additions & 1 deletion deepmd/pt_expt/model/polar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def forward_lower_exportable(
fparam: torch.Tensor | None = None,
aparam: torch.Tensor | None = None,
do_atomic_virial: bool = False,
**make_fx_kwargs: Any,
) -> torch.nn.Module:
model = self

Expand All @@ -126,6 +127,6 @@ def fn(
do_atomic_virial=do_atomic_virial,
)

return make_fx(fn)(
return make_fx(fn, **make_fx_kwargs)(
extended_coord, extended_atype, nlist, mapping, fparam, aparam
)
3 changes: 2 additions & 1 deletion deepmd/pt_expt/model/property_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def forward_lower_exportable(
fparam: torch.Tensor | None = None,
aparam: torch.Tensor | None = None,
do_atomic_virial: bool = False,
**make_fx_kwargs: Any,
) -> torch.nn.Module:
model = self

Expand All @@ -133,6 +134,6 @@ def fn(
do_atomic_virial=do_atomic_virial,
)

return make_fx(fn)(
return make_fx(fn, **make_fx_kwargs)(
extended_coord, extended_atype, nlist, mapping, fparam, aparam
)
9 changes: 7 additions & 2 deletions deepmd/pt_expt/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,12 +330,17 @@ def _trace_and_export(
metadata = _collect_metadata(model)

# 3. Create sample inputs on CPU for tracing
# Use nframes=2 so make_fx doesn't specialize on nframes=1
# Use nframes=5 to avoid two specialization traps:
# - nframes=1 causes make_fx to specialize on the scalar case
# - nframes=N where N == numb_fparam or numb_aparam causes PyTorch's
# symbolic tracer to merge symbols (e.g. fparam.shape=(2,2) when
# nframes=2 and numb_fparam=2), so a guard on one dim constrains
# the other. 5 is unlikely to collide with typical param counts.
_orig_device = _env.DEVICE
_env.DEVICE = torch.device("cpu")
try:
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = _make_sample_inputs(
model, nframes=2
model, nframes=5
Comment thread
wanghan-iapcm marked this conversation as resolved.
)
finally:
_env.DEVICE = _orig_device
Expand Down
16 changes: 16 additions & 0 deletions source/tests/pt_expt/descriptor/test_dpa1.py
Comment thread
iProzd marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
from ...seed import (
GLOBAL_SEED,
)
from ..export_helpers import (
export_save_load_and_compare,
make_descriptor_dynamic_shapes,
)


class TestDescrptDPA1(TestCaseSingleFrameWithNlist):
Expand Down Expand Up @@ -223,3 +227,15 @@ def fn(coord_ext, atype_ext, nlist):
rtol=rtol,
atol=atol,
)

# --- symbolic trace + export + .pte round-trip ---
dynamic_shapes = make_descriptor_dynamic_shapes(has_mapping=False)
inputs = (coord_ext, atype_ext, nlist)
export_save_load_and_compare(
fn,
inputs,
(rd_eager, grad_eager),
dynamic_shapes,
rtol=rtol,
atol=atol,
)
16 changes: 16 additions & 0 deletions source/tests/pt_expt/descriptor/test_dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
from ...seed import (
GLOBAL_SEED,
)
from ..export_helpers import (
export_save_load_and_compare,
make_descriptor_dynamic_shapes,
)


class TestDescrptDPA2(TestCaseSingleFrameWithNlist):
Expand Down Expand Up @@ -410,3 +414,15 @@ def fn(coord_ext, atype_ext, nlist, mapping):
rtol=rtol,
atol=atol,
)

# --- symbolic trace + export + .pte round-trip ---
dynamic_shapes = make_descriptor_dynamic_shapes(has_mapping=True)
inputs = (coord_ext, atype_ext, nlist, mapping)
export_save_load_and_compare(
fn,
inputs,
(rd_eager, grad_eager),
dynamic_shapes,
rtol=rtol,
atol=atol,
)
16 changes: 16 additions & 0 deletions source/tests/pt_expt/descriptor/test_dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
from ...seed import (
GLOBAL_SEED,
)
from ..export_helpers import (
export_save_load_and_compare,
make_descriptor_dynamic_shapes,
)


class TestDescrptDPA3(TestCaseSingleFrameWithNlist):
Expand Down Expand Up @@ -244,3 +248,15 @@ def fn(coord_ext, atype_ext, nlist, mapping):
rtol=rtol,
atol=atol,
)

# --- symbolic trace + export + .pte round-trip ---
dynamic_shapes = make_descriptor_dynamic_shapes(has_mapping=True)
inputs = (coord_ext, atype_ext, nlist, mapping)
export_save_load_and_compare(
fn,
inputs,
(rd_eager, grad_eager),
dynamic_shapes,
rtol=rtol,
atol=atol,
)
16 changes: 16 additions & 0 deletions source/tests/pt_expt/descriptor/test_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
from ...seed import (
GLOBAL_SEED,
)
from ..export_helpers import (
export_save_load_and_compare,
make_descriptor_dynamic_shapes,
)


class TestDescrptHybrid(TestCaseSingleFrameWithNlist):
Expand Down Expand Up @@ -215,3 +219,15 @@ def fn(coord_ext, atype_ext, nlist):
rtol=rtol,
atol=atol,
)

# --- symbolic trace + export + .pte round-trip ---
dynamic_shapes = make_descriptor_dynamic_shapes(has_mapping=False)
inputs = (coord_ext, atype_ext, nlist)
export_save_load_and_compare(
fn,
inputs,
(rd_eager, grad_eager),
dynamic_shapes,
rtol=rtol,
atol=atol,
)
16 changes: 16 additions & 0 deletions source/tests/pt_expt/descriptor/test_se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
from ...seed import (
GLOBAL_SEED,
)
from ..export_helpers import (
export_save_load_and_compare,
make_descriptor_dynamic_shapes,
)


class TestDescrptSeAttenV2(TestCaseSingleFrameWithNlist):
Expand Down Expand Up @@ -218,3 +222,15 @@ def fn(coord_ext, atype_ext, nlist):
rtol=rtol,
atol=atol,
)

# --- symbolic trace + export + .pte round-trip ---
dynamic_shapes = make_descriptor_dynamic_shapes(has_mapping=False)
inputs = (coord_ext, atype_ext, nlist)
export_save_load_and_compare(
fn,
inputs,
(rd_eager, grad_eager),
dynamic_shapes,
rtol=rtol,
atol=atol,
)
16 changes: 16 additions & 0 deletions source/tests/pt_expt/descriptor/test_se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
from ...seed import (
GLOBAL_SEED,
)
from ..export_helpers import (
export_save_load_and_compare,
make_descriptor_dynamic_shapes,
)


class TestDescrptSeA(TestCaseSingleFrameWithNlist):
Expand Down Expand Up @@ -205,3 +209,15 @@ def fn(coord_ext, atype_ext, nlist):
rtol=rtol,
atol=atol,
)

# --- symbolic trace + export + .pte round-trip ---
dynamic_shapes = make_descriptor_dynamic_shapes(has_mapping=False)
inputs = (coord_ext, atype_ext, nlist)
export_save_load_and_compare(
fn,
inputs,
(rd_eager, grad_eager),
dynamic_shapes,
rtol=rtol,
atol=atol,
)
16 changes: 16 additions & 0 deletions source/tests/pt_expt/descriptor/test_se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
from ...seed import (
GLOBAL_SEED,
)
from ..export_helpers import (
export_save_load_and_compare,
make_descriptor_dynamic_shapes,
)


class TestDescrptSeR(TestCaseSingleFrameWithNlist):
Expand Down Expand Up @@ -200,3 +204,15 @@ def fn(coord_ext, atype_ext, nlist):
rtol=rtol,
atol=atol,
)

# --- symbolic trace + export + .pte round-trip ---
dynamic_shapes = make_descriptor_dynamic_shapes(has_mapping=False)
inputs = (coord_ext, atype_ext, nlist)
export_save_load_and_compare(
fn,
inputs,
(rd_eager, grad_eager),
dynamic_shapes,
rtol=rtol,
atol=atol,
)
16 changes: 16 additions & 0 deletions source/tests/pt_expt/descriptor/test_se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
from ...seed import (
GLOBAL_SEED,
)
from ..export_helpers import (
export_save_load_and_compare,
make_descriptor_dynamic_shapes,
)


class TestDescrptSeT(TestCaseSingleFrameWithNlist):
Expand Down Expand Up @@ -204,3 +208,15 @@ def fn(coord_ext, atype_ext, nlist):
rtol=rtol,
atol=atol,
)

# --- symbolic trace + export + .pte round-trip ---
dynamic_shapes = make_descriptor_dynamic_shapes(has_mapping=False)
inputs = (coord_ext, atype_ext, nlist)
export_save_load_and_compare(
fn,
inputs,
(rd_eager, grad_eager),
dynamic_shapes,
rtol=rtol,
atol=atol,
)
16 changes: 16 additions & 0 deletions source/tests/pt_expt/descriptor/test_se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
from ...seed import (
GLOBAL_SEED,
)
from ..export_helpers import (
export_save_load_and_compare,
make_descriptor_dynamic_shapes,
)


class TestDescrptSeTTebd(TestCaseSingleFrameWithNlist):
Expand Down Expand Up @@ -235,3 +239,15 @@ def fn(coord_ext, atype_ext, nlist):
rtol=rtol,
atol=atol,
)

# --- symbolic trace + export + .pte round-trip ---
dynamic_shapes = make_descriptor_dynamic_shapes(has_mapping=False)
inputs = (coord_ext, atype_ext, nlist)
export_save_load_and_compare(
fn,
inputs,
(rd_eager, grad_eager),
dynamic_shapes,
rtol=rtol,
atol=atol,
)
Loading
Loading