Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions pylops_mpi/DistributedArray.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,34 @@ def ravel(self, order: Optional[str] = "C"):
x = local_array.copy()
arr[:] = x
return arr

def reshape(self, local_shape, axis=0):
"""Return a reshaped DistributedArray

Parameters
----------

Returns
-------
arr : :obj:`pylops_mpi.DistributedArray`
Reshaped N-D DistributedArray
"""
local_shapes = self.base_comm.allgather(local_shape)
global_shape = list(local_shapes[0])
global_shape[axis] = np.sum([ls[axis] for ls in local_shapes])
arr = DistributedArray(global_shape=tuple(global_shape),
base_comm=self.base_comm,
base_comm_nccl=self.base_comm_nccl,
local_shapes=local_shapes,
mask=self.mask,
partition=self.partition,
engine=self.engine,
dtype=self.dtype)
local_array = self.local_array.reshape(local_shapes[self.rank])
x = local_array.copy()
arr[:] = x
return arr

def empty_like(self):
"""Creates an empty like DistributedArray with uninitialized values
"""
Expand Down
56 changes: 48 additions & 8 deletions pylops_mpi/LinearOperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from scipy.sparse._sputils import isintlike
from scipy.sparse.linalg._interface import _get_dtype

from pylops import get_ndarray_multiplication
from pylops import LinearOperator
from pylops.utils import DTypeLike, ShapeLike

Expand Down Expand Up @@ -150,11 +151,34 @@ def dot(self, x):
elif np.isscalar(x):
return _ScaledLinearOperator(self, x)
else:
if x is None or x.ndim == 1:
return self.matvec(x)
if not get_ndarray_multiplication() and x.ndim >= 2:
msg = (
"Operator can only be applied to 1D vectors. "
"Enable ndarray multiplication with pylops.set_ndarray_multiplication(True)."
)
raise ValueError(msg)
# current workaround as dims is not guaranteed to be set in a MPILinearOperator
# TODO: make dims/dimsd first-class attributes of MPILinearOperator as for
# PyLops LinearOperator
is_dims_shaped = x.global_shape == getattr(self, "dims", (1, ))
if is_dims_shaped:
# (dims1, ..., dimsK) => (dims1 * ... * dimsK,) == self.shape
x = x.ravel()
if x.ndim == 1:
y = self.matvec(x)
if (
is_dims_shaped
and get_ndarray_multiplication()
):
y = y.reshape(self._dims_dimsd_local)
return y
else:
raise ValueError('expected 1-d DistributedArray, got %r'
% x.global_shape)
msg = (
"Wrong shape.\nExpects either a 1d array or, an ndarray of "
f"size `dims` when `dims` and `dimsd` both are available.\n"
f"Instead, received an array of shape {x.shape}."
)
raise ValueError(msg)

def adjoint(self):
"""Adjoint MPI LinearOperator
Expand Down Expand Up @@ -216,10 +240,20 @@ def __sub__(self, x):
return self.__add__(-x)

def _adjoint(self):
return _AdjointLinearOperator(self)
Op = _AdjointLinearOperator(self)
if hasattr(self, "dimsd"):
Op.dims = self.dimsd
if hasattr(self, "dims"):
Op.dimsd = self.dims
return Op

def _transpose(self):
return _TransposedLinearOperator(self)
Op = _TransposedLinearOperator(self)
if hasattr(self, "dimsd"):
Op.dims = self.dimsd
if hasattr(self, "dims"):
Op.dimsd = self.dims
return Op

def conj(self):
"""Complex conjugate operator
Expand Down Expand Up @@ -251,10 +285,16 @@ def __init__(self, A: MPILinearOperator):
base_comm=MPI.COMM_WORLD)

def _matvec(self, x: DistributedArray) -> DistributedArray:
return self.A.rmatvec(x)
y = self.A.rmatvec(x)
# Inherit current dims/dimsd from A
self._dims_dimsd_local = self.A._dims_dimsd_local
return y

def _rmatvec(self, x: DistributedArray) -> DistributedArray:
return self.A.matvec(x)
y = self.A.matvec(x)
# Inherit current dims/dimsd from A
self._dims_dimsd_local = self.A._dims_dimsd_local
return y


class _TransposedLinearOperator(MPILinearOperator):
Expand Down
5 changes: 4 additions & 1 deletion pylops_mpi/basicoperators/FirstDerivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def __init__(self,
base_comm: MPI.Comm = MPI.COMM_WORLD,
dtype: DTypeLike = np.float64):
self.dims = _value_or_sized_to_tuple(dims)
shape = (int(np.prod(dims)),) * 2
self.dimsd = self.dims
shape = (int(np.prod(self.dims)),) * 2
super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)
self.sampling = sampling
self.kind = kind
Expand Down Expand Up @@ -130,12 +131,14 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
# If Partition.BROADCAST, then convert to Partition.SCATTER
if x.partition is Partition.BROADCAST:
x = DistributedArray.to_dist(x=x.local_array, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl)
self._dims_dimsd_local = (self.dimsd[0] // x.size, *self.dimsd[1:])
return self._hmatvec(x)

def _rmatvec(self, x: DistributedArray) -> DistributedArray:
# If Partition.BROADCAST, then convert to Partition.SCATTER
if x.partition is Partition.BROADCAST:
x = DistributedArray.to_dist(x=x.local_array, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl)
self._dims_dimsd_local = (self.dims[0] // x.size, *self.dims[1:])
return self._hrmatvec(x)

@reshaped
Expand Down
Loading