Commit d67e75b
feat(pt_expt): multi-rank LAMMPS support for GNN models (DPA3 / DPA2 / spin) (#5430)
## Summary
Adds multi-rank LAMMPS support to the `pt_expt` (.pt2 / AOTInductor)
backend for GNN models — DPA3 (repflows) and DPA2 (repformers), plus
spin GNN — at parity with the existing `pt` (.pth / torch.jit) backend.
Without this, multi-rank LAMMPS users with GNN .pt2 models fall back to
single-rank-only, and the C++ side crashes on the first ghost exchange
when given a non-`use_loc_mapping` GNN .pt2.
The mechanism mirrors the pt backend's per-layer ghost-atom MPI
exchange: each repflow/repformer block exchanges `g1` across ranks via
`border_op` so each rank sees up-to-date ghost embeddings. To survive
`torch.export` + AOTInductor packaging, `border_op` is wrapped as an
opaque `torch.library.custom_op` (`deepmd_export::border_op`) with a
separate `border_op_backward` C++ symbol for autograd.
## Design
- **Phase 1 — dpmodel plumbing**: thread `comm_dict: dict | None = None`
through `make_model`, `base_atomic_model`, descriptor wrappers
(dpa1/dpa2/dpa3/hybrid/se_*), and the repflows/repformers blocks. Lift
the per-layer `node_ebd_ext` construction into a `_exchange_ghosts`
method (default array-api impl ignores `comm_dict`).
- **Phase 2 — pt_expt opaque op + block overrides**:
- `deepmd::border_op_backward` C++ op (additive accumulation into local
atom slots — symmetric exchange used by autograd backward).
- `deepmd_export::border_op` Python `custom_op` wrapper with
`register_fake` and `register_autograd` so the op is opaque to
`torch.export`.
- `pt_expt/descriptor/repflows.py` and `repformers.py` block subclasses
with `_exchange_ghosts` overrides that call the opaque op (with the spin
real/virtual split + `concat_switch_virtual` when
`comm_dict[\"has_spin\"]` is set).
- **Phase 3 — dual-artifact AOTInductor export**: when
`_has_message_passing(model)` is true, compile **two** artifacts into
the .pt2 ZIP:
- `forward_lower_no_comm.pt2` — current behavior (single-rank,
mapping-based gather).
- `forward_lower_with_comm.pt2` — adds positional comm tensors
`(send_list, send_proc, recv_proc, send_num, recv_num, communicator,
nlocal, nghost)` to the trace input list, plus `has_spin=tensor([1])`
baked in for spin GNN.
- Metadata: `has_message_passing` + `has_comm_artifact` flags so the C++
loader picks the right artifact.
- **Phase 4 — C++ dispatch**: `DeepPotPTExpt::compute` and
`DeepSpinPTExpt::compute` route to the with-comm artifact when
`lmp_list.nswap > 0`. `commPTExpt` adds `build_comm_tensors_positional`
and `build_comm_tensors_positional_with_virtual_atoms` (the latter
remaps sendlists via `fwd_map` when NULL atoms drop out of the model's
view).
- **Phase 5 — LAMMPS tests**: end-to-end multi-rank tests (DPA3, DPA2,
spin DPA3) covering the basic case + the structural edge cases
(NULL-type atoms, empty subdomain, all-NULL rank, isolated NULL, nlist
rebuild, N>2 decomposition).
## Coverage matrix
| | DPA3 | DPA2 | spin DPA3 |
| - | - | - | - |
| basic mpi-2 | ✓ | ✓ | ✓ |
| empty subdomain (rank 1 empty) | ✓ | — | ✓ |
| all-NULL rank | ✓ | — | — |
| isolated NULL | ✓ | — | — |
| NULL atoms straddling boundary | ✓ | — | ✓ |
| NULL across nlist rebuild | ✓ | — | — |
| N>2 (2x2x1, 4x1x1, 2x2x2) | ✓ | — | — |
| cached mapping_tensor (ago>0) | ✓ | — | — |
Tests compare mpi-N vs same-archive mpi-1 for force / force_mag / virial
/ energy (atol 1e-8); no hardcoded numerical references.
Plus a unit test (`test_has_message_passing.py`) pinning the
`_has_message_passing` schema-drift contract.
## Co-existence with #5407
This branch was rebased onto upstream master after PR #5407 (.pt2 perf)
merged. The merge required:
- `forward_common_lower_exportable_with_comm` (spin and non-spin
variants) now applies `_pad_nlist_for_export` +
`need_sorted_nlist_for_lower=True` workarounds matching #5407's regular
variant — keeps the with-comm trace's `nnei` axis dynamic.
- CUDA `realize_opcount_threshold=0` workaround applied around BOTH
artifact compiles.
- `do_atomic_virial=True` is used for all multi-rank fixtures to avoid
AOTI compile-time changes from #5407's default `=False`. (Multi-rank
with `do_atomic_virial=False` is a known coverage gap — see
Limitations.)
## Known limitations
- **Hybrid + GNN + multi-rank**: `_has_message_passing` doesn't recurse
into hybrid children → no with-comm artifact produced for hybrid-of-GNN.
Multi-rank LAMMPS with such a model would fall back to single-artifact
and crash on first ghost exchange. Out of scope for this PR; if a user
hits this, they get the same crash they had before.
- **`do_atomic_virial=False` under multi-rank**: production default. All
multi-rank tests use `=True` (matches #5407 fixture conventions). Not
exercised end-to-end yet.
- **CUDA**: the `TORCH_LIBRARY_IMPL(deepmd_export, CUDA, m)`
registration exists and the symbol is callable, but no GPU end-to-end
test runs in this PR (CPU-only build environment locally).
- **DPA2 (repformers) edge cases**: only basic mpi-2 is tested for DPA2.
NULL / empty-subdomain / nlist-rebuild covered for DPA3 only — these
paths are descriptor-agnostic by construction but not exercised
end-to-end for DPA2.
- **Spin DPA3 edge cases**: empty subdomain and NULL straddling boundary
covered. all-NULL-rank, isolated-NULL, nlist-rebuild not.
- **Multiple spin-active types** (`use_spin=[True, True, ...]`): only
`[True, False]` tested.
- **Frame batching nb>1 with `comm_dict`**: `_exchange_ghosts` uses
`.squeeze(0)` / `.unsqueeze(0)` (mirrors pt). LAMMPS feeds nb=1 — fine
in practice; breaks if reused outside LAMMPS.
- **Float16/bfloat16**: comm_dict path was developed in float64 only.
- **Cross-backend (jax/paddle/numpy) `comm_dict=None` neutrality**:
dpmodel default `_exchange_ghosts` is the original code lifted into a
method, behaviorally equivalent. Not separately re-tested via running
consistent tests with `comm_dict=None` explicitly threaded.
- **Old .pt2 files lacking `has_comm_artifact` metadata**: C++ defaults
to single-artifact when key is missing. Not negative-tested.
- **AOTI compile failure cleanup**: if the with-comm compile fails after
the regular artifact is written, the resulting .pt2 is half-formed. RAII
/ SIGKILL leakage of `TempFile` in /tmp also pre-exists.
A full catalog of touched-but-untested paths is maintained at
`memory/gnn_mpi_untested_paths.md` (local to the author).
## Test plan
- [ ] `pytest source/tests/pt_expt/` — eager-parity, export round-trip,
schema-drift unit tests
- [ ] `pytest source/tests/consistent/descriptor/test_dpa3.py
source/tests/consistent/descriptor/test_dpa2.py` — non-regression for
the single-rank path
- [ ] `pytest source/lmp/tests/test_lammps_dpa3_pt2.py
source/lmp/tests/test_lammps_dpa2_pt2.py
source/lmp/tests/test_lammps_spin_dpa3_pt2.py` — multi-rank end-to-end
(requires mpirun + mpi4py)
- [ ] `ctest -R PtExpt` — C++ tests (single-rank, since multi-rank is
exercised via LAMMPS)
- [ ] CI runs across CPU + CUDA matrix
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Optional comms/MPI metadata support across inference paths enabling
ghost-exchange; dual “with-comm” export/run artifact and tracing
entrypoints for message-passing GNNs and spin models, plus improved
export/tracing support for comm-enabled ops.
* **Bug Fixes**
* Fixed unsafe neighbor-list memory handling.
* **Tests**
* Added extensive MPI parity, export-with-comm, repflow/repformer
parallel-path, spin integration tests and CLI MPI runner scripts.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>1 parent 7ca7baf commit d67e75b
67 files changed
Lines changed: 5737 additions & 135 deletions
File tree
- deepmd
- dpmodel
- atomic_model
- descriptor
- model
- jax
- atomic_model
- model
- pt_expt
- descriptor
- model
- utils
- source
- api_cc
- include
- src
- tests
- lmp/tests
- op/pt
- tests
- infer
- pt_expt
- descriptor
- model
- utils
Some content is hidden
Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
231 | 231 | | |
232 | 232 | | |
233 | 233 | | |
| 234 | + | |
234 | 235 | | |
235 | 236 | | |
236 | 237 | | |
| |||
252 | 253 | | |
253 | 254 | | |
254 | 255 | | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
255 | 259 | | |
256 | 260 | | |
257 | 261 | | |
| |||
279 | 283 | | |
280 | 284 | | |
281 | 285 | | |
| 286 | + | |
282 | 287 | | |
283 | 288 | | |
284 | 289 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
157 | 157 | | |
158 | 158 | | |
159 | 159 | | |
| 160 | + | |
160 | 161 | | |
161 | 162 | | |
162 | 163 | | |
| |||
174 | 175 | | |
175 | 176 | | |
176 | 177 | | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
177 | 181 | | |
178 | 182 | | |
179 | 183 | | |
| |||
215 | 219 | | |
216 | 220 | | |
217 | 221 | | |
| 222 | + | |
218 | 223 | | |
219 | 224 | | |
220 | 225 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
224 | 224 | | |
225 | 225 | | |
226 | 226 | | |
| 227 | + | |
227 | 228 | | |
228 | 229 | | |
229 | 230 | | |
| |||
241 | 242 | | |
242 | 243 | | |
243 | 244 | | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
244 | 249 | | |
245 | 250 | | |
246 | 251 | | |
| |||
280 | 285 | | |
281 | 286 | | |
282 | 287 | | |
| 288 | + | |
283 | 289 | | |
284 | 290 | | |
285 | 291 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
253 | 253 | | |
254 | 254 | | |
255 | 255 | | |
| 256 | + | |
256 | 257 | | |
| 258 | + | |
257 | 259 | | |
258 | 260 | | |
259 | 261 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
397 | 397 | | |
398 | 398 | | |
399 | 399 | | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
400 | 408 | | |
401 | 409 | | |
402 | 410 | | |
| |||
500 | 508 | | |
501 | 509 | | |
502 | 510 | | |
| 511 | + | |
503 | 512 | | |
504 | 513 | | |
505 | 514 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
687 | 687 | | |
688 | 688 | | |
689 | 689 | | |
| 690 | + | |
| 691 | + | |
| 692 | + | |
| 693 | + | |
| 694 | + | |
| 695 | + | |
| 696 | + | |
| 697 | + | |
| 698 | + | |
| 699 | + | |
690 | 700 | | |
691 | 701 | | |
692 | 702 | | |
| |||
831 | 841 | | |
832 | 842 | | |
833 | 843 | | |
| 844 | + | |
834 | 845 | | |
835 | 846 | | |
836 | 847 | | |
| |||
844 | 855 | | |
845 | 856 | | |
846 | 857 | | |
| 858 | + | |
| 859 | + | |
| 860 | + | |
| 861 | + | |
| 862 | + | |
847 | 863 | | |
848 | 864 | | |
849 | 865 | | |
| |||
912 | 928 | | |
913 | 929 | | |
914 | 930 | | |
915 | | - | |
916 | | - | |
917 | | - | |
| 931 | + | |
| 932 | + | |
| 933 | + | |
| 934 | + | |
| 935 | + | |
| 936 | + | |
| 937 | + | |
| 938 | + | |
| 939 | + | |
| 940 | + | |
| 941 | + | |
| 942 | + | |
918 | 943 | | |
919 | 944 | | |
920 | 945 | | |
| |||
926 | 951 | | |
927 | 952 | | |
928 | 953 | | |
| 954 | + | |
929 | 955 | | |
930 | 956 | | |
931 | 957 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
527 | 527 | | |
528 | 528 | | |
529 | 529 | | |
| 530 | + | |
| 531 | + | |
| 532 | + | |
| 533 | + | |
| 534 | + | |
| 535 | + | |
| 536 | + | |
| 537 | + | |
| 538 | + | |
| 539 | + | |
| 540 | + | |
530 | 541 | | |
531 | 542 | | |
532 | 543 | | |
| |||
616 | 627 | | |
617 | 628 | | |
618 | 629 | | |
| 630 | + | |
619 | 631 | | |
620 | 632 | | |
621 | 633 | | |
| |||
629 | 641 | | |
630 | 642 | | |
631 | 643 | | |
| 644 | + | |
| 645 | + | |
| 646 | + | |
632 | 647 | | |
633 | 648 | | |
634 | 649 | | |
| |||
695 | 710 | | |
696 | 711 | | |
697 | 712 | | |
| 713 | + | |
698 | 714 | | |
699 | 715 | | |
700 | 716 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
168 | 168 | | |
169 | 169 | | |
170 | 170 | | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
171 | 181 | | |
172 | 182 | | |
173 | 183 | | |
| |||
276 | 286 | | |
277 | 287 | | |
278 | 288 | | |
| 289 | + | |
279 | 290 | | |
280 | 291 | | |
281 | 292 | | |
| |||
332 | 343 | | |
333 | 344 | | |
334 | 345 | | |
335 | | - | |
| 346 | + | |
| 347 | + | |
| 348 | + | |
336 | 349 | | |
337 | 350 | | |
338 | 351 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
107 | 107 | | |
108 | 108 | | |
109 | 109 | | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
110 | 128 | | |
111 | 129 | | |
112 | 130 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
506 | 506 | | |
507 | 507 | | |
508 | 508 | | |
| 509 | + | |
| 510 | + | |
| 511 | + | |
| 512 | + | |
| 513 | + | |
| 514 | + | |
| 515 | + | |
| 516 | + | |
| 517 | + | |
| 518 | + | |
| 519 | + | |
| 520 | + | |
| 521 | + | |
| 522 | + | |
| 523 | + | |
| 524 | + | |
| 525 | + | |
| 526 | + | |
| 527 | + | |
| 528 | + | |
| 529 | + | |
| 530 | + | |
| 531 | + | |
| 532 | + | |
| 533 | + | |
| 534 | + | |
509 | 535 | | |
510 | 536 | | |
511 | 537 | | |
| |||
514 | 540 | | |
515 | 541 | | |
516 | 542 | | |
| 543 | + | |
517 | 544 | | |
518 | 545 | | |
519 | 546 | | |
| |||
641 | 668 | | |
642 | 669 | | |
643 | 670 | | |
644 | | - | |
645 | | - | |
| 671 | + | |
| 672 | + | |
| 673 | + | |
| 674 | + | |
| 675 | + | |
| 676 | + | |
| 677 | + | |
| 678 | + | |
646 | 679 | | |
647 | 680 | | |
648 | | - | |
649 | | - | |
650 | | - | |
651 | | - | |
652 | | - | |
| 681 | + | |
| 682 | + | |
| 683 | + | |
| 684 | + | |
| 685 | + | |
| 686 | + | |
| 687 | + | |
| 688 | + | |
653 | 689 | | |
654 | 690 | | |
655 | 691 | | |
| |||
696 | 732 | | |
697 | 733 | | |
698 | 734 | | |
| 735 | + | |
| 736 | + | |
| 737 | + | |
| 738 | + | |
| 739 | + | |
| 740 | + | |
| 741 | + | |
| 742 | + | |
| 743 | + | |
| 744 | + | |
699 | 745 | | |
700 | 746 | | |
701 | 747 | | |
| |||
0 commit comments