Skip to content

Commit ab0566b

Browse files
authored
Merge branch 'master' into 0330-default-pf
2 parents a4be608 + 9847afb commit ab0566b

71 files changed

Lines changed: 5755 additions & 137 deletions

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,7 @@ repos:
6262
- mdformat-gfm-alerts==2.0.0
6363
# C++
6464
- repo: https://github.com/pre-commit/mirrors-clang-format
65-
rev: v22.1.4
65+
rev: v22.1.5
6666
hooks:
6767
- id: clang-format
6868
exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$|source/tests/infer/.+\.json$)

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -231,6 +231,7 @@ def forward_common_atomic(
231231
mapping: Array | None = None,
232232
fparam: Array | None = None,
233233
aparam: Array | None = None,
234+
comm_dict: dict | None = None,
234235
) -> dict[str, Array]:
235236
"""Common interface for atomic inference.
236237
@@ -252,6 +253,9 @@ def forward_common_atomic(
252253
frame parameters, shape: nf x dim_fparam
253254
aparam
254255
atomic parameter, shape: nf x nloc x dim_aparam
256+
comm_dict
257+
MPI communication metadata for parallel inference. ``None`` for
258+
non-parallel inference (default).
255259
256260
Returns
257261
-------
@@ -279,6 +283,7 @@ def forward_common_atomic(
279283
mapping=mapping,
280284
fparam=fparam,
281285
aparam=aparam,
286+
comm_dict=comm_dict,
282287
)
283288
ret_dict = self.apply_out_stat(ret_dict, atype)
284289

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -157,6 +157,7 @@ def forward_atomic(
157157
mapping: Array | None = None,
158158
fparam: Array | None = None,
159159
aparam: Array | None = None,
160+
comm_dict: dict | None = None,
160161
) -> dict[str, Array]:
161162
"""Models' atomic predictions.
162163
@@ -174,6 +175,9 @@ def forward_atomic(
174175
frame parameter. nf x ndf
175176
aparam
176177
atomic parameter. nf x nloc x nda
178+
comm_dict
179+
MPI communication metadata for parallel inference. ``None`` for
180+
non-parallel inference (default). Forwarded to the descriptor.
177181
178182
Returns
179183
-------
@@ -215,6 +219,7 @@ def forward_atomic(
215219
nlist,
216220
mapping=mapping,
217221
fparam=fparam_input_for_des if self.add_chg_spin_ebd else None,
222+
comm_dict=comm_dict,
218223
)
219224
ret = self.fitting_net(
220225
descriptor,

deepmd/dpmodel/atomic_model/linear_atomic_model.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -224,6 +224,7 @@ def forward_atomic(
224224
mapping: Array | None = None,
225225
fparam: Array | None = None,
226226
aparam: Array | None = None,
227+
comm_dict: dict | None = None,
227228
) -> dict[str, Array]:
228229
"""Return atomic prediction.
229230
@@ -241,6 +242,10 @@ def forward_atomic(
241242
frame parameter. (nframes, ndf)
242243
aparam
243244
atomic parameter. (nframes, nloc, nda)
245+
comm_dict
246+
MPI communication metadata. Forwarded to each sub-model so GNN
247+
sub-descriptors can perform parallel ghost exchange. ``None`` for
248+
non-parallel inference (default).
244249
245250
Returns
246251
-------
@@ -280,6 +285,7 @@ def forward_atomic(
280285
mapping,
281286
fparam,
282287
aparam,
288+
comm_dict,
283289
)["energy"]
284290
)
285291
weights = self._compute_weight(extended_coord, extended_atype, nlists_)

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -253,7 +253,9 @@ def forward_atomic(
253253
mapping: Array | None = None,
254254
fparam: Array | None = None,
255255
aparam: Array | None = None,
256+
comm_dict: dict | None = None,
256257
) -> dict[str, Array]:
258+
del comm_dict # pairtab is local; no MPI ghost exchange needed.
257259
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
258260
nframes, nloc, nnei = nlist.shape
259261
extended_coord = xp.reshape(extended_coord, (nframes, -1, 3))

deepmd/dpmodel/descriptor/dpa1.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -397,6 +397,14 @@ def has_message_passing(self) -> bool:
397397
"""Returns whether the descriptor has message passing."""
398398
return self.se_atten.has_message_passing()
399399

400+
def has_message_passing_across_ranks(self) -> bool:
401+
"""Returns whether per-layer node embeddings need MPI ghost exchange.
402+
403+
DPA1 (se_atten) is single-layer and does not exchange features
404+
across ranks; same as the base se_e2_a path.
405+
"""
406+
return False
407+
400408
def need_sorted_nlist_for_lower(self) -> bool:
401409
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
402410
return self.se_atten.need_sorted_nlist_for_lower()
@@ -500,6 +508,7 @@ def call(
500508
nlist: Array,
501509
mapping: Array | None = None,
502510
fparam: Array | None = None,
511+
comm_dict: dict | None = None,
503512
) -> Array:
504513
"""Compute the descriptor.
505514

deepmd/dpmodel/descriptor/dpa2.py

Lines changed: 29 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -687,6 +687,16 @@ def has_message_passing(self) -> bool:
687687
[self.repinit.has_message_passing(), self.repformers.has_message_passing()]
688688
)
689689

690+
def has_message_passing_across_ranks(self) -> bool:
691+
"""Returns whether per-layer node embeddings need MPI ghost exchange.
692+
693+
DPA2's repformers always passes ``g1`` in ``[nb, nall, n_dim]``
694+
layout (no ``use_loc_mapping`` opt-out exists at the block level),
695+
so multi-rank deployment always needs cross-rank exchange of
696+
per-atom features between layers.
697+
"""
698+
return self.repformers.has_message_passing_across_ranks()
699+
690700
def need_sorted_nlist_for_lower(self) -> bool:
691701
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
692702
return True
@@ -831,6 +841,7 @@ def call(
831841
nlist: Array,
832842
mapping: Array | None = None,
833843
fparam: Array | None = None,
844+
comm_dict: dict | None = None,
834845
) -> tuple[Array, Array, Array, Array, Array]:
835846
"""Compute the descriptor.
836847
@@ -844,6 +855,11 @@ def call(
844855
The neighbor list. shape: nf x nloc x nnei
845856
mapping
846857
The index mapping, maps extended region index to local region.
858+
comm_dict
859+
MPI communication metadata for parallel inference. Forwarded to
860+
the repformer block (the message-passing part). The repinit
861+
sub-block does no message passing and does not receive it.
862+
``None`` for non-parallel inference (default).
847863
848864
Returns
849865
-------
@@ -912,9 +928,18 @@ def call(
912928
assert self.tebd_transform is not None
913929
g1 = g1 + self.tebd_transform(g1_inp)
914930
# mapping g1
915-
assert mapping is not None
916-
mapping_ext = xp.tile(xp.expand_dims(mapping, axis=-1), (1, 1, g1.shape[-1]))
917-
g1_ext = xp_take_along_axis(g1, mapping_ext, axis=1)
931+
if comm_dict is None:
932+
# non-parallel: gather g1 -> g1_ext via mapping, hand the
933+
# nall-sized embedding to the repformer block.
934+
assert mapping is not None
935+
mapping_ext = xp.tile(
936+
xp.expand_dims(mapping, axis=-1), (1, 1, g1.shape[-1])
937+
)
938+
g1_ext = xp_take_along_axis(g1, mapping_ext, axis=1)
939+
else:
940+
# parallel mode: hand the local-only g1 to the repformer block;
941+
# its per-layer override fills ghosts via the MPI exchange.
942+
g1_ext = g1
918943
# repformer
919944
g1, g2, h2, rot_mat, sw = self.repformers(
920945
nlist_dict[
@@ -926,6 +951,7 @@ def call(
926951
atype_ext,
927952
g1_ext,
928953
mapping,
954+
comm_dict=comm_dict,
929955
)
930956
if self.concat_output_tebd:
931957
g1 = xp.concat([g1, g1_inp], axis=-1)

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -527,6 +527,17 @@ def has_message_passing(self) -> bool:
527527
"""Returns whether the descriptor has message passing."""
528528
return self.repflows.has_message_passing()
529529

530+
def has_message_passing_across_ranks(self) -> bool:
531+
"""Returns whether per-layer node embeddings need MPI ghost exchange.
532+
533+
Delegates to repflows: ``False`` when ``use_loc_mapping=True``
534+
(per-layer messages stay within each rank's local atoms),
535+
``True`` when ``use_loc_mapping=False`` (ghost slots in
536+
``[nb, nall, n_dim]`` layout must be filled by cross-rank
537+
exchange before each layer).
538+
"""
539+
return self.repflows.has_message_passing_across_ranks()
540+
530541
def need_sorted_nlist_for_lower(self) -> bool:
531542
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
532543
return True
@@ -616,6 +627,7 @@ def call(
616627
nlist: Array,
617628
mapping: Array | None = None,
618629
fparam: Array | None = None,
630+
comm_dict: dict | None = None,
619631
) -> tuple[Array, Array, Array, Array, Array]:
620632
"""Compute the descriptor.
621633
@@ -629,6 +641,9 @@ def call(
629641
The neighbor list. shape: nf x nloc x nnei
630642
mapping
631643
The index mapping, maps extended region index to local region.
644+
comm_dict
645+
MPI communication metadata for parallel inference. Forwarded to
646+
the repflows block. ``None`` for non-parallel inference (default).
632647
633648
Returns
634649
-------
@@ -695,6 +710,7 @@ def call(
695710
atype_ext,
696711
node_ebd_ext,
697712
mapping,
713+
comm_dict=comm_dict,
698714
)
699715
if self.concat_output_tebd:
700716
node_ebd = xp.concat([node_ebd, node_ebd_inp], axis=-1)

deepmd/dpmodel/descriptor/hybrid.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -168,6 +168,16 @@ def has_message_passing(self) -> bool:
168168
"""Returns whether the descriptor has message passing."""
169169
return any(descrpt.has_message_passing() for descrpt in self.descrpt_list)
170170

171+
def has_message_passing_across_ranks(self) -> bool:
172+
"""Returns whether per-layer node embeddings need MPI ghost exchange.
173+
174+
``True`` if any child descriptor needs cross-rank message passing
175+
(e.g. a hybrid wrapping a DPA3 with ``use_loc_mapping=False``).
176+
"""
177+
return any(
178+
descrpt.has_message_passing_across_ranks() for descrpt in self.descrpt_list
179+
)
180+
171181
def need_sorted_nlist_for_lower(self) -> bool:
172182
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
173183
return True
@@ -276,6 +286,7 @@ def call(
276286
nlist: Array,
277287
mapping: Array | None = None,
278288
fparam: Array | None = None,
289+
comm_dict: dict | None = None,
279290
) -> tuple[
280291
Array,
281292
Array | None,
@@ -332,7 +343,9 @@ def call(
332343
# mixed_types is True, but descrpt.mixed_types is False
333344
assert nl_distinguish_types is not None
334345
nl = nl_distinguish_types[:, :, nci]
335-
odescriptor, gr, g2, h2, sw = descrpt(coord_ext, atype_ext, nl, mapping)
346+
odescriptor, gr, _g2, _h2, _sw = descrpt(
347+
coord_ext, atype_ext, nl, mapping, comm_dict=comm_dict
348+
)
336349
out_descriptor.append(odescriptor)
337350
if gr is not None:
338351
out_gr.append(gr)

deepmd/dpmodel/descriptor/make_base_descriptor.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -107,6 +107,24 @@ def mixed_types(self) -> bool:
107107
def has_message_passing(self) -> bool:
108108
"""Returns whether the descriptor has message passing."""
109109

110+
def has_message_passing_across_ranks(self) -> bool:
111+
"""Returns whether the descriptor's message passing extends across rank
112+
boundaries — i.e. whether it requires cross-rank exchange of intermediate
113+
atomic features (per-layer node embeddings) during the forward pass.
114+
115+
Distinct from generic ghost-coord/force exchange that every LAMMPS
116+
pair_style does. This question gates whether the pt_expt backend
117+
compiles a second "with-comm" AOTI artifact for multi-rank deployment.
118+
119+
Concrete default ``False`` (non-GNN behavior) so pt and pd backend
120+
descriptors that subclass ``BaseDescriptor`` directly do not have
121+
to implement this method until they grow a multi-rank GNN path of
122+
their own. GNN descriptors that need MPI ghost-feature exchange
123+
(DPA2, DPA3 with ``use_loc_mapping=False``, hybrids wrapping such
124+
children) override to return ``True``.
125+
"""
126+
return False
127+
110128
@abstractmethod
111129
def need_sorted_nlist_for_lower(self) -> bool:
112130
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""

0 commit comments

Comments
 (0)