|
| 1 | +# Separable sum, using slices of an array as variables |
| 2 | + |
| 3 | +export PrecomposedSlicedSeparableSum |
| 4 | + |
| 5 | +""" |
| 6 | + precomposedSlicedSeparableSum((f_1, ..., f_k), (J_1, ..., J_k), (L_1, ..., L_k)) |
| 7 | +
|
| 8 | +Return the function |
| 9 | +```math |
| 10 | +g(x) = \\sum_{i=1}^k f_i(L_i * x_{J_i}). |
| 11 | +``` |
| 12 | +
|
| 13 | + precomposedSlicedSeparableSum(f, (J_1, ..., J_k), (L_1, ..., L_k)) |
| 14 | +
|
| 15 | +Analogous to the previous one, but apply the same function `f` to all slices |
| 16 | +of the variable `x`: |
| 17 | +```math |
| 18 | +g(x) = \\sum_{i=1}^k f(L_i * x_{J_i}). |
| 19 | +``` |
| 20 | +""" |
| 21 | +struct PrecomposedSlicedSeparableSum{S <: Tuple, T <: AbstractArray, U <: AbstractArray, V <: AbstractArray, N} |
| 22 | + fs::S # Tuple, where each element is a Vector with elements of the same type; the functions to prox on |
| 23 | + # Example: S = Tuple{Array{ProximalOperators.NormL1{Float64},1}, Array{ProximalOperators.NormL2{Float64},1}} |
| 24 | + idxs::T # Vector, where each element is a Vector containing the indices to prox on |
| 25 | + # Example: T = Array{Array{Tuple{Colon,UnitRange{Int64}},1},1} |
| 26 | + ops::U # Vector of operations (matrices or AbstractOperators) to apply to the function |
| 27 | + # Example: U = Array{Array{Matrix{Float64},1},1} |
| 28 | + μs::V # Vector of mu values for each function |
| 29 | +end |
| 30 | + |
| 31 | +function PrecomposedSlicedSeparableSum(fs::Tuple, idxs::Tuple, ops::Tuple, μs::Tuple) |
| 32 | + @assert length(fs) == length(idxs) |
| 33 | + @assert length(fs) == length(ops) |
| 34 | + ftypes = DataType[] |
| 35 | + fsarr = Array{Any,1}[] |
| 36 | + indarr = Array{eltype(idxs),1}[] |
| 37 | + opsarr = Array{Any,1}[] |
| 38 | + μsarr = Array{Any,1}[] |
| 39 | + for (i,f) in enumerate(fs) |
| 40 | + t = typeof(f) |
| 41 | + fi = findfirst(isequal(t), ftypes) |
| 42 | + if fi === nothing |
| 43 | + push!(ftypes, t) |
| 44 | + push!(fsarr, Any[f]) |
| 45 | + push!(indarr, eltype(idxs)[idxs[i]]) |
| 46 | + push!(opsarr, Any[ops[i]]) |
| 47 | + push!(μsarr, Any[μs[i]]) |
| 48 | + else |
| 49 | + push!(fsarr[fi], f) |
| 50 | + push!(indarr[fi], idxs[i]) |
| 51 | + push!(opsarr[fi], ops[i]) |
| 52 | + push!(μsarr[fi], μs[i]) |
| 53 | + end |
| 54 | + end |
| 55 | + fsnew = ((Array{typeof(fs[1]),1}(fs) for fs in fsarr)...,) |
| 56 | + @assert typeof(fsnew) == Tuple{(Array{ft,1} for ft in ftypes)...} |
| 57 | + PrecomposedSlicedSeparableSum{typeof(fsnew),typeof(indarr),typeof(opsarr),typeof(μsarr),length(fsnew)}(fsnew, indarr, opsarr, μsarr) |
| 58 | +end |
| 59 | + |
| 60 | +# Constructor for the case where the same function is applied to all slices |
| 61 | +PrecomposedSlicedSeparableSum(f::F, idxs::T, ops::U, μs::V) where {F, T <: Tuple, U <: Tuple, V <: Tuple} = |
| 62 | + PrecomposedSlicedSeparableSum(Tuple(f for k in eachindex(idxs)), idxs, ops, μs) |
| 63 | + |
| 64 | +# Unroll the loop over the different types of functions to evaluate |
| 65 | +function (f::PrecomposedSlicedSeparableSum)(x::Tuple) |
| 66 | + v = zero(eltype(x[1])) |
| 67 | + for (fs_group, idxs_group, ops_group) = zip(f.fs, f.idxs, f.ops) # For each function type |
| 68 | + for (fun, idx_group, hcat_op) in zip(fs_group, idxs_group, ops_group) # For each function of that type |
| 69 | + for (var_index, (x_var, idx)) in enumerate(zip(x, idx_group)) |
| 70 | + if idx isa Tuple |
| 71 | + v += fun(hcat_op[var_index] * view(x_var, idx...)) |
| 72 | + elseif idx isa Colon |
| 73 | + v += fun(hcat_op[var_index] * x_var) |
| 74 | + elseif idx isa Nothing |
| 75 | + # do nothing |
| 76 | + else |
| 77 | + v += fun(hcat_op[var_index] * view(x_var, idx)) |
| 78 | + end |
| 79 | + end |
| 80 | + end |
| 81 | + end |
| 82 | + return v |
| 83 | +end |
| 84 | + |
| 85 | +function slice_var(x, idx) |
| 86 | + if idx isa Tuple |
| 87 | + return view(x, idx...) |
| 88 | + elseif idx isa Colon |
| 89 | + return x |
| 90 | + else |
| 91 | + return view(x, idx) |
| 92 | + end |
| 93 | +end |
| 94 | + |
| 95 | +# Unroll the loop over the different types of functions to prox on |
| 96 | +function prox!(y::Tuple, f::PrecomposedSlicedSeparableSum, x::Tuple, gamma) |
| 97 | + v = zero(eltype(x[1])) |
| 98 | + for (fs_group, idxs_group, ops_group, μ_group) = zip(f.fs, f.idxs, f.ops, f.μs) # For each function type |
| 99 | + for (fun, idx_group, hcat_op, μ) in zip(fs_group, idxs_group, ops_group, μ_group) # For each function of that type |
| 100 | + for (idx, op, x_var, y_var) in zip(idx_group, hcat_op, x, y) |
| 101 | + if idx isa Nothing |
| 102 | + continue |
| 103 | + end |
| 104 | + sliced_x = slice_var(x_var, idx) |
| 105 | + sliced_y = slice_var(y_var, idx) |
| 106 | + res = op * sliced_x |
| 107 | + prox_res, g = prox(fun, res, μ.*gamma) |
| 108 | + prox_res .-= res |
| 109 | + prox_res ./= μ |
| 110 | + mul!(sliced_y, adjoint(op), prox_res) |
| 111 | + sliced_y .+= sliced_x |
| 112 | + v += g |
| 113 | + end |
| 114 | + end |
| 115 | + end |
| 116 | + return v |
| 117 | +end |
| 118 | + |
| 119 | +component_types(::Type{PrecomposedSlicedSeparableSum{S, T, N}}) where {S, T, N} = Tuple(A.parameters[1] for A in fieldtypes(S)) |
| 120 | + |
| 121 | +@generated is_proximable(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_proximable, component_types(T)) ? :(true) : :(false) |
| 122 | +@generated is_convex(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_convex, component_types(T)) ? :(true) : :(false) |
| 123 | +@generated is_set_indicator(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_set_indicator, component_types(T)) ? :(true) : :(false) |
| 124 | +@generated is_singleton_indicator(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_singleton_indicator, component_types(T)) ? :(true) : :(false) |
| 125 | +@generated is_cone_indicator(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_cone_indicator, component_types(T)) ? :(true) : :(false) |
| 126 | +@generated is_affine_indicator(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_affine_indicator, component_types(T)) ? :(true) : :(false) |
| 127 | +@generated is_smooth(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_smooth, component_types(T)) ? :(true) : :(false) |
| 128 | +@generated is_generalized_quadratic(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_generalized_quadratic, component_types(T)) ? :(true) : :(false) |
| 129 | +@generated is_strongly_convex(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_strongly_convex, component_types(T)) ? :(true) : :(false) |
| 130 | + |
| 131 | +function prox_naive(f::PrecomposedSlicedSeparableSum, x, gamma) |
| 132 | + fy = 0 |
| 133 | + y = similar.(x) |
| 134 | + for (fs_group, idxs_group, ops_group, μ_group) = zip(f.fs, f.idxs, f.ops, f.μs) # For each function type |
| 135 | + for (fun, idx_group, hcat_op, μ) in zip(fs_group, idxs_group, ops_group, μ_group) # For each function of that type |
| 136 | + for (idx, op, x_var, y_var) in zip(idx_group, hcat_op, x, y) |
| 137 | + if idx isa Nothing |
| 138 | + continue |
| 139 | + end |
| 140 | + sliced_x = slice_var(x_var, idx) |
| 141 | + sliced_y = slice_var(y_var, idx) |
| 142 | + res = op * sliced_x |
| 143 | + prox_res, _fy = prox_naive(fun, res, μ.*gamma) |
| 144 | + prox_res = (prox_res .- res) ./ μ |
| 145 | + mul!(sliced_y, adjoint(op), prox_res) |
| 146 | + fy += _fy |
| 147 | + sliced_y .+= sliced_x |
| 148 | + end |
| 149 | + end |
| 150 | + end |
| 151 | + return y, fy |
| 152 | +end |
0 commit comments