Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
9e6d120
binf to box
arnavk23 Oct 1, 2025
35ff431
Delete Manifest.toml
arnavk23 Oct 5, 2025
5c2d2b4
Update src/shiftedIndBallL0Box.jl
arnavk23 Oct 11, 2025
25fd6a9
Update src/shiftedGroupNormL2Box.jl
arnavk23 Oct 11, 2025
bcbd34f
Update src/shiftedGroupNormL2Box.jl
arnavk23 Oct 11, 2025
17dfe64
precompile
arnavk23 Oct 11, 2025
aa30ee7
Merge branch 'master' into master
arnavk23 Oct 11, 2025
08a2fd4
changes to be committed:
arnavk23 Oct 12, 2025
7cba1a6
Update src/ShiftedProximalOperators.jl
arnavk23 Oct 12, 2025
98e7e7c
Update src/shiftedNormL1B2.jl
arnavk23 Oct 12, 2025
226ae91
Update src/shiftedGroupNormL2Box.jl
arnavk23 Oct 12, 2025
b662925
psvd implementation
arnavk23 Oct 15, 2025
8d8854c
failing checks
arnavk23 Oct 15, 2025
54d09b8
Update src/ShiftedProximalOperators.jl
arnavk23 Oct 15, 2025
59cb443
shifted failing checks
arnavk23 Oct 15, 2025
fa44c5f
buffer
arnavk23 Oct 15, 2025
299013b
Update src/psvd.jl
arnavk23 Oct 15, 2025
d9a25a3
Update src/psvd.jl
arnavk23 Oct 15, 2025
c60aaa5
Update src/psvd.jl
arnavk23 Oct 15, 2025
3f74bac
Update src/psvd.jl
arnavk23 Oct 15, 2025
4735288
Update src/psvd.jl
arnavk23 Oct 15, 2025
d7c0da1
Update src/psvd.jl
arnavk23 Oct 15, 2025
7f38bf4
Update src/psvd.jl
arnavk23 Oct 15, 2025
577f84c
Make Box variants canonical for IndBallL0 and GroupNormL2; reduce all…
arnavk23 Oct 15, 2025
50835ab
Reduce allocations in prox! implementations: in-place updates for nuc…
arnavk23 Oct 17, 2025
e9b9c12
Update src/psvd.jl
arnavk23 Oct 17, 2025
144be82
Refactor ShiftedNormL1B2 to reduce allocations: replace broadcasted o…
arnavk23 Oct 19, 2025
cbcbfb0
tests: align CompositeOp callbacks with PR #147 by renaming c!/J! to …
arnavk23 Oct 19, 2025
7ebd781
allocations for shiftedNormL1B2
arnavk23 Oct 27, 2025
228c2e8
test allocs for shifted NormL1 B2 proximal operator
arnavk23 Oct 27, 2025
d6e315a
Delete Manifest.toml
arnavk23 Oct 27, 2025
078ce25
review changes
arnavk23 Dec 17, 2025
9898610
Merge branch 'master' into fix/allocs-pr145
arnavk23 Dec 17, 2025
69d14e7
Update test_allocs.jl
arnavk23 Dec 17, 2025
37bc27f
review changes - 2
arnavk23 Dec 17, 2025
f11fcee
Update src/shiftedGroupNormL2Box.jl
arnavk23 Dec 17, 2025
a4368c8
Update src/shiftedIndBallL0Box.jl
arnavk23 Dec 17, 2025
0d32e4a
copilot additions
arnavk23 Dec 17, 2025
8c5d1a0
Revert "copilot additions"
arnavk23 Dec 17, 2025
d75c700
Merge branch 'master' into fix/allocs-pr145
arnavk23 Feb 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions src/Nuclearnorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,23 @@ function prox!(
x::AbstractVector{R},
gamma::R,
) where {R <: Real, S <: AbstractArray, T, Tr, M <: AbstractArray{T}}
f.A .= reshape_array(x, size(f.A))
# copy reshaped x into internal matrix A
copyto!(f.A, reshape_array(x, size(f.A)))
psvd_dd!(f.F, f.A, full = false)
c = sqrt(2 * f.lambda * gamma)
f.F.S .= max.(0, f.F.S .- f.lambda * gamma)
# in-place shrink singular values
@inbounds for i in eachindex(f.F.S)
v = f.F.S[i] - f.lambda * gamma
f.F.S[i] = v > 0 ? v : zero(v)
end
# scale U by singular values in-place
for i ∈ eachindex(f.F.S)
s = f.F.S[i]
for j = 1:size(f.A, 1)
f.F.U[j, i] = f.F.U[j, i] * f.F.S[i]
f.F.U[j, i] = f.F.U[j, i] * s
end
end
mul!(f.A, f.F.U, f.F.Vt)
y .= reshape_array(f.A, (size(y, 1), 1))
copyto!(y, reshape_array(f.A, (size(y, 1), 1)))
return y
end
41 changes: 38 additions & 3 deletions src/ShiftedProximalOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,27 @@ include("shiftedGroupNormL2.jl")
include("shiftedNormL1B2.jl")
include("shiftedNormL1Box.jl")
include("shiftedIndBallL0.jl")
include("shiftedIndBallL0BInf.jl")
include("shiftedIndBallL0Box.jl")
include("shiftedRootNormLhalfBox.jl")
include("shiftedGroupNormL2Binf.jl")
include("shiftedGroupNormL2Box.jl")
Comment thread
arnavk23 marked this conversation as resolved.
include("shiftedRank.jl")
include("shiftedCappedl1.jl")
include("shiftedNuclearnorm.jl")

