-
Notifications
You must be signed in to change notification settings - Fork 27
Expand file tree
/
Copy pathindPolyhedralQPDAS.jl
More file actions
146 lines (121 loc) · 3.84 KB
/
indPolyhedralQPDAS.jl
File metadata and controls
146 lines (121 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# IndPolyhedral: QPDAS implementation
import QPDAS
"""
**Indicator of a polyhedral set**

    IndPolyhedralQPDAS(l, A, u, [xmin, xmax])
    IndPolyhedralQPDAS(l, A, [xmin, xmax])
    IndPolyhedralQPDAS(A, u, [xmin, xmax])
    IndPolyhedralQPDAS{R}(A, b, C, d)

Return the indicator function of the polyhedron
```math
S = \\{x : Ax = b,\\ Cx \\le d\\}.
```
The bound-based constructors build `A`, `b`, `C`, `d` from the two-sided constraint
`l ≤ Ax ≤ u` (optionally extended with `xmin ≤ x ≤ xmax`): rows with `l[i] == u[i]`
become equality constraints, lower bounds equal to `-Inf` and upper bounds equal to
`+Inf` are dropped, and the remaining bounds become rows of `Cx ≤ d`. A missing `l`
or `u` defaults to `-Inf` or `+Inf` respectively.
Projections (`prox!`) are computed with the `QPDAS` quadratic-programming solver.
"""
struct IndPolyhedralQPDAS{R<:Real, MT<:AbstractMatrix{R}, VT<:AbstractVector{R}, QP<:QPDAS.QuadraticProgram} <: IndPolyhedral
    A::MT                            # equality-constraint matrix (Ax = b)
    b::VT                            # equality right-hand side
    C::MT                            # inequality-constraint matrix (Cx ≤ d)
    d::VT                            # inequality right-hand side
    z::VT                            # buffer for the QP linear term, overwritten by prox!
    qp::QP                           # QPDAS problem, updated in place on every prox! call
    first_prox::Base.RefValue{Bool}  # true until the first prox! call has run smartstart
    function IndPolyhedralQPDAS{R}(
        A::MT, b::VT, C::MT, d::VT
    ) where {R<:Real, MT<:AbstractMatrix{R}, VT<:AbstractVector{R}}
        # Validate shapes up front so QPDAS never sees inconsistent data.
        size(A, 1) == length(b) || throw(DimensionMismatch(
            "A has $(size(A, 1)) rows but b has length $(length(b))"
        ))
        size(C, 1) == length(d) || throw(DimensionMismatch(
            "C has $(size(C, 1)) rows but d has length $(length(d))"
        ))
        size(A, 2) == size(C, 2) || throw(DimensionMismatch(
            "A has $(size(A, 2)) columns but C has $(size(C, 2))"
        ))
        # smartstart is deferred to the first prox! call, when the linear term is known.
        qp = QPDAS.QuadraticProgram(A, b, C, d, smartstart=false)
        return new{R, MT, VT, typeof(qp)}(A, b, C, d, zeros(R, size(A, 2)), qp, Ref(true))
    end
end
# properties
# prox! returns the QPDAS solution directly, so this indicator is
# reported as having an accurate prox.
function is_prox_accurate(::IndPolyhedralQPDAS)
    return true
end
# constructors
"""
Build the indicator of `{x : l ≤ Ax ≤ u}` by splitting the two-sided bounds into
equality constraints (`l[i] == u[i]`) and one-sided inequalities of the form `Cx ≤ d`.

Throws `DimensionMismatch` if `l`, `u` and the rows of `A` disagree in length, and an
error if some lower bound exceeds its upper bound.
"""
function IndPolyhedralQPDAS(
    l::AbstractVector{R}, A::AbstractMatrix{R}, u::AbstractVector{R}
) where R
    # Every bound must correspond to exactly one row of A; otherwise broadcasting
    # below would silently expand length-1 bounds or fail with an opaque BoundsError.
    if !(length(l) == size(A, 1) == length(u))
        throw(DimensionMismatch(
            "bounds must have one entry per row of A: got length(l) = $(length(l)), " *
            "size(A, 1) = $(size(A, 1)), length(u) = $(length(u))"
        ))
    end
    if !all(l .<= u)
        error("function is improper (are some bounds inverted?)")
    end
    # Rows with identical bounds become equality constraints.
    eqinds = l .== u
    Aeq = A[eqinds, :]
    beq = l[eqinds]
    # A bound is active unless it is the type's extreme value (±Inf for floats) or NaN.
    _islower(v::T) where T = v != typemin(T) && !isnan(v)
    _isupper(v::T) where T = v != typemax(T) && !isnan(v)
    lower = _islower.(l) .& (.!eqinds)
    upper = _isupper.(u) .& (.!eqinds)
    lower_only = lower .& (.!upper)
    upper_only = upper .& (.!lower)
    upper_and_lower = upper .& lower
    # Express everything as Cx ≤ d: a lower bound l ≤ a'x becomes -a'x ≤ -l.
    Cieq = [-A[lower_only, :];
            A[upper_only, :];
            -A[upper_and_lower, :];
            A[upper_and_lower, :]]
    dieq = [-l[lower_only];
            u[upper_only];
            -l[upper_and_lower];
            u[upper_and_lower]]
    return IndPolyhedralQPDAS{R}(Aeq, beq, Cieq, dieq)
end
# Variable bounds xmin ≤ x ≤ xmax are appended as extra constraint rows
# formed by stacking the identity `I` under A.
function IndPolyhedralQPDAS(
    l::AbstractVector{R}, A::AbstractMatrix{R}, u::AbstractVector{R},
    xmin::AbstractVector{R}, xmax::AbstractVector{R}
) where R
    lower_all = vcat(l, xmin)
    A_all = vcat(A, I)
    upper_all = vcat(u, xmax)
    return IndPolyhedralQPDAS(lower_all, A_all, upper_all)
end
# Only lower bounds are given: every row of A gets an upper bound of +Inf,
# which the (l, A, u) constructor treats as "no upper bound".
function IndPolyhedralQPDAS(
    l::AbstractVector{R}, A::AbstractMatrix{R}, args...
) where R
    nrows = size(A, 1)
    no_upper = fill(R(Inf), nrows)
    return IndPolyhedralQPDAS(l, A, no_upper, args...)
end
# Only upper bounds are given: every row of A gets a lower bound of -Inf,
# which the (l, A, u) constructor treats as "no lower bound".
function IndPolyhedralQPDAS(
    A::AbstractMatrix{R}, u::AbstractVector{R}, args...
) where R
    nrows = size(A, 1)
    no_lower = fill(R(-Inf), nrows)
    return IndPolyhedralQPDAS(no_lower, A, u, args...)
end
# function evaluation
# Return 0 if x satisfies Ax = b and Cx ≤ d, and +Inf otherwise.
# Feasibility is checked up to a relative tolerance of sqrt(eps) of the element type,
# because numerically computed projections satisfy the equality constraints only
# approximately. (The previous expression `Ax .<= f.b .& Cx .<= f.d` parsed as a
# chained comparison against `f.b .& Cx`, since `.&` binds tighter than `.<=`, and it
# also tested equalities with `<=`.)
function (f::IndPolyhedralQPDAS{R})(x::AbstractVector{R}) where R
    tol = sqrt(eps(float(R)))
    Ax = f.A * x
    Cx = f.C * x
    # Equalities: |(Ax)_i - b_i| must be small relative to the magnitude of b_i.
    eq_ok = all(abs(ai - bi) <= tol * (1 + abs(bi)) for (ai, bi) in zip(Ax, f.b))
    # Inequalities: (Cx)_i may exceed d_i by at most the same scaled tolerance.
    ineq_ok = all(ci - di <= tol * (1 + abs(di)) for (ci, di) in zip(Cx, f.d))
    # R(Inf) rather than Inf keeps the return type equal to R in both branches.
    return (eq_ok && ineq_ok) ? R(0) : R(Inf)
end
# prox
# Project x onto the polyhedron by solving the QP stored in f.qp and write the
# result into y. gamma is accepted for interface compatibility but unused here.
# Mutates f.z, f.qp and f.first_prox.
function prox!(y::AbstractVector{R}, f::IndPolyhedralQPDAS{R}, x::AbstractVector{R}, gamma::R=R(1)) where R
    # The QP's linear term is -x.
    @. f.z = -x
    QPDAS.update!(f.qp, z=f.z)
    # Smartstart chooses the initial active set from z; it must only run once,
    # on the first call, after the linear term has been set.
    if f.first_prox[]
        QPDAS.run_smartstart(f.qp.boxQP)
        f.first_prox[] = false
    end
    solution, _ = QPDAS.solve!(f.qp)
    y .= solution
    # Indicator value reported at y (assumes the QP solve succeeded).
    return R(0)
end
# naive prox
# We want to compute the projection p of a point x onto {p : l ≤ Ap ≤ u}.
#
# primal problem is: minimize_p (1/2)||p-x||^2 + g(Ap)
# where g is the indicator of the box [l, u]
#
# dual problem is: minimize_y (1/2)||-A'y||^2 - x'A'y + g*(y)
# can solve with (fast) dual proximal gradient method;
# the primal solution is then recovered as p = x - A'y.
#
# Keyword arguments (defaults match the previous hard-coded values):
#   maxit - maximum number of dual iterations
#   tol   - stop when ||y - w|| / (1 + ||w||) <= tol
function prox_naive(
    f::IndPolyhedralQPDAS{R}, x::AbstractVector{R}, gamma::R=R(1);
    maxit::Integer=1_000_000, tol=1e-12
) where R
    # Rewrite to l ≤ Ax ≤ u: equalities get l = u = b, inequalities get l = -Inf, u = d.
    A = [f.A; f.C]
    l = [f.b; fill(R(-Inf), length(f.d))]
    u = [f.b; f.d]
    gstar = Conjugate(IndBox(l, u))
    # 1/L, where L = ||A A'|| is the Lipschitz constant of the smooth dual gradient A(A'y - x).
    stepsize = R(1) / opnorm(Matrix(A * A'))
    y = zeros(R, size(A, 1))  # dual vector
    y_prev = y
    for it in 1:maxit
        # Extrapolation (momentum) step of the fast proximal gradient method.
        w = y + (it - 1) / (it + 2) * (y - y_prev)
        y_prev = y
        # Gradient step on (1/2)||A'y||^2 - x'A'y, followed by the prox of g*.
        z = w - stepsize * (A * (A' * w - x))
        y, = prox(gstar, z, stepsize)
        if norm(y - w) / (1 + norm(w)) <= tol
            break
        end
    end
    p = -A' * y + x
    return p, R(0)
end