Skip to content

Commit 4880f21

Browse files
wanghan-iapcm and Han Wang authored
feat(pt_expt): enhance per-module export tests + fix nframes symbol collision (#5367)
## Summary

- Enhance existing `test_make_fx` (9 descriptors, 6 fittings) and `test_forward_lower_exportable` (8 models) to cover the full `dp freeze` pipeline: symbolic tracing, `torch.export` with dynamic shapes, `.pte` save/load round-trip, and inference with different `nframes`
- Add `**make_fx_kwargs` forwarding to 6 model `forward_lower_exportable` methods (dipole, dos, polar, property, dp_zbl, dp_linear)
- Create `source/tests/pt_expt/export_helpers.py` with shared test helpers
- Add new `test_make_fx` to `test_ener_fitting.py` and `test_invar_fitting.py` with fparam/aparam parametrization
- Fix nframes symbol collision: change tracing `nframes` from 2 to 5 in `_trace_and_export` to avoid PyTorch's symbolic tracer merging dim symbols when `nframes == numb_fparam`

## Test plan

- [x] 16 descriptor `test_make_fx` pass
- [x] 6 fitting `test_make_fx` pass
- [x] 8 model `test_forward_lower_exportable` pass
- [x] `test_export_pipeline` with `fparam=True` passes for se_e2_a and dpa1
- [ ] CI

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * FX tracing for export is now caller-configurable for exportable model methods.
  * Increased sample frames used for export traces to avoid unwanted specialization/collisions.
* **Tests**
  * Added comprehensive export → save → load round-trip validations across descriptor, fitting, and model tests.
  * Introduced shared test helpers/utilities for dynamic-shape export verification.
  * Added export-pipeline and linear-model test coverage; many tests refactored to use the new helpers and one obsolete test module removed.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent 229cd57 commit 4880f21

33 files changed

+1625
-406
lines changed

deepmd/pt_expt/model/dipole_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def forward_lower_exportable(
124124
fparam: torch.Tensor | None = None,
125125
aparam: torch.Tensor | None = None,
126126
do_atomic_virial: bool = False,
127+
**make_fx_kwargs: Any,
127128
) -> torch.nn.Module:
128129
model = self
129130

@@ -146,6 +147,6 @@ def fn(
146147
do_atomic_virial=do_atomic_virial,
147148
)
148149

149-
return make_fx(fn)(
150+
return make_fx(fn, **make_fx_kwargs)(
150151
extended_coord, extended_atype, nlist, mapping, fparam, aparam
151152
)

deepmd/pt_expt/model/dos_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def forward_lower_exportable(
104104
fparam: torch.Tensor | None = None,
105105
aparam: torch.Tensor | None = None,
106106
do_atomic_virial: bool = False,
107+
**make_fx_kwargs: Any,
107108
) -> torch.nn.Module:
108109
model = self
109110

@@ -126,6 +127,6 @@ def fn(
126127
do_atomic_virial=do_atomic_virial,
127128
)
128129

129-
return make_fx(fn)(
130+
return make_fx(fn, **make_fx_kwargs)(
130131
extended_coord, extended_atype, nlist, mapping, fparam, aparam
131132
)

deepmd/pt_expt/model/dp_linear_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def forward_lower_exportable(
129129
fparam: torch.Tensor | None = None,
130130
aparam: torch.Tensor | None = None,
131131
do_atomic_virial: bool = False,
132+
**make_fx_kwargs: Any,
132133
) -> torch.nn.Module:
133134
model = self
134135

@@ -151,7 +152,7 @@ def fn(
151152
do_atomic_virial=do_atomic_virial,
152153
)
153154

154-
return make_fx(fn)(
155+
return make_fx(fn, **make_fx_kwargs)(
155156
extended_coord, extended_atype, nlist, mapping, fparam, aparam
156157
)
157158

deepmd/pt_expt/model/dp_zbl_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def forward_lower_exportable(
126126
fparam: torch.Tensor | None = None,
127127
aparam: torch.Tensor | None = None,
128128
do_atomic_virial: bool = False,
129+
**make_fx_kwargs: Any,
129130
) -> torch.nn.Module:
130131
model = self
131132

@@ -148,6 +149,6 @@ def fn(
148149
do_atomic_virial=do_atomic_virial,
149150
)
150151

151-
return make_fx(fn)(
152+
return make_fx(fn, **make_fx_kwargs)(
152153
extended_coord, extended_atype, nlist, mapping, fparam, aparam
153154
)

deepmd/pt_expt/model/polar_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def forward_lower_exportable(
104104
fparam: torch.Tensor | None = None,
105105
aparam: torch.Tensor | None = None,
106106
do_atomic_virial: bool = False,
107+
**make_fx_kwargs: Any,
107108
) -> torch.nn.Module:
108109
model = self
109110

@@ -126,6 +127,6 @@ def fn(
126127
do_atomic_virial=do_atomic_virial,
127128
)
128129

129-
return make_fx(fn)(
130+
return make_fx(fn, **make_fx_kwargs)(
130131
extended_coord, extended_atype, nlist, mapping, fparam, aparam
131132
)

deepmd/pt_expt/model/property_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def forward_lower_exportable(
111111
fparam: torch.Tensor | None = None,
112112
aparam: torch.Tensor | None = None,
113113
do_atomic_virial: bool = False,
114+
**make_fx_kwargs: Any,
114115
) -> torch.nn.Module:
115116
model = self
116117

@@ -133,6 +134,6 @@ def fn(
133134
do_atomic_virial=do_atomic_virial,
134135
)
135136

136-
return make_fx(fn)(
137+
return make_fx(fn, **make_fx_kwargs)(
137138
extended_coord, extended_atype, nlist, mapping, fparam, aparam
138139
)

deepmd/pt_expt/utils/serialization.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,12 +330,17 @@ def _trace_and_export(
330330
metadata = _collect_metadata(model)
331331

332332
# 3. Create sample inputs on CPU for tracing
333-
# Use nframes=2 so make_fx doesn't specialize on nframes=1
333+
# Use nframes=5 to avoid two specialization traps:
334+
# - nframes=1 causes make_fx to specialize on the scalar case
335+
# - nframes=N where N == numb_fparam or numb_aparam causes PyTorch's
336+
# symbolic tracer to merge symbols (e.g. fparam.shape=(2,2) when
337+
# nframes=2 and numb_fparam=2), so a guard on one dim constrains
338+
# the other. 5 is unlikely to collide with typical param counts.
334339
_orig_device = _env.DEVICE
335340
_env.DEVICE = torch.device("cpu")
336341
try:
337342
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = _make_sample_inputs(
338-
model, nframes=2
343+
model, nframes=5
339344
)
340345
finally:
341346
_env.DEVICE = _orig_device

source/tests/pt_expt/descriptor/test_dpa1.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
from ...seed import (
2626
GLOBAL_SEED,
2727
)
28+
from ..export_helpers import (
29+
export_save_load_and_compare,
30+
make_descriptor_dynamic_shapes,
31+
)
2832

2933

3034
class TestDescrptDPA1(TestCaseSingleFrameWithNlist):
@@ -223,3 +227,15 @@ def fn(coord_ext, atype_ext, nlist):
223227
rtol=rtol,
224228
atol=atol,
225229
)
230+
231+
# --- symbolic trace + export + .pte round-trip ---
232+
dynamic_shapes = make_descriptor_dynamic_shapes(has_mapping=False)
233+
inputs = (coord_ext, atype_ext, nlist)
234+
export_save_load_and_compare(
235+
fn,
236+
inputs,
237+
(rd_eager, grad_eager),
238+
dynamic_shapes,
239+
rtol=rtol,
240+
atol=atol,
241+
)

source/tests/pt_expt/descriptor/test_dpa2.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
from ...seed import (
3030
GLOBAL_SEED,
3131
)
32+
from ..export_helpers import (
33+
export_save_load_and_compare,
34+
make_descriptor_dynamic_shapes,
35+
)
3236

3337

3438
class TestDescrptDPA2(TestCaseSingleFrameWithNlist):
@@ -410,3 +414,15 @@ def fn(coord_ext, atype_ext, nlist, mapping):
410414
rtol=rtol,
411415
atol=atol,
412416
)
417+
418+
# --- symbolic trace + export + .pte round-trip ---
419+
dynamic_shapes = make_descriptor_dynamic_shapes(has_mapping=True)
420+
inputs = (coord_ext, atype_ext, nlist, mapping)
421+
export_save_load_and_compare(
422+
fn,
423+
inputs,
424+
(rd_eager, grad_eager),
425+
dynamic_shapes,
426+
rtol=rtol,
427+
atol=atol,
428+
)

source/tests/pt_expt/descriptor/test_dpa3.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
from ...seed import (
2929
GLOBAL_SEED,
3030
)
31+
from ..export_helpers import (
32+
export_save_load_and_compare,
33+
make_descriptor_dynamic_shapes,
34+
)
3135

3236

3337
class TestDescrptDPA3(TestCaseSingleFrameWithNlist):
@@ -244,3 +248,15 @@ def fn(coord_ext, atype_ext, nlist, mapping):
244248
rtol=rtol,
245249
atol=atol,
246250
)
251+
252+
# --- symbolic trace + export + .pte round-trip ---
253+
dynamic_shapes = make_descriptor_dynamic_shapes(has_mapping=True)
254+
inputs = (coord_ext, atype_ext, nlist, mapping)
255+
export_save_load_and_compare(
256+
fn,
257+
inputs,
258+
(rd_eager, grad_eager),
259+
dynamic_shapes,
260+
rtol=rtol,
261+
atol=atol,
262+
)

0 commit comments

Comments
 (0)