Skip to content

Commit 34ad36d

Browse files
committed
fix ut
1 parent adc550f commit 34ad36d

2 files changed

Lines changed: 64 additions & 12 deletions

File tree

deepmd/pt_expt/model/spin_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def forward_common_lower_exportable(
5757
mapping: torch.Tensor | None = None,
5858
fparam: torch.Tensor | None = None,
5959
aparam: torch.Tensor | None = None,
60+
charge_spin: torch.Tensor | None = None,
6061
do_atomic_virial: bool = False,
6162
**make_fx_kwargs: Any,
6263
) -> torch.nn.Module:
@@ -96,6 +97,7 @@ def fn(
9697
mapping: torch.Tensor | None,
9798
fparam: torch.Tensor | None,
9899
aparam: torch.Tensor | None,
100+
charge_spin: torch.Tensor | None,
99101
) -> dict[str, torch.Tensor]:
100102
extended_coord = extended_coord.detach().requires_grad_(True)
101103
nlist = _pad_nlist_for_export(nlist)
@@ -107,6 +109,7 @@ def fn(
107109
mapping,
108110
fparam=fparam,
109111
aparam=aparam,
112+
charge_spin=charge_spin,
110113
do_atomic_virial=do_atomic_virial,
111114
)
112115

@@ -130,6 +133,7 @@ def fn(
130133
mapping,
131134
fparam,
132135
aparam,
136+
charge_spin,
133137
)
134138
finally:
135139
backbone.need_sorted_nlist_for_lower = _orig_need_sort

