Skip to content

Commit 8a2c5b9

Browse files
committed
Removed normalization; restored normal behaviour
1 parent 93f7a75 commit 8a2c5b9

4 files changed

Lines changed: 64 additions & 109 deletions

File tree

examples/plot_halo.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -209,8 +209,7 @@ def local_extent_from_slice(local_shape, local_slice, halo):
209209
dims=dims,
210210
halo=halo,
211211
proc_grid_shape=proc_grid_shape,
212-
comm=comm,
213-
normalize=True, # why?
212+
comm=comm
214213
)
215214

216215
# Global array
@@ -261,8 +260,7 @@ def local_extent_from_slice(local_shape, local_slice, halo):
261260
dims=dims,
262261
halo=halo,
263262
proc_grid_shape=proc_grid_shape,
264-
comm=comm,
265-
normalize=False, # why?
263+
comm=comm
266264
)
267265

268266
for axis in [0, 1]:
@@ -307,8 +305,7 @@ def local_extent_from_slice(local_shape, local_slice, halo):
307305
dims=dims,
308306
halo=halo,
309307
proc_grid_shape=proc_grid_shape,
310-
comm=comm,
311-
normalize=False, # why?
308+
comm=comm
312309
)
313310

314311
# Global array

examples/plot_halo_nsconv.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -86,8 +86,7 @@ def local_extent_from_slice(local_shape, local_slice, halo):
8686
dims=(n, ),
8787
halo=halo,
8888
proc_grid_shape=proc_grid_shape,
89-
comm=comm,
90-
normalize=True,
89+
comm=comm
9190
)
9291

9392
# Distributed array

pylops_mpi/basicoperators/Halo.py

Lines changed: 21 additions & 87 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44

55
from pylops.utils.backend import get_module
66
from pylops_mpi import MPILinearOperator, DistributedArray, Partition
7+
from pylops_mpi.Distributed import DistributedMixIn
78

89

910
def halo_block_split(global_shape: tuple, comm, grid_shape: tuple = None) -> tuple:
@@ -31,20 +32,18 @@ def halo_block_split(global_shape: tuple, comm, grid_shape: tuple = None) -> tup
3132
return tuple(slices)
3233

3334

34-
class MPIHalo(MPILinearOperator):
35+
class MPIHalo(DistributedMixIn, MPILinearOperator):
3536
def __init__(
3637
self,
3738
dims: tuple,
3839
halo,
3940
proc_grid_shape: tuple = None,
4041
comm: MPI.Comm = MPI.COMM_WORLD,
4142
dtype=np.float64,
42-
normalize: bool = False,
4343
):
4444
self.global_dims = tuple(dims)
4545
self.ndim = len(dims)
4646

47-
self.normalize = normalize
4847
self.comm = comm
4948
self.dtype = dtype
5049

