-
Notifications
You must be signed in to change notification settings - Fork 27
Expand file tree
/
Copy pathprecomposeDiagonal.jl
More file actions
75 lines (64 loc) · 2.68 KB
/
precomposeDiagonal.jl
File metadata and controls
75 lines (64 loc) · 2.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Precompose with diagonal scaling and translation
export PrecomposeDiagonal
"""
    PrecomposeDiagonal(f, a=1, b=0)

Return the function
```math
g(x) = f(\\mathrm{diag}(a)x + b)
```
Function ``f`` must be convex and separable, or `a` must be a scalar, for the
`prox` of ``g`` to be computable. Parameters `a` and `b` can be arrays of
multiple dimensions, according to the shape/size of the input `x` that will be
provided to the function: the way the above expression for ``g`` should be
thought of is `g(x) = f(a .* x .+ b)`.

# Throws
- `ErrorException` if `a` is zero or has any zero element.
- `ErrorException` if `a` is an array and `f` is not both convex and separable.
"""
struct PrecomposeDiagonal{T, R, S}
    f::T
    a::R
    b::S
    function PrecomposeDiagonal{T,R,S}(f::T, a::R, b::S) where {T, R, S}
        # Test every entry: `a == 0` would compare a whole array against the
        # scalar 0 (always `false`), letting zero entries through. For a scalar
        # `a`, iterating the number yields the number itself, so this covers
        # both cases. Checked first because it depends only on `a`.
        if any(iszero, a)
            error("elements of `a` must be nonzero")
        end
        # With an array `a`, prox! hands `f` an elementwise stepsize, which the
        # docstring above requires `f` to be convex and separable for.
        if R <: AbstractArray && !(is_convex(f) && is_separable(f))
            error("`f` must be convex and separable since `a` is of type $(R)")
        end
        return new(f, a, b)
    end
end
# Structural traits of g(x) = f(a .* x .+ b) are read off the wrapped function's
# type `T`: each trait listed here gets one method on `PrecomposeDiagonal{T}` that
# simply forwards the query to `T`.
# NOTE(review): a nonzero translation `b` generally maps a cone to a set that is
# not a cone, yet `is_cone_indicator` is forwarded unconditionally (the value of
# `b` is not visible at the type level) — confirm this is intended.
for trait in (
    :is_separable,
    :is_proximable,
    :is_convex,
    :is_set_indicator,
    :is_singleton_indicator,
    :is_cone_indicator,
    :is_affine_indicator,
    :is_smooth,
    :is_locally_smooth,
    :is_generalized_quadratic,
    :is_strongly_convex,
)
    @eval $trait(::Type{<:PrecomposeDiagonal{T}}) where {T} = $trait(T)
end
# Outer constructors: infer the type parameters from the arguments, dispatching
# on whether `a` and `b` are scalars or arrays. Omitted arguments default to
# `a = 1` and `b = 0`, i.e. no scaling and no shift.
# Scalar `a` and `b` carry independent type parameters `R` and `S`. Mixed numeric
# types such as `PrecomposeDiagonal(f, 2.0)` (whose default `b = 0` is an `Int`)
# or `PrecomposeDiagonal(f, 2, 0.5)` are therefore accepted; a single shared
# parameter made those calls throw a `MethodError`.
function PrecomposeDiagonal(f::T, a::R=1, b::S=0) where {T, R <: Real, S <: Real}
    return PrecomposeDiagonal{T, R, S}(f, a, b)
end
function PrecomposeDiagonal(f::T, a::R, b::S=0) where {T, R <: AbstractArray, S <: Real}
    return PrecomposeDiagonal{T, R, S}(f, a, b)
end
function PrecomposeDiagonal(
    f::T, a::R, b::S
) where {T, R <: Union{AbstractArray, Real}, S <: AbstractArray}
    return PrecomposeDiagonal{T, R, S}(f, a, b)
end
# Evaluate g(x) = f(a .* x .+ b): push `x` through the diagonal affine change of
# variables, then apply the wrapped function to the result.
function (g::PrecomposeDiagonal)(x)
    a, b = g.a, g.b
    z = @. a * x + b
    return g.f(z)
end
# Gradient of g(x) = f(a .* x .+ b) by the chain rule: a .* ∇f(a .* x .+ b).
# The wrapped `gradient!` writes ∇f at the mapped point into `y`, which is then
# rescaled in place by `a`. Returns whatever the wrapped `gradient!` returned.
function gradient!(y, g::PrecomposeDiagonal, x)
    a, b = g.a, g.b
    val = gradient!(y, g.f, @. a * x + b)
    @. y = y * a
    return val
end
# In-place proximal map of g(x) = f(a .* x .+ b), obtained by the change of
# variables z = a .* x .+ b:
#     prox_{γg}(x) = (prox_{(a .* a) .* γ, f}(z) .- b) ./ a
# With an array `a`, the stepsize passed to `f` is elementwise. The result is
# written into `y`; the value returned is the one reported by the wrapped `prox!`.
function prox!(y, g::PrecomposeDiagonal, x, gamma)
    a, b = g.a, g.b
    z = @. a * x + b
    stepsize = (a .* a) .* gamma
    fval = prox!(y, g.f, z, stepsize)
    # Undo the change of variables in two separate in-place steps.
    y .-= b
    y ./= a
    return fval
end
# Non-mutating counterpart of `prox!` for PrecomposeDiagonal. It uses the same
# change-of-variables formula and returns the mapped-back point together with
# the value reported by the wrapped `prox_naive`.
function prox_naive(g::PrecomposeDiagonal, x, gamma)
    a, b = g.a, g.b
    w, fw = prox_naive(g.f, a .* x .+ b, (a .* a) .* gamma)
    point = @. (w - b) / a
    return point, fw
end