Skip to content

Commit e98b6a8

Browse files
committed
fix ut
1 parent e619fa8 commit e98b6a8

5 files changed

Lines changed: 60 additions & 13 deletions

File tree

deepmd/pt/modifier/base_modifier.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def forward(
8383
fparam: torch.Tensor | None = None,
8484
aparam: torch.Tensor | None = None,
8585
do_atomic_virial: bool = False,
86+
charge_spin: torch.Tensor | None = None,
8687
) -> dict[str, torch.Tensor]:
8788
"""Compute energy, force, and virial corrections."""
8889

deepmd/pt_expt/infer/deep_eval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1137,7 +1137,7 @@ def _eval_model(
11371137
mapping_t,
11381138
fparam_t,
11391139
aparam_t,
1140-
charge_spin=charge_spin_t,
1140+
charge_spin_t,
11411141
)
11421142

11431143
# Apply communicate_extended_output to map extended atoms → local atoms
@@ -1317,7 +1317,7 @@ def _eval_model_spin(
13171317
mapping_t,
13181318
fparam_t,
13191319
aparam_t,
1320-
charge_spin=charge_spin_t,
1320+
charge_spin_t,
13211321
)
13221322

13231323
# Apply communicate_extended_output to map extended atoms → local atoms

source/tests/pt/test_data_modifier.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def forward(
116116
fparam: torch.Tensor | None = None,
117117
aparam: torch.Tensor | None = None,
118118
do_atomic_virial: bool = False,
119+
charge_spin: torch.Tensor | None = None,
119120
) -> dict[str, torch.Tensor]:
120121
"""Implementation of abstractmethod."""
121122
return {}
@@ -158,6 +159,7 @@ def forward(
158159
fparam: torch.Tensor | None = None,
159160
aparam: torch.Tensor | None = None,
160161
do_atomic_virial: bool = False,
162+
charge_spin: torch.Tensor | None = None,
161163
) -> dict[str, torch.Tensor]:
162164
"""Implementation of abstractmethod."""
163165
return {}
@@ -205,6 +207,7 @@ def forward(
205207
fparam: torch.Tensor | None = None,
206208
aparam: torch.Tensor | None = None,
207209
do_atomic_virial: bool = False,
210+
charge_spin: torch.Tensor | None = None,
208211
) -> dict[str, torch.Tensor]:
209212
"""Take scaled model prediction as data modification."""
210213
model_pred = self.model(

source/tests/pt_expt/export_helpers.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def model_forward_lower_export_round_trip(
118118
fparam,
119119
aparam,
120120
output_keys: tuple[str, ...],
121+
charge_spin=None,
121122
rtol: float = 1e-10,
122123
atol: float = 1e-10,
123124
):
@@ -141,6 +142,9 @@ def model_forward_lower_export_round_trip(
141142
Frame and atom parameters.
142143
output_keys : tuple of str
143144
Output dictionary keys to verify.
145+
charge_spin : torch.Tensor or None
146+
Charge/spin parameter for descriptors that consume it (e.g. DPA3
147+
with ``add_chg_spin_ebd=True``).
144148
rtol, atol : float
145149
Tolerances for np.testing.assert_allclose.
146150
"""
@@ -156,6 +160,7 @@ def model_forward_lower_export_round_trip(
156160
mapping_t,
157161
fparam=fparam,
158162
aparam=aparam,
163+
charge_spin=charge_spin,
159164
)
160165

161166
# 2. Concrete trace
@@ -166,21 +171,24 @@ def model_forward_lower_export_round_trip(
166171
mapping_t,
167172
fparam=fparam,
168173
aparam=aparam,
174+
charge_spin=charge_spin,
169175
)
170176
assert isinstance(traced, torch.nn.Module)
171177

172178
# 3. Basic export (no dynamic shapes)
173179
exported = torch.export.export(
174180
traced,
175-
(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam),
181+
(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin),
176182
strict=False,
177183
)
178184
assert exported is not None
179185

180186
# 4. Compare traced and exported vs eager
181-
ret_traced = traced(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam)
187+
ret_traced = traced(
188+
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin
189+
)
182190
ret_exported = exported.module()(
183-
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam
191+
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin
184192
)
185193
for key in output_keys:
186194
np.testing.assert_allclose(
@@ -201,7 +209,15 @@ def model_forward_lower_export_round_trip(
201209
# 5. Symbolic trace + dynamic shapes + .pte round-trip
202210
inputs_2f = tuple(
203211
torch.cat([t, t], dim=0) if t is not None else None
204-
for t in (ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam)
212+
for t in (
213+
ext_coord,
214+
ext_atype,
215+
nlist_t,
216+
mapping_t,
217+
fparam,
218+
aparam,
219+
charge_spin,
220+
)
205221
)
206222
traced_sym = md_pt.forward_lower_exportable(
207223
inputs_2f[0],
@@ -210,6 +226,7 @@ def model_forward_lower_export_round_trip(
210226
inputs_2f[3],
211227
fparam=inputs_2f[4],
212228
aparam=inputs_2f[5],
229+
charge_spin=inputs_2f[6],
213230
tracing_mode="symbolic",
214231
_allow_non_fake_inputs=True,
215232
)
@@ -226,7 +243,9 @@ def model_forward_lower_export_round_trip(
226243
loaded = torch.export.load(f.name).module()
227244

228245
# 6. Compare loaded vs eager (nf=1 — different shapes)
229-
ret_loaded_1f = loaded(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam)
246+
ret_loaded_1f = loaded(
247+
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin
248+
)
230249
for key in output_keys:
231250
np.testing.assert_allclose(
232251
ret_eager[key].detach().cpu().numpy(),

source/tests/pt_expt/model/test_export_pipeline.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@ def test_export_pipeline(self, descriptor_type, with_fparam) -> None:
121121
inputs_trace = _make_sample_inputs(model2, nframes=5, nloc=7)
122122
finally:
123123
_env.DEVICE = orig_device
124-
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = inputs_trace
124+
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin = (
125+
inputs_trace
126+
)
125127

126128
# 4. Eager reference
127129
eager_out = model2.forward_common_lower(
@@ -132,6 +134,7 @@ def test_export_pipeline(self, descriptor_type, with_fparam) -> None:
132134
fparam=fparam,
133135
aparam=aparam,
134136
do_atomic_virial=True,
137+
charge_spin=charge_spin,
135138
)
136139

137140
# 5. Trace with symbolic mode (same as dp freeze)
@@ -142,6 +145,7 @@ def test_export_pipeline(self, descriptor_type, with_fparam) -> None:
142145
mapping_t,
143146
fparam=fparam,
144147
aparam=aparam,
148+
charge_spin=charge_spin,
145149
do_atomic_virial=True,
146150
tracing_mode="symbolic",
147151
_allow_non_fake_inputs=True,
@@ -155,11 +159,12 @@ def test_export_pipeline(self, descriptor_type, with_fparam) -> None:
155159
mapping_t,
156160
fparam,
157161
aparam,
162+
charge_spin,
158163
model_nnei=sum(model2.get_sel()),
159164
)
160165
exported = torch.export.export(
161166
traced,
162-
(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam),
167+
(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin),
163168
dynamic_shapes=dynamic_shapes,
164169
strict=False,
165170
prefer_deferred_runtime_asserts_over_guards=True,
@@ -171,7 +176,9 @@ def test_export_pipeline(self, descriptor_type, with_fparam) -> None:
171176
loaded = torch.export.load(tmp.name).module()
172177

173178
# 8. Verify: traced output matches eager (same shapes as trace)
174-
traced_out = traced(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam)
179+
traced_out = traced(
180+
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin
181+
)
175182
for key in eager_out:
176183
np.testing.assert_allclose(
177184
eager_out[key].detach().cpu().numpy(),
@@ -182,7 +189,9 @@ def test_export_pipeline(self, descriptor_type, with_fparam) -> None:
182189
)
183190

184191
# 9. Verify: loaded (.pte) output matches eager (same shapes)
185-
loaded_out = loaded(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam)
192+
loaded_out = loaded(
193+
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin
194+
)
186195
for key in eager_out:
187196
np.testing.assert_allclose(
188197
eager_out[key].detach().cpu().numpy(),
@@ -206,6 +215,7 @@ def test_export_pipeline(self, descriptor_type, with_fparam) -> None:
206215
mapping_t2,
207216
fparam2,
208217
aparam2,
218+
charge_spin2,
209219
) = inputs_infer
210220

211221
eager_out2 = model2.forward_common_lower(
@@ -216,9 +226,16 @@ def test_export_pipeline(self, descriptor_type, with_fparam) -> None:
216226
fparam=fparam2,
217227
aparam=aparam2,
218228
do_atomic_virial=True,
229+
charge_spin=charge_spin2,
219230
)
220231
loaded_out2 = loaded(
221-
ext_coord2, ext_atype2, nlist_t2, mapping_t2, fparam2, aparam2
232+
ext_coord2,
233+
ext_atype2,
234+
nlist_t2,
235+
mapping_t2,
236+
fparam2,
237+
aparam2,
238+
charge_spin2,
222239
)
223240
for key in eager_out2:
224241
np.testing.assert_allclose(
@@ -248,9 +265,16 @@ def test_export_pipeline(self, descriptor_type, with_fparam) -> None:
248265
fparam=fparam_ones,
249266
aparam=aparam,
250267
do_atomic_virial=True,
268+
charge_spin=charge_spin,
251269
)
252270
loaded_out_fp1 = loaded(
253-
ext_coord, ext_atype, nlist_t, mapping_t, fparam_ones, aparam
271+
ext_coord,
272+
ext_atype,
273+
nlist_t,
274+
mapping_t,
275+
fparam_ones,
276+
aparam,
277+
charge_spin,
254278
)
255279
# Loaded with fparam=1 should match eager with fparam=1
256280
for key in eager_out_fp1:

0 commit comments

Comments
 (0)