@@ -687,6 +687,16 @@ def has_message_passing(self) -> bool:
687687 [self .repinit .has_message_passing (), self .repformers .has_message_passing ()]
688688 )
689689
def has_message_passing_across_ranks(self) -> bool:
    """Report whether per-layer node embeddings need MPI ghost exchange.

    The answer is not computed here: the call is forwarded to the
    repformers block and its result is returned unchanged.

    Notes
    -----
    Per the original author's note, DPA2's repformers always passes
    ``g1`` in ``[nb, nall, n_dim]`` layout (no ``use_loc_mapping``
    opt-out exists at the block level), so multi-rank deployment is
    expected to need cross-rank exchange of per-atom features between
    layers.

    Returns
    -------
    bool
        Whatever ``self.repformers.has_message_passing_across_ranks()``
        reports.
    """
    repformers_block = self.repformers
    needs_exchange = repformers_block.has_message_passing_across_ranks()
    return needs_exchange
699+
def need_sorted_nlist_for_lower(self) -> bool:
    """Tell callers whether `forward_lower` must be given a sorted nlist.

    Returns
    -------
    bool
        Always ``True`` for this descriptor.
    """
    requires_sorted_nlist = True
    return requires_sorted_nlist
@@ -831,6 +841,7 @@ def call(
831841 nlist : Array ,
832842 mapping : Array | None = None ,
833843 fparam : Array | None = None ,
844+ comm_dict : dict | None = None ,
834845 ) -> tuple [Array , Array , Array , Array , Array ]:
835846 """Compute the descriptor.
836847
@@ -844,6 +855,11 @@ def call(
844855 The neighbor list. shape: nf x nloc x nnei
845856 mapping
846857 The index mapping, maps extended region index to local region.
858+ comm_dict
859+ MPI communication metadata for parallel inference. Forwarded to
860+ the repformer block (the message-passing part). The repinit
861+ sub-block does no message passing and does not receive it.
862+ ``None`` for non-parallel inference (default).
847863
848864 Returns
849865 -------
@@ -912,9 +928,18 @@ def call(
912928 assert self .tebd_transform is not None
913929 g1 = g1 + self .tebd_transform (g1_inp )
914930 # mapping g1
915- assert mapping is not None
916- mapping_ext = xp .tile (xp .expand_dims (mapping , axis = - 1 ), (1 , 1 , g1 .shape [- 1 ]))
917- g1_ext = xp_take_along_axis (g1 , mapping_ext , axis = 1 )
931+ if comm_dict is None :
932+ # non-parallel: gather g1 -> g1_ext via mapping, hand the
933+ # nall-sized embedding to the repformer block.
934+ assert mapping is not None
935+ mapping_ext = xp .tile (
936+ xp .expand_dims (mapping , axis = - 1 ), (1 , 1 , g1 .shape [- 1 ])
937+ )
938+ g1_ext = xp_take_along_axis (g1 , mapping_ext , axis = 1 )
939+ else :
940+ # parallel mode: hand the local-only g1 to the repformer block;
941+ # its per-layer override fills ghosts via the MPI exchange.
942+ g1_ext = g1
918943 # repformer
919944 g1 , g2 , h2 , rot_mat , sw = self .repformers (
920945 nlist_dict [
@@ -926,6 +951,7 @@ def call(
926951 atype_ext ,
927952 g1_ext ,
928953 mapping ,
954+ comm_dict = comm_dict ,
929955 )
930956 if self .concat_output_tebd :
931957 g1 = xp .concat ([g1 , g1_inp ], axis = - 1 )
0 commit comments