44
55from pylops .utils .backend import get_module
66from pylops_mpi import MPILinearOperator , DistributedArray , Partition
7+ from pylops_mpi .Distributed import DistributedMixIn
78
89
910def halo_block_split (global_shape : tuple , comm , grid_shape : tuple = None ) -> tuple :
@@ -31,20 +32,18 @@ def halo_block_split(global_shape: tuple, comm, grid_shape: tuple = None) -> tup
3132 return tuple (slices )
3233
3334
34- class MPIHalo (MPILinearOperator ):
35+ class MPIHalo (DistributedMixIn , MPILinearOperator ):
3536 def __init__ (
3637 self ,
3738 dims : tuple ,
3839 halo ,
3940 proc_grid_shape : tuple = None ,
4041 comm : MPI .Comm = MPI .COMM_WORLD ,
4142 dtype = np .float64 ,
42- normalize : bool = False ,
4343 ):
4444 self .global_dims = tuple (dims )
4545 self .ndim = len (dims )
4646
47- self .normalize = normalize
4847 self .comm = comm
4948 self .dtype = dtype
5049
@@ -55,13 +54,16 @@ def __init__(
5554
5655 self .local_dims = self ._calc_local_dims ()
5756 self .local_extent = self ._calc_local_extent ()
58- self ._local_extent_sizes = self .comm .allgather (int (np .prod (self .local_extent )))
57+ self ._local_extent_sizes = self ._allgather (
58+ self .comm ,
59+ None ,
60+ int (np .prod (self .local_extent )),
61+ )
5962
6063 self .shape = (
6164 int (np .sum (self ._local_extent_sizes )),
6265 int (np .prod (self .global_dims )),
6366 )
64- self ._norm_factors = self ._compute_norm_factors () if self .normalize else None
6567 super ().__init__ (shape = self .shape , dtype = np .dtype (dtype ), base_comm = comm )
6668
6769 def _parse_halo (self , h ):
@@ -117,26 +119,6 @@ def _calc_local_extent(self):
117119 ext .append (self .local_dims [ax ] + minus_halo + plus_halo )
118120 return tuple (ext )
119121
120- def _compute_norm_factors (self ):
121- weights = np .ones (self .local_dims , dtype = np .float64 )
122- for ax in range (self .ndim ):
123- before , after = self .halo [2 * ax ], self .halo [2 * ax + 1 ]
124- if before == 0 and after == 0 :
125- continue
126- n_local = self .local_dims [ax ]
127- factors = np .ones (n_local , dtype = np .float64 )
128- minus_nbr , plus_nbr = self .neigh [("-" , ax )], self .neigh [("+" , ax )]
129- minus_copy = minus_nbr != MPI .PROC_NULL
130- plus_copy = plus_nbr != MPI .PROC_NULL
131- if before and minus_copy :
132- factors [:before ] += 1.0
133- if after and plus_copy :
134- factors [n_local - after :] += 1.0
135- shape = [1 ] * self .ndim
136- shape [ax ] = n_local
137- weights *= factors .reshape (shape )
138- return (1.0 / np .sqrt (weights )).astype (self .dtype , copy = False )
139-
140122 def _exchange_along_axis (self , ncp , arr , axis , before , after ):
141123 minus_nbr , plus_nbr = self .neigh [("-" , axis )], self .neigh [("+" , axis )]
142124 # slice definitions
@@ -162,46 +144,6 @@ def _exchange_along_axis(self, ncp, arr, axis, before, after):
162144 self .cart_comm .Sendrecv (snd , dest = plus_nbr , recvbuf = rcv , source = plus_nbr )
163145 arr [tuple (rcv_s )] = rcv
164146
165- def _exchange_adjoint_along_axis (self , ncp , arr , axis , before , after ):
166- minus_nbr , plus_nbr = self .neigh [("-" , axis )], self .neigh [("+" , axis )]
167- out = arr .copy ()
168- if before == 0 and after == 0 :
169- return out
170-
171- def axis_slice (start , end ):
172- s = [slice (None )] * self .ndim
173- s [axis ] = slice (start , end )
174- return tuple (s )
175-
176- empty = axis_slice (0 , 0 )
177- halo_minus = axis_slice (0 , before ) if before else None
178- halo_plus = axis_slice (- after , None ) if after else None
179- core_minus = axis_slice (before , 2 * before ) if before else None
180- core_plus = axis_slice (- 2 * after , - after ) if after else None
181-
182- # remove contributions from halo slabs along this axis (forward overwrites them)
183- if halo_minus :
184- out [halo_minus ] = 0
185- if halo_plus :
186- out [halo_plus ] = 0
187-
188- # send right halo to plus neighbor, receive right halo from minus neighbor
189- snd_slice = halo_plus or empty
190- rcv_slice = halo_minus or empty
191- rcv = ncp .zeros_like (arr [rcv_slice ])
192- self .cart_comm .Sendrecv (arr [snd_slice ].copy (), dest = plus_nbr , recvbuf = rcv , source = minus_nbr )
193- if core_minus :
194- out [core_minus ] += rcv
195-
196- # send left halo to minus neighbor, receive left halo from plus neighbor
197- snd_slice = halo_minus or empty
198- rcv_slice = halo_plus or empty
199- rcv = ncp .zeros_like (arr [rcv_slice ])
200- self .cart_comm .Sendrecv (arr [snd_slice ].copy (), dest = minus_nbr , recvbuf = rcv , source = plus_nbr )
201- if core_plus :
202- out [core_plus ] += rcv
203- return out
204-
205147 def _matvec (self , x ):
206148 ncp = get_module (x .engine )
207149 if x .partition != Partition .SCATTER :
@@ -211,12 +153,13 @@ def _matvec(self, x):
211153 global_shape = self .shape [0 ],
212154 partition = Partition .SCATTER ,
213155 local_shapes = self ._local_extent_sizes ,
156+ base_comm = x .base_comm ,
157+ base_comm_nccl = x .base_comm_nccl ,
158+ engine = x .engine ,
159+ dtype = self .dtype ,
214160 )
215161
216162 core = x .local_array .reshape (self .local_dims )
217- if self .normalize :
218- norm = self ._norm_factors if ncp is np else ncp .asarray (self ._norm_factors )
219- core = core * norm
220163 halo_arr = ncp .zeros (self .local_extent , dtype = self .dtype )
221164 # insert core
222165 core_slices = [
@@ -234,25 +177,16 @@ def _matvec(self, x):
234177 return y
235178
236179 def _rmatvec (self , x ):
237- ncp = get_module (x .engine )
238180 if x .partition != Partition .SCATTER :
239181 raise ValueError (f"x should have partition={ Partition .SCATTER } Got { x .partition } instead..." )
240-
241- y = DistributedArray (global_shape = self .shape [1 ], partition = Partition .SCATTER )
242-
243- arr = x .local_array .reshape (self .local_extent ).copy ()
244- # adjoint of halo exchange (reverse axis order because of the Transpose)
245- for ax in reversed (range (self .ndim )):
246- before , after = self .halo [2 * ax ], self .halo [2 * ax + 1 ]
247- arr = self ._exchange_adjoint_along_axis (ncp , arr , axis = ax , before = before , after = after )
248- core_slices = [
249- slice (left , left + ldim )
250- for left , ldim in zip (self .halo [::2 ], self .local_dims )
251- ]
182+ res = DistributedArray (global_shape = self .shape [1 ],
183+ partition = Partition .SCATTER ,
184+ base_comm = x .base_comm ,
185+ base_comm_nccl = x .base_comm_nccl ,
186+ engine = x .engine ,
187+ dtype = self .dtype )
188+ arr = x .local_array .reshape (self .local_extent )
189+ core_slices = [slice (left , left + ldim ) for left , ldim in zip (self .halo [::2 ], self .local_dims )]
252190 core = arr [tuple (core_slices )]
253- if self .normalize :
254- norm = self ._norm_factors if ncp is np else ncp .asarray (self ._norm_factors )
255- core = core * norm
256-
257- y [:] = core .ravel ()
258- return y
191+ res [:] = core .ravel ()
192+ return res
0 commit comments