diff --git a/src/ShiftedProximalOperators.jl b/src/ShiftedProximalOperators.jl index b21055a..d047b86 100644 --- a/src/ShiftedProximalOperators.jl +++ b/src/ShiftedProximalOperators.jl @@ -31,6 +31,9 @@ include("Rank.jl") include("cappedl1.jl") include("Nuclearnorm.jl") +include("null.jl") +include("shiftedNullBox.jl") + include("shiftedCompositeNormL2.jl") include("shiftedNormL0.jl") include("shiftedNormL0Box.jl") @@ -98,6 +101,7 @@ end set_radius!(ψ::ShiftedNormL0Box, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ) set_radius!(ψ::ShiftedNormL1Box, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ) set_radius!(ψ::ShiftedRootNormLhalfBox, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ) +set_radius!(ψ::ShiftedNullRegularizerBox, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ) """ set_bounds!(ψ, l, u) diff --git a/src/null.jl b/src/null.jl new file mode 100644 index 0000000..3f95dfa --- /dev/null +++ b/src/null.jl @@ -0,0 +1,66 @@ +# Null regularizer +export NullRegularizer + +@doc raw""" + NullRegularizer(::Type{T}) where {T <: Real} + NullRegularizer(lambda::T) where {T <: Real} + + +Returns the null regularizer, i.e., the function that is identically zero. +```math +h(x) = 0 +``` + +### Arguments +- `T`: The type of zero that is expected to be returned by the regularizer. + +In the second constructor, the type of lambda is used to infer the type of zero that is expected to be returned by the regularizer. The value of lambda is ignored. +""" +struct NullRegularizer{T <: Real} <: ShiftedProximableFunction end + +NullRegularizer(lambda::T) where {T <: Real} = NullRegularizer{T}() +NullRegularizer(::Type{T}) where {T <: Real} = NullRegularizer{T}() + +shifted(h::NullRegularizer{T}, xk::AbstractVector{T}) where {T <: Real} = + NullRegularizer(T) + +function shift!(h::NullRegularizer{T}, xk::AbstractVector{T}) where {T <: Real} + return h +end + +function (h::NullRegularizer{T})(y) where {T <: Real} + return zero(T) +end + +fun_name(h::NullRegularizer{T}) where {T <: Real} = "null regularizer" +fun_expr(h::NullRegularizer{T}) where {T <: Real} = "t ↦ 0" +fun_params(h::NullRegularizer{T}) where {T <: Real} = "" + +function Base.show(io::IO, h::NullRegularizer{T}) where {T <: Real} + println(io, "description : ", fun_name(h)) + println(io, "expression : ", fun_expr(h)) + println(io, "parameters : ", fun_params(h)) +end + +function prox!( + y::AbstractVector{T}, + h::NullRegularizer{T}, + q::AbstractVector{T}, + ν::T +) where {T <: Real} + @assert ν > zero(T) + y .= q + return y +end + +function iprox!( + y::AbstractVector{T}, + h::NullRegularizer{T}, + g::AbstractVector{T}, + d::AbstractVector{T}, +) where {T <: Real} + @inbounds for i ∈ eachindex(y) + @assert d[i] > 0 + y[i] = - g[i] / d[i] + end +end \ No newline at end of file diff --git a/src/shiftedNullBox.jl b/src/shiftedNullBox.jl new file mode 100644 index 0000000..83d757f --- /dev/null +++ b/src/shiftedNullBox.jl @@ -0,0 +1,111 @@ +# Null box regularizer +export ShiftedNullRegularizerBox + +@doc raw""" + ShiftedNullRegularizerBox(h, sj, shifted_twice, l, u) + +Returns the shifted box null regularizer, i.e., the function that is identically zero on a box and +∞ outside of it. +```math +ψ(x) = h(xk + sj + x) + \chi(sj + x | [l, u]) +``` +where `h` is identically zero, `xk`represents a shift, `sj` is an additional shift that is applied to the indicator +function as well and `[l, u]` is the box that defines the domain of the function. + +### Arguments +- `h`: The unshifted null regularizer (see `NullRegularizer`). +- `sj`: The shift of the indicator function. +- `shifted_twice`: A boolean indicating whether `sj` is updated or not on shifts, true means that `sj` is updated, false means that `xk` is updated. +- `l`: The lower bound of the box. +- `u`: The upper bound of the box. +""" +mutable struct ShiftedNullRegularizerBox{T <: Real, V <: AbstractVector{T}, VT <: Union{AbstractVector{T}, T}} <: ShiftedProximableFunction + h::NullRegularizer{T} + sj::V + shifted_twice::Bool + l::VT + u::VT + + function ShiftedNullRegularizerBox( + h::NullRegularizer{T}, + sj::V, + shifted_twice::Bool, + l::VT, + u::VT, + ) where {T <: Real, V <: AbstractVector{T}, VT <: Union{AbstractVector{T}, T}} + new{T, V, VT}(h, sj, shifted_twice, l, u) + end +end + +shifted( + h::NullRegularizer{T}, + xk::AbstractVector{T}, + l, + u, + selected::AbstractArray{I} = 1:length(xk), +) where {T <: Real, I <: Integer} = shifted(h, xk, l, u) +shifted( + h::NullRegularizer{T}, + xk::AbstractVector{T}, + l::VT, + u::VT, +) where {T <: Real, VT} = ShiftedNullRegularizerBox(h, zero(xk), false, l, u) +shifted( + h::NullRegularizer{T}, + xk::AbstractVector{T}, + Δ::T, + χ::Conjugate{IndBallL1{T}}, +) where {T <: Real} = ShiftedNullRegularizerBox(h, zero(xk), false, -Δ, Δ) +shifted( + ψ::ShiftedNullRegularizerBox{T, V, VT}, + sj::AbstractVector{T}, +) where {T <: Real, V <: AbstractVector{T}, VT} = ShiftedNullRegularizerBox(ψ.h, sj, true, ψ.l, ψ.u) + +function shift!(ψ::ShiftedNullRegularizerBox{T, V, VT}, shift::AbstractVector{T}) where {T <: Real, V <: AbstractVector{T}, VT} + ψ.shifted_twice && (ψ.sj .= shift) +end + +function (ψ::ShiftedNullRegularizerBox{T, V})(y) where {T <: Real, V <: AbstractVector{T}} + ϵ = √eps(eltype(y)) + @inbounds for i in eachindex(y) + l = ψ.l isa AbstractVector ? ψ.l[i] : ψ.l + u = ψ.u isa AbstractVector ? ψ.u[i] : ψ.u + if !(l - ϵ ≤ ψ.sj[i] + y[i] ≤ u + ϵ) + return T(Inf) + end + end + return zero(T) +end + +fun_name(ψ::ShiftedNullRegularizerBox{T, V}) where {T <: Real, V <: AbstractVector{T}} = "shifted null regularizer with box indicator" +fun_expr(ψ::ShiftedNullRegularizerBox{T, V}) where {T <: Real, V <: AbstractVector{T}} = "t ↦ χ({sj + t .∈ [l,u]})" +fun_params(ψ::ShiftedNullRegularizerBox{T, V}) where {T <: Real, V <: AbstractVector{T}} = + "sj = $(ψ.sj)\n" * " "^14 * "l = $(ψ.l)\n" * " "^14 * "u = $(ψ.u)" + +function prox!( + y::AbstractVector{T}, + ψ::ShiftedNullRegularizerBox{T, V}, + q::AbstractVector{T}, + σ::T, +) where {T <: Real, V <: AbstractVector{T}} + @assert σ > zero(T) + @inbounds for i ∈ eachindex(y) + l = ψ.l isa AbstractVector ? ψ.l[i] : ψ.l + u = ψ.u isa AbstractVector ? ψ.u[i] : ψ.u + y[i] = prox_zero(q[i], l - ψ.sj[i], u - ψ.sj[i]) + end + return y +end + +function iprox!( + y::AbstractVector{T}, + ψ::ShiftedNullRegularizerBox{T, V}, + g::AbstractVector{T}, + d::AbstractVector{T}, +) where {T <: Real, V <: AbstractVector{T}} + @inbounds for i ∈ eachindex(y) + l = ψ.l isa AbstractVector ? ψ.l[i] : ψ.l + u = ψ.u isa AbstractVector ? ψ.u[i] : ψ.u + y[i] = iprox_zero(d[i], g[i], l - ψ.sj[i], u - ψ.sj[i]) + end + return y +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 4855827..86fc5cd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,6 +7,65 @@ using Test include("test_psvd.jl") +@testset "NullRegularizer" begin + for T in (Float64, Float32) + h = NullRegularizer(T) + @test h(T.([1.0, 1.0])) == T(0.0) + y = similar(T.([1.0, 2.0])) + prox!(y, h, T.([3.0, 4.0]), T(1.0)) + @test all(y .== T.([3.0, 4.0])) + iprox!(y, h, T.([3.0, 4.0]), T.([2.0, 4.0])) + @test all(y .== -T.([1.5, 1.0])) + + h_shifted = shifted(h, T.([1.0, 2.0])) + @test h_shifted(T.([1.0, 1.0])) == T(0.0) + prox!(y, h_shifted, T.([3.0, 4.0]), T(1.0)) + @test all(y .== T.([3.0, 4.0])) + iprox!(y, h_shifted, T.([3.0, 4.0]), T.([2.0, 4.0])) + @test all(y .== -T.([1.5, 1.0])) + + shift!(h_shifted, T.([5.0, 6.0])) + @test h_shifted(T.([1.0, 1.0])) == T(0.0) + prox!(y, h_shifted, T.([3.0, 4.0]), T(1.0)) + @test all(y .== T.([3.0, 4.0])) + iprox!(y, h_shifted, T.([3.0, 4.0]), T.([2.0, 4.0])) + @test all(y .== -T.([1.5, 1.0])) + + h_shifted_box = shifted(h, T.([1.0, 2.0]), T.([0.0, 0.0]), T.([2.0, 3.0])) + @test h_shifted_box.shifted_twice == false + @test h_shifted_box(T.([0.0, 0.0])) == T(0.0) + @test h_shifted_box(T.([-1.0, 0.0])) == T(Inf) + @test h_shifted_box(T.([2.0, 3.0])) == T(0.0) + @test h_shifted_box(T.([1.0, 1.5])) == T(0.0) + prox!(y, h_shifted_box, T.([3.0, 4.0]), T(1.0)) + @test all(y .== T.([2.0, 3.0])) + iprox!(y, h_shifted_box, T.([3.0, 4.0]), T.([2.0, 4.0])) + @test all(y .== T.([0.0, 0.0])) + + shift!(h_shifted_box, T.([5.0, 6.0])) + + @test h_shifted_box(T.([0.0, 0.0])) == T(0.0) + @test h_shifted_box(T.([-1.0, 0.0])) == T(Inf) + @test h_shifted_box(T.([2.0, 3.0])) == T(0.0) + @test h_shifted_box(T.([1.0, 1.5])) == T(0.0) + prox!(y, h_shifted_box, T.([3.0, 4.0]), T(1.0)) + @test all(y .== T.([2.0, 3.0])) + iprox!(y, h_shifted_box, T.([3.0, 4.0]), T.([2.0, 4.0])) + @test all(y .== T.([0.0, 0.0])) + + h_shifted_twice_box = shifted(h_shifted_box, T.([1.0, 2.0])) + @test h_shifted_twice_box.shifted_twice == true + @test h_shifted_twice_box(T.([0.0, 0.0])) == h_shifted_box(T.([1.0, 2.0])) + @test h_shifted_twice_box(T.([-1.0, 0.0])) == h_shifted_box(T.([0.0, 2.0])) + @test h_shifted_twice_box(T.([2.0, 3.0])) == h_shifted_box(T.([3.0, 5.0])) + @test h_shifted_twice_box(T.([1.0, 1.5])) == h_shifted_box(T.([2.0, 3.5])) + prox!(y, h_shifted_twice_box, T.([3.0, 4.0]), T(1.0)) + @test all(y .== T.([1.0, 1.0])) + iprox!(y, h_shifted_twice_box, T.([3.0, 4.0]), T.([2.0, 4.0])) + @test all(y .== T.([-1.0, -1.0])) + end +end + for (op, composite_op, shifted_op) ∈ zip((:NormL2,), (:CompositeNormL2,), (:ShiftedCompositeNormL2,)) @testset "$shifted_op" begin diff --git a/test/test_allocs.jl b/test/test_allocs.jl index d2824b3..9b90d39 100644 --- a/test/test_allocs.jl +++ b/test/test_allocs.jl @@ -37,6 +37,17 @@ macro wrappedallocs(expr) end @testset "allocs" begin + + for op ∈ (:NullRegularizer,) + h = eval(op)(Float64) + @test @wrappedallocs(h([0.0, 0.0])) == 0 + h_shifted_box = shifted(h, [1.0, 2.0], [0.0, 0.0], [2.0, 3.0]) + @test @wrappedallocs(h_shifted_box([0.0, 0.0])) == 0 + y = zeros(Float64, 2) + @test @wrappedallocs(prox!(y, h_shifted_box, [3.0, 4.0], 1.0)) == 0 + @test @wrappedallocs(iprox!(y, h_shifted_box, [3.0, 4.0], [2.0, 4.0])) == 0 + end + for (op, composite_op) ∈ ((:NormL2, :CompositeNormL2),) CompositeOp = eval(composite_op) @@ -66,7 +77,7 @@ end ν = 0.1056 @test @wrappedallocs(prox!(y, ϕ, x, ν)) == 0 end - for op ∈ (:NormL0, :NormL1, :RootNormLhalf) + for op ∈ (:NormL0, :NormL1, :RootNormLhalf,) h = eval(op)(1.0) n = 1000 xk = rand(n) @@ -99,7 +110,7 @@ end @test allocs == 0 end - for op ∈ (:NormL0, :NormL1) + for op ∈ (:NormL0, :NormL1,) h = eval(op)(1.0) n = 1000 xk = rand(n)