Skip to content

Commit ae03972

Browse files
MaxenceGollier and dpo authored
Shifted Composite L2 Norm (#124)
Co-authored-by: Dominique <dominique.orban@gmail.com>
1 parent ab012a2 commit ae03972

6 files changed

Lines changed: 397 additions & 1 deletion

File tree

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,21 @@ version = "0.2.1"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
OpenBLAS32_jll = "656ef2d0-ae68-5445-9ca0-591084a874a2"
99
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"
10+
QRMumps = "422b30a1-cc69-4d85-abe7-cc07b540c444"
1011
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
12+
SparseMatricesCOO = "fa32481b-f100-4b48-8dc8-c62f61b13870"
1113
libblastrampoline_jll = "8e850b90-86db-534c-a0d3-1478176c7d93"
1214

1315
[compat]
1416
OpenBLAS32_jll = "0.3.9"
1517
ProximalOperators = "0.15"
18+
QRMumps = "^0.3.0"
1619
Roots = "^1.0.0"
1720
julia = "^1.3.0"
1821

1922
[extras]
2023
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
24+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2125

2226
[targets]
23-
test = ["Test"]
27+
test = ["Test", "SparseArrays"]

src/ShiftedProximalOperators.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
module ShiftedProximalOperators
22

33
using LinearAlgebra
4+
using QRMumps
5+
using SparseMatricesCOO
46

57
using libblastrampoline_jll
68
using OpenBLAS32_jll
@@ -25,16 +27,22 @@ import ProximalOperators.prox, ProximalOperators.prox!
2527

2628
"Abstract type for shifted proximable functions."
2729
abstract type ShiftedProximableFunction end
30+
"Abstract type for proximable functions of the form `h ∘ c`, where `c` is evaluated in place by a user-supplied `c!`."
abstract type CompositeProximableFunction end
31+
32+
"Abstract type for compositions of a norm with a function (e.g. `CompositeNormL2`)."
abstract type AbstractCompositeNorm <: CompositeProximableFunction end
33+
"Abstract type for shifted composite proximable functions, i.e. composite functions whose inner map is linearized about a shift."
abstract type ShiftedCompositeProximableFunction <: ShiftedProximableFunction end
2834

2935
include("utils.jl")
3036
include("psvd.jl")
3137

38+
include("compositeNormL2.jl")
3239
include("rootNormLhalf.jl")
3340
include("groupNormL2.jl")
3441
include("Rank.jl")
3542
include("cappedl1.jl")
3643
include("Nuclearnorm.jl")
3744

45+
include("shiftedCompositeNormL2.jl")
3846
include("shiftedNormL0.jl")
3947
include("shiftedNormL0Box.jl")
4048
include("shiftedRootNormLhalf.jl")
@@ -56,6 +64,17 @@ function (ψ::ShiftedProximableFunction)(y)
5664
return ψ.h.xsy)
5765
end
5866

67+
# Evaluate the shifted composite function at `y`, i.e. return `h(A * y + b)`.
# The affine term is formed in place in the preallocated buffer `ψ.g`, so every call
# overwrites `ψ.g`.
function (ψ::ShiftedCompositeProximableFunction)(y)
  mul!(ψ.g, ψ.A, y)  # ψ.g ← A * y
  ψ.g .+= ψ.b        # ψ.g ← A * y + b
  return ψ.h(ψ.g)
end
72+
73+
# Evaluate the composite function at `y`, i.e. return `h(c(y))`.
# `c(y)` is written in place into `ψ.b`, so every call overwrites `ψ.b`.
function (ψ::CompositeProximableFunction)(y)
  ψ.c!(ψ.b, y)  # ψ.b ← c(y)
  return ψ.h(ψ.b)
end
77+
5978
"""
6079
shift!(ψ, x)
6180
@@ -70,6 +89,12 @@ function shift!(ψ::ShiftedProximableFunction, shift::AbstractVector{R}) where {
7089
return ψ
7190
end
7291

92+
# Re-linearize a shifted composite function about `shift`: refresh in place the stored
# value `ψ.b ← c(shift)` and Jacobian `ψ.A ← J(shift)`, then return `ψ`.
function shift!(ψ::ShiftedCompositeProximableFunction, shift::AbstractVector{R}) where {R <: Real}
  ψ.c!(ψ.b, shift)
  ψ.J!(ψ.A, shift)
  return ψ
end
97+
7398
"""
7499
set_radius!(ψ, Δ)
75100