function (ψ::ShiftedProximableFunction)(y)
@. ψ.xsy = ψ.xk + ψ.sj + y
return ψ.h(ψ.xsy)
h = ψ.h
if isa(h, NormL1)
return h.lambda * norm(ψ.xsy, 1)
elseif isa(h, NormL0)
return h.lambda * count(!iszero, ψ.xsy)
elseif isa(h, RootNormLhalf)
return h.lambda * sum(sqrt ∘ abs, ψ.xsy)
elseif isa(h, NormL2)
return h.lambda * norm(ψ.xsy)
else
return h(ψ.xsy)
end
end

function (ψ::ShiftedCompositeProximableFunction)(y)
Expand Down Expand Up @@ -98,6 +109,8 @@ end
set_radius!(ψ::ShiftedNormL0Box, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ)
set_radius!(ψ::ShiftedNormL1Box, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ)
set_radius!(ψ::ShiftedRootNormLhalfBox, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ)
set_radius!(ψ::ShiftedIndBallL0Box, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ)
set_radius!(ψ::ShiftedGroupNormL2Box, Δ::R) where {R <: Real} = set_bounds!(ψ, -Δ, Δ)

"""
set_bounds!(ψ, l, u)
Expand All @@ -116,6 +129,28 @@ end
return ψ.h.lambda
elseif prop === :r
return ψ.h.r
elseif prop === :Δ
# For Box variants, convert symmetric box constraints back to radius
if hasfield(typeof(ψ), :l) && hasfield(typeof(ψ), :u)
l = getfield(ψ, :l)
u = getfield(ψ, :u)
if isa(l, Real) && isa(u, Real) && l == -u
return u # Return radius when box is symmetric [-Δ, Δ]
elseif isa(l, AbstractVector) && isa(u, AbstractVector) && all(l .== -u)
return u[1] # Return radius when all elements are symmetric
else
error("Cannot convert asymmetric box constraints to radius Δ")
end
else
return getfield(ψ, prop) # Fall back to field access for Binf types
end
elseif prop === :χ
# For backward compatibility, provide a dummy χ for Box variants
if hasfield(typeof(ψ), :l) && hasfield(typeof(ψ), :u)
return Conjugate(IndBallL1(1.0)) # Dummy conjugate
else
return getfield(ψ, prop)
end
else
return getfield(ψ, prop)
end
Expand Down
185 changes: 32 additions & 153 deletions src/psvd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,38 @@ PSVD{T}(F::PSVD) where {T} = PSVD(
convert(AbstractMatrix{T}, F.Vt),
convert(AbstractVector{T}, F.work),
convert(AbstractVector{BlasInt}, F.iwork),
convert(AbstractVector{Tr}, F.rwork),
convert(AbstractVector{real(T)}, F.rwork),
)
Factorization{T}(F::PSVD) where {T} = PSVD{T}(F)

function psvd(
A::StridedMatrix{T};
full::Bool = false,
alg::Algorithm = default_svd_alg(A),
destructive::Bool = false,
) where {T <: BlasFloat}
m, n = size(A)
if m == 0 || n == 0
u = Matrix{T}(I, m, full ? m : n)
s = real(zeros(T, 0))
vt = Matrix{T}(I, n, n)
Tr = real(T)
return PSVD(u, s, vt, T[], BlasInt[], Tr[])
end

if typeof(alg) <: LinearAlgebra.QRIteration
F = psvd_workspace_qr(A, full = full)
return psvd_qr!(F, destructive ? A : copy(A); full = full)
else
F = psvd_workspace_dd(A, full = full)
return psvd_dd!(F, destructive ? A : copy(A); full = full)
end
end

# iteration for destructuring into components
Base.iterate(S::PSVD) = (S.U, Val(:S))
Base.iterate(S::PSVD, ::Val{:S}) = (S.S, Val(:V))
Base.iterate(S::PSVD, ::Val{:V}) = (S.V, Val(:done))
Base.iterate(S::PSVD, ::Val{:V}) = (S.Vt', Val(:done))
Base.iterate(S::PSVD, ::Val{:done}) = nothing

# Functions for alg = QRIteration()
Expand Down Expand Up @@ -149,7 +173,6 @@ for (gesvd, elty, relty) in ((:dgesvd_, :Float64, :Float64), (:sgesvd_, :Float32
) where {M}
jobuvt = full ? 'A' : 'S'
m, n = size(A)
m, n = size(A)
minmn = min(m, n)
@assert length(F.S) == minmn
@assert size(F.U) == (jobuvt == 'A' ? (m, m) : (m, minmn))
Expand Down Expand Up @@ -204,14 +227,15 @@ for (gesvd, elty, relty) in ((:zgesvd_, :ComplexF64, :Float64), (:cgesvd_, :Comp
@eval begin
function psvd_workspace_qr(A::StridedMatrix{$elty}; full::Bool = false)
jobuvt = full ? 'A' : 'S'
m, n = size(A)
minmn = min(m, n)
S = similar(A, $relty, minmn)
U = similar(A, $elty, jobuvt == 'A' ? (m, m) : (m, minmn))
Vt = similar(A, $elty, jobuvt == 'A' ? (n, n) : (minmn, n))
work = Vector{$elty}(undef, 1)
lwork = BlasInt(-1)
info = Ref{BlasInt}()
rwork = Vector{R}(undef, 5minmn)
rwork = Vector{$relty}(undef, max(1, 5 * minmn))
ccall(
(@blasfunc($gesvd), libblastrampoline),
Cvoid,
Expand All @@ -234,8 +258,8 @@ for (gesvd, elty, relty) in ((:zgesvd_, :ComplexF64, :Float64), (:cgesvd_, :Comp
Clong,
Clong,
),
jobu,
jobvt,
jobuvt,
jobuvt,
m,
n,
A,
Expand Down Expand Up @@ -295,8 +319,8 @@ for (gesvd, elty, relty) in ((:zgesvd_, :ComplexF64, :Float64), (:cgesvd_, :Comp
Clong,
Clong,
),
jobu,
jobvt,
jobuvt,
jobuvt,
m,
n,
A,
Expand Down Expand Up @@ -439,148 +463,3 @@ for (gesdd, elty, relty) in ((:dgesdd_, :Float64, :Float64), (:sgesdd_, :Float32
end
end
end

for (gesdd, elty, relty) in ((:zgesdd_, :ComplexF64, :Float64), (:cgesdd_, :ComplexF32, :Float32))
@eval begin
function psvd_workspace_dd(A::StridedMatrix{$elty}; full::Bool = false)
require_one_based_indexing(A)
chkstride1(A)
job = full ? 'A' : 'S'
m, n = size(A)
minmn = min(m, n)
U = similar(A, $elty, job == 'A' ? (m, m) : (m, minmn))
Vt = similar(A, $elty, job == 'A' ? (n, n) : (minmn, n))
work = Vector{$elty}(undef, 1)
lwork = BlasInt(-1)
S = similar(A, $relty, minmn)
rwork = Vector{$relty}(undef, minmn * max(5 * minmn + 7, 2 * max(m, n) + 2 * minmn + 1))
iwork = Vector{BlasInt}(undef, 8 * minmn)
info = Ref{BlasInt}()
ccall(
(@blasfunc($gesdd), libblastrampoline),
Cvoid,
(
Ref{UInt8},
Ref{BlasInt},
Ref{BlasInt},
Ptr{$elty},
Ref{BlasInt},
Ptr{$relty},
Ptr{$elty},
Ref{BlasInt},
Ptr{$elty},
Ref{BlasInt},
Ptr{$elty},
Ref{BlasInt},
Ptr{$relty},
Ptr{BlasInt},
Ptr{BlasInt},
Clong,
),
job,
m,
n,
A,
max(1, stride(A, 2)),
S,
U,
max(1, stride(U, 2)),
Vt,
max(1, stride(Vt, 2)),
work,
lwork,
rwork,
iwork,
info,
1,
)
chklapackerror(info[])
# Work around issue with truncated Float32 representation of lwork in
# sgesdd by using nextfloat. See
# http://icl.cs.utk.edu/lapack-forum/viewtopic.php?f=13&t=4587&p=11036&hilit=sgesdd#p11036
# and
# https://github.com/scipy/scipy/issues/5401
lwork = round(BlasInt, nextfloat(real(work[1])))
resize!(work, lwork)
rwork = Vector{$relty}(undef, 0)
return PSVD(U, S, Vt, work, iwork, rwork)
end

# !!! this call destroys the contents of A
function psvd_dd!(
F::PSVD{$elty, $relty, M},
A::StridedMatrix{$elty};
full::Bool = false,
) where {M}
job = full ? 'A' : 'S'
m, n = size(A)
minmn = min(m, n)
@assert length(F.S) == minmn
@assert size(F.U) == job == 'A' ? (m, m) : (m, minmn)
@assert size(F.Vt) == job == 'A' ? (n, n) : (minmn, n)
info = Ref{BlasInt}()
lwork = length(F.work)
ccall(
(@blasfunc($gesdd), libblastrampoline),
Cvoid,
(
Ref{UInt8},
Ref{BlasInt},
Ref{BlasInt},
Ptr{$elty},
Ref{BlasInt},
Ptr{$relty},
Ptr{$elty},
Ref{BlasInt},
Ptr{$elty},
Ref{BlasInt},
Ptr{$elty},
Ref{BlasInt},
Ptr{$relty},
Ptr{BlasInt},
Ptr{BlasInt},
Clong,
),
job,
m,
n,
A,
max(1, stride(A, 2)),
S,
U,
max(1, stride(U, 2)),
VT,
max(1, stride(VT, 2)),
F.work,
lwork,
F.rwork,
F.iwork,
info,
1,
)
chklapackerror(info[])
return F
end
end
end

function psvd(
A::StridedMatrix{T};
full::Bool = false,
alg::Algorithm = default_svd_alg(A),
) where {T <: BlasFloat}
m, n = size(A)
if m == 0 || n == 0
u, s, vt = (Matrix{T}(I, m, full ? m : n), real(zeros(T, 0)), Matrix{T}(I, n, n))
Tr = real(T)
return PSVD(u, s, vt, T[], BlasInt[], Tr[])
else
if typeof(alg) <: LinearAlgebra.QRIteration
F = psvd_workspace_qr(A, full = full)
return psvd_qr!(F, copy(A), full = full)
else
F = psvd_workspace_dd(A, full = full)
return psvd_dd!(F, copy(A), full = full)
end
end
end
Loading