@@ -55,13 +54,16 @@ def __init__(
5554

5655
self.local_dims = self._calc_local_dims()
5756
self.local_extent = self._calc_local_extent()
58-
self._local_extent_sizes = self.comm.allgather(int(np.prod(self.local_extent)))
57+
self._local_extent_sizes = self._allgather(
58+
self.comm,
59+
None,
60+
int(np.prod(self.local_extent)),
61+
)
5962

6063
self.shape = (
6164
int(np.sum(self._local_extent_sizes)),
6265
int(np.prod(self.global_dims)),
6366
)
64-
self._norm_factors = self._compute_norm_factors() if self.normalize else None
6567
super().__init__(shape=self.shape, dtype=np.dtype(dtype), base_comm=comm)
6668

6769
def _parse_halo(self, h):
@@ -117,26 +119,6 @@ def _calc_local_extent(self):
117119
ext.append(self.local_dims[ax] + minus_halo + plus_halo)
118120
return tuple(ext)
119121

120-
def _compute_norm_factors(self):
121-
weights = np.ones(self.local_dims, dtype=np.float64)
122-
for ax in range(self.ndim):
123-
before, after = self.halo[2 * ax], self.halo[2 * ax + 1]
124-
if before == 0 and after == 0:
125-
continue
126-
n_local = self.local_dims[ax]
127-
factors = np.ones(n_local, dtype=np.float64)
128-
minus_nbr, plus_nbr = self.neigh[("-", ax)], self.neigh[("+", ax)]
129-
minus_copy = minus_nbr != MPI.PROC_NULL
130-
plus_copy = plus_nbr != MPI.PROC_NULL
131-
if before and minus_copy:
132-
factors[:before] += 1.0
133-
if after and plus_copy:
134-
factors[n_local - after:] += 1.0
135-
shape = [1] * self.ndim
136-
shape[ax] = n_local
137-
weights *= factors.reshape(shape)
138-
return (1.0 / np.sqrt(weights)).astype(self.dtype, copy=False)
139-
140122
def _exchange_along_axis(self, ncp, arr, axis, before, after):
141123
minus_nbr, plus_nbr = self.neigh[("-", axis)], self.neigh[("+", axis)]
142124
# slice definitions
@@ -162,46 +144,6 @@ def _exchange_along_axis(self, ncp, arr, axis, before, after):
162144
self.cart_comm.Sendrecv(snd, dest=plus_nbr, recvbuf=rcv, source=plus_nbr)
163145
arr[tuple(rcv_s)] = rcv
164146

165-
def _exchange_adjoint_along_axis(self, ncp, arr, axis, before, after):
166-
minus_nbr, plus_nbr = self.neigh[("-", axis)], self.neigh[("+", axis)]
167-
out = arr.copy()
168-
if before == 0 and after == 0:
169-
return out
170-
171-
def axis_slice(start, end):
172-
s = [slice(None)] * self.ndim
173-
s[axis] = slice(start, end)
174-
return tuple(s)
175-
176-
empty = axis_slice(0, 0)
177-
halo_minus = axis_slice(0, before) if before else None
178-
halo_plus = axis_slice(-after, None) if after else None
179-
core_minus = axis_slice(before, 2 * before) if before else None
180-
core_plus = axis_slice(-2 * after, -after) if after else None
181-
182-
# remove contributions from halo slabs along this axis (forward overwrites them)
183-
if halo_minus:
184-
out[halo_minus] = 0
185-
if halo_plus:
186-
out[halo_plus] = 0
187-
188-
# send right halo to plus neighbor, receive right halo from minus neighbor
189-
snd_slice = halo_plus or empty
190-
rcv_slice = halo_minus or empty
191-
rcv = ncp.zeros_like(arr[rcv_slice])
192-
self.cart_comm.Sendrecv(arr[snd_slice].copy(), dest=plus_nbr, recvbuf=rcv, source=minus_nbr)
193-
if core_minus:
194-
out[core_minus] += rcv
195-
196-
# send left halo to minus neighbor, receive left halo from plus neighbor
197-
snd_slice = halo_minus or empty
198-
rcv_slice = halo_plus or empty
199-
rcv = ncp.zeros_like(arr[rcv_slice])
200-
self.cart_comm.Sendrecv(arr[snd_slice].copy(), dest=minus_nbr, recvbuf=rcv, source=plus_nbr)
201-
if core_plus:
202-
out[core_plus] += rcv
203-
return out
204-
205147
def _matvec(self, x):
206148
ncp = get_module(x.engine)
207149
if x.partition != Partition.SCATTER:
@@ -211,12 +153,13 @@ def _matvec(self, x):
211153
global_shape=self.shape[0],
212154
partition=Partition.SCATTER,
213155
local_shapes=self._local_extent_sizes,
156+
base_comm=x.base_comm,
157+
base_comm_nccl=x.base_comm_nccl,
158+
engine=x.engine,
159+
dtype=self.dtype,
214160
)
215161

