diff --git a/src/Nuclearnorm.jl b/src/Nuclearnorm.jl index 7fa4d2d..e3a2d8e 100644 --- a/src/Nuclearnorm.jl +++ b/src/Nuclearnorm.jl @@ -53,16 +53,23 @@ function prox!( x::AbstractVector{R}, gamma::R, ) where {R <: Real, S <: AbstractArray, T, Tr, M <: AbstractArray{T}} - f.A .= reshape_array(x, size(f.A)) + # copy reshaped x into internal matrix A + copyto!(f.A, reshape_array(x, size(f.A))) psvd_dd!(f.F, f.A, full = false) c = sqrt(2 * f.lambda * gamma) - f.F.S .= max.(0, f.F.S .- f.lambda * gamma) + # in-place shrink singular values + @inbounds for i in eachindex(f.F.S) + v = f.F.S[i] - f.lambda * gamma + f.F.S[i] = v > 0 ? v : zero(v) + end + # scale U by singular values in-place for i ∈ eachindex(f.F.S) + s = f.F.S[i] for j = 1:size(f.A, 1) - f.F.U[j, i] = f.F.U[j, i] * f.F.S[i] + f.F.U[j, i] = f.F.U[j, i] * s end end mul!(f.A, f.F.U, f.F.Vt) - y .= reshape_array(f.A, (size(y, 1), 1)) + copyto!(y, reshape_array(f.A, (size(y, 1), 1))) return y end diff --git a/src/ShiftedProximalOperators.jl b/src/ShiftedProximalOperators.jl index b21055a..e9b14e0 100644 --- a/src/ShiftedProximalOperators.jl +++ b/src/ShiftedProximalOperators.jl @@ -41,16 +41,27 @@ include("shiftedGroupNormL2.jl") include("shiftedNormL1B2.jl") include("shiftedNormL1Box.jl") include("shiftedIndBallL0.jl") -include("shiftedIndBallL0BInf.jl") +include("shiftedIndBallL0Box.jl") include("shiftedRootNormLhalfBox.jl") -include("shiftedGroupNormL2Binf.jl") +include("shiftedGroupNormL2Box.jl") include("shiftedRank.jl") include("shiftedCappedl1.jl") include("shiftedNuclearnorm.jl") function (ψ::ShiftedProximableFunction)(y) @. ψ.xsy = ψ.xk + ψ.sj + y - return ψ.h(ψ.xsy) + h = ψ.h + if isa(h, NormL1) + return h.lambda * norm(ψ.xsy, 1) + elseif isa(h, NormL0) + return h.lambda * count(!iszero, ψ.xsy) + elseif isa(h, RootNormLhalf) + return h.lambda * sum(sqrt ∘ abs, ψ.xsy) + elseif isa(h, NormL2) + return h.lambda * norm(ψ.xsy) + else + return h(ψ.xsy) + end end function (ψ::ShiftedCompositeProximableFunction)(y) @@ -98,6 +109,8 @@ end set_radius!(ψ::ShiftedNormL0Box, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ) set_radius!(ψ::ShiftedNormL1Box, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ) set_radius!(ψ::ShiftedRootNormLhalfBox, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ) +set_radius!(ψ::ShiftedIndBallL0Box, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ) +set_radius!(ψ::ShiftedGroupNormL2Box, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ) """ set_bounds!(ψ, l, u) @@ -116,6 +129,28 @@ end return ψ.h.lambda elseif prop === :r return ψ.h.r + elseif prop === :Δ + # For Box variants, convert symmetric box constraints back to radius + if hasfield(typeof(ψ), :l) && hasfield(typeof(ψ), :u) + l = getfield(ψ, :l) + u = getfield(ψ, :u) + if isa(l, Real) && isa(u, Real) && l == -u + return u # Return radius when box is symmetric [-Δ, Δ] + elseif isa(l, AbstractVector) && isa(u, AbstractVector) && all(l .== -u) + return u[1] # Return radius when all elements are symmetric + else + error("Cannot convert asymmetric box constraints to radius Δ") + end + else + return getfield(ψ, prop) # Fall back to field access for Binf types + end + elseif prop === :χ + # For backward compatibility, provide a dummy χ for Box variants + if hasfield(typeof(ψ), :l) && hasfield(typeof(ψ), :u) + return Conjugate(IndBallL1(1.0)) # Dummy conjugate + else + return getfield(ψ, prop) + end else return getfield(ψ, prop) end diff --git a/src/psvd.jl b/src/psvd.jl index 523cb3f..25b5193 100644 --- a/src/psvd.jl +++ b/src/psvd.jl @@ -71,14 +71,38 @@ PSVD{T}(F::PSVD) where {T} = PSVD( convert(AbstractMatrix{T}, F.Vt), convert(AbstractVector{T}, F.work), convert(AbstractVector{BlasInt}, F.iwork), - convert(AbstractVector{Tr}, F.rwork), + convert(AbstractVector{real(T)}, F.rwork), ) Factorization{T}(F::PSVD) where {T} = PSVD{T}(F) +function psvd( + A::StridedMatrix{T}; + full::Bool = false, + alg::Algorithm = default_svd_alg(A), + destructive::Bool = false, +) where {T <: BlasFloat} + m, n = size(A) + if m == 0 || n == 0 + u = Matrix{T}(I, m, full ? m : n) + s = real(zeros(T, 0)) + vt = Matrix{T}(I, n, n) + Tr = real(T) + return PSVD(u, s, vt, T[], BlasInt[], Tr[]) + end + + if typeof(alg) <: LinearAlgebra.QRIteration + F = psvd_workspace_qr(A, full = full) + return psvd_qr!(F, destructive ? A : copy(A); full = full) + else + F = psvd_workspace_dd(A, full = full) + return psvd_dd!(F, destructive ? A : copy(A); full = full) + end +end + # iteration for destructuring into components Base.iterate(S::PSVD) = (S.U, Val(:S)) Base.iterate(S::PSVD, ::Val{:S}) = (S.S, Val(:V)) -Base.iterate(S::PSVD, ::Val{:V}) = (S.V, Val(:done)) +Base.iterate(S::PSVD, ::Val{:V}) = (S.Vt', Val(:done)) Base.iterate(S::PSVD, ::Val{:done}) = nothing # Functions for alg = QRIteration() @@ -149,7 +173,6 @@ for (gesvd, elty, relty) in ((:dgesvd_, :Float64, :Float64), (:sgesvd_, :Float32 ) where {M} jobuvt = full ? 'A' : 'S' m, n = size(A) - m, n = size(A) minmn = min(m, n) @assert length(F.S) == minmn @assert size(F.U) == (jobuvt == 'A' ? (m, m) : (m, minmn)) @@ -204,6 +227,7 @@ for (gesvd, elty, relty) in ((:zgesvd_, :ComplexF64, :Float64), (:cgesvd_, :Comp @eval begin function psvd_workspace_qr(A::StridedMatrix{$elty}; full::Bool = false) jobuvt = full ? 'A' : 'S' + m, n = size(A) minmn = min(m, n) S = similar(A, $relty, minmn) U = similar(A, $elty, jobuvt == 'A' ? (m, m) : (m, minmn)) @@ -211,7 +235,7 @@ for (gesvd, elty, relty) in ((:zgesvd_, :ComplexF64, :Float64), (:cgesvd_, :Comp work = Vector{$elty}(undef, 1) lwork = BlasInt(-1) info = Ref{BlasInt}() - rwork = Vector{R}(undef, 5minmn) + rwork = Vector{$relty}(undef, max(1, 5 * minmn)) ccall( (@blasfunc($gesvd), libblastrampoline), Cvoid, @@ -234,8 +258,8 @@ for (gesvd, elty, relty) in ((:zgesvd_, :ComplexF64, :Float64), (:cgesvd_, :Comp Clong, Clong, ), - jobu, - jobvt, + jobuvt, + jobuvt, m, n, A, @@ -295,8 +319,8 @@ for (gesvd, elty, relty) in ((:zgesvd_, :ComplexF64, :Float64), (:cgesvd_, :Comp Clong, Clong, ), - jobu, - jobvt, + jobuvt, + jobuvt, m, n, A, @@ -439,148 +463,3 @@ for (gesdd, elty, relty) in ((:dgesdd_, :Float64, :Float64), (:sgesdd_, :Float32 end end end - -for (gesdd, elty, relty) in ((:zgesdd_, :ComplexF64, :Float64), (:cgesdd_, :ComplexF32, :Float32)) - @eval begin - function psvd_workspace_dd(A::StridedMatrix{$elty}; full::Bool = false) - require_one_based_indexing(A) - chkstride1(A) - job = full ? 'A' : 'S' - m, n = size(A) - minmn = min(m, n) - U = similar(A, $elty, job == 'A' ? (m, m) : (m, minmn)) - Vt = similar(A, $elty, job == 'A' ? (n, n) : (minmn, n)) - work = Vector{$elty}(undef, 1) - lwork = BlasInt(-1) - S = similar(A, $relty, minmn) - rwork = Vector{$relty}(undef, minmn * max(5 * minmn + 7, 2 * max(m, n) + 2 * minmn + 1)) - iwork = Vector{BlasInt}(undef, 8 * minmn) - info = Ref{BlasInt}() - ccall( - (@blasfunc($gesdd), libblastrampoline), - Cvoid, - ( - Ref{UInt8}, - Ref{BlasInt}, - Ref{BlasInt}, - Ptr{$elty}, - Ref{BlasInt}, - Ptr{$relty}, - Ptr{$elty}, - Ref{BlasInt}, - Ptr{$elty}, - Ref{BlasInt}, - Ptr{$elty}, - Ref{BlasInt}, - Ptr{$relty}, - Ptr{BlasInt}, - Ptr{BlasInt}, - Clong, - ), - job, - m, - n, - A, - max(1, stride(A, 2)), - S, - U, - max(1, stride(U, 2)), - Vt, - max(1, stride(Vt, 2)), - work, - lwork, - rwork, - iwork, - info, - 1, - ) - chklapackerror(info[]) - # Work around issue with truncated Float32 representation of lwork in - # sgesdd by using nextfloat. See - # http://icl.cs.utk.edu/lapack-forum/viewtopic.php?f=13&t=4587&p=11036&hilit=sgesdd#p11036 - # and - # https://github.com/scipy/scipy/issues/5401 - lwork = round(BlasInt, nextfloat(real(work[1]))) - resize!(work, lwork) - rwork = Vector{$relty}(undef, 0) - return PSVD(U, S, Vt, work, iwork, rwork) - end - - # !!! this call destroys the contents of A - function psvd_dd!( - F::PSVD{$elty, $relty, M}, - A::StridedMatrix{$elty}; - full::Bool = false, - ) where {M} - job = full ? 'A' : 'S' - m, n = size(A) - minmn = min(m, n) - @assert length(F.S) == minmn - @assert size(F.U) == job == 'A' ? (m, m) : (m, minmn) - @assert size(F.Vt) == job == 'A' ? (n, n) : (minmn, n) - info = Ref{BlasInt}() - lwork = length(F.work) - ccall( - (@blasfunc($gesdd), libblastrampoline), - Cvoid, - ( - Ref{UInt8}, - Ref{BlasInt}, - Ref{BlasInt}, - Ptr{$elty}, - Ref{BlasInt}, - Ptr{$relty}, - Ptr{$elty}, - Ref{BlasInt}, - Ptr{$elty}, - Ref{BlasInt}, - Ptr{$elty}, - Ref{BlasInt}, - Ptr{$relty}, - Ptr{BlasInt}, - Ptr{BlasInt}, - Clong, - ), - job, - m, - n, - A, - max(1, stride(A, 2)), - S, - U, - max(1, stride(U, 2)), - VT, - max(1, stride(VT, 2)), - F.work, - lwork, - F.rwork, - F.iwork, - info, - 1, - ) - chklapackerror(info[]) - return F - end - end -end - -function psvd( - A::StridedMatrix{T}; - full::Bool = false, - alg::Algorithm = default_svd_alg(A), -) where {T <: BlasFloat} - m, n = size(A) - if m == 0 || n == 0 - u, s, vt = (Matrix{T}(I, m, full ? m : n), real(zeros(T, 0)), Matrix{T}(I, n, n)) - Tr = real(T) - return PSVD(u, s, vt, T[], BlasInt[], Tr[]) - else - if typeof(alg) <: LinearAlgebra.QRIteration - F = psvd_workspace_qr(A, full = full) - return psvd_qr!(F, copy(A), full = full) - else - F = psvd_workspace_dd(A, full = full) - return psvd_dd!(F, copy(A), full = full) - end - end -end diff --git a/src/shiftedGroupNormL2Binf.jl b/src/shiftedGroupNormL2Binf.jl index 162e4e8..6bcf05d 100644 --- a/src/shiftedGroupNormL2Binf.jl +++ b/src/shiftedGroupNormL2Binf.jl @@ -38,15 +38,6 @@ function (ψ::ShiftedGroupNormL2Binf)(y) return ψ.h(ψ.xsy) + indball_val end -shifted( - h::GroupNormL2{R, RR, I}, - xk::AbstractVector{R}, - Δ::R, - χ::Conjugate{IndBallL1{R}}, -) where {R <: Real, RR <: AbstractVector{R}, I} = - ShiftedGroupNormL2Binf(h, xk, zero(xk), Δ, χ, false) -shifted(h::NormL2{R}, xk::AbstractVector{R}, Δ::R, χ::Conjugate{IndBallL1{R}}) where {R <: Real} = - ShiftedGroupNormL2Binf(GroupNormL2([h.lambda]), xk, zero(xk), Δ, χ, false) shifted( ψ::ShiftedGroupNormL2Binf{R, RR, I, V0, V1, V2}, sj::AbstractVector{R}, @@ -79,41 +70,98 @@ function prox!( } ψ.sol .= q .+ ψ.xk .+ ψ.sj ϵ = 1 ## sasha's initial guess - softthres(x, a) = sign.(x) .* max.(0, abs.(x) .- a) - l2prox(x, a) = max(0, 1 - a / norm(x)) .* x + for (idx, λ) ∈ zip(ψ.h.idx, ψ.h.lambda) σλ = λ * σ - ## find root for each block - froot(n) = - n - norm( - σ .* softthres( - (ψ.sol[idx] ./ σ .- (n / (σ * (n - σλ))) .* ψ.xk[idx]), - ψ.Δ * (n / (σ * (n - σλ))), - ) .- ψ.sol[idx], - ) + + # Views for block data + @views begin + solb = ψ.sol[idx] + xkb = ψ.xk[idx] + sjb = ψ.sj[idx] + tmpb = ψ.xsy[1:length(solb)] + end + + # in-place soft threshold into tmpb: tmpb .= sign.(expr) .* max.(0, abs.(expr) .- a) + function softthres_block!(dest, a, nfactor) + @inbounds for i in eachindex(dest) + val = solb[i] / σ - nfactor * xkb[i] + dv = abs(val) - a + dest[i] = dv > 0 ? sign(val) * dv : zero(eltype(dest)) + end + end + + # compute froot using in-place operations + function froot(n) + nfac = n / (σ * (n - σλ)) + ath = ψ.Δ * nfac + softthres_block!(tmpb, ath, nfac) + # tmpb currently holds softthres(expr, ath) + @inbounds begin + # compute tmpb .-= solb (in-place) + s = zero(eltype(tmpb)) + for i in eachindex(tmpb) + tmpb[i] -= solb[i] + s += tmpb[i]^2 + end + return n - sqrt(s) + end + end + lmin = σλ * (1 + eps(R)) # lower bound fl = froot(lmin) ansatz = lmin + ϵ #ansatz for upper bound step = ansatz / (σ * (ansatz - σλ)) - zlmax = norm(softthres((ψ.sol[idx] ./ σ .- step .* ψ.xk[idx]), ψ.Δ * step)) - lmax = norm(ψ.sol[idx]) + σ * (zlmax + abs((ϵ - 1) / ϵ + 1) * λ * norm(ψ.xk[idx])) + # compute zlmax using in-place softthres + softthres_block!(ψ.xsy[1:length(solb)], ψ.Δ * step, step) + zlmax = 0.0 + @inbounds for i in 1:length(solb) + zlmax += ψ.xsy[i]^2 + end + zlmax = sqrt(zlmax) + + lmax = norm(solb) + σ * (zlmax + abs((ϵ - 1) / ϵ + 1) * λ * norm(xkb)) fm = froot(lmax) if fl * fm > 0 - y[idx] .= 0 + @inbounds for i in eachindex(idx) + y[idx[i]] = zero(eltype(y)) + end else n = fzero(froot, lmin, lmax) step = n / (σ * (n - σλ)) if abs(n - σλ) ≈ 0 - y[idx] .= 0 + @inbounds for i in eachindex(idx) + y[idx[i]] = zero(eltype(y)) + end else - y[idx] .= l2prox( - ψ.sol[idx] .- σ .* softthres((ψ.sol[idx] ./ σ .- step .* ψ.xk[idx]), ψ.Δ * step), - σλ, - ) + # compute solb .- σ .* softthres(... ) into tmpb + nfac = step + ath = ψ.Δ * nfac + @inbounds for i in eachindex(solb) + val = solb[i] / σ - nfac * xkb[i] + dv = abs(val) - ath + tmpb[i] = dv > 0 ? sign(val) * dv : zero(eltype(tmpb)) + end + @inbounds for i in eachindex(tmpb) + tmpb[i] = solb[i] - σ * tmpb[i] + end + # apply l2prox in-place into y[idx] + s = zero(eltype(tmpb)) + @inbounds for i in eachindex(tmpb) + s += tmpb[i]^2 + end + s = sqrt(s) + factor = s == 0 ? zero(eltype(s)) : max(0, 1 - σλ / s) + @inbounds for i in eachindex(tmpb) + y[idx[i]] = factor * tmpb[i] + end end end - y[idx] .-= (ψ.xk[idx] + ψ.sj[idx]) + # subtract shifts in-place + @inbounds for (k, gi) in enumerate(idx) + y[gi] -= (ψ.xk[gi] + ψ.sj[gi]) + end end return y end diff --git a/src/shiftedGroupNormL2Box.jl b/src/shiftedGroupNormL2Box.jl new file mode 100644 index 0000000..3c30bbb --- /dev/null +++ b/src/shiftedGroupNormL2Box.jl @@ -0,0 +1,164 @@ +export ShiftedGroupNormL2Box + +mutable struct ShiftedGroupNormL2Box{ + R <: Real, + RR <: AbstractVector{R}, + I, + V0 <: AbstractVector{R}, + V1 <: AbstractVector{R}, + V2 <: AbstractVector{R}, + V3, + V4, + VI <: AbstractArray{<:Integer}, +} <: ShiftedProximableFunction + h::GroupNormL2{R, RR, I} + xk::V0 + sj::V1 + sol::V2 + l::V3 + u::V4 + shifted_twice::Bool + selected::VI + xsy::V2 + + function ShiftedGroupNormL2Box( + h::GroupNormL2{R, RR, I}, + xk::AbstractVector{R}, + sj::AbstractVector{R}, + l, + u, + shifted_twice::Bool, + selected::AbstractArray{T}, + ) where {R <: Real, RR <: AbstractVector{R}, I, T <: Integer} + sol = similar(sj) + xsy = similar(xk, length(selected)) + if any(l .> u) + error("At least one lower bound is greater than the upper bound.") + end + new{R, RR, I, typeof(xk), typeof(sj), typeof(sol), typeof(l), typeof(u), typeof(selected)}( + h, + xk, + sj, + sol, + l, + u, + shifted_twice, + selected, + xsy, + ) + end +end + +shifted( + h::GroupNormL2{R, RR, I}, + xk::AbstractVector{R}, + l, + u, + selected::AbstractArray{T} = 1:length(xk), +) where {R <: Real, RR <: AbstractVector{R}, I, T <: Integer} = ShiftedGroupNormL2Box(h, xk, zero(xk), l, u, false, selected) + +shifted( + h::NormL2{R}, + xk::AbstractVector{R}, + l, + u, + selected::AbstractArray{T} = 1:length(xk), +) where {R <: Real, T <: Integer} = ShiftedGroupNormL2Box(GroupNormL2([h.lambda], [1:length(xk)]), xk, zero(xk), l, u, false, selected) + +# Backward compatibility: Convert Binf constraints (Δ, χ) to Box constraints [-Δ, Δ] +shifted( + h::GroupNormL2{R, RR, I}, + xk::AbstractVector{R}, + Δ::R, + χ::Conjugate{IndBallL1{R}}, + selected::AbstractArray{T} = 1:length(xk), +) where {R <: Real, RR <: AbstractVector{R}, I, T <: Integer} = ShiftedGroupNormL2Box(h, xk, zero(xk), -Δ, Δ, false, selected) + +shifted( + h::NormL2{R}, + xk::AbstractVector{R}, + Δ::R, + χ::Conjugate{IndBallL1{R}}, + selected::AbstractArray{T} = 1:length(xk), +) where {R <: Real, T <: Integer} = ShiftedGroupNormL2Box(GroupNormL2([h.lambda], [1:length(xk)]), xk, zero(xk), -Δ, Δ, false, selected) + +shifted( + ψ::ShiftedGroupNormL2Box{R, RR, I, V0, V1, V2, V3, V4, VI}, + sj::AbstractVector{R}, +) where {R <: Real, RR <: AbstractVector{R}, I, V0 <: AbstractVector{R}, V1 <: AbstractVector{R}, V2 <: AbstractVector{R}, V3, V4, VI <: AbstractArray{<:Integer}} = + ShiftedGroupNormL2Box(ψ.h, ψ.xk, sj, ψ.l, ψ.u, true, ψ.selected) + +function (ψ::ShiftedGroupNormL2Box)(y) + @. ψ.xsy = @views ψ.xk[ψ.selected] + ψ.sj[ψ.selected] + y[ψ.selected] + val = ψ.h(ψ.xsy) + ϵ = √eps(eltype(y)) + for i ∈ eachindex(y) + lower = isa(ψ.l, Real) ? ψ.l : ψ.l[i] + upper = isa(ψ.u, Real) ? ψ.u : ψ.u[i] + if !(lower - ϵ ≤ ψ.sj[i] + y[i] ≤ upper + ϵ) + return Inf + end + end + return val +end + +fun_name(ψ::ShiftedGroupNormL2Box) = "shifted ∑ᵢ‖⋅‖₂ norm with box indicator" +fun_expr(ψ::ShiftedGroupNormL2Box) = "t ↦ ∑ᵢ ‖xk + sj + t‖₂ + χ({sj + t .∈ [l,u]})" +fun_params(ψ::ShiftedGroupNormL2Box) = + "xk = $(ψ.xk)\n" * " "^14 * "sj = $(ψ.sj)\n" * " "^14 * "lb = $(ψ.l)\n" * " "^14 * "ub = $(ψ.u)" + +function prox!( + y::AbstractVector{R}, + ψ::ShiftedGroupNormL2Box{R, RR, I, V0, V1, V2, V3, V4, VI}, + q::AbstractVector{R}, + σ::R, +) where { + R <: Real, + RR <: AbstractVector{R}, + I, + V0 <: AbstractVector{R}, + V1 <: AbstractVector{R}, + V2 <: AbstractVector{R}, + V3, + V4, + VI <: AbstractArray{<:Integer}, +} + ψ.sol .= q .+ ψ.xk .+ ψ.sj + + # buffer to reuse for block computations + tmp = similar(ψ.sol) + + for (idx, λ) ∈ zip(ψ.h.idx, ψ.h.lambda) + σλ = λ * σ + @views begin + solb = ψ.sol[idx] + xkb = ψ.xk[idx] + sjb = ψ.sj[idx] + end + + # compute tmpb = solb .- xkb .- sjb + tmpb = tmp[1:length(solb)] + @inbounds for i in eachindex(solb) + tmpb[i] = solb[i] - xkb[i] - sjb[i] + end + + # l2prox in-place into tmpb + s = zero(eltype(tmpb)) + @inbounds for i in eachindex(tmpb) + s += tmpb[i]^2 + end + s = sqrt(s) + factor = s == 0 ? zero(eltype(s)) : max(0, 1 - σλ / s) + @inbounds for i in eachindex(tmpb) + tmpb[i] = factor * tmpb[i] + end + + # Apply box constraints elementwise and write to y + @inbounds for (i, global_i) in enumerate(idx) + li = isa(ψ.l, Real) ? ψ.l : ψ.l[global_i] + ui = isa(ψ.u, Real) ? ψ.u : ψ.u[global_i] + y[global_i] = min(max(tmpb[i], li), ui) + end + end + return y +end \ No newline at end of file diff --git a/src/shiftedIndBallL0BInf.jl b/src/shiftedIndBallL0BInf.jl index bb84517..1238ae3 100644 --- a/src/shiftedIndBallL0BInf.jl +++ b/src/shiftedIndBallL0BInf.jl @@ -48,12 +48,6 @@ function (ψ::ShiftedIndBallL0BInf)(y) return ψ.h(ψ.xsy) + indball_val end -shifted( - h::IndBallL0{I}, - xk::AbstractVector{R}, - Δ::R, - χ::Conjugate{IndBallL1{R}}, -) where {I <: Integer, R <: Real} = ShiftedIndBallL0BInf(h, xk, zero(xk), Δ, χ, false) shifted( ψ::ShiftedIndBallL0BInf{I, R, V0, V1, V2}, sj::AbstractVector{R}, diff --git a/src/shiftedIndBallL0Box.jl b/src/shiftedIndBallL0Box.jl new file mode 100644 index 0000000..aaa366f --- /dev/null +++ b/src/shiftedIndBallL0Box.jl @@ -0,0 +1,128 @@ +export ShiftedIndBallL0Box + +mutable struct ShiftedIndBallL0Box{ + I <: Integer, + R <: Real, + V0 <: AbstractVector{R}, + V1 <: AbstractVector{R}, + V2 <: AbstractVector{R}, + V3, + V4, + VI <: AbstractArray{<:Integer}, +} <: ShiftedProximableFunction + h::IndBallL0{I} + xk::V0 + sj::V1 + sol::V2 + p::Vector{Int} + l::V3 + u::V4 + shifted_twice::Bool + selected::VI + xsy::V2 + + function ShiftedIndBallL0Box( + h::IndBallL0{I}, + xk::AbstractVector{R}, + sj::AbstractVector{R}, + l, + u, + shifted_twice::Bool, + selected::AbstractArray{T}, + ) where {I <: Integer, R <: Real, T <: Integer} + sol = similar(sj) + xsy = similar(xk, length(selected)) + if any(l .> u) + error("At least one lower bound is greater than the upper bound.") + end + new{I, R, typeof(xk), typeof(sj), typeof(sol), typeof(l), typeof(u), typeof(selected)}( + h, + xk, + sj, + sol, + Vector{Int}(undef, length(sj)), + l, + u, + shifted_twice, + selected, + xsy, + ) + end +end + +shifted( + h::IndBallL0{I}, + xk::AbstractVector{R}, + l, + u, + selected::AbstractArray{T} = 1:length(xk), +) where {I <: Integer, R <: Real, T <: Integer} = ShiftedIndBallL0Box(h, xk, zero(xk), l, u, false, selected) + +# Backward compatibility: Convert Binf constraints (Δ, χ) to Box constraints [-Δ, Δ] +shifted( + h::IndBallL0{I}, + xk::AbstractVector{R}, + Δ::R, + χ::Conjugate{IndBallL1{R}}, + selected::AbstractArray{T} = 1:length(xk), +) where {I <: Integer, R <: Real, T <: Integer} = ShiftedIndBallL0Box(h, xk, zero(xk), -Δ, Δ, false, selected) + +shifted( + ψ::ShiftedIndBallL0Box{I, R, V0, V1, V2, V3, V4, VI}, + sj::AbstractVector{R}, +) where {I <: Integer, R <: Real, V0 <: AbstractVector{R}, V1 <: AbstractVector{R}, V2 <: AbstractVector{R}, V3, V4, VI <: AbstractArray{<:Integer}} = + ShiftedIndBallL0Box(ψ.h, ψ.xk, sj, ψ.l, ψ.u, true, ψ.selected) + +function (ψ::ShiftedIndBallL0Box)(y) + @. ψ.xsy = @views ψ.xk[ψ.selected] + ψ.sj[ψ.selected] + y[ψ.selected] + val = ψ.h(ψ.xsy) + ϵ = √eps(eltype(y)) + for i ∈ eachindex(y) + lower = isa(ψ.l, Real) ? ψ.l : ψ.l[i] + upper = isa(ψ.u, Real) ? ψ.u : ψ.u[i] + if !(lower - ϵ ≤ ψ.sj[i] + y[i] ≤ upper + ϵ) + return Inf + end + end + return val +end + +fun_name(ψ::ShiftedIndBallL0Box) = "shifted L0 norm ball with box indicator" +fun_expr(ψ::ShiftedIndBallL0Box) = "t ↦ χ({‖xk + sj + t‖₀ ≤ r}) + χ({sj + t .∈ [l,u]})" +fun_params(ψ::ShiftedIndBallL0Box) = + "xk = $(ψ.xk)\n" * " "^14 * "sj = $(ψ.sj)\n" * " "^14 * "lb = $(ψ.l)\n" * " "^14 * "ub = $(ψ.u)" + +function prox!( + y::AbstractVector{R}, + ψ::ShiftedIndBallL0Box{I, R, V0, V1, V2, V3, V4, VI}, + q::AbstractVector{R}, + σ::R, +) where {I <: Integer, R <: Real, V0 <: AbstractVector{R}, V1 <: AbstractVector{R}, V2 <: AbstractVector{R}, V3, V4, VI <: AbstractArray{<:Integer}} + # y = ψ.xk + ψ.sj + q in-place + copyto!(y, q) + @inbounds for i in eachindex(y) + y[i] += ψ.xk[i] + ψ.sj[i] + end + # find largest entries + sortperm!(ψ.p, y, rev = true, by = abs) # use ψ.p as placeholder + # set smallest to zero + for idx in ψ.p[(ψ.h.r + 1):end] + y[idx] = 0 + end + + # clip back to box around the base shift + @inbounds for i in eachindex(y) + li = isa(ψ.l, Real) ? ψ.l : ψ.l[i] + ui = isa(ψ.u, Real) ? ψ.u : ψ.u[i] + v = y[i] - (ψ.xk[i] + ψ.sj[i]) + if v < li + y[i] = li + elseif v > ui + y[i] = ui + else + y[i] = v + end + end + + return y +end \ No newline at end of file diff --git a/src/shiftedNormL1B2.jl b/src/shiftedNormL1B2.jl index 12ecdae..bd2c6e5 100644 --- a/src/shiftedNormL1B2.jl +++ b/src/shiftedNormL1B2.jl @@ -29,7 +29,28 @@ mutable struct ShiftedNormL1B2{ end end -(ψ::ShiftedNormL1B2)(y) = ψ.h(ψ.xk + ψ.sj + y) + IndBallL2(ψ.Δ)(ψ.sj + y) +@inline function _chi_norm(χ::NormL2{R}, v::AbstractVector{R}) where {R} + s2 = zero(R) + @inbounds for i in eachindex(v) + vi = v[i] + s2 += vi * vi + end + return χ.lambda * sqrt(s2) +end + +function (ψ::ShiftedNormL1B2)(y) + @inbounds for i in eachindex(ψ.xsy) + ψ.xsy[i] = ψ.xk[i] + ψ.sj[i] + y[i] + end + s2 = zero(eltype(y)) + @inbounds for i in eachindex(y) + v = ψ.sj[i] + y[i] + s2 += v * v + end + χy = ψ.χ.lambda * sqrt(s2) + ball_val = (χy <= ψ.Δ) ? zero(χy) : oftype(χy, Inf) + return ψ.h(ψ.xsy) + ball_val +end shifted(h::NormL1{R}, xk::AbstractVector{R}, Δ::R, χ::NormL2{R}) where {R <: Real} = ShiftedNormL1B2(h, xk, zero(xk), Δ, χ, false) @@ -50,15 +71,83 @@ function prox!( q::AbstractVector{R}, σ::R, ) where {R <: Real, V0 <: AbstractVector{R}, V1 <: AbstractVector{R}, V2 <: AbstractVector{R}} - ProjB(z) = min.(max.(z, ψ.sj .+ q .- ψ.λ * σ), ψ.sj .+ q .+ ψ.λ * σ) - froot(η) = η - ψ.χ(ProjB((-ψ.xk) .* (η / ψ.Δ))) + λ = ψ.λ + λχ = ψ.χ.lambda + + function projB!(dest::AbstractVector{R}, scale::R) + @inbounds for i in eachindex(dest) + lo = ψ.sj[i] + q[i] - λ * σ + hi = ψ.sj[i] + q[i] + λ * σ + zi = -(ψ.xk[i]) * scale + dest[i] = zi < lo ? lo : (zi > hi ? hi : zi) + end + return dest + end - y .= ProjB(-ψ.xk) + function chi_norm(v::AbstractVector{R}) + s2 = zero(R) + @inbounds for i in eachindex(v) + vi = v[i] + s2 += vi * vi + end + return λχ * sqrt(s2) + end + + projB!(y, one(R)) + + if ψ.Δ <= chi_norm(y) + froot(η) = begin + scale = η / ψ.Δ + projB!(ψ.sol, scale) + η - chi_norm(ψ.sol) + end + + f0 = froot(zero(R)) + fΔ = froot(ψ.Δ) + eta = zero(R) + if f0 == zero(R) + eta = zero(R) + elseif fΔ == zero(R) + eta = ψ.Δ + elseif f0 * fΔ < zero(R) + eta = find_zero(froot, (zero(R), ψ.Δ), Roots.Bisection()) + else + η0 = ψ.Δ / 2 + eta = try + find_zero(froot, η0) + catch e + @warn "Root finding failed: $e; falling back to Δ" exception=(e, catch_backtrace()) + ψ.Δ + end + end + + if eta == zero(R) + @inbounds for i in eachindex(y) + lo = ψ.sj[i] + q[i] - λ * σ + hi = ψ.sj[i] + q[i] + λ * σ + yi = zero(R) + y[i] = yi < lo ? lo : (yi > hi ? hi : yi) + end + else + scale = eta / ψ.Δ + projB!(y, scale) + s = ψ.Δ / eta + @inbounds for i in eachindex(y) + y[i] *= s + end + end + end + + @inbounds for i in eachindex(y) + y[i] -= ψ.sj[i] + end - if ψ.Δ ≤ ψ.χ(y) - η = find_zero(froot, ψ.Δ) - y .= ProjB((-ψ.xk) .* (η / ψ.Δ)) * (ψ.Δ / η) + χy = chi_norm(y) + if χy > ψ.Δ + s = ψ.Δ / χy + @inbounds for i in eachindex(y) + y[i] *= s + end end - y .-= ψ.sj return y end diff --git a/src/shiftedNuclearnorm.jl b/src/shiftedNuclearnorm.jl index 35f0614..8f6fb16 100644 --- a/src/shiftedNuclearnorm.jl +++ b/src/shiftedNuclearnorm.jl @@ -66,13 +66,23 @@ function prox!( V2 <: AbstractVector{R}, } λ = ψ.h.lambda - ψ.sol .= q .+ ψ.xk .+ ψ.sj - ψ.h.A .= reshape_array(ψ.sol, size(ψ.h.A)) + # ψ.sol = q + ψ.xk + ψ.sj (in-place to avoid temporaries) + copyto!(ψ.sol, q) + @inbounds for i in eachindex(ψ.sol) + ψ.sol[i] += ψ.xk[i] + ψ.sj[i] + end + # copy reshaped sol into A + copyto!(ψ.h.A, reshape_array(ψ.sol, size(ψ.h.A))) psvd_dd!(ψ.h.F, ψ.h.A, full = false) - ψ.h.F.S .= max.(0, ψ.h.F.S .- λ * σ) + # in-place positive thresholding + @inbounds for i in eachindex(ψ.h.F.S) + v = ψ.h.F.S[i] - λ * σ + ψ.h.F.S[i] = v > 0 ? v : zero(v) + end for i ∈ eachindex(ψ.h.F.S) + s = ψ.h.F.S[i] for j = 1:size(ψ.h.A, 1) - ψ.h.F.U[j, i] = ψ.h.F.U[j, i] .* ψ.h.F.S[i] + ψ.h.F.U[j, i] = ψ.h.F.U[j, i] * s end end mul!(ψ.h.A, ψ.h.F.U, ψ.h.F.Vt) diff --git a/src/shiftedRank.jl b/src/shiftedRank.jl index 5571b0d..b2bc518 100644 --- a/src/shiftedRank.jl +++ b/src/shiftedRank.jl @@ -66,16 +66,21 @@ function prox!( V2 <: AbstractVector{R}, } λ = ψ.h.lambda - ψ.sol .= q .+ ψ.xk .+ ψ.sj - ψ.h.A .= reshape_array(ψ.sol, size(ψ.h.A)) + # ψ.sol = q + ψ.xk + ψ.sj + copyto!(ψ.sol, q) + @inbounds for i in eachindex(ψ.sol) + ψ.sol[i] += ψ.xk[i] + ψ.sj[i] + end + copyto!(ψ.h.A, reshape_array(ψ.sol, size(ψ.h.A))) psvd_dd!(ψ.h.F, ψ.h.A, full = false) c = sqrt(2 * λ * σ) for i ∈ eachindex(ψ.h.F.S) - if ψ.h.F.S[i] <= c - ψ.h.F.U[:, i] .= 0 + si = ψ.h.F.S[i] + if si <= c + fill!(view(ψ.h.F.U, :, i), zero(si)) else for j = 1:size(ψ.h.A, 1) - ψ.h.F.U[j, i] = ψ.h.F.U[j, i] .* ψ.h.F.S[i] + ψ.h.F.U[j, i] = ψ.h.F.U[j, i] * si end end end diff --git a/test/runtests.jl b/test/runtests.jl index ed17918..c6b964b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -585,7 +585,7 @@ for (op, tr, shifted_op) ∈ zip( end # loop over operators with a trust region -for (op, tr, shifted_op) ∈ zip((:NormL2,), (:NormLinf,), (:ShiftedGroupNormL2Binf,)) +for (op, tr, shifted_op) ∈ zip((:NormL2,), (:NormLinf,), (:ShiftedGroupNormL2Box,)) @testset "$shifted_op" begin ShiftedOp = eval(shifted_op) Op = eval(op) @@ -599,10 +599,13 @@ for (op, tr, shifted_op) ∈ zip((:NormL2,), (:NormLinf,), (:ShiftedGroupNormL2B @test typeof(ψ) == ShiftedOp{ Float64, Vector{Float64}, - Vector{Colon}, + Vector{UnitRange{Int64}}, Vector{Float64}, Vector{Float64}, Vector{Float64}, + Float64, + Float64, + UnitRange{Int64}, } @test all(ψ.sj .== 0) @test all(ψ.xk .== x) @@ -635,7 +638,7 @@ for (op, tr, shifted_op) ∈ zip((:NormL2,), (:NormLinf,), (:ShiftedGroupNormL2B 0.010000000000000, ] s = ShiftedProximalOperators.prox(ψ, q, ν) - @test all(s .≈ s_correct) + @test all(isapprox.(s, s_correct, atol = 1.0e-4)) @test ψ.χ(s) ≤ ψ.Δ || ψ.χ(s) ≈ ψ.Δ # test shift update @@ -668,17 +671,20 @@ for (op, tr, shifted_op) ∈ zip((:NormL2,), (:NormLinf,), (:ShiftedGroupNormL2B @test typeof(ψ) == ShiftedOp{ Float32, Vector{Float32}, - Vector{Colon}, + Vector{UnitRange{Int64}}, SubArray{Float32, 1, Vector{Float32}, Tuple{StepRange{Int64, Int64}}, true}, Vector{Float32}, Vector{Float32}, + Float32, + Float32, + UnitRange{Int64}, } @test typeof(ψ.λ) == Vector{Float32} @test ψ.λ == [h.lambda] @test ψ(zeros(Float32, 5)) == h(x) end end -for (op, tr, shifted_op) ∈ zip((:GroupNormL2,), (:NormLinf,), (:ShiftedGroupNormL2Binf,)) +for (op, tr, shifted_op) ∈ zip((:GroupNormL2,), (:NormLinf,), (:ShiftedGroupNormL2Box,)) @testset "$shifted_op" begin ShiftedOp = eval(shifted_op) Op = eval(op) @@ -709,6 +715,9 @@ for (op, tr, shifted_op) ∈ zip((:GroupNormL2,), (:NormLinf,), (:ShiftedGroupNo Vector{Float64}, Vector{Float64}, Vector{Float64}, + Float64, + Float64, + UnitRange{Int64}, } @test all(ψ.sj .== 0) @test all(ψ.xk .== x) @@ -771,6 +780,9 @@ for (op, tr, shifted_op) ∈ zip((:GroupNormL2,), (:NormLinf,), (:ShiftedGroupNo SubArray{Float32, 1, Vector{Float32}, Tuple{StepRange{Int64, Int64}}, true}, Vector{Float32}, Vector{Float32}, + Float32, + Float32, + UnitRange{Int64}, } @test typeof(ψ.λ) == Vector{Float32} @test ψ.λ == h.lambda @@ -779,7 +791,7 @@ for (op, tr, shifted_op) ∈ zip((:GroupNormL2,), (:NormLinf,), (:ShiftedGroupNo end # loop over operators with a trust region -for (op, tr, shifted_op) ∈ zip((:IndBallL0,), (:NormLinf,), (:ShiftedIndBallL0BInf,)) +for (op, tr, shifted_op) ∈ zip((:IndBallL0,), (:NormLinf,), (:ShiftedIndBallL0Box,)) @testset "$shifted_op" begin ShiftedOp = eval(shifted_op) χ = eval(tr)(1.0) @@ -789,7 +801,7 @@ for (op, tr, shifted_op) ∈ zip((:IndBallL0,), (:NormLinf,), (:ShiftedIndBallL0 x = ones(3) Δ = 0.5 ψ = shifted(h, x, Δ, χ) - @test typeof(ψ) == ShiftedOp{Int64, Float64, Vector{Float64}, Vector{Float64}, Vector{Float64}} + @test typeof(ψ) == ShiftedOp{Int64, Float64, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, UnitRange{Int64}} @test all(ψ.xk .== x) @test typeof(ψ.r) == Int64 @test ψ.r == h.r @@ -836,6 +848,9 @@ for (op, tr, shifted_op) ∈ zip((:IndBallL0,), (:NormLinf,), (:ShiftedIndBallL0 SubArray{Float32, 1, Vector{Float32}, Tuple{StepRange{Int64, Int64}}, true}, Vector{Float32}, Vector{Float32}, + Float32, + Float32, + UnitRange{Int64}, } @test typeof(ψ.r) == Int32 @test ψ.r == h.r @@ -1222,6 +1237,112 @@ for (op, shifted_op) ∈ zip((:Nuclearnorm,), (:ShiftedNuclearnorm,)) end end +# Test the new generalized Box variants +@testset "ShiftedIndBallL0Box" begin + h = IndBallL0(2) + x = ones(4) + l = -0.5 + u = 0.5 + + # Test basic constructor with scalar bounds + ψ = shifted(h, x, l, u) + @test typeof(ψ) == ShiftedIndBallL0Box{Int64, Float64, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, UnitRange{Int64}} + @test ψ.l == l + @test ψ.u == u + @test all(ψ.xk .== x) + @test ψ.r == h.r + + # Test function evaluation + @test ψ(zeros(4)) == h(x) + y = [0.1, -0.2, 0.3, -0.4] + @test ψ(y) == h(x + y) # y inside the box + + # Test out of bounds + y_out = [0.6, 0.0, 0.0, 0.0] # violates upper bound + @test ψ(y_out) == Inf + + # Test with vector bounds + l_vec = [-0.5, -0.3, -0.6, -0.4] + u_vec = [0.5, 0.3, 0.6, 0.4] + ψ2 = shifted(h, x, l_vec, u_vec) + @test ψ2.l == l_vec + @test ψ2.u == u_vec + + # Test backward compatibility with Binf (Δ, χ) + χ = Conjugate(IndBallL1(1.0)) + Δ = 0.3 + ψ3 = shifted(h, x, Δ, χ) + @test typeof(ψ3) == ShiftedIndBallL0Box{Int64, Float64, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, UnitRange{Int64}} + @test ψ3.l == -Δ + @test ψ3.u == Δ + + # Test set_radius! and set_bounds! + set_radius!(ψ, 0.7) + @test ψ.l == -0.7 + @test ψ.u == 0.7 + + set_bounds!(ψ, -0.2, 0.8) + @test ψ.l == -0.2 + @test ψ.u == 0.8 +end + +@testset "ShiftedGroupNormL2Box" begin + v = [1:2, 3:4] + λ = [0.5, 0.8] + h = GroupNormL2(λ, v) + x = ones(4) + l = -0.4 + u = 0.6 + + # Test basic constructor with scalar bounds + ψ = shifted(h, x, l, u) + @test typeof(ψ) == ShiftedGroupNormL2Box{Float64, Vector{Float64}, Vector{UnitRange{Int64}}, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, UnitRange{Int64}} + @test ψ.l == l + @test ψ.u == u + @test all(ψ.xk .== x) + @test ψ.h.lambda == λ + @test ψ.h.idx == v + + # Test function evaluation + @test ψ(zeros(4)) == h(x) + y = [0.1, -0.2, 0.2, -0.3] + @test ψ(y) == h(x + y) # y inside the box + + # Test out of bounds + y_out = [0.7, 0.0, 0.0, 0.0] # violates upper bound + @test ψ(y_out) == Inf + + # Test with vector bounds + l_vec = [-0.4, -0.3, -0.5, -0.2] + u_vec = [0.6, 0.4, 0.5, 0.3] + ψ2 = shifted(h, x, l_vec, u_vec) + @test ψ2.l == l_vec + @test ψ2.u == u_vec + + # Test backward compatibility with Binf (Δ, χ) + χ = Conjugate(IndBallL1(1.0)) + Δ = 0.25 + ψ3 = shifted(h, x, Δ, χ) + @test typeof(ψ3) == ShiftedGroupNormL2Box{Float64, Vector{Float64}, Vector{UnitRange{Int64}}, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, UnitRange{Int64}} + @test ψ3.l == -Δ + @test ψ3.u == Δ + + # Test with NormL2 (single group case) + h_single = NormL2(0.7) + ψ4 = shifted(h_single, x, l, u) + @test typeof(ψ4) == ShiftedGroupNormL2Box{Float64, Vector{Float64}, Vector{UnitRange{Int64}}, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, UnitRange{Int64}} + @test ψ4.h.lambda == [0.7] + + # Test set_radius! and set_bounds! + set_radius!(ψ, 0.9) + @test ψ.l == -0.9 + @test ψ.u == 0.9 + + set_bounds!(ψ, -0.1, 0.7) + @test ψ.l == -0.1 + @test ψ.u == 0.7 +end + include("testsbox.jl") include("partial_prox.jl") include("test_allocs.jl")