Skip to content

Commit b6711df

Browse files
Add the special cases for di<0 and di=0, along with corresponding test
1 parent 394725d commit b6711df

3 files changed

Lines changed: 36 additions & 10 deletions

File tree

src/shiftedNormL0.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,22 @@ function iprox!(
6767
λ = ψ.h.lambda
6868
for i eachindex(y)
6969
di = d[i]
70-
if d[i] 0
71-
y[i] = - 1/eps(R)
72-
continue
70+
gi = g[i]
71+
if di < 0
72+
y[i] = -Inf
73+
elseif di == 0
74+
if gi == 0
75+
y[i] = -ψ.xk[i] - ψ.sj[i]
76+
else
77+
y[i] = sign(gi) * Inf
78+
end
7379
else
7480
ci = sqrt(2 * λ * di)
7581
xps = ψ.xk[i] + ψ.sj[i]
76-
if abs(di * xps - g[i]) ci
82+
if abs(di * xps - gi) ci
7783
y[i] = -xps
7884
else
79-
y[i] = -g[i] / di
85+
y[i] = -gi / di
8086
end
8187
end
8288
end

src/shiftedNormL1.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,16 @@ function iprox!(
6767
@. y = -ψ.xk - ψ.sj
6868

6969
for i eachindex(y)
70-
if d[i] < 0
71-
y[i] = - 1/eps(R)
72-
continue
70+
di = d[i]
71+
gi = g[i]
72+
if di < 0
73+
y[i] = -Inf
74+
elseif di == 0
75+
if abs(gi) > λ
76+
y[i] = sign(gi) * Inf
77+
end
7378
else
74-
y[i] = min(max(y[i], -g[i] / d[i] - λ / d[i]), -g[i] / d[i] + λ / d[i])
79+
y[i] = min(max(y[i], -gi / di - λ / di), -gi / di + λ / di)
7580
end
7681
end
7782

test/partial_prox.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# test partial prox feature for operators that implement it
22
for op (:NormL0, :NormL1, :RootNormLhalf)
33
@testset "shifted $op with box partial prox" begin
4-
h = eval(op)(3.14)
4+
λ = 3.14
5+
h = eval(op)(λ)
56
n = 5
67
l = zeros(n)
78
u = ones(n)
@@ -58,6 +59,7 @@ for op ∈ (:NormL0, :NormL1, :RootNormLhalf)
5859
# tests iprox without bounds
5960
if op == :NormL0 || op == :NormL1
6061
ψ = shifted(h, x)
62+
# test iprox with d > 0
6163
for d [ones(n), 2 * ones(n)]
6264
y = iprox(ψ, q, d)
6365
σ = d[1]
@@ -68,6 +70,19 @@ for op ∈ (:NormL0, :NormL1, :RootNormLhalf)
6870
end
6971
end
7072
end
73+
# test iprox with d < 0
74+
for d [-ones(n), -2 * ones(n)]
75+
y = iprox(ψ, q, d)
76+
@test all(isinf.(y))
77+
end
78+
# test iprox with d = 0
79+
d = zeros(n)
80+
q1 =+ 1) * ones(n)
81+
y = iprox(ψ, q1, d)
82+
@test all(isinf.(y))
83+
q2 = zeros(n)
84+
y = iprox(ψ, q2, d)
85+
@test all(y .== -ψ.xk - ψ.sj)
7186
end
7287
end
7388
end

0 commit comments

Comments
 (0)