Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
535 changes: 535 additions & 0 deletions Manifest.toml

Large diffs are not rendered by default.

15 changes: 11 additions & 4 deletions src/Nuclearnorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,23 @@ function prox!(
x::AbstractVector{R},
gamma::R,
) where {R <: Real, S <: AbstractArray, T, Tr, M <: AbstractArray{T}}
f.A .= reshape_array(x, size(f.A))
# copy reshaped x into internal matrix A
copyto!(f.A, reshape_array(x, size(f.A)))
psvd_dd!(f.F, f.A, full = false)
c = sqrt(2 * f.lambda * gamma)
f.F.S .= max.(0, f.F.S .- f.lambda * gamma)
# in-place shrink singular values
@inbounds for i in eachindex(f.F.S)
v = f.F.S[i] - f.lambda * gamma
f.F.S[i] = v > 0 ? v : zero(v)
end
# scale U by singular values in-place
for i ∈ eachindex(f.F.S)
s = f.F.S[i]
for j = 1:size(f.A, 1)
f.F.U[j, i] = f.F.U[j, i] * f.F.S[i]
f.F.U[j, i] = f.F.U[j, i] * s
end
end
mul!(f.A, f.F.U, f.F.Vt)
y .= reshape_array(f.A, (size(y, 1), 1))
copyto!(y, reshape_array(f.A, (size(y, 1), 1)))
return y
end
67 changes: 63 additions & 4 deletions src/ShiftedProximalOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,51 @@ include("shiftedGroupNormL2.jl")
include("shiftedNormL1B2.jl")
include("shiftedNormL1Box.jl")
include("shiftedIndBallL0.jl")
include("shiftedIndBallL0BInf.jl")
include("shiftedIndBallL0Box.jl")
include("shiftedRootNormLhalfBox.jl")
include("shiftedGroupNormL2Binf.jl")
include("shiftedGroupNormL2Box.jl")
include("shiftedRank.jl")
include("shiftedCappedl1.jl")
include("shiftedNuclearnorm.jl")

function (ψ::ShiftedProximableFunction)(y)
@. ψ.xsy = ψ.xk + ψ.sj + y
return ψ.h(ψ.xsy)
# assign elementwise to avoid temporary allocations from broadcasted RHS
for i in eachindex(ψ.xsy)
ψ.xsy[i] = ψ.xk[i] + ψ.sj[i] + y[i]
end
# Fast, allocation-friendly evaluations for common proximable h types
h = ψ.h
if isa(h, NormL1)
λ = h.lambda
s = zero(eltype(ψ.xsy))
for i in eachindex(ψ.xsy)
s += abs(ψ.xsy[i])
end
return λ * s
elseif isa(h, NormL0)
λ = h.lambda
cnt = zero(Int)
for i in eachindex(ψ.xsy)
cnt += (ψ.xsy[i] == zero(eltype(ψ.xsy))) ? 0 : 1
end
return λ * cnt
elseif isa(h, RootNormLhalf)
λ = h.lambda
s = zero(eltype(ψ.xsy))
for i in eachindex(ψ.xsy)
s += sqrt(abs(ψ.xsy[i]))
end
return λ * s
elseif isa(h, NormL2)
λ = h.lambda
s = zero(eltype(ψ.xsy))
for i in eachindex(ψ.xsy)
s += ψ.xsy[i]^2
end
return λ * sqrt(s)
else
return h(ψ.xsy)
end
end

function (ψ::ShiftedCompositeProximableFunction)(y)
Expand Down Expand Up @@ -97,6 +132,8 @@ end
set_radius!(ψ::ShiftedNormL0Box, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ)
set_radius!(ψ::ShiftedNormL1Box, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ)
set_radius!(ψ::ShiftedRootNormLhalfBox, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ)
set_radius!(ψ::ShiftedIndBallL0Box, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ)
set_radius!(ψ::ShiftedGroupNormL2Box, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ)

