Skip to content

Commit dae002d

Browse files
authored
Merge pull request #192 from rohanbabbar04/ista
ISTA/FISTA implementation in cls_sparsity
2 parents 901b706 + 4594c93 commit dae002d

18 files changed

Lines changed: 2372 additions & 21 deletions

docs/source/api/index.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,24 @@ Basic
106106
cg
107107
cgls
108108

109+
Sparsity
110+
~~~~~
111+
112+
.. currentmodule:: pylops_mpi.optimization.cls_sparsity
113+
114+
.. autosummary::
115+
:toctree: generated/
116+
117+
ISTA
118+
FISTA
119+
120+
.. currentmodule:: pylops_mpi.optimization.sparsity
121+
122+
.. autosummary::
123+
:toctree: generated/
124+
125+
ista
126+
fista
109127

110128
Utils
111129
-----

environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ dependencies:
88
- numpy>=1.15.0
99
- scipy>=1.8.0
1010
- mpi4py
11-
- pylops
11+
- pylops>=2.0.0
1212
- matplotlib
1313
- ipython
1414
- pytest

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ dependencies:
66
- python>=3.8.0
77
- numpy>=1.15.0
88
- scipy>=1.8.0
9-
- pylops>=2.0
9+
- pylops>=2.0.0
1010
- matplotlib
1111
- mpi4py

pylops_mpi/DistributedArray.py

Lines changed: 74 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,19 @@ def __setitem__(self, index, value):
205205
the specified index positions.
206206
"""
207207
if self.partition is Partition.BROADCAST:
208+
ncp = get_module(self.engine)
208209
view = self.local_array[index]
210+
buf = ncp.empty(view.shape, dtype=self.dtype)
209211
if self.rank == 0:
210-
view[...] = value
211-
view = self._bcast(self.base_comm, self.base_comm_nccl,
212-
view, root=0, engine=self.engine)
213-
self.local_array[index] = view
212+
buf[...] = value
213+
buf = self._bcast(
214+
self.base_comm,
215+
self.base_comm_nccl,
216+
buf,
217+
root=0,
218+
engine=self.engine
219+
)
220+
self.local_array[index] = buf
214221
else:
215222
self.local_array[index] = value
216223

@@ -644,8 +651,25 @@ def multiply(self, dist_array):
644651
ProductArray[:] = self.local_array * dist_array
645652
return ProductArray
646653

647-
def dot(self, dist_array):
648-
"""Distributed Dot Product
654+
def dot(self, dist_array, vdot: bool = False):
655+
"""
656+
Compute the distributed dot product between this array and another
657+
distributed array.
658+
659+
Parameters
660+
----------
661+
dist_array : :obj:`pylops_mpi.DistributedArray`
662+
The distributed array with which to compute the dot product.
663+
It must have a compatible shape and partitioning.
664+
vdot : bool, optional, Defaults to `False`
665+
If True, compute the complex conjugate dot product (like numpy.vdot),
666+
where the first argument is conjugated. If False, compute the standard
667+
dot product.
668+
669+
Returns
670+
-------
671+
result : float
672+
The result of the dot product across all ranks. This is reduced across all processes.
649673
"""
650674
self._check_partition_shape(dist_array)
651675
self._check_mask(dist_array)
@@ -656,8 +680,9 @@ def dot(self, dist_array):
656680
y = DistributedArray.to_dist(x=dist_array.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
657681
if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else dist_array
658682
# Flatten the local arrays and calculate dot product
683+
dot_func = ncp.vdot if vdot else ncp.dot
659684
return self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
660-
ncp.dot(x.local_array.flatten(), y.local_array.flatten()),
685+
dot_func(x.local_array.flatten(), y.local_array.flatten()),
661686
engine=self.engine)
662687

663688
def _compute_vector_norm(self, local_array: NDArray,
@@ -839,6 +864,15 @@ def ravel(self, order: Optional[str] = "C"):
839864
arr[:] = x
840865
return arr
841866

867+
def empty_like(self):
868+
"""Creates an empty like DistributedArray with uninitialized values
869+
"""
870+
dist = DistributedArray(global_shape=self.global_shape, base_comm=self.base_comm,
871+
base_comm_nccl=self.base_comm_nccl, partition=self.partition,
872+
axis=self.axis, local_shapes=self.local_shapes, mask=self.mask,
873+
engine=self.engine, dtype=self.dtype)
874+
return dist
875+
842876
def add_ghost_cells(self, cells_front: Optional[int] = None,
843877
cells_back: Optional[int] = None):
844878
"""Add ghost cells to the DistributedArray along the axis
@@ -1038,13 +1072,30 @@ def multiply(self, stacked_array):
10381072
ProductArray[iarr][:] = (self[iarr] * stacked_array)[:]
10391073
return ProductArray
10401074

1041-
def dot(self, stacked_array):
1042-
"""Dot Product of Stacked Distributed Arrays
1075+
def dot(self, stacked_array, vdot: bool = False):
1076+
"""
1077+
Compute the distributed dot product between this array and another
1078+
distributed array.
1079+
1080+
Parameters
1081+
----------
1082+
stacked_array : :obj:`pylops_mpi.StackedDistributedArray`
1083+
The distributed array with which to compute the dot product.
1084+
It must have a compatible shape and partitioning.
1085+
vdot : bool, optional, Defaults to `False`
1086+
If True, compute the complex conjugate dot product (like numpy.vdot),
1087+
where the first argument is conjugated. If False, compute the standard
1088+
dot product.
1089+
1090+
Returns
1091+
-------
1092+
result : float
1093+
The result of the dot product across all ranks. This is reduced across all processes.
10431094
"""
10441095
self._check_stacked_size(stacked_array)
10451096
dotprod = 0.
10461097
for iarr in range(self.narrays):
1047-
dotprod += self[iarr].dot(stacked_array[iarr])
1098+
dotprod += self[iarr].dot(stacked_array[iarr], vdot=vdot)
10481099
return dotprod
10491100

10501101
def norm(self, ord: Optional[int] = None):
@@ -1085,6 +1136,19 @@ def copy(self):
10851136
arr = StackedDistributedArray([distarray.copy() for distarray in self.distarrays])
10861137
return arr
10871138

1139+
def empty_like(self):
1140+
"""Creates an empty like StackedDistributedArray with uninitialized values
1141+
"""
1142+
dists = []
1143+
for iarr in range(self.narrays):
1144+
distarray = self.distarrays[iarr]
1145+
dist = DistributedArray(global_shape=distarray.global_shape, base_comm=distarray.base_comm,
1146+
base_comm_nccl=distarray.base_comm_nccl, partition=distarray.partition,
1147+
axis=distarray.axis, local_shapes=distarray.local_shapes, mask=distarray.mask,
1148+
engine=distarray.engine, dtype=distarray.dtype)
1149+
dists.append(dist)
1150+
return StackedDistributedArray(distarrays=dists)
1151+
10881152
def __repr__(self):
10891153
repr_dist = "\n".join([distarray.__repr__() for distarray in self.distarrays])
10901154
return f"<StackedDistributedArray with {self.narrays} distributed arrays: \n" + repr_dist

pylops_mpi/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from .plotting.plotting import *
1313
from .optimization.basic import *
14+
from .optimization.sparsity import *
1415

1516
try:
1617
from .version import version as __version__

pylops_mpi/optimization/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@
88
99
cg Conjugate gradient
1010
cgls Conjugate gradient least-squares.
11-
11+
ista Iterative Shrinkage-Thresholding Algorithm
12+
fista Fast Iterative Shrinkage-Thresholding Algorithm
1213
"""

0 commit comments

Comments
 (0)