Commit 7f01a21
feat(pt_expt): implement DeepSpin model in pt_expt backend (#5293)
## Summary
- Implement `SpinModel` and `SpinEnergyModel` in the pt_expt backend,
supporting spin degrees of freedom for magnetic systems
- Make dpmodel `SpinModel` array-API compatible so the same code works
across numpy/torch/jax backends
- Add spin virial correction (`coord_corr_for_virial`) to dpmodel and
pt_expt, matching the pt backend
- Fix `get_spin_model` in dpmodel to not mutate the caller's input data
dict (pt backend already used `deepcopy`)
## Changes
### dpmodel (`deepmd/dpmodel/model/`)
- `spin_model.py`: Replace all `np.*` operations with `array_api_compat`
equivalents (`xp.concat`, `xp.where`, `xp.zeros` with `device=`, slicing
instead of `xp.split`). Add `compute_or_load_stat` and virial correction
support via
`coord_corr_for_virial` / `extended_coord_corr`.
- `make_model.py`: Thread `coord_corr_for_virial` through `call_common`
→ `model_call_from_call_lower` (extends to ghost atoms via mapping) →
`call_common_lower` → `forward_common_atomic`.
- `model.py`: Add `copy.deepcopy(data)` in `get_spin_model` to prevent
in-place mutation of input dict.
### pt_expt (`deepmd/pt_expt/model/`)
- `spin_model.py` (new): `@torch_module` wrapper inheriting from dpmodel
`SpinModel`.
- `spin_ener_model.py` (new): `SpinEnergyModel` with `forward()` /
`forward_lower()` / `forward_lower_exportable()` providing user-facing
output translation.
- `make_model.py`, `transform_output.py`: Accept `extended_coord_corr`
for virial correction.
### Tests
- `test_spin_ener_model.py` (new): Unit tests for output keys/shapes,
serialize/deserialize round-trip, dpmodel consistency, force
finite-difference, virial finite-difference, and `torch.export`
exportability.
- `test_spin_ener.py`: Cross-backend consistency tests for
`call`/`call_lower`, `compute_or_load_stat`, and load-from-file. Virial
output now compared across pt and pt_expt.
## Test plan
- [x] `python -m pytest source/tests/pt_expt/model/ -v` — all 28 tests
pass
- [x] `python -m pytest source/tests/consistent/model/test_spin_ener.py
-v` — all 12 tests pass (18 skipped for uninstalled backends)
- [x] Force and virial verified by finite-difference tests
- [x] `torch.export.export` verified on `forward_lower_exportable`
- [x] `compute_or_load_stat` load-from-file verified across
dp/pt/pt_expt
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Added SpinEnergyModel with exportable lower-level forward,
energy/force/virial outputs, and compute_or_load_stat preprocessing.
* Optional virial coordinate-correction can be supplied and is
propagated through forward paths.
* **Bug Fixes**
* Prevented in-place mutation of input data during model preparation.
* **Tests**
* Expanded tests for exportable workflows, force/virial validation,
multi-backend (including PT_EXPT) and array‑API strict modes.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Signed-off-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com>
Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
Co-authored-by: Duo <50307526+iProzd@users.noreply.github.com>1 parent 91e3d62 commit 7f01a21
15 files changed
Lines changed: 1723 additions & 99 deletions
File tree
- deepmd
- dpmodel
- fitting
- model
- jax/model
- pt_expt/model
- pt/model/model
- source/tests
- consistent/model
- pt_expt/model
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
277 | 277 | | |
278 | 278 | | |
279 | 279 | | |
280 | | - | |
281 | | - | |
282 | | - | |
283 | | - | |
| 280 | + | |
| 281 | + | |
| 282 | + | |
284 | 283 | | |
285 | 284 | | |
286 | 285 | | |
287 | | - | |
288 | | - | |
| 286 | + | |
| 287 | + | |
289 | 288 | | |
290 | 289 | | |
291 | 290 | | |
| |||
335 | 334 | | |
336 | 335 | | |
337 | 336 | | |
| 337 | + | |
338 | 338 | | |
339 | 339 | | |
340 | 340 | | |
341 | 341 | | |
342 | | - | |
343 | | - | |
344 | | - | |
| 342 | + | |
| 343 | + | |
| 344 | + | |
345 | 345 | | |
346 | | - | |
347 | | - | |
| 346 | + | |
| 347 | + | |
348 | 348 | | |
349 | 349 | | |
350 | 350 | | |
351 | 351 | | |
352 | | - | |
353 | | - | |
| 352 | + | |
| 353 | + | |
354 | 354 | | |
355 | 355 | | |
356 | 356 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
68 | 68 | | |
69 | 69 | | |
70 | 70 | | |
| 71 | + | |
71 | 72 | | |
72 | 73 | | |
73 | 74 | | |
| |||
119 | 120 | | |
120 | 121 | | |
121 | 122 | | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
122 | 144 | | |
123 | 145 | | |
124 | 146 | | |
125 | 147 | | |
126 | 148 | | |
127 | | - | |
128 | | - | |
129 | | - | |
| 149 | + | |
130 | 150 | | |
131 | 151 | | |
132 | 152 | | |
| |||
237 | 257 | | |
238 | 258 | | |
239 | 259 | | |
| 260 | + | |
240 | 261 | | |
241 | 262 | | |
242 | 263 | | |
| |||
255 | 276 | | |
256 | 277 | | |
257 | 278 | | |
| 279 | + | |
| 280 | + | |
| 281 | + | |
258 | 282 | | |
259 | 283 | | |
260 | 284 | | |
| |||
279 | 303 | | |
280 | 304 | | |
281 | 305 | | |
| 306 | + | |
282 | 307 | | |
283 | 308 | | |
284 | 309 | | |
| |||
292 | 317 | | |
293 | 318 | | |
294 | 319 | | |
| 320 | + | |
295 | 321 | | |
296 | 322 | | |
297 | 323 | | |
| |||
314 | 340 | | |
315 | 341 | | |
316 | 342 | | |
| 343 | + | |
| 344 | + | |
| 345 | + | |
317 | 346 | | |
318 | 347 | | |
319 | 348 | | |
| |||
341 | 370 | | |
342 | 371 | | |
343 | 372 | | |
| 373 | + | |
344 | 374 | | |
345 | 375 | | |
346 | 376 | | |
| |||
354 | 384 | | |
355 | 385 | | |
356 | 386 | | |
| 387 | + | |
357 | 388 | | |
358 | 389 | | |
359 | 390 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
164 | 164 | | |
165 | 165 | | |
166 | 166 | | |
| 167 | + | |
167 | 168 | | |
168 | 169 | | |
169 | 170 | | |
| |||
0 commit comments