"""
set_bounds!(ψ, l, u)
Expand All @@ -115,6 +152,28 @@ end
return ψ.h.lambda
elseif prop === :r
return ψ.h.r
elseif prop === :Δ
# For Box variants, convert symmetric box constraints back to radius
if hasfield(typeof(ψ), :l) && hasfield(typeof(ψ), :u)
l = getfield(ψ, :l)
u = getfield(ψ, :u)
if isa(l, Real) && isa(u, Real) && l == -u
return u # Return radius when box is symmetric [-Δ, Δ]
elseif isa(l, AbstractVector) && isa(u, AbstractVector) && all(l .== -u)
return u[1] # Return radius when all elements are symmetric
else
error("Cannot convert asymmetric box constraints to radius Δ")
end
else
return getfield(ψ, prop) # Fall back to field access for Binf types
end
elseif prop === :χ
# For backward compatibility, provide a dummy χ for Box variants
if hasfield(typeof(ψ), :l) && hasfield(typeof(ψ), :u)
return Conjugate(IndBallL1(1.0)) # Dummy conjugate
else
return getfield(ψ, prop)
end
else
return getfield(ψ, prop)
end
Expand Down
185 changes: 32 additions & 153 deletions src/psvd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,38 @@ PSVD{T}(F::PSVD) where {T} = PSVD(
convert(AbstractMatrix{T}, F.Vt),
convert(AbstractVector{T}, F.work),
convert(AbstractVector{BlasInt}, F.iwork),
convert(AbstractVector{Tr}, F.rwork),
convert(AbstractVector{real(T)}, F.rwork),
)
Factorization{T}(F::PSVD) where {T} = PSVD{T}(F)

function psvd(
A::StridedMatrix{T};
full::Bool = false,
alg::Algorithm = default_svd_alg(A),
destructive::Bool = false,
) where {T <: BlasFloat}
m, n = size(A)
if m == 0 || n == 0
u = Matrix{T}(I, m, full ? m : n)
s = real(zeros(T, 0))
vt = Matrix{T}(I, n, n)
Tr = real(T)
return PSVD(u, s, vt, T[], BlasInt[], Tr[])
end

if typeof(alg) <: LinearAlgebra.QRIteration
F = psvd_workspace_qr(A, full = full)
return psvd_qr!(F, destructive ? A : copy(A); full = full)
else
F = psvd_workspace_dd(A, full = full)
return psvd_dd!(F, destructive ? A : copy(A); full = full)
end
end

# iteration for destructuring into components
Base.iterate(S::PSVD) = (S.U, Val(:S))
Base.iterate(S::PSVD, ::Val{:S}) = (S.S, Val(:V))
Base.iterate(S::PSVD, ::Val{:V}) = (S.V, Val(:done))
Base.iterate(S::PSVD, ::Val{:V}) = (S.Vt', Val(:done))
Base.iterate(S::PSVD, ::Val{:done}) = nothing

# Functions for alg = QRIteration()
Expand Down Expand Up @@ -149,7 +173,6 @@ for (gesvd, elty, relty) in ((:dgesvd_, :Float64, :Float64), (:sgesvd_, :Float32
) where {M}
jobuvt = full ? 'A' : 'S'
m, n = size(A)
m, n = size(A)
minmn = min(m, n)
@assert length(F.S) == minmn
@assert size(F.U) == (jobuvt == 'A' ? (m, m) : (m, minmn))
Expand Down Expand Up @@ -204,14 +227,15 @@ for (gesvd, elty, relty) in ((:zgesvd_, :ComplexF64, :Float64), (:cgesvd_, :Comp
@eval begin
function psvd_workspace_qr(A::StridedMatrix{$elty}; full::Bool = false)
jobuvt = full ? 'A' : 'S'
m, n = size(A)
minmn = min(m, n)
S = similar(A, $relty, minmn)
U = similar(A, $elty, jobuvt == 'A' ? (m, m) : (m, minmn))
Vt = similar(A, $elty, jobuvt == 'A' ? (n, n) : (minmn, n))
work = Vector{$elty}(undef, 1)
lwork = BlasInt(-1)
info = Ref{BlasInt}()
rwork = Vector{R}(undef, 5minmn)
rwork = Vector{$relty}(undef, max(1, 5 * minmn))
ccall(
(@blasfunc($gesvd), libblastrampoline),
Cvoid,
Expand All @@ -234,8 +258,8 @@ for (gesvd, elty, relty) in ((:zgesvd_, :ComplexF64, :Float64), (:cgesvd_, :Comp
Clong,
Clong,
),
jobu,
jobvt,
jobuvt,
jobuvt,
m,
n,
A,
Expand Down Expand Up @@ -295,8 +319,8 @@ for (gesvd, elty, relty) in ((:zgesvd_, :ComplexF64, :Float64), (:cgesvd_, :Comp
Clong,
Clong,
),
jobu,
jobvt,
jobuvt,
jobuvt,
m,
n,
A,
Expand Down Expand Up @@ -439,148 +463,3 @@ for (gesdd, elty, relty) in ((:dgesdd_, :Float64, :Float64), (:sgesdd_, :Float32
end
end
end

for (gesdd, elty, relty) in ((:zgesdd_, :ComplexF64, :Float64), (:cgesdd_, :ComplexF32, :Float32))
@eval begin
function psvd_workspace_dd(A::StridedMatrix{$elty}; full::Bool = false)
require_one_based_indexing(A)
chkstride1(A)
job = full ? 'A' : 'S'
m, n = size(A)
minmn = min(m, n)
U = similar(A, $elty, job == 'A' ? (m, m) : (m, minmn))
Vt = similar(A, $elty, job == 'A' ? (n, n) : (minmn, n))
work = Vector{$elty}(undef, 1)
lwork = BlasInt(-1)
S = similar(A, $relty, minmn)
rwork = Vector{$relty}(undef, minmn * max(5 * minmn + 7, 2 * max(m, n) + 2 * minmn + 1))
iwork = Vector{BlasInt}(undef, 8 * minmn)
info = Ref{BlasInt}()
ccall(
(@blasfunc($gesdd), libblastrampoline),
Cvoid,
(
Ref{UInt8},
Ref{BlasInt},
Ref{BlasInt},
Ptr{$elty},
Ref{BlasInt},
Ptr{$relty},
Ptr{$elty},
Ref{BlasInt},
Ptr{$elty},
Ref{BlasInt},
Ptr{$elty},
Ref{BlasInt},
Ptr{$relty},
Ptr{BlasInt},
Ptr{BlasInt},
Clong,
),
job,
m,
n,
A,
max(1, stride(A, 2)),
S,
U,
max(1, stride(U, 2)),
Vt,
max(1, stride(Vt, 2)),
work,
lwork,
rwork,
iwork,
info,
1,
)
chklapackerror(info[])
# Work around issue with truncated Float32 representation of lwork in
# sgesdd by using nextfloat. See
# http://icl.cs.utk.edu/lapack-forum/viewtopic.php?f=13&t=4587&p=11036&hilit=sgesdd#p11036
# and
# https://github.com/scipy/scipy/issues/5401
lwork = round(BlasInt, nextfloat(real(work[1])))
resize!(work, lwork)
rwork = Vector{$relty}(undef, 0)
return PSVD(U, S, Vt, work, iwork, rwork)
end

# !!! this call destroys the contents of A
function psvd_dd!(
F::PSVD{$elty, $relty, M},
A::StridedMatrix{$elty};
full::Bool = false,
) where {M}
job = full ? 'A' : 'S'
m, n = size(A)
minmn = min(m, n)
@assert length(F.S) == minmn
@assert size(F.U) == job == 'A' ? (m, m) : (m, minmn)
@assert size(F.Vt) == job == 'A' ? (n, n) : (minmn, n)
info = Ref{BlasInt}()
lwork = length(F.work)
ccall(
(@blasfunc($gesdd), libblastrampoline),
Cvoid,
(
Ref{UInt8},
Ref{BlasInt},
Ref{BlasInt},
Ptr{$elty},
Ref{BlasInt},
Ptr{$relty},
Ptr{$elty},
Ref{BlasInt},
Ptr{$elty},
Ref{BlasInt},
Ptr{$elty},
Ref{BlasInt},
Ptr{$relty},
Ptr{BlasInt},
Ptr{BlasInt},
Clong,
),
job,
m,
n,
A,
max(1, stride(A, 2)),
S,
U,
max(1, stride(U, 2)),
VT,
max(1, stride(VT, 2)),
F.work,
lwork,
F.rwork,
F.iwork,
info,
1,
)
chklapackerror(info[])
return F
end
end
end

function psvd(
A::StridedMatrix{T};
full::Bool = false,
alg::Algorithm = default_svd_alg(A),
) where {T <: BlasFloat}
m, n = size(A)
if m == 0 || n == 0
u, s, vt = (Matrix{T}(I, m, full ? m : n), real(zeros(T, 0)), Matrix{T}(I, n, n))
Tr = real(T)
return PSVD(u, s, vt, T[], BlasInt[], Tr[])
else
if typeof(alg) <: LinearAlgebra.QRIteration
F = psvd_workspace_qr(A, full = full)
return psvd_qr!(F, copy(A), full = full)
else
F = psvd_workspace_dd(A, full = full)
return psvd_dd!(F, copy(A), full = full)
end
end
end
Loading