11import math
2+ from typing import Any , Dict , Optional , Tuple , Union
23
34import numpy as np
45from mpi4py import MPI
89from pylops_mpi .Distributed import DistributedMixIn
910
1011
11- def halo_block_split (global_shape : tuple , comm , grid_shape : tuple = None ) -> tuple :
12+ def halo_block_split (
13+ global_shape : tuple ,
14+ comm : MPI .Comm ,
15+ grid_shape : Optional [tuple ] = None ,
16+ ) -> tuple :
1217 r"""Split a global array over a Cartesian process grid.
1318
1419 Compute the local slice owned by the calling rank when ``global_shape`` is
@@ -133,11 +138,11 @@ class MPIHalo(DistributedMixIn, MPILinearOperator):
133138 def __init__ (
134139 self ,
135140 dims : tuple ,
136- halo ,
137- proc_grid_shape : tuple = None ,
141+ halo : Union [ int , tuple ] ,
142+ proc_grid_shape : Optional [ tuple ] = None ,
138143 comm : MPI .Comm = MPI .COMM_WORLD ,
139- dtype = np .float64 ,
140- ):
144+ dtype : Any = np .float64 ,
145+ ) -> None :
141146 self .global_dims = tuple (dims )
142147 self .ndim = len (dims )
143148
@@ -163,7 +168,8 @@ def __init__(
163168 )
164169 super ().__init__ (shape = self .shape , dtype = np .dtype (dtype ), base_comm = comm )
165170
166- def _parse_halo (self , h ):
171+ def _parse_halo (self , h : Union [int , tuple ]) -> tuple :
172+ """Normalize halo input and trim halos at global boundaries."""
167173 if isinstance (h , (int , np .int64 , np .int32 )):
168174 halo = (h ,) * (2 * self .ndim )
169175 trimmed = list (halo )
@@ -185,7 +191,8 @@ def _parse_halo(self, h):
185191 raise ValueError (f"Invalid halo length { len (h )} for ndim={ self .ndim } " )
186192 return halo
187193
188- def _build_topo (self ):
194+ def _build_topo (self ) -> Tuple [MPI .Comm , Dict [Tuple [str , int ], int ]]:
195+ """Create the Cartesian communicator and map neighboring ranks on the distribution axis."""
189196 cart_comm = self .comm .Create_cart (
190197 self .proc_grid_shape ,
191198 periods = [False ] * self .ndim ,
@@ -198,7 +205,8 @@ def _build_topo(self):
198205 neigh [("+" , ax )] = after
199206 return cart_comm , neigh
200207
201- def _calc_local_dims (self ):
208+ def _calc_local_dims (self ) -> tuple :
209+ """Compute this rank's local block shape before halo padding."""
202210 rank = self .cart_comm .Get_rank ()
203211 coords = self .cart_comm .Get_coords (rank )
204212 local = []
@@ -211,14 +219,16 @@ def _calc_local_dims(self):
211219 local .append (end - start )
212220 return tuple (local )
213221
214- def _calc_local_extent (self ):
222+ def _calc_local_extent (self ) -> tuple :
223+ """Compute this rank's local block shape after halo padding."""
215224 ext = []
216225 for ax in range (self .ndim ):
217226 minus_halo , plus_halo = self .halo [2 * ax ], self .halo [2 * ax + 1 ]
218227 ext .append (self .local_dims [ax ] + minus_halo + plus_halo )
219228 return tuple (ext )
220229
221- def _exchange_along_axis (self , ncp , arr , axis , before , after , engine ):
230+ def _exchange_along_axis (self , ncp : Any , arr : Any , axis : int , before : int , after : int , engine : str ) -> None :
231+ """Exchange boundary/halo slices with neighboring ranks along one axis."""
222232 minus_nbr , plus_nbr = self .neigh [("-" , axis )], self .neigh [("+" , axis )]
223233 # slice definitions
224234 slicer = [slice (None )] * self .ndim
@@ -259,7 +269,7 @@ def _exchange_along_axis(self, ncp, arr, axis, before, after, engine):
259269 )
260270 arr [tuple (rcv_s )] = rcv
261271
262- def _matvec (self , x ) :
272+ def _matvec (self , x : DistributedArray ) -> DistributedArray :
263273 ncp = get_module (x .engine )
264274 if x .partition != Partition .SCATTER :
265275 raise ValueError (
@@ -295,7 +305,7 @@ def _matvec(self, x):
295305 y [:] = halo_arr .ravel ()
296306 return y
297307
298- def _rmatvec (self , x ) :
308+ def _rmatvec (self , x : DistributedArray ) -> DistributedArray :
299309 if x .partition != Partition .SCATTER :
300310 raise ValueError (
301311 f"x should have partition={ Partition .SCATTER } Got { x .partition } instead..."
0 commit comments