Skip to content

Commit 7f01a21

Browse files
wanghan-iapcmHan WangiProzd
authored
feat(pt_expt): implement DeepSpin model in pt_expt backend (#5293)
## Summary - Implement `SpinModel` and `SpinEnergyModel` in the pt_expt backend, supporting spin degrees of freedom for magnetic systems - Make dpmodel `SpinModel` array-API compatible so the same code works across numpy/torch/jax backends - Add spin virial correction (`coord_corr_for_virial`) to dpmodel and pt_expt, matching the pt backend - Fix `get_spin_model` in dpmodel to not mutate the caller's input data dict (pt backend already used `deepcopy`) ## Changes ### dpmodel (`deepmd/dpmodel/model/`) - `spin_model.py`: Replace all `np.*` operations with `array_api_compat` equivalents (`xp.concat`, `xp.where`, `xp.zeros` with `device=`, slicing instead of `xp.split`). Add `compute_or_load_stat` and virial correction support via `coord_corr_for_virial` / `extended_coord_corr`. - `make_model.py`: Thread `coord_corr_for_virial` through `call_common` → `model_call_from_call_lower` (extends to ghost atoms via mapping) → `call_common_lower` → `forward_common_atomic`. - `model.py`: Add `copy.deepcopy(data)` in `get_spin_model` to prevent in-place mutation of input dict. ### pt_expt (`deepmd/pt_expt/model/`) - `spin_model.py` (new): `@torch_module` wrapper inheriting from dpmodel `SpinModel`. - `spin_ener_model.py` (new): `SpinEnergyModel` with `forward()` / `forward_lower()` / `forward_lower_exportable()` providing user-facing output translation. - `make_model.py`, `transform_output.py`: Accept `extended_coord_corr` for virial correction. ### Tests - `test_spin_ener_model.py` (new): Unit tests for output keys/shapes, serialize/deserialize round-trip, dpmodel consistency, force finite-difference, virial finite-difference, and `torch.export` exportability. - `test_spin_ener.py`: Cross-backend consistency tests for `call`/`call_lower`, `compute_or_load_stat`, and load-from-file. Virial output now compared across pt and pt_expt. ## Test plan - [x] `python -m pytest source/tests/pt_expt/model/ -v` — all 28 tests pass - [x] `python -m pytest source/tests/consistent/model/test_spin_ener.py -v` — all 12 tests pass (18 skipped for uninstalled backends) - [x] Force and virial verified by finite-difference tests - [x] `torch.export.export` verified on `forward_lower_exportable` - [x] `compute_or_load_stat` load-from-file verified across dp/pt/pt_expt <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added SpinEnergyModel with exportable lower-level forward, energy/force/virial outputs, and compute_or_load_stat preprocessing. * Optional virial coordinate-correction can be supplied and is propagated through forward paths. * **Bug Fixes** * Prevented in-place mutation of input data during model preparation. * **Tests** * Expanded tests for exportable workflows, force/virial validation, multi-backend (including PT_EXPT) and array‑API strict modes. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> Co-authored-by: Han Wang <wang_han@iapcm.ac.cn> Co-authored-by: Duo <50307526+iProzd@users.noreply.github.com>
1 parent 91e3d62 commit 7f01a21

15 files changed

Lines changed: 1723 additions & 99 deletions

File tree

deepmd/dpmodel/fitting/general_fitting.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -277,15 +277,14 @@ def compute_input_stats(
277277
f"numb_fparam > 0 but no fparam data is provided "
278278
f"for system {ii}."
279279
)
280-
cat_data = np.concatenate(
281-
[frame["fparam"] for frame in sampled], axis=0
282-
)
283-
cat_data = np.reshape(cat_data, [-1, self.numb_fparam])
280+
xp_fp = array_api_compat.array_namespace(sampled[0]["fparam"])
281+
cat_data = xp_fp.concat([frame["fparam"] for frame in sampled], axis=0)
282+
cat_data = xp_fp.reshape(cat_data, (-1, self.numb_fparam))
284283
fparam_stats = [
285284
StatItem(
286285
number=cat_data.shape[0],
287-
sum=np.sum(cat_data[:, ii]),
288-
squared_sum=np.sum(cat_data[:, ii] ** 2),
286+
sum=float(xp_fp.sum(cat_data[:, ii])),
287+
squared_sum=float(xp_fp.sum(cat_data[:, ii] ** 2)),
289288
)
290289
for ii in range(self.numb_fparam)
291290
]
@@ -335,22 +334,23 @@ def compute_input_stats(
335334
f"numb_aparam > 0 but no aparam data is provided "
336335
f"for system {ii}."
337336
)
337+
xp_ap = array_api_compat.array_namespace(sampled[0]["aparam"])
338338
sys_sumv = []
339339
sys_sumv2 = []
340340
sys_sumn = []
341341
for ss_ in [frame["aparam"] for frame in sampled]:
342-
ss = np.reshape(ss_, [-1, self.numb_aparam])
343-
sys_sumv.append(np.sum(ss, axis=0))
344-
sys_sumv2.append(np.sum(ss * ss, axis=0))
342+
ss = xp_ap.reshape(ss_, (-1, self.numb_aparam))
343+
sys_sumv.append(xp_ap.sum(ss, axis=0))
344+
sys_sumv2.append(xp_ap.sum(ss * ss, axis=0))
345345
sys_sumn.append(ss.shape[0])
346-
sumv = np.sum(np.stack(sys_sumv), axis=0)
347-
sumv2 = np.sum(np.stack(sys_sumv2), axis=0)
346+
sumv = xp_ap.sum(xp_ap.stack(sys_sumv), axis=0)
347+
sumv2 = xp_ap.sum(xp_ap.stack(sys_sumv2), axis=0)
348348
sumn = sum(sys_sumn)
349349
aparam_stats = [
350350
StatItem(
351351
number=sumn,
352-
sum=sumv[ii],
353-
squared_sum=sumv2[ii],
352+
sum=float(sumv[ii]),
353+
squared_sum=float(sumv2[ii]),
354354
)
355355
for ii in range(self.numb_aparam)
356356
]

deepmd/dpmodel/model/make_model.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def model_call_from_call_lower(
6868
fparam: Array | None = None,
6969
aparam: Array | None = None,
7070
do_atomic_virial: bool = False,
71+
coord_corr_for_virial: Array | None = None,
7172
) -> dict[str, Array]:
7273
"""Return model prediction from lower interface.
7374
@@ -119,14 +120,33 @@ def model_call_from_call_lower(
119120
distinguish_types=False,
120121
)
121122
extended_coord = extended_coord.reshape(nframes, -1, 3)
123+
if coord_corr_for_virial is not None:
124+
xp = array_api_compat.array_namespace(coord_corr_for_virial)
125+
# mapping: nf x nall -> nf x nall x 1, then tile to nf x nall x 3
126+
mapping_idx = xp.tile(
127+
xp.reshape(mapping, (nframes, -1, 1)),
128+
(1, 1, 3),
129+
)
130+
extended_coord_corr = xp.take_along_axis(
131+
coord_corr_for_virial,
132+
mapping_idx,
133+
axis=1,
134+
)
135+
else:
136+
extended_coord_corr = None
137+
call_lower_kwargs: dict[str, Any] = {
138+
"fparam": fp,
139+
"aparam": ap,
140+
"do_atomic_virial": do_atomic_virial,
141+
}
142+
if extended_coord_corr is not None:
143+
call_lower_kwargs["extended_coord_corr"] = extended_coord_corr
122144
model_predict_lower = call_lower(
123145
extended_coord,
124146
extended_atype,
125147
nlist,
126148
mapping,
127-
fparam=fp,
128-
aparam=ap,
129-
do_atomic_virial=do_atomic_virial,
149+
**call_lower_kwargs,
130150
)
131151
model_predict = communicate_extended_output(
132152
model_predict_lower,
@@ -237,6 +257,7 @@ def call_common(
237257
fparam: Array | None = None,
238258
aparam: Array | None = None,
239259
do_atomic_virial: bool = False,
260+
coord_corr_for_virial: Array | None = None,
240261
) -> dict[str, Array]:
241262
"""Return model prediction.
242263
@@ -255,6 +276,9 @@ def call_common(
255276
atomic parameter. nf x nloc x nda
256277
do_atomic_virial
257278
If calculate the atomic virial.
279+
coord_corr_for_virial
280+
The coordinates correction for virial.
281+
shape: nf x (nloc x 3)
258282
259283
Returns
260284
-------
@@ -279,6 +303,7 @@ def call_common(
279303
fparam=fp,
280304
aparam=ap,
281305
do_atomic_virial=do_atomic_virial,
306+
coord_corr_for_virial=coord_corr_for_virial,
282307
)
283308
model_predict = self._output_type_cast(model_predict, input_prec)
284309
return model_predict
@@ -292,6 +317,7 @@ def call_common_lower(
292317
fparam: Array | None = None,
293318
aparam: Array | None = None,
294319
do_atomic_virial: bool = False,
320+
extended_coord_corr: Array | None = None,
295321
) -> dict[str, Array]:
296322
"""Return model prediction. Lower interface that takes
297323
extended atomic coordinates and types, nlist, and mapping
@@ -314,6 +340,9 @@ def call_common_lower(
314340
atomic parameter. nf x nloc x nda
315341
do_atomic_virial
316342
whether calculate atomic virial
343+
extended_coord_corr
344+
coordinates correction for virial in extended region.
345+
nf x (nall x 3)
317346
318347
Returns
319348
-------
@@ -341,6 +370,7 @@ def call_common_lower(
341370
fparam=fp,
342371
aparam=ap,
343372
do_atomic_virial=do_atomic_virial,
373+
extended_coord_corr=extended_coord_corr,
344374
)
345375
model_predict = self._output_type_cast(model_predict, input_prec)
346376
return model_predict
@@ -354,6 +384,7 @@ def forward_common_atomic(
354384
fparam: Array | None = None,
355385
aparam: Array | None = None,
356386
do_atomic_virial: bool = False,
387+
extended_coord_corr: Array | None = None,
357388
) -> dict[str, Array]:
358389
atomic_ret = self.atomic_model.forward_common_atomic(
359390
extended_coord,

deepmd/dpmodel/model/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def get_spin_model(data: dict) -> SpinModel:
164164
data : dict
165165
The data to construct the model.
166166
"""
167+
data = copy.deepcopy(data)
167168
# include virtual spin and placeholder types
168169
data["type_map"] += [item + "_spin" for item in data["type_map"]]
169170
spin = Spin(

0 commit comments

Comments
 (0)