Skip to content
4 changes: 4 additions & 0 deletions src/ShiftedProximalOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ include("Rank.jl")
include("cappedl1.jl")
include("Nuclearnorm.jl")

include("null.jl")
include("shiftedNullBox.jl")

include("shiftedCompositeNormL2.jl")
include("shiftedNormL0.jl")
include("shiftedNormL0Box.jl")
Expand Down Expand Up @@ -98,6 +101,7 @@ end
set_radius!(ψ::ShiftedNormL0Box, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ)
set_radius!(ψ::ShiftedNormL1Box, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ)
set_radius!(ψ::ShiftedRootNormLhalfBox, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ)
set_radius!(ψ::ShiftedNullRegularizerBox, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ)

"""
set_bounds!(ψ, l, u)
Expand Down
66 changes: 66 additions & 0 deletions src/null.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Null regularizer
export NullRegularizer

@doc raw"""
NullRegularizer(::Type{T}) where {T <: Real}
NullRegularizer(lambda::T) where {T <: Real}


Returns the null regularizer, i.e., the function that is identically zero.
```math
h(x) = 0
```

### Arguments
- `T`: The type of zero that is expected to be returned by the regularizer.

In the second constructor, the type of lambda is used to infer the type of zero that is expected to be returned by the regularizer. The value of lambda is ignored.
"""
struct NullRegularizer{T <: Real} <: ShiftedProximableFunction end

NullRegularizer(lambda::T) where {T <: Real} = NullRegularizer{T}()
NullRegularizer(::Type{T}) where {T <: Real} = NullRegularizer{T}()

Comment thread
MaxenceGollier marked this conversation as resolved.
shifted(h::NullRegularizer{T}, xk::AbstractVector{T}) where {T <: Real} =
NullRegularizer(T)

function shift!(h::NullRegularizer{T}, xk::AbstractVector{T}) where {T <: Real}
return h
end

function (h::NullRegularizer{T})(y) where {T <: Real}
return zero(T)
end

fun_name(h::NullRegularizer{T}) where {T <: Real} = "null regularizer"
fun_expr(h::NullRegularizer{T}) where {T <: Real} = "t ↦ 0"
fun_params(h::NullRegularizer{T}) where {T <: Real} = ""

function Base.show(io::IO, h::NullRegularizer{T}) where {T <: Real}
println(io, "description : ", fun_name(h))
println(io, "expression : ", fun_expr(h))
println(io, "parameters : ", fun_params(h))
end

function prox!(
y::AbstractVector{T},
h::NullRegularizer{T},
q::AbstractVector{T},
ν::T
) where {T <: Real}
@assert ν > zero(T)
y .= q
return y
end

function iprox!(
y::AbstractVector{T},
h::NullRegularizer{T},
g::AbstractVector{T},
d::AbstractVector{T},
) where {T <: Real}
@inbounds for i ∈ eachindex(y)
@assert d[i] > 0
y[i] = - g[i] / d[i]
end
end
111 changes: 111 additions & 0 deletions src/shiftedNullBox.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Null box regularizer
export ShiftedNullRegularizerBox

@doc raw"""
ShiftedNullRegularizerBox(h, sj, shifted_twice, l, u)

Returns the shifted box null regularizer, i.e., the function that is identically zero on a box and +∞ outside of it.
```math
ψ(x) = h(xk + sj + x) + \chi(sj + x | [l, u])
```
where `h` is identically zero, `xk`represents a shift, `sj` is an additional shift that is applied to the indicator
function as well and `[l, u]` is the box that defines the domain of the function.

