Skip to content

Commit fc46398

Browse files
authored
Fix type for LeastSquares gradient (#134)
* add gradient test util * fix function value type
1 parent c9a250f commit fc46398

16 files changed

+44
-40
lines changed

src/functions/leastSquaresDirect.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ function gradient!(y::AbstractArray{C, N}, f::LeastSquaresDirect{N, R, C, M, V,
128128
f.res .-= f.b
129129
mul!(y, adjoint(f.A), f.res)
130130
y .*= f.lambda
131-
fy = (f.lambda/2)*dot(f.res, f.res)
131+
return (f.lambda / 2) * real(dot(f.res, f.res))
132132
end
133133

134134
function prox_naive(f::LeastSquaresDirect{N, R, C}, x::AbstractArray{C, N}, gamma::R=R(1)) where {N, R, C <: RealOrComplex{R}}

src/functions/leastSquaresIterative.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ function gradient!(y::AbstractArray{D, N}, f::LeastSquaresIterative{N, R, RC, M,
7070
f.res .-= f.b
7171
mul!(y, adjoint(f.A), f.res)
7272
y .*= f.lambda
73-
fy = (f.lambda/2)*dot(f.res, f.res)
73+
return (f.lambda / 2) * real(dot(f.res, f.res))
7474
end
7575

7676
function prox_naive(f::LeastSquaresIterative{N}, x::AbstractArray{D, N}, gamma::R=R(1)) where {N, R, D <: RealOrComplex{R}}

test/runtests.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,13 @@ function prox_test(f, x::ArrayOrTuple{R}, gamma=R(1)) where R <: Real
6767
return y, fy
6868
end
6969

70+
# tests equality of the results of prox, prox! and prox_naive
71+
function gradient_test(f, x::ArrayOrTuple{R}, gamma=R(1)) where R <: Real
72+
grad_fx, fx = gradient(f, x)
73+
@test typeof(fx) == R
74+
return grad_fx, fx
75+
end
76+
7077
# test predicates consistency
7178
# i.e., that more specific properties imply less specific ones
7279
# e.g., the indicator of a subspace is the indicator of a set in particular

test/test_cubeNormL2.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ for R in [Float16, Float32, Float64]
1414
call_test(f, x)
1515
gamma = R(0.5)+rand(R)
1616
y, f_y = prox_test(f, x, gamma)
17-
grad_f_y, f_y = gradient(f, y)
17+
grad_f_y, f_y = gradient_test(f, y)
1818
@test grad_f_y (x - y)/gamma
1919
end
2020
end

test/test_gradients.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,14 +149,13 @@ for i = 1:length(stuff)
149149
ref_∇f = stuff[i]["∇f(x)"]
150150

151151
ref_fx = f(x)
152-
∇f = similar(x)
153-
fx = gradient!(∇f, f, x)
152+
∇f, fx = gradient_test(f, x)
154153
@test fx ref_fx
155154
@test ∇f ref_∇f
156155

157156
for j = 1:11
158157
#For initial point x and 10 other random points
159-
fx = gradient!(∇f, f, x)
158+
∇f, fx = gradient_test(f, x)
160159
for k = 1:10
161160
# Test conditions in different directions
162161
if ProximalOperators.is_convex(f)

test/test_huberLoss.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ x = 1.6*x/norm(x)
1717

1818
call_test(f, x)
1919
prox_test(f, x, 1.3)
20-
grad_fx, fx = gradient(f, x)
20+
grad_fx, fx = gradient_test(f, x)
2121

2222
@test abs(fx - f(x)) <= 1e-12
2323
@test norm(0.7*1.5*x/norm(x) - grad_fx, Inf) <= 1e-12
@@ -27,7 +27,7 @@ x = 1.4*x/norm(x)
2727

2828
call_test(f, x)
2929
prox_test(f, x, 0.9)
30-
grad_fx, fx = gradient(f, x)
30+
grad_fx, fx = gradient_test(f, x)
3131

3232
@test abs(fx - f(x)) <= 1e-12
3333
@test norm(0.7*x - grad_fx, Inf) <= 1e-12

test/test_leastSquares.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ predicates_test(f)
3838
@test ProximalOperators.is_generalized_quadratic(f) == true
3939
@test ProximalOperators.is_set(f) == false
4040

41-
grad_fx, fx = gradient(f, x)
41+
grad_fx, fx = gradient_test(f, x)
4242
lsres = A*x - b
4343
@test fx 0.5*norm(lsres)^2
4444
@test all(grad_fx .≈ (A'*lsres))
@@ -51,7 +51,7 @@ lam = R(0.1) + rand(R)
5151
f = LeastSquares(A, b, lam, iterative=(mode == :iterative))
5252
predicates_test(f)
5353

54-
grad_fx, fx = gradient(f, x)
54+
grad_fx, fx = gradient_test(f, x)
5555
@test fx (lam/2)*norm(lsres)^2
5656
@test all(grad_fx .≈ lam*(A'*lsres))
5757

test/test_linear.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ for R in [Float16, Float32, Float64]
1111
f = Linear(c)
1212
predicates_test(f)
1313
x = randn(R, shape)
14-
@test gradient(f, x) == (c, f(x))
14+
@test gradient_test(f, x) == (c, f(x))
1515
call_test(f, x)
1616
prox_test(f, x, R(0.5)+rand(R))
1717
end

test/test_logisticLoss.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ f_x_1 = f(x)
1515

1616
@test typeof(f_x_1) == T
1717

18-
grad_f_x, f_x_2 = gradient(f, x)
18+
grad_f_x, f_x_2 = gradient_test(f, x)
1919

2020
f_x_ref = 5.893450123044199
2121
grad_f_x_ref = [-1.0965878679450072, 0.17880438303317633, -0.07113880976635019, 1.3211956169668235, -0.4034121320549927]
@@ -25,13 +25,13 @@ grad_f_x_ref = [-1.0965878679450072, 0.17880438303317633, -0.07113880976635019,
2525
@test all(grad_f_x .≈ grad_f_x_ref)
2626

2727
z1, f_z1 = prox(f, x)
28-
grad_f_z1, = gradient(f, z1)
28+
grad_f_z1, = gradient_test(f, z1)
2929

3030
@test typeof(f_z1) == T
3131
@test norm((x - z1)./1.0 - grad_f_z1, Inf)/norm(grad_f_z1, Inf) <= 1e-4
3232

3333
z2, f_z2 = prox(f, x, T(2.0))
34-
grad_f_z2, = gradient(f, z2)
34+
grad_f_z2, = gradient_test(f, z2)
3535

3636
@test typeof(f_z2) == T
3737
@test norm((x - z2)./2.0 - grad_f_z2, Inf)/norm(grad_f_z2, Inf) <= 1e-4

test/test_moreauEnvelope.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ using LinearAlgebra
2020

2121
x = R[1.0, 2.0, 3.0, 4.0, 5.0]
2222

23-
grad_g_x, g_x = gradient(g, x)
23+
grad_g_x, g_x = gradient_test(g, x)
2424

2525
y, g_y = prox_test(g, x, R(1/2))
26-
grad_g_y, _ = gradient(g, y)
26+
grad_g_y, _ = gradient_test(g, y)
2727

2828
@test y + grad_g_y / 2 x
2929
@test g(y) g_y
@@ -48,15 +48,15 @@ end
4848

4949
@test g(x) h(x)
5050

51-
grad_g_x, g_x = gradient(g, x)
52-
grad_h_x, h_x = gradient(h, x)
51+
grad_g_x, g_x = gradient_test(g, x)
52+
grad_h_x, h_x = gradient_test(h, x)
5353

5454
@test g_x g(x)
5555
@test h_x h(x)
5656
@test all(grad_g_x .≈ grad_h_x)
5757

5858
y, g_y = prox_test(g, x, R(1/2))
59-
grad_g_y, _ = gradient(g, y)
59+
grad_g_y, _ = gradient_test(g, y)
6060

6161
@test y + grad_g_y / 2 x
6262
@test g(y) g_y

0 commit comments

Comments
 (0)