Skip to content

Commit d67e75b

Browse files
wanghan-iapcmHan Wang
andauthored
feat(pt_expt): multi-rank LAMMPS support for GNN models (DPA3 / DPA2 / spin) (#5430)
## Summary Adds multi-rank LAMMPS support to the `pt_expt` (.pt2 / AOTInductor) backend for GNN models — DPA3 (repflows) and DPA2 (repformers), plus spin GNN — at parity with the existing `pt` (.pth / torch.jit) backend. Without this, multi-rank LAMMPS users with GNN .pt2 models fall back to single-rank-only, and the C++ side crashes on the first ghost exchange when given a non-`use_loc_mapping` GNN .pt2. The mechanism mirrors the pt backend's per-layer ghost-atom MPI exchange: each repflow/repformer block exchanges `g1` across ranks via `border_op` so each rank sees up-to-date ghost embeddings. To survive `torch.export` + AOTInductor packaging, `border_op` is wrapped as an opaque `torch.library.custom_op` (`deepmd_export::border_op`) with a separate `border_op_backward` C++ symbol for autograd. ## Design - **Phase 1 — dpmodel plumbing**: thread `comm_dict: dict | None = None` through `make_model`, `base_atomic_model`, descriptor wrappers (dpa1/dpa2/dpa3/hybrid/se_*), and the repflows/repformers blocks. Lift the per-layer `node_ebd_ext` construction into a `_exchange_ghosts` method (default array-api impl ignores `comm_dict`). - **Phase 2 — pt_expt opaque op + block overrides**: - `deepmd::border_op_backward` C++ op (additive accumulation into local atom slots — symmetric exchange used by autograd backward). - `deepmd_export::border_op` Python `custom_op` wrapper with `register_fake` and `register_autograd` so the op is opaque to `torch.export`. - `pt_expt/descriptor/repflows.py` and `repformers.py` block subclasses with `_exchange_ghosts` overrides that call the opaque op (with the spin real/virtual split + `concat_switch_virtual` when `comm_dict[\"has_spin\"]` is set). - **Phase 3 — dual-artifact AOTInductor export**: when `_has_message_passing(model)` is true, compile **two** artifacts into the .pt2 ZIP: - `forward_lower_no_comm.pt2` — current behavior (single-rank, mapping-based gather). - `forward_lower_with_comm.pt2` — adds positional comm tensors `(send_list, send_proc, recv_proc, send_num, recv_num, communicator, nlocal, nghost)` to the trace input list, plus `has_spin=tensor([1])` baked in for spin GNN. - Metadata: `has_message_passing` + `has_comm_artifact` flags so the C++ loader picks the right artifact. - **Phase 4 — C++ dispatch**: `DeepPotPTExpt::compute` and `DeepSpinPTExpt::compute` route to the with-comm artifact when `lmp_list.nswap > 0`. `commPTExpt` adds `build_comm_tensors_positional` and `build_comm_tensors_positional_with_virtual_atoms` (the latter remaps sendlists via `fwd_map` when NULL atoms drop out of the model's view). - **Phase 5 — LAMMPS tests**: end-to-end multi-rank tests (DPA3, DPA2, spin DPA3) covering the basic case + the structural edge cases (NULL-type atoms, empty subdomain, all-NULL rank, isolated NULL, nlist rebuild, N>2 decomposition). ## Coverage matrix | | DPA3 | DPA2 | spin DPA3 | | - | - | - | - | | basic mpi-2 | ✓ | ✓ | ✓ | | empty subdomain (rank 1 empty) | ✓ | — | ✓ | | all-NULL rank | ✓ | — | — | | isolated NULL | ✓ | — | — | | NULL atoms straddling boundary | ✓ | — | ✓ | | NULL across nlist rebuild | ✓ | — | — | | N>2 (2x2x1, 4x1x1, 2x2x2) | ✓ | — | — | | cached mapping_tensor (ago>0) | ✓ | — | — | Tests compare mpi-N vs same-archive mpi-1 for force / force_mag / virial / energy (atol 1e-8); no hardcoded numerical references. Plus a unit test (`test_has_message_passing.py`) pinning the `_has_message_passing` schema-drift contract. ## Co-existence with #5407 This branch was rebased onto upstream master after PR #5407 (.pt2 perf) merged. The merge required: - `forward_common_lower_exportable_with_comm` (spin and non-spin variants) now applies `_pad_nlist_for_export` + `need_sorted_nlist_for_lower=True` workarounds matching #5407's regular variant — keeps the with-comm trace's `nnei` axis dynamic. - CUDA `realize_opcount_threshold=0` workaround applied around BOTH artifact compiles. - `do_atomic_virial=True` is used for all multi-rank fixtures to avoid AOTI compile-time changes from #5407's default `=False`. (Multi-rank with `do_atomic_virial=False` is a known coverage gap — see Limitations.) ## Known limitations - **Hybrid + GNN + multi-rank**: `_has_message_passing` doesn't recurse into hybrid children → no with-comm artifact produced for hybrid-of-GNN. Multi-rank LAMMPS with such a model would fall back to single-artifact and crash on first ghost exchange. Out of scope for this PR; if a user hits this, they get the same crash they had before. - **`do_atomic_virial=False` under multi-rank**: production default. All multi-rank tests use `=True` (matches #5407 fixture conventions). Not exercised end-to-end yet. - **CUDA**: the `TORCH_LIBRARY_IMPL(deepmd_export, CUDA, m)` registration exists and the symbol is callable, but no GPU end-to-end test runs in this PR (CPU-only build environment locally). - **DPA2 (repformers) edge cases**: only basic mpi-2 is tested for DPA2. NULL / empty-subdomain / nlist-rebuild covered for DPA3 only — these paths are descriptor-agnostic by construction but not exercised end-to-end for DPA2. - **Spin DPA3 edge cases**: empty subdomain and NULL straddling boundary covered. all-NULL-rank, isolated-NULL, nlist-rebuild not. - **Multiple spin-active types** (`use_spin=[True, True, ...]`): only `[True, False]` tested. - **Frame batching nb>1 with `comm_dict`**: `_exchange_ghosts` uses `.squeeze(0)` / `.unsqueeze(0)` (mirrors pt). LAMMPS feeds nb=1 — fine in practice; breaks if reused outside LAMMPS. - **Float16/bfloat16**: comm_dict path was developed in float64 only. - **Cross-backend (jax/paddle/numpy) `comm_dict=None` neutrality**: dpmodel default `_exchange_ghosts` is the original code lifted into a method, behaviorally equivalent. Not separately re-tested via running consistent tests with `comm_dict=None` explicitly threaded. - **Old .pt2 files lacking `has_comm_artifact` metadata**: C++ defaults to single-artifact when key is missing. Not negative-tested. - **AOTI compile failure cleanup**: if the with-comm compile fails after the regular artifact is written, the resulting .pt2 is half-formed. RAII / SIGKILL leakage of `TempFile` in /tmp also pre-exists. A full catalog of touched-but-untested paths is maintained at `memory/gnn_mpi_untested_paths.md` (local to the author). ## Test plan - [ ] `pytest source/tests/pt_expt/` — eager-parity, export round-trip, schema-drift unit tests - [ ] `pytest source/tests/consistent/descriptor/test_dpa3.py source/tests/consistent/descriptor/test_dpa2.py` — non-regression for the single-rank path - [ ] `pytest source/lmp/tests/test_lammps_dpa3_pt2.py source/lmp/tests/test_lammps_dpa2_pt2.py source/lmp/tests/test_lammps_spin_dpa3_pt2.py` — multi-rank end-to-end (requires mpirun + mpi4py) - [ ] `ctest -R PtExpt` — C++ tests (single-rank, since multi-rank is exercised via LAMMPS) - [ ] CI runs across CPU + CUDA matrix <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Optional comms/MPI metadata support across inference paths enabling ghost-exchange; dual “with-comm” export/run artifact and tracing entrypoints for message-passing GNNs and spin models, plus improved export/tracing support for comm-enabled ops. * **Bug Fixes** * Fixed unsafe neighbor-list memory handling. * **Tests** * Added extensive MPI parity, export-with-comm, repflow/repformer parallel-path, spin integration tests and CLI MPI runner scripts. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent 7ca7baf commit d67e75b

67 files changed

Lines changed: 5737 additions & 135 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ def forward_common_atomic(
231231
mapping: Array | None = None,
232232
fparam: Array | None = None,
233233
aparam: Array | None = None,
234+
comm_dict: dict | None = None,
234235
) -> dict[str, Array]:
235236
"""Common interface for atomic inference.
236237
@@ -252,6 +253,9 @@ def forward_common_atomic(
252253
frame parameters, shape: nf x dim_fparam
253254
aparam
254255
atomic parameter, shape: nf x nloc x dim_aparam
256+
comm_dict
257+
MPI communication metadata for parallel inference. ``None`` for
258+
non-parallel inference (default).
255259
256260
Returns
257261
-------
@@ -279,6 +283,7 @@ def forward_common_atomic(
279283
mapping=mapping,
280284
fparam=fparam,
281285
aparam=aparam,
286+
comm_dict=comm_dict,
282287
)
283288
ret_dict = self.apply_out_stat(ret_dict, atype)
284289

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def forward_atomic(
157157
mapping: Array | None = None,
158158
fparam: Array | None = None,
159159
aparam: Array | None = None,
160+
comm_dict: dict | None = None,
160161
) -> dict[str, Array]:
161162
"""Models' atomic predictions.
162163
@@ -174,6 +175,9 @@ def forward_atomic(
174175
frame parameter. nf x ndf
175176
aparam
176177
atomic parameter. nf x nloc x nda
178+
comm_dict
179+
MPI communication metadata for parallel inference. ``None`` for
180+
non-parallel inference (default). Forwarded to the descriptor.
177181
178182
Returns
179183
-------
@@ -215,6 +219,7 @@ def forward_atomic(
215219
nlist,
216220
mapping=mapping,
217221
fparam=fparam_input_for_des if self.add_chg_spin_ebd else None,
222+
comm_dict=comm_dict,
218223
)
219224
ret = self.fitting_net(
220225
descriptor,

deepmd/dpmodel/atomic_model/linear_atomic_model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def forward_atomic(
224224
mapping: Array | None = None,
225225
fparam: Array | None = None,
226226
aparam: Array | None = None,
227+
comm_dict: dict | None = None,
227228
) -> dict[str, Array]:
228229
"""Return atomic prediction.
229230
@@ -241,6 +242,10 @@ def forward_atomic(
241242
frame parameter. (nframes, ndf)
242243
aparam
243244
atomic parameter. (nframes, nloc, nda)
245+
comm_dict
246+
MPI communication metadata. Forwarded to each sub-model so GNN
247+
sub-descriptors can perform parallel ghost exchange. ``None`` for
248+
non-parallel inference (default).
244249
245250
Returns
246251
-------
@@ -280,6 +285,7 @@ def forward_atomic(
280285
mapping,
281286
fparam,
282287
aparam,
288+
comm_dict,
283289
)["energy"]
284290
)
285291
weights = self._compute_weight(extended_coord, extended_atype, nlists_)

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,9 @@ def forward_atomic(
253253
mapping: Array | None = None,
254254
fparam: Array | None = None,
255255
aparam: Array | None = None,
256+
comm_dict: dict | None = None,
256257
) -> dict[str, Array]:
258+
del comm_dict # pairtab is local; no MPI ghost exchange needed.
257259
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
258260
nframes, nloc, nnei = nlist.shape
259261
extended_coord = xp.reshape(extended_coord, (nframes, -1, 3))

deepmd/dpmodel/descriptor/dpa1.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,14 @@ def has_message_passing(self) -> bool:
397397
"""Returns whether the descriptor has message passing."""
398398
return self.se_atten.has_message_passing()
399399

400+
def has_message_passing_across_ranks(self) -> bool:
401+
"""Returns whether per-layer node embeddings need MPI ghost exchange.
402+
403+
DPA1 (se_atten) is single-layer and does not exchange features
404+
across ranks; same as the base se_e2_a path.
405+
"""
406+
return False
407+
400408
def need_sorted_nlist_for_lower(self) -> bool:
401409
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
402410
return self.se_atten.need_sorted_nlist_for_lower()
@@ -500,6 +508,7 @@ def call(
500508
nlist: Array,
501509
mapping: Array | None = None,
502510
fparam: Array | None = None,
511+
comm_dict: dict | None = None,
503512
) -> Array:
504513
"""Compute the descriptor.
505514

deepmd/dpmodel/descriptor/dpa2.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,16 @@ def has_message_passing(self) -> bool:
687687
[self.repinit.has_message_passing(), self.repformers.has_message_passing()]
688688
)
689689

690+
def has_message_passing_across_ranks(self) -> bool:
691+
"""Returns whether per-layer node embeddings need MPI ghost exchange.
692+
693+
DPA2's repformers always passes ``g1`` in ``[nb, nall, n_dim]``
694+
layout (no ``use_loc_mapping`` opt-out exists at the block level),
695+
so multi-rank deployment always needs cross-rank exchange of
696+
per-atom features between layers.
697+
"""
698+
return self.repformers.has_message_passing_across_ranks()
699+
690700
def need_sorted_nlist_for_lower(self) -> bool:
691701
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
692702
return True
@@ -831,6 +841,7 @@ def call(
831841
nlist: Array,
832842
mapping: Array | None = None,
833843
fparam: Array | None = None,
844+
comm_dict: dict | None = None,
834845
) -> tuple[Array, Array, Array, Array, Array]:
835846
"""Compute the descriptor.
836847
@@ -844,6 +855,11 @@ def call(
844855
The neighbor list. shape: nf x nloc x nnei
845856
mapping
846857
The index mapping, maps extended region index to local region.
858+
comm_dict
859+
MPI communication metadata for parallel inference. Forwarded to
860+
the repformer block (the message-passing part). The repinit
861+
sub-block does no message passing and does not receive it.
862+
``None`` for non-parallel inference (default).
847863
848864
Returns
849865
-------
@@ -912,9 +928,18 @@ def call(
912928
assert self.tebd_transform is not None
913929
g1 = g1 + self.tebd_transform(g1_inp)
914930
# mapping g1
915-
assert mapping is not None
916-
mapping_ext = xp.tile(xp.expand_dims(mapping, axis=-1), (1, 1, g1.shape[-1]))
917-
g1_ext = xp_take_along_axis(g1, mapping_ext, axis=1)
931+
if comm_dict is None:
932+
# non-parallel: gather g1 -> g1_ext via mapping, hand the
933+
# nall-sized embedding to the repformer block.
934+
assert mapping is not None
935+
mapping_ext = xp.tile(
936+
xp.expand_dims(mapping, axis=-1), (1, 1, g1.shape[-1])
937+
)
938+
g1_ext = xp_take_along_axis(g1, mapping_ext, axis=1)
939+
else:
940+
# parallel mode: hand the local-only g1 to the repformer block;
941+
# its per-layer override fills ghosts via the MPI exchange.
942+
g1_ext = g1
918943
# repformer
919944
g1, g2, h2, rot_mat, sw = self.repformers(
920945
nlist_dict[
@@ -926,6 +951,7 @@ def call(
926951
atype_ext,
927952
g1_ext,
928953
mapping,
954+
comm_dict=comm_dict,
929955
)
930956
if self.concat_output_tebd:
931957
g1 = xp.concat([g1, g1_inp], axis=-1)

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,17 @@ def has_message_passing(self) -> bool:
527527
"""Returns whether the descriptor has message passing."""
528528
return self.repflows.has_message_passing()
529529

530+
def has_message_passing_across_ranks(self) -> bool:
531+
"""Returns whether per-layer node embeddings need MPI ghost exchange.
532+
533+
Delegates to repflows: ``False`` when ``use_loc_mapping=True``
534+
(per-layer messages stay within each rank's local atoms),
535+
``True`` when ``use_loc_mapping=False`` (ghost slots in
536+
``[nb, nall, n_dim]`` layout must be filled by cross-rank
537+
exchange before each layer).
538+
"""
539+
return self.repflows.has_message_passing_across_ranks()
540+
530541
def need_sorted_nlist_for_lower(self) -> bool:
531542
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
532543
return True
@@ -616,6 +627,7 @@ def call(
616627
nlist: Array,
617628
mapping: Array | None = None,
618629
fparam: Array | None = None,
630+
comm_dict: dict | None = None,
619631
) -> tuple[Array, Array, Array, Array, Array]:
620632
"""Compute the descriptor.
621633
@@ -629,6 +641,9 @@ def call(
629641
The neighbor list. shape: nf x nloc x nnei
630642
mapping
631643
The index mapping, mapps extended region index to local region.
644+
comm_dict
645+
MPI communication metadata for parallel inference. Forwarded to
646+
the repflows block. ``None`` for non-parallel inference (default).
632647
633648
Returns
634649
-------
@@ -695,6 +710,7 @@ def call(
695710
atype_ext,
696711
node_ebd_ext,
697712
mapping,
713+
comm_dict=comm_dict,
698714
)
699715
if self.concat_output_tebd:
700716
node_ebd = xp.concat([node_ebd, node_ebd_inp], axis=-1)

deepmd/dpmodel/descriptor/hybrid.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,16 @@ def has_message_passing(self) -> bool:
168168
"""Returns whether the descriptor has message passing."""
169169
return any(descrpt.has_message_passing() for descrpt in self.descrpt_list)
170170

171+
def has_message_passing_across_ranks(self) -> bool:
172+
"""Returns whether per-layer node embeddings need MPI ghost exchange.
173+
174+
``True`` if any child descriptor needs cross-rank message passing
175+
(e.g. a hybrid wrapping a DPA3 with ``use_loc_mapping=False``).
176+
"""
177+
return any(
178+
descrpt.has_message_passing_across_ranks() for descrpt in self.descrpt_list
179+
)
180+
171181
def need_sorted_nlist_for_lower(self) -> bool:
172182
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
173183
return True
@@ -276,6 +286,7 @@ def call(
276286
nlist: Array,
277287
mapping: Array | None = None,
278288
fparam: Array | None = None,
289+
comm_dict: dict | None = None,
279290
) -> tuple[
280291
Array,
281292
Array | None,
@@ -332,7 +343,9 @@ def call(
332343
# mixed_types is True, but descrpt.mixed_types is False
333344
assert nl_distinguish_types is not None
334345
nl = nl_distinguish_types[:, :, nci]
335-
odescriptor, gr, g2, h2, sw = descrpt(coord_ext, atype_ext, nl, mapping)
346+
odescriptor, gr, _g2, _h2, _sw = descrpt(
347+
coord_ext, atype_ext, nl, mapping, comm_dict=comm_dict
348+
)
336349
out_descriptor.append(odescriptor)
337350
if gr is not None:
338351
out_gr.append(gr)

deepmd/dpmodel/descriptor/make_base_descriptor.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,24 @@ def mixed_types(self) -> bool:
107107
def has_message_passing(self) -> bool:
108108
"""Returns whether the descriptor has message passing."""
109109

110+
def has_message_passing_across_ranks(self) -> bool:
111+
"""Returns whether the descriptor's message passing extends across rank
112+
boundaries — i.e. whether it requires cross-rank exchange of intermediate
113+
atomic features (per-layer node embeddings) during the forward pass.
114+
115+
Distinct from generic ghost-coord/force exchange that every LAMMPS
116+
pair_style does. This question gates whether the pt_expt backend
117+
compiles a second "with-comm" AOTI artifact for multi-rank deployment.
118+
119+
Concrete default ``False`` (non-GNN behavior) so pt and pd backend
120+
descriptors that subclass ``BaseDescriptor`` directly do not have
121+
to implement this method until they grow a multi-rank GNN path of
122+
their own. GNN descriptors that need MPI ghost-feature exchange
123+
(DPA2, DPA3 with ``use_loc_mapping=False``, hybrids wrapping such
124+
children) override to return ``True``.
125+
"""
126+
return False
127+
110128
@abstractmethod
111129
def need_sorted_nlist_for_lower(self) -> bool:
112130
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,32 @@ def reinit_exclude(
506506
self.exclude_types = exclude_types
507507
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)
508508

509+
def _exchange_ghosts(
510+
self,
511+
node_ebd: Array,
512+
mapping_tiled: Array | None,
513+
comm_dict: dict | None,
514+
nall: int,
515+
nloc: int,
516+
) -> Array:
517+
"""Build node_ebd_ext (the ghost-aware embedding) for the per-layer loop.
518+
519+
Default: array-api gather via the pre-tiled `mapping_tiled`, or pass the
520+
local-only `node_ebd` through when ``self.use_loc_mapping`` is set.
521+
``comm_dict``, ``nall``, ``nloc`` are unused in this default impl; they
522+
exist so the pt_expt subclass can perform the per-layer MPI ghost
523+
exchange (``deepmd_export::border_op``) when ``comm_dict is not None``.
524+
"""
525+
del comm_dict, nall, nloc
526+
if self.use_loc_mapping:
527+
return node_ebd
528+
if mapping_tiled is None:
529+
raise ValueError(
530+
"`mapping` is required when use_loc_mapping=False unless "
531+
"`_exchange_ghosts` is overridden for parallel comm handling."
532+
)
533+
return xp_take_along_axis(node_ebd, mapping_tiled, axis=1)
534+
509535
def call(
510536
self,
511537
nlist: Array,
@@ -514,6 +540,7 @@ def call(
514540
atype_embd_ext: Array | None = None,
515541
mapping: Array | None = None,
516542
type_embedding: Array | None = None,
543+
comm_dict: dict | None = None,
517544
) -> tuple[Array, Array, Array, Array, Array]:
518545
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
519546
nframes, nloc, nnei = nlist.shape
@@ -641,15 +668,24 @@ def call(
641668
# nf x nloc x a_nnei x a_nnei x a_dim [OR] n_angle x a_dim
642669
angle_ebd = self.angle_embd(angle_input)
643670

644-
# nb x nall x n_dim
645-
mapping = xp.tile(xp.expand_dims(mapping, axis=-1), (1, 1, self.n_dim))
671+
# nb x nall x n_dim (pre-tiled mapping reused across layers when not
672+
# using comm_dict). Skip the tile when mapping is None — pt_expt's
673+
# parallel-mode override consults comm_dict instead.
674+
mapping_tiled = (
675+
xp.tile(xp.expand_dims(mapping, axis=-1), (1, 1, self.n_dim))
676+
if mapping is not None
677+
else None
678+
)
646679
for idx, ll in enumerate(self.layers):
647680
# node_ebd: nb x nloc x n_dim
648-
# node_ebd_ext: nb x nall x n_dim
649-
node_ebd_ext = (
650-
node_ebd
651-
if self.use_loc_mapping
652-
else xp_take_along_axis(node_ebd, mapping, axis=1)
681+
# node_ebd_ext: nb x nall x n_dim (or nb x nloc x n_dim when
682+
# use_loc_mapping=True)
683+
node_ebd_ext = self._exchange_ghosts(
684+
node_ebd,
685+
mapping_tiled,
686+
comm_dict,
687+
nall,
688+
nloc,
653689
)
654690
node_ebd, edge_ebd, angle_ebd = ll.call(
655691
node_ebd_ext,
@@ -696,6 +732,16 @@ def has_message_passing(self) -> bool:
696732
"""Returns whether the descriptor block has message passing."""
697733
return True
698734

735+
def has_message_passing_across_ranks(self) -> bool:
736+
"""Returns whether per-layer node embeddings need MPI ghost exchange.
737+
738+
Repflows passes ``node_ebd`` either in ``[nb, nloc, n_dim]`` layout
739+
(``use_loc_mapping=True``: messages stay within the rank's local atoms)
740+
or ``[nb, nall, n_dim]`` layout (``use_loc_mapping=False``: ghost slots
741+
must be filled by cross-rank exchange before each layer).
742+
"""
743+
return not self.use_loc_mapping
744+
699745
def need_sorted_nlist_for_lower(self) -> bool:
700746
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""
701747
return True

0 commit comments

Comments
 (0)