src/compositeNormL2.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Composition of the L2 norm with a function
2+
export CompositeNormL2
3+
4+
@doc raw"""
    CompositeNormL2(λ, c!, J!, A, b)

Return the `ℓ₂` norm composed with a function `c`:
```math
f(x) = λ ‖c(x)‖₂,
```
where `λ > 0`. The callables `c!` and `J!` evaluate, in place,
```math
c : ℝⁿ ↦ ℝᵐ,
```
```math
J : ℝⁿ ↦ ℝᵐˣⁿ,
```
where `J` is the Jacobian of `c`. It is expected that `m ≤ n`.

`A` is the matrix storage for the values of `J` and `b` is the vector storage for the
values of `c`; `b` must have as many entries as `A` has rows.
`A` is expected to be sparse. The expected signatures are
```
c!(b <: AbstractVector{Real}, xk <: AbstractVector{Real})
J!(A <: AbstractSparseMatrixCOO{Real, Integer}, xk <: AbstractVector{Real})
```

An error is raised if `λ` is not positive or if the dimensions of `A` and `b` disagree.
"""
mutable struct CompositeNormL2{
  T <: Real,
  F0 <: Function,
  F1 <: Function,
  M <: AbstractMatrix{T},
  V <: AbstractVector{T},
} <: AbstractCompositeNorm
  h::NormL2{T}  # outer function λ‖⋅‖₂
  c!::F0        # in-place evaluation of c
  J!::F1        # in-place evaluation of the Jacobian of c
  A::M          # storage for J(x)
  b::V          # storage for c(x)

  function CompositeNormL2(
    λ::T,
    c!::Function,
    J!::Function,
    A::AbstractMatrix{T},
    b::AbstractVector{T},
  ) where {T <: Real}
    if !(λ > 0)
      error("CompositeNormL2: λ should be positive")
    end
    if length(b) != size(A, 1)
      error("Composite Norm L2: Wrong input dimensions, the length of c(x) should be the same as the number of rows of J(x)")
    end
    return new{T, typeof(c!), typeof(J!), typeof(A), typeof(b)}(NormL2(λ), c!, J!, A, b)
  end
end
51+
52+
# Descriptive strings for `CompositeNormL2`: name, domain and expression.
fun_name(::CompositeNormL2) = "ℓ₂ norm of the function c"
fun_dom(::CompositeNormL2) = "AbstractVector{Real}"
fun_expr(::CompositeNormL2) = "x ↦ λ ‖c(x)‖₂"

src/shiftedCompositeNormL2.jl

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
export ShiftedCompositeNormL2
2+
@doc raw"""
    ShiftedCompositeNormL2(λ, c!, J!, A, b)

Return the shift of a function `c` composed with the `ℓ₂` norm (see `CompositeNormL2`).
Here, `c` is linearized, i.e., `c(x + s) ≈ c(x) + J(x)s`, which gives
```math
f(s) = λ ‖c(x) + J(x)s‖₂,
```
where `λ > 0`. `c!` and `J!` should implement functions
```math
c : ℝⁿ ↦ ℝᵐ,
```
```math
J : ℝⁿ ↦ ℝᵐˣⁿ,
```
such that `J` is the Jacobian of `c`. It is expected that `m ≤ n`.
`A` and `b` should respectively be a matrix and a vector which can respectively store the values of `J` and `c`;
`b` must have as many entries as `A` has rows, otherwise an error is raised.
`A` is expected to be sparse, `c!` and `J!` should have signatures
```
c!(b <: AbstractVector{Real}, xk <: AbstractVector{Real})
J!(A <: AbstractSparseMatrixCOO{Real, Integer}, xk <: AbstractVector{Real})
```

The constructor also sets up the QRMumps objects (shifted sparse matrix and sparse
factorization) and the work vectors reused by `prox!`.
"""
mutable struct ShiftedCompositeNormL2{
  T <: Real,
  F0 <: Function,
  F1 <: Function,
  M <: AbstractMatrix{T},
  V <: AbstractVector{T},
} <: ShiftedCompositeProximableFunction
  h::NormL2{T}
  c!::F0
  J!::F1
  A::M
  shifted_spmat::qrm_shifted_spmat{T}
  spfct::qrm_spfct{T}
  b::V
  g::V # Preallocated vector used either to compute A*y + b when we call ψ(y) or the RHS of the dual of the proximal problem.
  q::V # Preallocated solution vector of the dual of the proximal problem.
  dq::V # Preallocated vector to refine the q solution.
  p::V # Preallocated vector used to compute s(α)ᵀ∇s(α) for the secular equation.
  dp::V # Preallocated vector used to refine the p vector.

  function ShiftedCompositeNormL2(
    λ::T,
    c!::Function,
    J!::Function,
    A::AbstractMatrix{T},
    b::AbstractVector{T},
  ) where {T <: Real}
    # Validate the dimensions first, before any allocation or QRMumps initialization.
    m, n = size(A)
    if length(b) != m
      error("ShiftedCompositeNormL2: Wrong input dimensions, there should be as many constraints as rows in the Jacobian")
    end

    # `p` and `dp` have n + m entries; the remaining work vectors have the size of `b`.
    p = similar(b, n + m)
    dp = similar(b, n + m)
    g = similar(b)
    q = similar(b)
    dq = similar(b)

    spmat = qrm_spmat_init(A; sym = false)
    shifted_spmat = qrm_shift_spmat(spmat)
    spfct = qrm_spfct_init(spmat)

    return new{T, typeof(c!), typeof(J!), typeof(A), typeof(b)}(
      NormL2(λ),
      c!,
      J!,
      A,
      shifted_spmat,
      spfct,
      b,
      g,
      q,
      dq,
      p,
      dp,
    )
  end
