Skip to content

Commit 6c7dda4

Browse files
authored
Remove allocs l2 (#137)
remove allocs from ψ evaluation and prox! in shifted-L2, add tests.
1 parent 20e71ce commit 6c7dda4

3 files changed

Lines changed: 25 additions & 6 deletions

File tree

src/groupNormL2.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ GroupNormL2(lambda::AbstractVector{R} = [one(R)], idx::I = [:]) where {R <: Real
3333
function (f::GroupNormL2)(x::AbstractArray{R}) where {R <: Real}
3434
sum_c = R(0)
3535
for (idx, λ) zip(f.idx, f.lambda)
36-
sum_c += λ * norm(x[idx])
36+
@views sum_c += λ * norm(x[idx])
3737
end
3838
return sum_c
3939
end

src/shiftedGroupNormL2.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,18 @@ function prox!(
6262
V1 <: AbstractVector{R},
6363
V2 <: AbstractVector{R},
6464
}
65-
ψ.sol .= q + ψ.xk + ψ.sj
65+
@. ψ.sol = q + ψ.xk + ψ.sj
6666
for (idx, λ) zip.h.idx, ψ.h.lambda)
67-
snorm = norm.sol[idx])
67+
sol_idx = view.sol, idx)
68+
yv = view(y, idx)
69+
snorm = norm(sol_idx)
6870
if snorm == 0
6971
y[idx] .= 0
7072
else
71-
y[idx] .= max(1 - σ * λ / snorm, 0) .* ψ.sol[idx]
73+
α = max(1 - σ * λ / snorm, 0)
74+
@. yv = α * sol_idx
7275
end
7376
end
74-
y .-=.xk + ψ.sj)
77+
@. y -=.xk + ψ.sj)
7578
return y
7679
end

test/test_allocs.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ end
4040
CompositeOp = eval(composite_op)
4141

4242
function c!(z, x)
43-
z[1] = 2*x[1] - x[4]
43+
z[1] = 2 * x[1] - x[4]
4444
z[2] = x[2] + x[3]
4545
end
4646
function J!(z, x)
@@ -113,6 +113,22 @@ end
113113
@test @wrappedallocs(iprox!(y, ψ, y, d)) == 0
114114
end
115115

116+
for op (:NormL2,)
117+
h = eval(op)(1.0)
118+
n = 1000
119+
xk = rand(n)
120+
y = rand(n)
121+
d = rand(n)
122+
123+
@test @wrappedallocs(prox!(y, h, y, 1.0)) == 0
124+
125+
ψ = shifted(h, xk)
126+
127+
@test @wrappedallocs(ψ(y)) == 0
128+
129+
@test @wrappedallocs(prox!(y, ψ, y, 1.0)) == 0
130+
end
131+
116132
for (op, shifted_op) zip((:Rank, :Nuclearnorm), (:ShiftedRank, :ShiftedNuclearnorm))
117133
ShiftedOp = eval(shifted_op)
118134
Op = eval(op)

0 commit comments

Comments
 (0)