@@ -687,6 +687,16 @@ def has_message_passing(self) -> bool:
687687 [self .repinit .has_message_passing (), self .repformers .has_message_passing ()]
688688 )
689689
def has_message_passing_across_ranks(self) -> bool:
    """Report whether per-layer node embeddings require MPI ghost exchange.

    The answer is delegated to the repformers block. As noted for DPA2,
    its repformers pass ``g1`` in ``[nb, nall, n_dim]`` layout and offer
    no block-level ``use_loc_mapping`` opt-out, so a multi-rank
    deployment is expected to exchange per-atom features across ranks
    between layers.

    Returns
    -------
    bool
        The value reported by
        ``self.repformers.has_message_passing_across_ranks()``.
    """
    repformers_block = self.repformers
    return repformers_block.has_message_passing_across_ranks()
699+
def need_sorted_nlist_for_lower(self) -> bool:
    """Report whether `forward_lower` requires a sorted neighbor list.

    Returns
    -------
    bool
        Always ``True`` for this descriptor.
    """
    return True
@@ -831,6 +841,7 @@ def call(
831841 nlist : Array ,
832842 mapping : Array | None = None ,
833843 fparam : Array | None = None ,
844+ comm_dict : dict | None = None ,
834845 charge_spin : Array | None = None ,
835846 ) -> tuple [Array , Array , Array , Array , Array ]:
836847 """Compute the descriptor.
@@ -845,6 +856,11 @@ def call(
845856 The neighbor list. shape: nf x nloc x nnei
846857 mapping
847858 The index mapping, maps extended region index to local region.
859+ comm_dict
860+ MPI communication metadata for parallel inference. Forwarded to
861+ the repformer block (the message-passing part). The repinit
862+ sub-block does no message passing and does not receive it.
863+ ``None`` for non-parallel inference (default).
848864
849865 Returns
850866 -------
@@ -913,9 +929,18 @@ def call(
913929 assert self .tebd_transform is not None
914930 g1 = g1 + self .tebd_transform (g1_inp )
915931 # mapping g1
916- assert mapping is not None
917- mapping_ext = xp .tile (xp .expand_dims (mapping , axis = - 1 ), (1 , 1 , g1 .shape [- 1 ]))
918- g1_ext = xp_take_along_axis (g1 , mapping_ext , axis = 1 )
932+ if comm_dict is None :
933+ # non-parallel: gather g1 -> g1_ext via mapping, hand the
934+ # nall-sized embedding to the repformer block.
935+ assert mapping is not None
936+ mapping_ext = xp .tile (
937+ xp .expand_dims (mapping , axis = - 1 ), (1 , 1 , g1 .shape [- 1 ])
938+ )
939+ g1_ext = xp_take_along_axis (g1 , mapping_ext , axis = 1 )
940+ else :
941+ # parallel mode: hand the local-only g1 to the repformer block;
942+ # its per-layer override fills ghosts via the MPI exchange.
943+ g1_ext = g1
919944 # repformer
920945 g1 , g2 , h2 , rot_mat , sw = self .repformers (
921946 nlist_dict [
@@ -927,6 +952,7 @@ def call(
927952 atype_ext ,
928953 g1_ext ,
929954 mapping ,
955+ comm_dict = comm_dict ,
930956 )
931957 if self .concat_output_tebd :
932958 g1 = xp .concat ([g1 , g1_inp ], axis = - 1 )
0 commit comments