Skip to content

Commit 4aa0f11

Browse files
committed
merge master
2 parents 03d2141 + 9847afb commit 4aa0f11

71 files changed

Lines changed: 5771 additions & 149 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ repos:
6262
- mdformat-gfm-alerts==2.0.0
6363
# C++
6464
- repo: https://github.com/pre-commit/mirrors-clang-format
65-
rev: v22.1.4
65+
rev: v22.1.5
6666
hooks:
6767
- id: clang-format
6868
exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$|source/tests/infer/.+\.json$)

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def forward_common_atomic(
243243
mapping: Array | None = None,
244244
fparam: Array | None = None,
245245
aparam: Array | None = None,
246+
comm_dict: dict | None = None,
246247
charge_spin: Array | None = None,
247248
) -> dict[str, Array]:
248249
"""Common interface for atomic inference.
@@ -265,6 +266,9 @@ def forward_common_atomic(
265266
frame parameters, shape: nf x dim_fparam
266267
aparam
267268
atomic parameter, shape: nf x nloc x dim_aparam
269+
comm_dict
270+
MPI communication metadata for parallel inference. ``None`` for
271+
non-parallel inference (default).
268272
269273
Returns
270274
-------
@@ -292,6 +296,7 @@ def forward_common_atomic(
292296
mapping=mapping,
293297
fparam=fparam,
294298
aparam=aparam,
299+
comm_dict=comm_dict,
295300
charge_spin=charge_spin,
296301
)
297302
ret_dict = self.apply_out_stat(ret_dict, atype)

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def forward_atomic(
179179
mapping: Array | None = None,
180180
fparam: Array | None = None,
181181
aparam: Array | None = None,
182+
comm_dict: dict | None = None,
182183
charge_spin: Array | None = None,
183184
) -> dict[str, Array]:
184185
"""Models' atomic predictions.
@@ -197,6 +198,9 @@ def forward_atomic(
197198
frame parameter. nf x ndf
198199
aparam
199200
atomic parameter. nf x nloc x nda
201+
comm_dict
202+
MPI communication metadata for parallel inference. ``None`` for
203+
non-parallel inference (default). Forwarded to the descriptor.
200204
charge_spin
201205
charge and spin parameter for descriptor. nf x 2
202206
@@ -230,6 +234,7 @@ def forward_atomic(
230234
extended_atype,
231235
nlist,
232236
mapping=mapping,
237+
comm_dict=comm_dict,
233238
charge_spin=charge_spin if self.add_chg_spin_ebd else None,
234239
)
235240
ret = self.fitting_net(

deepmd/dpmodel/atomic_model/linear_atomic_model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def forward_atomic(
224224
mapping: Array | None = None,
225225
fparam: Array | None = None,
226226
aparam: Array | None = None,
227+
comm_dict: dict | None = None,
227228
charge_spin: Array | None = None,
228229
) -> dict[str, Array]:
229230
"""Return atomic prediction.
@@ -242,6 +243,10 @@ def forward_atomic(
242243
frame parameter. (nframes, ndf)
243244
aparam
244245
atomic parameter. (nframes, nloc, nda)
246+
comm_dict
247+
MPI communication metadata. Forwarded to each sub-model so GNN
248+
sub-descriptors can perform parallel ghost exchange. ``None`` for
249+
non-parallel inference (default).
245250
246251
Returns
247252
-------
@@ -281,6 +286,7 @@ def forward_atomic(
281286
mapping,
282287
fparam,
283288
aparam,
289+
comm_dict,
284290
charge_spin=charge_spin,
285291
)["energy"]
286292
)

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,10 @@ def forward_atomic(
253253
mapping: Array | None = None,
254254
fparam: Array | None = None,
255255
aparam: Array | None = None,
256+
comm_dict: dict | None = None,
256257
charge_spin: Array | None = None,
257258
) -> dict[str, Array]:
259+
del comm_dict # pairtab is local; no MPI ghost exchange needed.
258260
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
259261
nframes, nloc, nnei = nlist.shape
260262
extended_coord = xp.reshape(extended_coord, (nframes, -1, 3))

deepmd/dpmodel/descriptor/dpa1.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,14 @@ def has_message_passing(self) -> bool:
397397
"""Returns whether the descriptor has message passing."""
398398
return self.se_atten.has_message_passing()
399399

400+
def has_message_passing_across_ranks(self) -> bool:
401+
"""Returns whether per-layer node embeddings need MPI ghost exchange.
402+
403+
DPA1 (se_atten) is single-layer and does not exchange features
404+
across ranks; same as the base se_e2_a path.
405+
"""
406+
return False
407+
400408
def need_sorted_nlist_for_lower(self) -> bool:
401409
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
402410
return self.se_atten.need_sorted_nlist_for_lower()
@@ -500,6 +508,7 @@ def call(
500508
nlist: Array,
501509
mapping: Array | None = None,
502510
fparam: Array | None = None,
511+
comm_dict: dict | None = None,
503512
charge_spin: Array | None = None,
504513
) -> Array:
505514
"""Compute the descriptor.

deepmd/dpmodel/descriptor/dpa2.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,16 @@ def has_message_passing(self) -> bool:
687687
[self.repinit.has_message_passing(), self.repformers.has_message_passing()]
688688
)
689689

690+
def has_message_passing_across_ranks(self) -> bool:
691+
"""Returns whether per-layer node embeddings need MPI ghost exchange.
692+
693+
DPA2's repformers always passes ``g1`` in ``[nb, nall, n_dim]``
694+
layout (no ``use_loc_mapping`` opt-out exists at the block level),
695+
so multi-rank deployment always needs cross-rank exchange of
696+
per-atom features between layers.
697+
"""
698+
return self.repformers.has_message_passing_across_ranks()
699+
690700
def need_sorted_nlist_for_lower(self) -> bool:
691701
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
692702
return True
@@ -831,6 +841,7 @@ def call(
831841
nlist: Array,
832842
mapping: Array | None = None,
833843
fparam: Array | None = None,
844+
comm_dict: dict | None = None,
834845
charge_spin: Array | None = None,
835846
) -> tuple[Array, Array, Array, Array, Array]:
836847
"""Compute the descriptor.
@@ -845,6 +856,11 @@ def call(
845856
The neighbor list. shape: nf x nloc x nnei
846857
mapping
847858
The index mapping, maps extended region index to local region.
859+
comm_dict
860+
MPI communication metadata for parallel inference. Forwarded to
861+
the repformer block (the message-passing part). The repinit
862+
sub-block does no message passing and does not receive it.
863+
``None`` for non-parallel inference (default).
848864
849865
Returns
850866
-------
@@ -913,9 +929,18 @@ def call(
913929
assert self.tebd_transform is not None
914930
g1 = g1 + self.tebd_transform(g1_inp)
915931
# mapping g1
916-
assert mapping is not None
917-
mapping_ext = xp.tile(xp.expand_dims(mapping, axis=-1), (1, 1, g1.shape[-1]))
918-
g1_ext = xp_take_along_axis(g1, mapping_ext, axis=1)
932+
if comm_dict is None:
933+
# non-parallel: gather g1 -> g1_ext via mapping, hand the
934+
# nall-sized embedding to the repformer block.
935+
assert mapping is not None
936+
mapping_ext = xp.tile(
937+
xp.expand_dims(mapping, axis=-1), (1, 1, g1.shape[-1])
938+
)
939+
g1_ext = xp_take_along_axis(g1, mapping_ext, axis=1)
940+
else:
941+
# parallel mode: hand the local-only g1 to the repformer block;
942+
# its per-layer override fills ghosts via the MPI exchange.
943+
g1_ext = g1
919944
# repformer
920945
g1, g2, h2, rot_mat, sw = self.repformers(
921946
nlist_dict[
@@ -927,6 +952,7 @@ def call(
927952
atype_ext,
928953
g1_ext,
929954
mapping,
955+
comm_dict=comm_dict,
930956
)
931957
if self.concat_output_tebd:
932958
g1 = xp.concat([g1, g1_inp], axis=-1)

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,17 @@ def has_message_passing(self) -> bool:
545545
"""Returns whether the descriptor has message passing."""
546546
return self.repflows.has_message_passing()
547547

548+
def has_message_passing_across_ranks(self) -> bool:
549+
"""Returns whether per-layer node embeddings need MPI ghost exchange.
550+
551+
Delegates to repflows: ``False`` when ``use_loc_mapping=True``
552+
(per-layer messages stay within each rank's local atoms),
553+
``True`` when ``use_loc_mapping=False`` (ghost slots in
554+
``[nb, nall, n_dim]`` layout must be filled by cross-rank
555+
exchange before each layer).
556+
"""
557+
return self.repflows.has_message_passing_across_ranks()
558+
548559
def need_sorted_nlist_for_lower(self) -> bool:
549560
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
550561
return True
@@ -634,6 +645,7 @@ def call(
634645
nlist: Array,
635646
mapping: Array | None = None,
636647
fparam: Array | None = None,
648+
comm_dict: dict | None = None,
637649
charge_spin: Array | None = None,
638650
) -> tuple[Array, Array, Array, Array, Array]:
639651
"""Compute the descriptor.
@@ -648,6 +660,9 @@ def call(
648660
The neighbor list. shape: nf x nloc x nnei
649661
mapping
650662
The index mapping, mapps extended region index to local region.
663+
comm_dict
664+
MPI communication metadata for parallel inference. Forwarded to
665+
the repflows block. ``None`` for non-parallel inference (default).
651666
652667
Returns
653668
-------
@@ -714,6 +729,7 @@ def call(
714729
atype_ext,
715730
node_ebd_ext,
716731
mapping,
732+
comm_dict=comm_dict,
717733
)
718734
if self.concat_output_tebd:
719735
node_ebd = xp.concat([node_ebd, node_ebd_inp], axis=-1)

deepmd/dpmodel/descriptor/hybrid.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,16 @@ def has_message_passing(self) -> bool:
168168
"""Returns whether the descriptor has message passing."""
169169
return any(descrpt.has_message_passing() for descrpt in self.descrpt_list)
170170

171+
def has_message_passing_across_ranks(self) -> bool:
172+
"""Returns whether per-layer node embeddings need MPI ghost exchange.
173+
174+
``True`` if any child descriptor needs cross-rank message passing
175+
(e.g. a hybrid wrapping a DPA3 with ``use_loc_mapping=False``).
176+
"""
177+
return any(
178+
descrpt.has_message_passing_across_ranks() for descrpt in self.descrpt_list
179+
)
180+
171181
def need_sorted_nlist_for_lower(self) -> bool:
172182
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
173183
return True
@@ -276,6 +286,7 @@ def call(
276286
nlist: Array,
277287
mapping: Array | None = None,
278288
fparam: Array | None = None,
289+
comm_dict: dict | None = None,
279290
charge_spin: Array | None = None,
280291
) -> tuple[
281292
Array,
@@ -333,8 +344,13 @@ def call(
333344
# mixed_types is True, but descrpt.mixed_types is False
334345
assert nl_distinguish_types is not None
335346
nl = nl_distinguish_types[:, :, nci]
336-
odescriptor, gr, g2, h2, sw = descrpt(
337-
coord_ext, atype_ext, nl, mapping, charge_spin=charge_spin
347+
odescriptor, gr, _g2, _h2, _sw = descrpt(
348+
coord_ext,
349+
atype_ext,
350+
nl,
351+
mapping,
352+
comm_dict=comm_dict,
353+
charge_spin=charge_spin,
338354
)
339355
out_descriptor.append(odescriptor)
340356
if gr is not None:

deepmd/dpmodel/descriptor/make_base_descriptor.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,24 @@ def mixed_types(self) -> bool:
119119
def has_message_passing(self) -> bool:
120120
"""Returns whether the descriptor has message passing."""
121121

122+
def has_message_passing_across_ranks(self) -> bool:
123+
"""Returns whether the descriptor's message passing extends across rank
124+
boundaries — i.e. whether it requires cross-rank exchange of intermediate
125+
atomic features (per-layer node embeddings) during the forward pass.
126+
127+
Distinct from generic ghost-coord/force exchange that every LAMMPS
128+
pair_style does. This question gates whether the pt_expt backend
129+
compiles a second "with-comm" AOTI artifact for multi-rank deployment.
130+
131+
Concrete default ``False`` (non-GNN behavior) so pt and pd backend
132+
descriptors that subclass ``BaseDescriptor`` directly do not have
133+
to implement this method until they grow a multi-rank GNN path of
134+
their own. GNN descriptors that need MPI ghost-feature exchange
135+
(DPA2, DPA3 with ``use_loc_mapping=False``, hybrids wrapping such
136+
children) override to return ``True``.
137+
"""
138+
return False
139+
122140
@abstractmethod
123141
def need_sorted_nlist_for_lower(self) -> bool:
124142
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""

0 commit comments

Comments
 (0)