### Arguments
- `h`: The unshifted null regularizer (see `NullRegularizer`).
- `sj`: The shift of the indicator function.
- `shifted_twice`: A boolean indicating whether `sj` is updated or not on shifts, true means that `sj` is updated, false means that `xk` is updated.
- `l`: The lower bound of the box.
- `u`: The upper bound of the box.
"""
mutable struct ShiftedNullRegularizerBox{T <: Real, V <: AbstractVector{T}, VT <: Union{AbstractVector{T}, T}} <: ShiftedProximableFunction
h::NullRegularizer{T}
sj::V
shifted_twice::Bool
l::VT
u::VT

function ShiftedNullRegularizerBox(
h::NullRegularizer{T},
sj::V,
shifted_twice::Bool,
l::VT,
u::VT,
) where {T <: Real, V <: AbstractVector{T}, VT <: Union{AbstractVector{T}, T}}
new{T, V, VT}(h, sj, shifted_twice, l, u)
end
end

shifted(
h::NullRegularizer{T},
xk::AbstractVector{T},
l,
u,
selected::AbstractArray{I} = 1:length(xk),
) where {T <: Real, I <: Integer} = shifted(h, xk, l, u)
shifted(
h::NullRegularizer{T},
xk::AbstractVector{T},
l::VT,
u::VT,
) where {T <: Real, VT} = ShiftedNullRegularizerBox(h, zero(xk), false, l, u)
shifted(
h::NullRegularizer{T},
xk::AbstractVector{T},
Δ::T,
χ::Conjugate{IndBallL1{T}},
) where {T <: Real} = ShiftedNullRegularizerBox(h, zero(xk), false, -Δ, Δ)
shifted(
ψ::ShiftedNullRegularizerBox{T, V, VT},
sj::AbstractVector{T},
) where {T <: Real, V <: AbstractVector{T}, VT} = ShiftedNullRegularizerBox(ψ.h, sj, true, ψ.l, ψ.u)

function shift!(ψ::ShiftedNullRegularizerBox{T, V, VT}, shift::AbstractVector{T}) where {T <: Real, V <: AbstractVector{T}, VT}
ψ.shifted_twice && (ψ.sj .= shift)
end

function (ψ::ShiftedNullRegularizerBox{T, V})(y) where {T <: Real, V <: AbstractVector{T}}
ϵ = √eps(eltype(y))
@inbounds for i in eachindex(y)
l = ψ.l isa AbstractVector ? ψ.l[i] : ψ.l
u = ψ.u isa AbstractVector ? ψ.u[i] : ψ.u
if !(l - ϵ ≤ ψ.sj[i] + y[i] ≤ u + ϵ)
return T(Inf)
end
end
return zero(T)
end

fun_name(ψ::ShiftedNullRegularizerBox{T, V}) where {T <: Real, V <: AbstractVector{T}} = "shifted null regularizer with box indicator"
fun_expr(ψ::ShiftedNullRegularizerBox{T, V}) where {T <: Real, V <: AbstractVector{T}} = "t ↦ χ({sj + t .∈ [l,u]})"
fun_params(ψ::ShiftedNullRegularizerBox{T, V}) where {T <: Real, V <: AbstractVector{T}} =
"sj = $(ψ.sj)\n" * " "^14 * "l = $(ψ.l)\n" * " "^14 * "u = $(ψ.u)"

function prox!(
y::AbstractVector{T},
ψ::ShiftedNullRegularizerBox{T, V},
q::AbstractVector{T},
σ::T,
) where {T <: Real, V <: AbstractVector{T}}
@assert σ > zero(T)
@inbounds for i ∈ eachindex(y)
l = ψ.l isa AbstractVector ? ψ.l[i] : ψ.l
u = ψ.u isa AbstractVector ? ψ.u[i] : ψ.u
y[i] = prox_zero(q[i], l - ψ.sj[i], u - ψ.sj[i])
end
return y
end

function iprox!(
y::AbstractVector{T},
ψ::ShiftedNullRegularizerBox{T, V},
g::AbstractVector{T},
d::AbstractVector{T},
) where {T <: Real, V <: AbstractVector{T}}
@inbounds for i ∈ eachindex(y)
l = ψ.l isa AbstractVector ? ψ.l[i] : ψ.l
u = ψ.u isa AbstractVector ? ψ.u[i] : ψ.u
y[i] = iprox_zero(d[i], g[i], l - ψ.sj[i], u - ψ.sj[i])
end
return y
end
59 changes: 59 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,65 @@ using Test

include("test_psvd.jl")

@testset "NullRegularizer" begin
for T in (Float64, Float32)
h = NullRegularizer(T)
@test h(T.([1.0, 1.0])) == T(0.0)
y = similar(T.([1.0, 2.0]))
prox!(y, h, T.([3.0, 4.0]), T(1.0))
@test all(y .== T.([3.0, 4.0]))
iprox!(y, h, T.([3.0, 4.0]), T.([2.0, 4.0]))
@test all(y .== -T.([1.5, 1.0]))

h_shifted = shifted(h, T.([1.0, 2.0]))
@test h_shifted(T.([1.0, 1.0])) == T(0.0)
prox!(y, h_shifted, T.([3.0, 4.0]), T(1.0))
@test all(y .== T.([3.0, 4.0]))
iprox!(y, h_shifted, T.([3.0, 4.0]), T.([2.0, 4.0]))
@test all(y .== -T.([1.5, 1.0]))

shift!(h_shifted, T.([5.0, 6.0]))
@test h_shifted(T.([1.0, 1.0])) == T(0.0)
prox!(y, h_shifted, T.([3.0, 4.0]), T(1.0))
@test all(y .== T.([3.0, 4.0]))
iprox!(y, h_shifted, T.([3.0, 4.0]), T.([2.0, 4.0]))
@test all(y .== -T.([1.5, 1.0]))

h_shifted_box = shifted(h, T.([1.0, 2.0]), T.([0.0, 0.0]), T.([2.0, 3.0]))
@test h_shifted_box.shifted_twice == false
@test h_shifted_box(T.([0.0, 0.0])) == T(0.0)
@test h_shifted_box(T.([-1.0, 0.0])) == T(Inf)
@test h_shifted_box(T.([2.0, 3.0])) == T(0.0)
@test h_shifted_box(T.([1.0, 1.5])) == T(0.0)
prox!(y, h_shifted_box, T.([3.0, 4.0]), T(1.0))
@test all(y .== T.([2.0, 3.0]))
iprox!(y, h_shifted_box, T.([3.0, 4.0]), T.([2.0, 4.0]))
@test all(y .== T.([0.0, 0.0]))

shift!(h_shifted_box, T.([5.0, 6.0]))

@test h_shifted_box(T.([0.0, 0.0])) == T(0.0)
@test h_shifted_box(T.([-1.0, 0.0])) == T(Inf)
@test h_shifted_box(T.([2.0, 3.0])) == T(0.0)
@test h_shifted_box(T.([1.0, 1.5])) == T(0.0)
prox!(y, h_shifted_box, T.([3.0, 4.0]), T(1.0))
@test all(y .== T.([2.0, 3.0]))
iprox!(y, h_shifted_box, T.([3.0, 4.0]), T.([2.0, 4.0]))
@test all(y .== T.([0.0, 0.0]))

h_shifted_twice_box = shifted(h_shifted_box, T.([1.0, 2.0]))
@test h_shifted_twice_box.shifted_twice == true
@test h_shifted_twice_box(T.([0.0, 0.0])) == h_shifted_box(T.([1.0, 2.0]))
@test h_shifted_twice_box(T.([-1.0, 0.0])) == h_shifted_box(T.([0.0, 2.0]))
@test h_shifted_twice_box(T.([2.0, 3.0])) == h_shifted_box(T.([3.0, 5.0]))
@test h_shifted_twice_box(T.([1.0, 1.5])) == h_shifted_box(T.([2.0, 3.5]))
prox!(y, h_shifted_twice_box, T.([3.0, 4.0]), T(1.0))
@test all(y .== T.([1.0, 1.0]))
iprox!(y, h_shifted_twice_box, T.([3.0, 4.0]), T.([2.0, 4.0]))
@test all(y .== T.([-1.0, -1.0]))
end
end

for (op, composite_op, shifted_op) ∈
zip((:NormL2,), (:CompositeNormL2,), (:ShiftedCompositeNormL2,))
@testset "$shifted_op" begin
Expand Down
15 changes: 13 additions & 2 deletions test/test_allocs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ macro wrappedallocs(expr)
end

@testset "allocs" begin

for op ∈ (:NullRegularizer,)
h = eval(op)(Float64)
@test @wrappedallocs(h([0.0, 0.0])) == 0
h_shifted_box = shifted(h, [1.0, 2.0], [0.0, 0.0], [2.0, 3.0])
@test @wrappedallocs(h_shifted_box([0.0, 0.0])) == 0
y = zeros(Float64, 2)
@test @wrappedallocs(prox!(y, h_shifted_box, [3.0, 4.0], 1.0)) == 0
@test @wrappedallocs(iprox!(y, h_shifted_box, [3.0, 4.0], [2.0, 4.0])) == 0
end

for (op, composite_op) ∈ ((:NormL2, :CompositeNormL2),)
CompositeOp = eval(composite_op)

Expand Down Expand Up @@ -66,7 +77,7 @@ end
ν = 0.1056
@test @wrappedallocs(prox!(y, ϕ, x, ν)) == 0
end
for op ∈ (:NormL0, :NormL1, :RootNormLhalf)
for op ∈ (:NormL0, :NormL1, :RootNormLhalf,)
h = eval(op)(1.0)
n = 1000
xk = rand(n)
Expand Down Expand Up @@ -99,7 +110,7 @@ end
@test allocs == 0
end

for op ∈ (:NormL0, :NormL1)
for op ∈ (:NormL0, :NormL1,)
h = eval(op)(1.0)
n = 1000
xk = rand(n)
Expand Down
Loading