deepmd/pt_expt/utils/serialization.py

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,9 @@ def _make_sample_inputs(
119119
Returns
120120
-------
121121
tuple
122-
(ext_coord, ext_atype, nlist, mapping, fparam, aparam) or
123-
(ext_coord, ext_atype, ext_spin, nlist, mapping, fparam, aparam) when has_spin.
122+
(ext_coord, ext_atype, nlist, mapping, fparam, aparam, charge_spin) or
123+
(ext_coord, ext_atype, ext_spin, nlist, mapping, fparam, aparam,
124+
charge_spin) when has_spin.
124125
"""
125126
rcut = model.get_rcut()
126127
sel = model.get_sel()
@@ -187,14 +188,31 @@ def _make_sample_inputs(
187188
else:
188189
aparam = None
189190

191+
dim_chg_spin = model.get_dim_chg_spin() if hasattr(model, "get_dim_chg_spin") else 0
192+
if dim_chg_spin > 0:
193+
charge_spin = torch.zeros(
194+
nframes, dim_chg_spin, dtype=torch.float64, device=_env.DEVICE
195+
)
196+
else:
197+
charge_spin = None
198+
190199
if has_spin:
191200
nall = extended_coord.shape[1]
192201
ext_spin = torch.zeros(
193202
nframes, nall, 3, dtype=torch.float64, device=_env.DEVICE
194203
)
195-
return ext_coord, ext_atype, ext_spin, nlist_t, mapping_t, fparam, aparam
204+
return (
205+
ext_coord,
206+
ext_atype,
207+
ext_spin,
208+
nlist_t,
209+
mapping_t,
210+
fparam,
211+
aparam,
212+
charge_spin,
213+
)
196214

197-
return ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam
215+
return ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin
198216

199217

200218
def _build_dynamic_shapes(
@@ -224,9 +242,10 @@ def _build_dynamic_shapes(
224242
nnei_dim = torch.export.Dim("nnei", min=max(1, model_nnei))
225243

226244
if has_spin:
227-
# (ext_coord, ext_atype, ext_spin, nlist, mapping, fparam, aparam)
245+
# (ext_coord, ext_atype, ext_spin, nlist, mapping, fparam, aparam, charge_spin)
228246
fparam = sample_inputs[5]
229247
aparam = sample_inputs[6]
248+
charge_spin = sample_inputs[7]
230249
return (
231250
{0: nframes_dim, 1: nall_dim}, # extended_coord: (nframes, nall, 3)
232251
{0: nframes_dim, 1: nall_dim}, # extended_atype: (nframes, nall)
@@ -239,11 +258,13 @@ def _build_dynamic_shapes(
239258
{0: nframes_dim, 1: nall_dim}, # mapping: (nframes, nall)
240259
{0: nframes_dim} if fparam is not None else None, # fparam
241260
{0: nframes_dim, 1: nloc_dim} if aparam is not None else None, # aparam
261+
{0: nframes_dim} if charge_spin is not None else None, # charge_spin
242262
)
243263
else:
244-
# (ext_coord, ext_atype, nlist, mapping, fparam, aparam)
264+
# (ext_coord, ext_atype, nlist, mapping, fparam, aparam, charge_spin)
245265
fparam = sample_inputs[4]
246266
aparam = sample_inputs[5]
267+
charge_spin = sample_inputs[6]
247268
return (
248269
{0: nframes_dim, 1: nall_dim}, # extended_coord: (nframes, nall, 3)
249270
{0: nframes_dim, 1: nall_dim}, # extended_atype: (nframes, nall)
@@ -255,6 +276,7 @@ def _build_dynamic_shapes(
255276
{0: nframes_dim, 1: nall_dim}, # mapping: (nframes, nall)
256277
{0: nframes_dim} if fparam is not None else None, # fparam
257278
{0: nframes_dim, 1: nloc_dim} if aparam is not None else None, # aparam
279+
{0: nframes_dim} if charge_spin is not None else None, # charge_spin
258280
)
259281

260282

@@ -487,11 +509,26 @@ def _trace_and_export(
487509
_env.DEVICE = _orig_device
488510

489511
if is_spin:
490-
ext_coord, ext_atype, ext_spin, nlist_t, mapping_t, fparam, aparam = (
491-
sample_inputs
492-
)
512+
(
513+
ext_coord,
514+
ext_atype,
515+
ext_spin,
516+
nlist_t,
517+
mapping_t,
518+
fparam,
519+
aparam,
520+
charge_spin,
521+
) = sample_inputs
493522
else:
494-
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = sample_inputs
523+
(
524+
ext_coord,
525+
ext_atype,
526+
nlist_t,
527+
mapping_t,
528+
fparam,
529+
aparam,
530+
charge_spin,
531+
) = sample_inputs
495532

496533
# 4. Trace via make_fx on CPU.
497534
# This decomposes torch.autograd.grad into aten ops so the resulting
@@ -505,13 +542,21 @@ def _trace_and_export(
505542
mapping_t,
506543
fparam=fparam,
507544
aparam=aparam,
545+
charge_spin=charge_spin,
508546
do_atomic_virial=do_atomic_virial,
509547
tracing_mode="symbolic",
510548
_allow_non_fake_inputs=True,
511549
)
512550
# 5. Extract output keys from the CPU-traced module.
513551
sample_out = traced(
514-
ext_coord, ext_atype, ext_spin, nlist_t, mapping_t, fparam, aparam
552+
ext_coord,
553+
ext_atype,
554+
ext_spin,
555+
nlist_t,
556+
mapping_t,
557+
fparam,
558+
aparam,
559+
charge_spin,
515560
)
516561
else:
517562
traced = model.forward_common_lower_exportable(
@@ -521,12 +566,15 @@ def _trace_and_export(
521566
mapping_t,
522567
fparam=fparam,
523568
aparam=aparam,
569+
charge_spin=charge_spin,
524570
do_atomic_virial=do_atomic_virial,
525571
tracing_mode="symbolic",
526572
_allow_non_fake_inputs=True,
527573
)
528574
# 5. Extract output keys from the CPU-traced module.
529-
sample_out = traced(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam)
575+
sample_out = traced(
576+
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin
577+
)
530578

531579
output_keys = list(sample_out.keys())
532580

0 commit comments

Comments
 (0)