Skip to content

Commit 86c455c

Browse files
committed
Add PrecomposedSlicedSeparableSum implementation and corresponding tests
1 parent 287e9b4 commit 86c455c

4 files changed

Lines changed: 202 additions & 0 deletions

File tree

src/ProximalOperators.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ include("calculus/regularize.jl")
9292
include("calculus/separableSum.jl")
9393
include("calculus/slicedSeparableSum.jl")
9494
include("calculus/reshapeInput.jl")
95+
include("calculus/precomposedSlicedSeparableSum.jl")
9596
include("calculus/sqrDistL2.jl")
9697
include("calculus/tilt.jl")
9798
include("calculus/translate.jl")
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Separable sum, using slices of an array as variables
2+
3+
export PrecomposedSlicedSeparableSum
4+
5+
"""
6+
precomposedSlicedSeparableSum((f_1, ..., f_k), (J_1, ..., J_k), (L_1, ..., L_k))
7+
8+
Return the function
9+
```math
10+
g(x) = \\sum_{i=1}^k f_i(L_i * x_{J_i}).
11+
```
12+
13+
precomposedSlicedSeparableSum(f, (J_1, ..., J_k), (L_1, ..., L_k))
14+
15+
Analogous to the previous one, but apply the same function `f` to all slices
16+
of the variable `x`:
17+
```math
18+
g(x) = \\sum_{i=1}^k f(L_i * x_{J_i}).
19+
```
20+
"""
21+
struct PrecomposedSlicedSeparableSum{S <: Tuple, T <: AbstractArray, U <: AbstractArray, V <: AbstractArray, N}
22+
fs::S # Tuple, where each element is a Vector with elements of the same type; the functions to prox on
23+
# Example: S = Tuple{Array{ProximalOperators.NormL1{Float64},1}, Array{ProximalOperators.NormL2{Float64},1}}
24+
idxs::T # Vector, where each element is a Vector containing the indices to prox on
25+
# Example: T = Array{Array{Tuple{Colon,UnitRange{Int64}},1},1}
26+
ops::U # Vector of operations (matrices or AbstractOperators) to apply to the function
27+
# Example: U = Array{Array{Matrix{Float64},1},1}
28+
μs::V # Vector of mu values for each function
29+
end
30+
31+
function PrecomposedSlicedSeparableSum(fs::Tuple, idxs::Tuple, ops::Tuple, μs::Tuple)
32+
@assert length(fs) == length(idxs)
33+
@assert length(fs) == length(ops)
34+
ftypes = DataType[]
35+
fsarr = Array{Any,1}[]
36+
indarr = Array{eltype(idxs),1}[]
37+
opsarr = Array{Any,1}[]
38+
μsarr = Array{Any,1}[]
39+
for (i,f) in enumerate(fs)
40+
t = typeof(f)
41+
fi = findfirst(isequal(t), ftypes)
42+
if fi === nothing
43+
push!(ftypes, t)
44+
push!(fsarr, Any[f])
45+
push!(indarr, eltype(idxs)[idxs[i]])
46+
push!(opsarr, Any[ops[i]])
47+
push!(μsarr, Any[μs[i]])
48+
else
49+
push!(fsarr[fi], f)
50+
push!(indarr[fi], idxs[i])
51+
push!(opsarr[fi], ops[i])
52+
push!(μsarr[fi], μs[i])
53+
end
54+
end
55+
fsnew = ((Array{typeof(fs[1]),1}(fs) for fs in fsarr)...,)
56+
@assert typeof(fsnew) == Tuple{(Array{ft,1} for ft in ftypes)...}
57+
PrecomposedSlicedSeparableSum{typeof(fsnew),typeof(indarr),typeof(opsarr),typeof(μsarr),length(fsnew)}(fsnew, indarr, opsarr, μsarr)
58+
end
59+
60+
# Constructor for the case where the same function is applied to all slices
61+
PrecomposedSlicedSeparableSum(f::F, idxs::T, ops::U, μs::V) where {F, T <: Tuple, U <: Tuple, V <: Tuple} =
62+
PrecomposedSlicedSeparableSum(Tuple(f for k in eachindex(idxs)), idxs, ops, μs)
63+
64+
# Unroll the loop over the different types of functions to evaluate
65+
function (f::PrecomposedSlicedSeparableSum)(x::Tuple)
66+
v = zero(eltype(x[1]))
67+
for (fs_group, idxs_group, ops_group) = zip(f.fs, f.idxs, f.ops) # For each function type
68+
for (fun, idx_group, hcat_op) in zip(fs_group, idxs_group, ops_group) # For each function of that type
69+
for (var_index, (x_var, idx)) in enumerate(zip(x, idx_group))
70+
if idx isa Tuple
71+
v += fun(hcat_op[var_index] * view(x_var, idx...))
72+
elseif idx isa Colon
73+
v += fun(hcat_op[var_index] * x_var)
74+
elseif idx isa Nothing
75+
# do nothing
76+
else
77+
v += fun(hcat_op[var_index] * view(x_var, idx))
78+
end
79+
end
80+
end
81+
end
82+
return v
83+
end
84+
85+
function slice_var(x, idx)
86+
if idx isa Tuple
87+
return view(x, idx...)
88+
elseif idx isa Colon
89+
return x
90+
else
91+
return view(x, idx)
92+
end
93+
end
94+
95+
# Unroll the loop over the different types of functions to prox on
96+
function prox!(y::Tuple, f::PrecomposedSlicedSeparableSum, x::Tuple, gamma)
97+
v = zero(eltype(x[1]))
98+
for (fs_group, idxs_group, ops_group, μ_group) = zip(f.fs, f.idxs, f.ops, f.μs) # For each function type
99+
for (fun, idx_group, hcat_op, μ) in zip(fs_group, idxs_group, ops_group, μ_group) # For each function of that type
100+
for (idx, op, x_var, y_var) in zip(idx_group, hcat_op, x, y)
101+
if idx isa Nothing
102+
continue
103+
end
104+
sliced_x = slice_var(x_var, idx)
105+
sliced_y = slice_var(y_var, idx)
106+
res = op * sliced_x
107+
prox_res, g = prox(fun, res, μ.*gamma)
108+
prox_res .-= res
109+
prox_res ./= μ
110+
mul!(sliced_y, adjoint(op), prox_res)
111+
sliced_y .+= sliced_x
112+
v += g
113+
end
114+
end
115+
end
116+
return v
117+
end
118+
119+
component_types(::Type{PrecomposedSlicedSeparableSum{S, T, N}}) where {S, T, N} = Tuple(A.parameters[1] for A in fieldtypes(S))
120+
121+
@generated is_proximable(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_proximable, component_types(T)) ? :(true) : :(false)
122+
@generated is_convex(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_convex, component_types(T)) ? :(true) : :(false)
123+
@generated is_set_indicator(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_set_indicator, component_types(T)) ? :(true) : :(false)
124+
@generated is_singleton_indicator(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_singleton_indicator, component_types(T)) ? :(true) : :(false)
125+
@generated is_cone_indicator(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_cone_indicator, component_types(T)) ? :(true) : :(false)
126+
@generated is_affine_indicator(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_affine_indicator, component_types(T)) ? :(true) : :(false)
127+
@generated is_smooth(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_smooth, component_types(T)) ? :(true) : :(false)
128+
@generated is_generalized_quadratic(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_generalized_quadratic, component_types(T)) ? :(true) : :(false)
129+
@generated is_strongly_convex(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_strongly_convex, component_types(T)) ? :(true) : :(false)
130+
131+
function prox_naive(f::PrecomposedSlicedSeparableSum, x, gamma)
132+
fy = 0
133+
y = similar.(x)
134+
for (fs_group, idxs_group, ops_group, μ_group) = zip(f.fs, f.idxs, f.ops, f.μs) # For each function type
135+
for (fun, idx_group, hcat_op, μ) in zip(fs_group, idxs_group, ops_group, μ_group) # For each function of that type
136+
for (idx, op, x_var, y_var) in zip(idx_group, hcat_op, x, y)
137+
if idx isa Nothing
138+
continue
139+
end
140+
sliced_x = slice_var(x_var, idx)
141+
sliced_y = slice_var(y_var, idx)
142+
res = op * sliced_x
143+
prox_res, _fy = prox_naive(fun, res, μ.*gamma)
144+
prox_res = (prox_res .- res) ./ μ
145+
mul!(sliced_y, adjoint(op), prox_res)
146+
fy += _fy
147+
sliced_y .+= sliced_x
148+
end
149+
end
150+
end
151+
return y, fy
152+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ end
155155
include("test_regularize.jl")
156156
include("test_separableSum.jl")
157157
include("test_slicedSeparableSum.jl")
158+
include("test_precomposedSlicedSeparableSum.jl")
158159
include("test_sum.jl")
159160
include("test_reshapeInput.jl")
160161
end
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
using Test
2+
using Random
3+
using ProximalOperators
4+
using LinearAlgebra
5+
6+
Random.seed!(1234)
7+
8+
# x = (randn(10), randn(10))
9+
# norm(x[1], 1) + norm(A2[1:5, 1:5] * x[2][1:5], 2) + norm(A2[6:10, 6:10] * x[2][6:10], 2)^2
10+
11+
@testset "PrecomposedSlicedSeparableSum" begin
12+
13+
fs = (NormL1(), NormL2(), SqrNormL2())
14+
15+
A1 = (Diagonal(ones(10)), nothing)
16+
F = qr(randn(5, 5))
17+
A2 = (nothing, Matrix(F.Q))
18+
F = qr(randn(5, 5))
19+
A3 = (nothing, Matrix(F.Q))
20+
mu = rand(5)
21+
A3[2] .*= reshape(mu, 5, 1)
22+
ops = (A1, A2, A3)
23+
24+
idxs = ((Colon(), nothing), (nothing, 1:5), (nothing, 6:10))
25+
μs = (1.0, 1.0, mu)
26+
27+
AAc2 = A2[2] * A2[2]'
28+
@test AAc2 I
29+
AAc3 = A3[2] * A3[2]'
30+
@test AAc3 Diagonal(mu) .^ 2
31+
32+
f = PrecomposedSlicedSeparableSum(fs, idxs, ops, μs)
33+
x = (randn(10), rand(10))
34+
y = (zeros(10), zeros(10))
35+
fy = prox!(y, f, x, 1.0)
36+
yn, fyn = ProximalOperators.prox_naive(f, x, 1.0)
37+
y1, fy1 = prox(NormL1(), x[1], 1.0)
38+
y2, fy2 = prox(Precompose(NormL2(), A2[2], 1), x[2][1:5], 1.0)
39+
y3, fy3 = prox(Precompose(SqrNormL2(), A3[2], mu), x[2][6:10], 1.0)
40+
41+
@test abs(fyn-fy)<1e-11
42+
@test norm(yn[1]-y[1])+norm(yn[2]-y[2])<1e-11
43+
@test abs((fy1+fy2+fy3)-fy)<1e-11
44+
@test norm(y[1] - y1) < 1e-11
45+
@test norm(y[2][1:5] - y2) < 1e-11
46+
@test norm(y[2][6:10] - y3) < 1e-11
47+
48+
end

0 commit comments

Comments
 (0)