216162
core = x.local_array.reshape(self.local_dims)
217-
if self.normalize:
218-
norm = self._norm_factors if ncp is np else ncp.asarray(self._norm_factors)
219-
core = core * norm
220163
halo_arr = ncp.zeros(self.local_extent, dtype=self.dtype)
221164
# insert core
222165
core_slices = [
@@ -234,25 +177,16 @@ def _matvec(self, x):
234177
return y
235178

236179
def _rmatvec(self, x):
237-
ncp = get_module(x.engine)
238180
if x.partition != Partition.SCATTER:
239181
raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
240-
241-
y = DistributedArray(global_shape=self.shape[1], partition=Partition.SCATTER)
242-
243-
arr = x.local_array.reshape(self.local_extent).copy()
244-
# adjoint of halo exchange (reverse axis order because of the Transpose)
245-
for ax in reversed(range(self.ndim)):
246-
before, after = self.halo[2 * ax], self.halo[2 * ax + 1]
247-
arr = self._exchange_adjoint_along_axis(ncp, arr, axis=ax, before=before, after=after)
248-
core_slices = [
249-
slice(left, left + ldim)
250-
for left, ldim in zip(self.halo[::2], self.local_dims)
251-
]
182+
res = DistributedArray(global_shape=self.shape[1],
183+
partition=Partition.SCATTER,
184+
base_comm=x.base_comm,
185+
base_comm_nccl=x.base_comm_nccl,
186+
engine=x.engine,
187+
dtype=self.dtype)
188+
arr = x.local_array.reshape(self.local_extent)
189+
core_slices = [slice(left, left + ldim) for left, ldim in zip(self.halo[::2], self.local_dims)]
252190
core = arr[tuple(core_slices)]
253-
if self.normalize:
254-
norm = self._norm_factors if ncp is np else ncp.asarray(self._norm_factors)
255-
core = core * norm
256-
257-
y[:] = core.ravel()
258-
return y
191+
res[:] = core.ravel()
192+
return res

tests/test_halo.py

Lines changed: 39 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -6,9 +6,10 @@
66
else:
77
import numpy as np
88
backend = "numpy"
9-
import numpy as npp
109
from mpi4py import MPI
1110
import pytest
11+
import pylops
12+
from numpy.testing import assert_allclose
1213

1314
import pylops_mpi
1415
from pylops_mpi.basicoperators.Halo import MPIHalo
@@ -21,28 +22,24 @@
2122
np.cuda.Device(device_id).use()
2223

2324

24-
@pytest.mark.mpi(min_size=8)
25-
def test_halo_dottest_plot_config():
25+
@pytest.mark.mpi(min_size=2)
26+
def test_halo_composed_dottest_matches_serial():
2627
comm = MPI.COMM_WORLD
2728
size = comm.Get_size()
28-
p_prime = int(round(size ** (1 / 3)))
29-
if p_prime ** 3 != size:
30-
pytest.skip("MPI size must be a perfect cube for 3D halo grid")
31-
32-
gdim = (4 * p_prime, 4 * p_prime, 4 * p_prime)
33-
g_shape = (p_prime, p_prime, p_prime)
29+
nlocal = 16
30+
n = nlocal * size
3431
halo = 1
3532

3633
halo_op = MPIHalo(
37-
dims=gdim,
34+
dims=(n,),
3835
halo=halo,
39-
proc_grid_shape=g_shape,
36+
proc_grid_shape=(size,),
4037
comm=comm,
4138
dtype=np.float64,
4239
)
4340

4441
x_dist = pylops_mpi.DistributedArray(
45-
global_shape=npp.prod(gdim),
42+
global_shape=n,
4643
base_comm=comm,
4744
partition=pylops_mpi.Partition.SCATTER,
4845
engine=backend,
@@ -51,12 +48,40 @@ def test_halo_dottest_plot_config():
5148
x_dist[:] = np.random.normal(0.0, 1.0, x_dist.local_array.shape)
5249

5350
y_dist = pylops_mpi.DistributedArray(
54-
global_shape=halo_op.shape[0],
51+
global_shape=n,
5552
base_comm=comm,
5653
partition=pylops_mpi.Partition.SCATTER,
5754
engine=backend,
5855
dtype=np.float64,
5956
)
6057
y_dist[:] = np.random.normal(0.0, 1.0, y_dist.local_array.shape)
6158

62-
dottest(halo_op, x_dist, y_dist)
59+
local_extent = halo_op.local_extent[0]
60+
DOp = pylops.FirstDerivative(
61+
dims=local_extent,
62+
axis=0,
63+
kind="forward",
64+
dtype=np.float64,
65+
)
66+
DOp_dist = pylops_mpi.MPIBlockDiag([DOp], base_comm=comm, dtype=np.float64)
67+
Op_dist = halo_op.H @ DOp_dist @ halo_op
68+
69+
dottest(Op_dist, x_dist, y_dist, n, n)
70+
71+
y_dist = Op_dist @ x_dist
72+
y_adj_dist = Op_dist.H @ x_dist
73+
y = y_dist.asarray()
74+
y_adj = y_adj_dist.asarray()
75+
76+
x_global = x_dist.asarray()
77+
if rank == 0:
78+
DOp_serial = pylops.FirstDerivative(
79+
dims=n,
80+
axis=0,
81+
kind="forward",
82+
dtype=np.float64,
83+
)
84+
y_serial = DOp_serial @ x_global
85+
y_adj_serial = DOp_serial.H @ x_global
86+
assert_allclose(y, y_serial, rtol=1e-14)
87+
assert_allclose(y_adj, y_adj_serial, rtol=1e-14)

0 commit comments

Comments
 (0)