Skip to content

Commit 7f1bd88

Browse files
add unit tests for store previous Jacobian option
1 parent dca73f4 commit 7f1bd88

1 file changed

Lines changed: 28 additions & 1 deletion

File tree

test/runtests.jl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ for (op, composite_op, shifted_op) ∈
4848
y = similar(x)
4949
ν = 0.1056
5050
prox!(y, ϕ, x, ν)
51-
51+
@test ϕ.full_row_rank == true
52+
5253
if "$op" == "NormL2"
5354
y_true = [0.24545429, 0.75250248, -0.66619752, 1.19372286]
5455
norm = Op(1.0)
@@ -63,6 +64,31 @@ for (op, composite_op, shifted_op) ∈
6364
@test ϕ.A == SparseMatrixCOO(Float64[2 0 0 -1; 0 1 1 0])
6465
@test ϕ(ones(Float64, 4)) ==
6566
h([1.0, 2.0] + SparseMatrixCOO(Float64[2 0 0 -1; 0 1 1 0])*ones(Float64, 4))
67+
68+
# test store previous Jacobian option
69+
70+
function c_store_prev!(z, x)
71+
z[1] = 2*x[1]^2 - x[4]
72+
z[2] = x[2] + x[3]
73+
end
74+
function J_store_prev!(z, x)
75+
z.vals .= Float64[4.0*x[1], 1.0, 1.0, -1.0]
76+
end
77+
ψ_store_prev = CompositeOp(λ, c_store_prev!, J_store_prev!, A, b, store_previous_jacobian = true)
78+
@test ψ_store_prev.store_previous_jacobian == true
79+
80+
xk = [1.0, 0.0, 0.0, 0.0]
81+
ϕ_store_prev = shifted(ψ_store_prev, xk)
82+
@test !isnothing(ϕ_store_prev.A_prev)
83+
84+
@test all(ϕ_store_prev.A_prev .== ϕ_store_prev.A) # A_prev should be a copied version of A at this point
85+
@test all(ϕ_store_prev.A .== SparseMatrixCOO([4.0 0.0 0.0 -1.0; 0.0 1.0 1.0 0.0]))
86+
87+
xk = [2.0, 0.0, 0.0, 0.0]
88+
shift!(ϕ_store_prev, xk)
89+
@test all(ϕ_store_prev.A_prev .== SparseMatrixCOO([4.0 0.0 0.0 -1.0; 0.0 1.0 1.0 0.0]))
90+
@test all(ϕ_store_prev.A .== SparseMatrixCOO([8.0 0.0 0.0 -1.0; 0.0 1.0 1.0 0.0]))
91+
6692

6793
# test different types
6894
h = Op(Float32(λ))
@@ -101,6 +127,7 @@ for (op, composite_op, shifted_op) ∈
101127
y = similar(x)
102128
ν = 0.1056
103129
prox!(y, ϕ, x, ν)
130+
@test ϕ.full_row_rank == false
104131
if "$op" == "NormL2"
105132
y_true = [0.33642, 1.1287, -0.29, 1.14824]
106133
norm = Op(1.0)

0 commit comments

Comments
 (0)