end
67+
68+
"""
    shifted(ψ::CompositeNormL2, xk)

Return the `ShiftedCompositeNormL2` obtained by linearizing `ψ` about `xk`.

New storage is obtained with `similar` and filled with `c(xk)` and `J(xk)`, so `ψ.b` and
`ψ.A` are not modified.
"""
function shifted(
  ψ::CompositeNormL2{T, F0, F1, M, V},
  xk::AbstractVector{T},
) where {T <: Real, F0 <: Function, F1 <: Function, M <: AbstractMatrix{T}, V <: AbstractVector{T}}
  b = similar(ψ.b)
  ψ.c!(b, xk)  # b ← c(xk)
  A = similar(ψ.A)
  ψ.J!(A, xk)  # A ← J(xk)
  return ShiftedCompositeNormL2(ψ.h.lambda, ψ.c!, ψ.J!, A, b)
end
78+
79+
# Descriptive strings for `ShiftedCompositeNormL2`: name, expression and current parameters.
fun_name(ψ::ShiftedCompositeNormL2) = "shifted `ℓ₂` norm"
fun_expr(ψ::ShiftedCompositeNormL2) = "t ↦ ‖c(xk) + J(xk)t‖₂"
fun_params(ψ::ShiftedCompositeNormL2) = "c(xk) = $(ψ.b)\n" * " "^14 * "J(xk) = $(ψ.A)\n"
82+
83+
"""
    prox!(y, ψ::ShiftedCompositeNormL2, q, ν; max_iter = 10, atol = eps(T)^0.3, max_time = 180.0)

Compute in place the proximal step of `ψ` at `q` with step length `ν` and store it in `y`.

The routine forms `g = -(A q + b)`, solves regularized least-squares systems with QRMumps
for the vector `ψ.q`, and adjusts the regularization parameter `α` by a scalar Newton-type
iteration until `‖ψ.q‖₂ ≈ ν λ`. The result is `y = q + Aᵀ ψ.q`.

# Keywords
- `max_iter`: maximum number of root-finding iterations;
- `atol`: absolute tolerance on `|‖ψ.q‖₂ - νλ|`;
- `max_time`: time budget in seconds for the root-finding loop.

The work vectors `ψ.g`, `ψ.q`, `ψ.dq`, `ψ.p`, `ψ.dp` and the QRMumps objects stored in `ψ`
are overwritten. A warning is emitted if the iteration limit is reached before the
tolerance is met. Returns `y`.
"""
function prox!(
  y::AbstractVector{T},
  ψ::ShiftedCompositeNormL2{T, F0, F1, M, V},
  q::AbstractVector{T},
  ν::T;
  max_iter = 10,
  atol = eps(T)^0.3,
  max_time = 180.0,
) where {T <: Real, F0 <: Function, F1 <: Function, M <: AbstractMatrix{T}, V <: AbstractVector{T}}
  start_time = time()
  θ = T(0.8)
  α = zero(T)
  # Convert to T so that α keeps a single type; eps(T)^0.9 is a Float64 for any T.
  αmin = T(eps(T)^(0.9))

  # Compute RHS g = -(A * q + b)
  mul!(ψ.g, ψ.A, q, -one(T), zero(T))
  ψ.g .-= ψ.b

  # Retrieve qrm workspace
  shifted_spmat = ψ.shifted_spmat
  spmat = shifted_spmat.spmat
  spfct = ψ.spfct
  qrm_update_shift_spmat!(shifted_spmat, α)
  # Copy the current Jacobian values into the leading entries of the shifted matrix;
  # the trailing `m` entries are left to the shift.
  spmat.val[1:(spmat.mat.nz - spmat.mat.m)] .= ψ.A.vals
  qrm_spfct_init!(spfct, spmat)
  qrm_set(spfct, "qrm_keeph", 0) # Discard the Q matrix in all subsequent QR factorizations
  qrm_set(spfct, "qrm_rd_eps", eps(T)^(0.4)) # If a diagonal element of the R-factor is less than eps(T)^(0.4), we consider that A is rank deficient.

  # Check interior convergence
  qrm_analyse!(spmat, spfct; transp = 't')
  _obj_dot_grad!(spmat, spfct, ψ.p, ψ.q, ψ.g, ψ.dq)

  # Check full-rankness
  full_row_rank = (qrm_get(spfct, "qrm_rd_num") == 0)
  if !full_row_rank
    # QRMumps cannot factorize rank-deficient matrices; use the Golub-Riley iteration instead
    α = αmin
    qrm_golub_riley!(ψ.shifted_spmat, spfct, ψ.p, ψ.g, ψ.dp, ψ.q, ψ.dq, transp = 't', α = α, tol = eps(T)^(0.75)) # Now, ψ.p = Aᵀψ.q and ψ.q = (AAᵀ)†b

    # Compute residual
    qrm_spmat_mv!(spmat, T(1), ψ.q, T(0), ψ.dp, transp = 't')
    qrm_spmat_mv!(spmat, T(1), ψ.dp, T(0), ψ.dq, transp = 'n')
    @. ψ.dq = ψ.dq - ψ.g # In this case, ψ.dq = AAᵀψ.q - b. If ‖ψ.dq‖₂ is small, then g ∈ Range(AAᵀ)

    # NOTE(review): this test accepts only ‖ψ.q‖₂ ≈ νλ, not ‖ψ.q‖₂ < νλ — confirm this is intended.
    if abs(norm(ψ.q) - ν*ψ.h.lambda) < atol && norm(ψ.dq) ≤ eps(T)^(0.5) # Check interior optimality and range of AAᵀ
      @views y .= ψ.p[1:length(y)] # view: avoid allocating a temporary copy of the slice
      y .+= q
      return y
    end

    # The solution is not α = 0, prepare root finding
    qrm_update_shift_spmat!(shifted_spmat, α)
    _obj_dot_grad!(spmat, spfct, ψ.p, ψ.q, ψ.g, ψ.dq)
  end

  # Scalar Root finding
  k = 0
  elapsed_time = time() - start_time
  α₊ = α

  norm_q = norm(ψ.q)
  νλ = ν * ψ.h.lambda
  while abs(norm_q - νλ) > atol && k < max_iter && elapsed_time < max_time
    α₊ += (norm_q / νλ - 1) * (norm_q / norm(ψ.p))^2
    # Fall back to a contraction of the previous α when the step leaves the positive axis,
    # and never go below αmin.
    α = α₊ > 0 ? α₊ : θ*α
    α = α ≤ αmin ? αmin : α

    qrm_update_shift_spmat!(shifted_spmat, α)

    _obj_dot_grad!(spmat, spfct, ψ.p, ψ.q, ψ.g, ψ.dq)
    norm_q = norm(ψ.q)

    α == αmin && break

    k += 1
    elapsed_time = time() - start_time
  end

  # The loop stops as soon as k == max_iter, so the test must be `≥` (a strict `>` can
  # never hold). Only warn when the tolerance is actually not met.
  if k ≥ max_iter && abs(norm_q - νλ) > atol
    @warn "ShiftedCompositeNormL2: Newton method did not converge during prox computation returning with residual $(abs(norm(ψ.q) - ν*ψ.h.lambda)) instead"
  end
  mul!(y, ψ.A', ψ.q)
  y .+= q
  return y
end
170+
171+
# Utility function that computes in place both q = s(α) and p such that ‖p‖² = s(α)ᵀ∇s(α) for the secular equation.
172+
# Factorize `spmat` (transposed) and compute in place both q = s(α) and a vector p such
# that ‖p‖² = s(α)ᵀ∇s(α), as required by the secular equation. `g` is the right-hand side
# and `dq` is the work vector passed to the refinement step.
function _obj_dot_grad!(
  spmat::qrm_spmat{T},
  spfct::qrm_spfct{T},
  p::AbstractVector{T},
  q::AbstractVector{T},
  g::AbstractVector{T},
  dq::AbstractVector{T},
) where {T}
  qrm_factorize!(spmat, spfct; transp = 't')
  qrm_solve!(spfct, g, p; transp = 't')
  qrm_solve!(spfct, p, q; transp = 'n')
  qrm_refine!(spmat, spfct, q, g, dq, p)
  return qrm_solve!(spfct, q, p; transp = 't')
end

0 commit comments

Comments
 (0)