@@ -41,15 +41,23 @@ component_types(::Type{SeparableSum{T}}) where T = fieldtypes(T)
4141
4242(g:: SeparableSum )(xs:: Tuple ) = sum (f (x) for (f, x) in zip (g. fs, xs))
4343
44- prox! (ys:: Tuple , fs :: Tuple , xs:: Tuple , gamma:: Number ) = sum (prox! (y, f, x, gamma) for (y, f, x) in zip (ys, fs, xs))
44+ prox! (ys:: Tuple , g :: SeparableSum , xs:: Tuple , gamma:: Number ) = sum (prox! (y, f, x, gamma) for (y, f, x) in zip (ys, g . fs, xs))
4545
46- prox! (ys:: Tuple , fs :: Tuple , xs:: Tuple , gammas:: Tuple ) = sum (prox! (y, f, x, gamma) for (y, f, x, gamma) in zip (ys, fs, xs, gammas))
46+ prox! (ys:: Tuple , g :: SeparableSum , xs:: Tuple , gammas:: Tuple ) = sum (prox! (y, f, x, gamma) for (y, f, x, gamma) in zip (ys, g . fs, xs, gammas))
4747
48- prox! (ys:: Tuple , g:: SeparableSum , xs:: Tuple , gamma) = prox! (ys, g. fs, xs, gamma)
48+ function prox (g:: SeparableSum , xs:: Tuple , gamma= 1 )
49+ ys = similar .(xs)
50+ fys = prox! (ys, g, xs, gamma)
51+ return ys, fys
52+ end
4953
50- gradient! (grads:: Tuple , fs :: Tuple , xs:: Tuple ) = sum (gradient! (grad, f, x) for (grad, f, x) in zip (grads, fs, xs))
54+ gradient! (grads:: Tuple , g :: SeparableSum , xs:: Tuple ) = sum (gradient! (grad, f, x) for (grad, f, x) in zip (grads, g . fs, xs))
5155
52- gradient! (grads:: Tuple , g:: SeparableSum , xs:: Tuple ) = gradient! (grads, g. fs, xs)
56+ function gradient (g:: SeparableSum , xs:: Tuple )
57+ ys = similar .(xs)
58+ fxs = gradient! (ys, g, xs)
59+ return ys, fxs
60+ end
5361
5462function prox_naive (f:: SeparableSum , xs:: Tuple , gamma)
5563 fys = real (eltype (xs[1 ]))(0 )
0 commit comments