Skip to content

Commit dbf0959

Browse files
sfarrenschaithyagrchaithyagrpaquiteau
authored
Version 1.6.1 patch (#222)
* Add support for tensorflow backend which allows for differentiability (#112) * Added support for tensorflow * Updates to get tests passing * Or --> And * Moving modopt to allow working with tensorflow * Fix issues with wos * Fix all flakes finally! * Update modopt/base/backend.py Co-authored-by: Samuel Farrens <samuel.farrens@gmail.com> * Update modopt/base/backend.py Co-authored-by: Samuel Farrens <samuel.farrens@gmail.com> * Minute updates to codes * Add dynamic module * Fix docu * Fix PEP Co-authored-by: chaithyagr <chaithyagr@gitlab.com> Co-authored-by: Samuel Farrens <samuel.farrens@gmail.com> * Fix 115 (#116) * Fix issues * Add right tests * Fix PEP Co-authored-by: chaithyagr <chaithyagr@gitlab.com> * Minor bug fix, remove elif (#124) Co-authored-by: chaithyagr <chaithyagr@gitlab.com> * Add tests for modopt.base.backend and fix minute bug uncovered (#126) * Minor bug fix, remove elif * Add tests for backend * Fix tests * Add tests * Remove cupy * PEP fixes * Fix PEP * Fix PEP and update * Final PEP * Update setup.cfg Co-authored-by: Samuel Farrens <samuel.farrens@gmail.com> * Update test_base.py Co-authored-by: chaithyagr <chaithyagr@gitlab.com> Co-authored-by: Samuel Farrens <samuel.farrens@gmail.com> * Release cleanup (#128) * updated GPU dependencies * added logo to manifest * updated package version and release date * Unpin package dependencies (#189) * unpinned dependencies * updated pinned documentation dependency versions * Add Gradient descent algorithms (#196) * Version 1.5.1 patch release (#114) * Add support for tensorflow backend which allows for differentiability (#112) * Added support for tensorflow * Updates to get tests passing * Or --> And * Moving modopt to allow working with tensorflow * Fix issues with wos * Fix all flakes finally! * Update modopt/base/backend.py Co-authored-by: Samuel Farrens <samuel.farrens@gmail.com> * Update modopt/base/backend.py Co-authored-by: Samuel Farrens <samuel.farrens@gmail.com> * Minute updates to codes * Add dynamic module * Fix docu * Fix PEP Co-authored-by: chaithyagr <chaithyagr@gitlab.com> Co-authored-by: Samuel Farrens <samuel.farrens@gmail.com> * Fix 115 (#116) * Fix issues * Add right tests * Fix PEP Co-authored-by: chaithyagr <chaithyagr@gitlab.com> * Minor bug fix, remove elif (#124) Co-authored-by: chaithyagr <chaithyagr@gitlab.com> * Add tests for modopt.base.backend and fix minute bug uncovered (#126) * Minor bug fix, remove elif * Add tests for backend * Fix tests * Add tests * Remove cupy * PEP fixes * Fix PEP * Fix PEP and update * Final PEP * Update setup.cfg Co-authored-by: Samuel Farrens <samuel.farrens@gmail.com> * Update test_base.py Co-authored-by: chaithyagr <chaithyagr@gitlab.com> Co-authored-by: Samuel Farrens <samuel.farrens@gmail.com> * Release cleanup (#128) * updated GPU dependencies * added logo to manifest * updated package version and release date Co-authored-by: Chaithya G R <chaithyagr@gmail.com> Co-authored-by: chaithyagr <chaithyagr@gitlab.com> * make algorithms a module. * add Gradient Descent Algorithms * enforce WPS compliance. * add test for gradient descent * Docstrings improvements * Add See Also and minor corrections * add idx initialisation for all algorithms. * fix merge error * fix typo Co-authored-by: Samuel Farrens <samuel.farrens@gmail.com> Co-authored-by: Chaithya G R <chaithyagr@gmail.com> Co-authored-by: chaithyagr <chaithyagr@gitlab.com> * Release cleanup (#198) * started clean up for next release * update progress * further clean up * additional clean up * cleaned up link to logo * fixed index.rst * fixed conflict * Fast Singular Value Thresholding (#209) * add SingularValueThreshold This Method provides 10x faster SVT estimation than the LowRankMatrix Operator. * linting * add test for fast computation. * flake8 compliance * Ignore DAR000 Error. * Update modopt/signal/svd.py tuples in docstring Co-authored-by: Samuel Farrens <samuel.farrens@gmail.com> * Update modopt/signal/svd.py typo Co-authored-by: Samuel Farrens <samuel.farrens@gmail.com> * Update modopt/opt/proximity.py typo Co-authored-by: Samuel Farrens <samuel.farrens@gmail.com> * update docstring * fix isort * Update modopt/signal/svd.py Co-authored-by: Samuel Farrens <samuel.farrens@gmail.com> * Update modopt/signal/svd.py Co-authored-by: Samuel Farrens <samuel.farrens@gmail.com> * run isort Co-authored-by: Samuel Farrens <samuel.farrens@gmail.com> * added writeable input data array feature for benchopt (#213) * removed flake8 limit Co-authored-by: Chaithya G R <chaithyagr@gmail.com> Co-authored-by: chaithyagr <chaithyagr@gitlab.com> Co-authored-by: Pierre-Antoine Comby <77174042+paquiteau@users.noreply.github.com>
1 parent c3db304 commit dbf0959

6 files changed

Lines changed: 105 additions & 7 deletions

File tree

develop.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
coverage>=5.5
2-
flake8<4
2+
flake8>=4
33
nose>=1.3.7
44
pytest>=6.2.2
55
pytest-cov>=2.11.1

modopt/opt/gradient.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ class GradParent(object):
3434
Method for calculating the cost (default is ``None``)
3535
data_type : type, optional
3636
Expected data type of the input data (default is ``None``)
37+
input_data_writeable: bool, optional
38+
Option to make the observed data writeable (default is ``False``)
3739
verbose : bool, optional
3840
Option for verbose output (default is ``True``)
3941
@@ -66,10 +68,12 @@ def __init__(
6668
get_grad=None,
6769
cost=None,
6870
data_type=None,
71+
input_data_writeable=False,
6972
verbose=True,
7073
):
7174

7275
self.verbose = verbose
76+
self._input_data_writeable = input_data_writeable
7377
self._grad_data_type = data_type
7478
self.obs_data = input_data
7579
self.op = op
@@ -102,7 +106,7 @@ def obs_data(self, input_data):
102106
check_npndarray(
103107
input_data,
104108
dtype=self._grad_data_type,
105-
writeable=False,
109+
writeable=self._input_data_writeable,
106110
verbose=self.verbose,
107111
)
108112

modopt/opt/proximity.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from modopt.math.matrix import nuclear_norm
2929
from modopt.signal.noise import thresh
3030
from modopt.signal.positivity import positive
31-
from modopt.signal.svd import svd_thresh, svd_thresh_coef
31+
from modopt.signal.svd import svd_thresh, svd_thresh_coef, svd_thresh_coef_fast
3232

3333

3434
class ProximityParent(object):
@@ -237,6 +237,9 @@ class LowRankMatrix(ProximityParent):
237237
lowr_type : {'standard', 'ngole'}
238238
Low-rank implementation (options are 'standard' or 'ngole', default is
239239
'standard')
240+
initial_rank: int, optional
241+
Initial guess of the rank of future input_data.
242+
If provided this will save computation time.
240243
operator : class
241244
Operator class ('ngole' only)
242245
@@ -268,6 +271,7 @@ def __init__(
268271
threshold,
269272
thresh_type='soft',
270273
lowr_type='standard',
274+
initial_rank=None,
271275
operator=None,
272276
):
273277

@@ -277,8 +281,9 @@ def __init__(
277281
self.operator = operator
278282
self.op = self._op_method
279283
self.cost = self._cost_method
284+
self.rank = initial_rank
280285

281-
def _op_method(self, input_data, extra_factor=1.0):
286+
def _op_method(self, input_data, extra_factor=1.0, rank=None):
282287
"""Operator.
283288
284289
This method returns the input data after the singular values have been
@@ -290,22 +295,37 @@ def _op_method(self, input_data, extra_factor=1.0):
290295
Input data array
291296
extra_factor : float
292297
Additional multiplication factor (default is ``1.0``)
298+
rank: int, optional
299+
Estimation of the rank to save computation time in standard mode,
300+
if not set an internal estimation is used.
293301
294302
Returns
295303
-------
296304
numpy.ndarray
297305
SVD thresholded data
298306
307+
Raises
308+
------
309+
ValueError
310+
if lowr_type is not in ``{'standard', 'ngole'}``
299311
"""
300312
# Update threshold with extra factor.
301313
threshold = self.thresh * extra_factor
302-
303-
if self.lowr_type == 'standard':
314+
if self.lowr_type == 'standard' and self.rank is None and rank is None:
304315
data_matrix = svd_thresh(
305316
cube2matrix(input_data),
306317
threshold,
307318
thresh_type=self.thresh_type,
308319
)
320+
elif self.lowr_type == 'standard':
321+
data_matrix, update_rank = svd_thresh_coef_fast(
322+
cube2matrix(input_data),
323+
threshold,
324+
n_vals=rank or self.rank,
325+
extra_vals=5,
326+
thresh_type=self.thresh_type,
327+
)
328+
self.rank = update_rank # save for future use
309329

310330
elif self.lowr_type == 'ngole':
311331
data_matrix = svd_thresh_coef(
@@ -314,6 +334,8 @@ def _op_method(self, input_data, extra_factor=1.0):
314334
threshold,
315335
thresh_type=self.thresh_type,
316336
)
337+
else:
338+
raise ValueError('lowr_type should be standard or ngole')
317339

318340
# Return updated data.
319341
return matrix2cube(data_matrix, input_data.shape[1:])

modopt/signal/svd.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import numpy as np
1212
from scipy.linalg import svd
13+
from scipy.sparse.linalg import svds
1314

1415
from modopt.base.transform import matrix2cube
1516
from modopt.interface.errors import warn
@@ -200,6 +201,64 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
200201
return np.dot(u_vec, np.dot(s_new, v_vec))
201202

202203

204+
def svd_thresh_coef_fast(
205+
input_data,
206+
threshold,
207+
n_vals=-1,
208+
extra_vals=5,
209+
thresh_type='hard',
210+
):
211+
"""Threshold the singular values coefficients.
212+
213+
This method thresholds the input data by using singular value
214+
decomposition, but only computing the the greastest ``n_vals``
215+
values.
216+
217+
Parameters
218+
----------
219+
input_data : numpy.ndarray
220+
Input data array, 2D matrix
221+
Operator class instance
222+
threshold : float or numpy.ndarray
223+
Threshold value(s)
224+
n_vals: int, optional
225+
Number of singular values to compute.
226+
If None, compute all singular values.
227+
extra_vals: int, optional
228+
If the number of values computed is not enough to perform thresholding,
229+
recompute by using ``n_vals + extra_vals`` (default is ``5``)
230+
thresh_type : {'hard', 'soft'}
231+
Type of noise to be added (default is ``'hard'``)
232+
233+
Returns
234+
-------
235+
tuple
236+
The thresholded data (numpy.ndarray) and the estimated rank after
237+
thresholding (int)
238+
"""
239+
if n_vals == -1:
240+
n_vals = min(input_data.shape) - 1
241+
ok = False
242+
while not ok:
243+
(u_vec, s_values, v_vec) = svds(input_data, k=n_vals)
244+
ok = (s_values[0] <= threshold or n_vals == min(input_data.shape) - 1)
245+
n_vals = min(n_vals + extra_vals, *input_data.shape)
246+
247+
s_values = thresh(
248+
s_values,
249+
threshold,
250+
threshold_type=thresh_type,
251+
)
252+
rank = np.count_nonzero(s_values)
253+
return (
254+
np.dot(
255+
u_vec[:, -rank:] * s_values[-rank:],
256+
v_vec[-rank:, :],
257+
),
258+
rank,
259+
)
260+
261+
203262
def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'):
204263
"""Threshold the singular values coefficients.
205264

modopt/tests/test_opt.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,11 @@ def setUp(self):
675675
weights,
676676
)
677677
self.lowrank = proximity.LowRankMatrix(10.0, thresh_type='hard')
678+
self.lowrank_rank = proximity.LowRankMatrix(
679+
10.0,
680+
initial_rank=1,
681+
thresh_type='hard',
682+
)
678683
self.lowrank_ngole = proximity.LowRankMatrix(
679684
10.0,
680685
lowr_type='ngole',
@@ -763,6 +768,8 @@ def tearDown(self):
763768
self.positivity = None
764769
self.sparsethresh = None
765770
self.lowrank = None
771+
self.lowrank_rank = None
772+
self.lowrank_ngole = None
766773
self.combo = None
767774
self.data1 = None
768775
self.data2 = None
@@ -841,6 +848,11 @@ def test_low_rank_matrix(self):
841848
err_msg='Incorrect low rank operation: standard',
842849
)
843850

851+
npt.assert_almost_equal(
852+
self.lowrank_rank.op(self.data3),
853+
self.data4,
854+
err_msg='Incorrect low rank operation: standard with rank',
855+
)
844856
npt.assert_almost_equal(
845857
self.lowrank_ngole.op(self.data3),
846858
self.data5,

setup.cfg

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ per-file-ignores =
5858
#Justification: Needed to import matplotlib.pyplot
5959
modopt/plot/cost_plot.py: N802,WPS301
6060
#Todo: Investigate possible bug in find_n_pc function
61-
modopt/signal/svd.py: WPS345
61+
#Todo: Investigate darglint error
62+
modopt/signal/svd.py: WPS345, DAR000
6263
#Todo: Check security of using system executable call
6364
modopt/signal/wavelet.py: S404,S603
6465
#Todo: Clean up tests

0 commit comments

Comments
 (0)