Skip to content

Commit bd92643

Browse files
committed
Some code polishing
1 parent ef2dd69 commit bd92643

3 files changed

Lines changed: 77 additions & 67 deletions

File tree

src/iterative_methods.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ function lyapci(A::AbstractMatrix, C::AbstractMatrix; abstol = zero(float(real(e
1515
LinearAlgebra.checksquare(C) == n ||
1616
throw(DimensionMismatch("C must be a square matrix of dimension $n"))
1717
sym = isreal(A) && isreal(C) && issymmetric(C)
18-
her = ishermitian(C)
18+
her = ishermitian(C)
1919
LT = lyapop(A; her = sym)
20-
20+
2121
if sym
2222
xt, info = cgls(LT,-triu2vec(C); abstol, reltol, maxiter)
2323
else
@@ -1232,7 +1232,7 @@ function cgls!(x, WS, A, b; shift = 0, abstol = 0, reltol = 1e-6, maxiter = max(
12321232

12331233
#s = A'*r-shift*x
12341234
mul!(s,adjointA,r)
1235-
shift == 0 || axpy!(-shift, x, s)
1235+
shift == 0 || axpy!(-shift, x, s)
12361236

12371237
# Initialize
12381238
p .= s
@@ -1251,7 +1251,6 @@ function cgls!(x, WS, A, b; shift = 0, abstol = 0, reltol = 1e-6, maxiter = max(
12511251
# Main loop
12521252
#--------------------------------------------------------------------------
12531253
while (k < maxiter) && (flag == 0)
1254-
12551254
k += 1
12561255

12571256
#q = A*p;
@@ -1286,6 +1285,7 @@ function cgls!(x, WS, A, b; shift = 0, abstol = 0, reltol = 1e-6, maxiter = max(
12861285

12871286
# Output
12881287
resNE = norms / norms0;
1288+
#@show k, resNE
12891289
isnan(resNE) && (resNE = zero(norms))
12901290

12911291
end # while

src/meoperators.jl

Lines changed: 31 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,6 @@ function LinearMaps._unsafe_mul!(y::AbstractVector, L::LyapunovMap{T,TA,Continuo
477477
T1 = promote_type(T, eltype(x))
478478
if L.her
479479
X = vec2triu(convert(AbstractVector{T1}, x), her=true)
480-
#y[:] = triu2vec(L.A*X + X*L.A')
481480
mulcsym!(y, L.A, X)
482481
else
483482
X = reshape(convert(AbstractVector{T1}, x), n, n)
@@ -495,8 +494,7 @@ for ttype in (LinearMaps.TransposeMap, LinearMaps.AdjointMap)
495494
T1 = promote_type(T, eltype(y))
496495
if L.lmap.her
497496
Y = vec2triu(convert(AbstractVector{T1}, y), her=false)
498-
#x[:] = triu2vec(L.lmap.A'*Y + Y*L.lmap.A)
499-
mulcsym!(x, L.lmap.A', Y, dual = true)
497+
mulcsym!(x, L.lmap.A, Y, dual = true)
500498
else
501499
Y = reshape(convert(AbstractVector{T1}, y), n, n)
502500
# (x[:] = (L.A'*Y + Y*L.A)[:])
@@ -507,16 +505,16 @@ for ttype in (LinearMaps.TransposeMap, LinearMaps.AdjointMap)
507505
return x
508506
end
509507
end
510-
511-
function mulcsym!(y::AbstractVector, A::AbstractMatrix, X::AbstractMatrix; dual = false)
508+
function mulcsym!(y::AbstractVector{T}, A::AbstractMatrix{T}, X::AbstractMatrix{T}; dual = false) where {T <: Real}
512509
require_one_based_indexing(y, A, X)
513510
# A*X + X*A'
514511
n = size(A, 1)
515-
if dual
516-
#Y = A*X+X*A'
512+
Y = similar(X, n, n)
513+
if dual
514+
#Y = A'*X+X*A
517515
Y = similar(X, n, n)
518-
mul!(Y, X, A')
519-
mul!(Y, A, X, true, true)
516+
mul!(Y, X, A)
517+
mul!(Y, A', X, true, true)
520518
# y[:] = triu2vec(Y+transpose(Y)-Diagonal(Y))
521519
@inbounds begin
522520
k = 1
@@ -527,20 +525,16 @@ function mulcsym!(y::AbstractVector, A::AbstractMatrix, X::AbstractMatrix; dual
527525
end
528526
end
529527
end
530-
return y
531-
end
532-
ZERO = zero(eltype(y))
533-
@inbounds begin
534-
k = 1
535-
for j = 1:n
536-
for i = 1:j
537-
temp = ZERO
538-
for l = 1:n
539-
temp += A[i,l] * X[l,j] + X[i,l] * conj(A[j,l])
540-
end
541-
y[k] = temp
542-
k += 1
543-
end
528+
else
529+
mul!(Y, A, X)
530+
@inbounds begin
531+
k = 1
532+
for j = 1:n
533+
for i = 1:j
534+
y[k] = Y[i,j] + Y[j,i]
535+
k += 1
536+
end
537+
end
544538
end
545539
end
546540
return y
@@ -707,9 +701,10 @@ for ttype in (LinearMaps.TransposeMap, LinearMaps.AdjointMap)
707701
n = size(L.lmap.A, 1)
708702
T1 = promote_type(T, eltype(y))
709703
if L.lmap.her
710-
Y = vec2triu(convert(AbstractVector{T1}, y), her=false)
704+
Y = UpperTriangular(vec2triu(convert(AbstractVector{T1}, y), her=false))
711705
#x[:] = triu2vec(L.lmap.A'*Y + Y*L.lmap.A)
712-
mulcsym!(x, L.lmap.A', L.lmap.E', Y, dual = true)
706+
#mulcsym!(x, L.lmap.A', L.lmap.E', Y, dual = true)
707+
mulcsym!(x, L.lmap.A, L.lmap.E, Y, dual = true)
713708
else
714709
X = reshape(x, n, n)
715710
Y = reshape(convert(AbstractVector{T1}, y), n, n)
@@ -723,19 +718,12 @@ for ttype in (LinearMaps.TransposeMap, LinearMaps.AdjointMap)
723718
return x
724719
end
725720
end
726-
function mulcsym!(y::AbstractVector, A::AbstractMatrix, E::AbstractMatrix, X::AbstractMatrix; dual = false)
721+
function mulcsym!(y::AbstractVector{T}, A::AbstractMatrix{T}, E::AbstractMatrix{T}, X::AbstractMatrix{T}; dual = false) where {T <: Real}
727722
require_one_based_indexing(y, A)
728-
# AXE' + EXA'
729723
n = size(A, 1)
730-
Y = similar(X, n, n)
731724
if dual
732-
#Y = AXE' + EXA'
733-
Y = similar(X, n, n)
734-
temp = similar(Y, (n, n))
735-
mul!(temp, E, X)
736-
mul!(Y, temp, A')
737-
mul!(temp, X, E')
738-
mul!(Y, A, temp, 1, 1)
725+
# A'XE + E'XA with X upper triangular
726+
Y = E'*(X*A) + A'*(X*E)
739727
# y[:] = triu2vec(Y+transpose(Y)-Diagonal(Y))
740728
@inbounds begin
741729
k = 1
@@ -747,22 +735,16 @@ function mulcsym!(y::AbstractVector, A::AbstractMatrix, E::AbstractMatrix, X::Ab
747735
end
748736
end
749737
else
750-
ZERO = zero(eltype(y))
751-
# Y = XE'
752-
mul!(Y, X, E')
753-
# AY + Y'A'
754-
@inbounds begin
738+
# AXE' + EXA' with X symmetric
739+
Y = (A*X)*E'
740+
@inbounds begin
755741
k = 1
756742
for j = 1:n
757-
for i = 1:j
758-
temp = ZERO
759-
for l = 1:n
760-
temp += A[i,l] * Y[l,j] + conj(Y[l,i] * A[j,l])
761-
end
762-
y[k] = temp
763-
k += 1
764-
end
765-
end
743+
for i = 1:j
744+
y[k] = Y[i,j] + Y[j,i]
745+
k += 1
746+
end
747+
end
766748
end
767749
end
768750
return y

test/test_iterative.jl

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -418,13 +418,39 @@ end
418418
@test norm(X-X1) < 1.e-7
419419

420420
# the same with lyapc
421-
@time X2 = lyapc(A*A', -(A*B'+B*A'))
421+
@time X2 = lyapc(A*A', hermitianpart!(-(A*B'+B*A')));
422422
@test norm(X2*A-B)/norm(X2) < 1.e-7
423423

424424
# the same with lyapci
425-
@time X3, = lyapci(A*A', -(A*B'+B*A'),reltol=1.e-10);
425+
@time X3, = lyapci(A*A', hermitianpart!(-(A*B'+B*A')),reltol=1.e-10);
426426
@test norm(X3*A-B)/norm(X3) < 1.e-7
427427

428+
m, n = 1000, 123
429+
A = randn(m,n); B = hermitianpart!(randn(m,m)) * A;
430+
@time X, info = gtsylvi([I],[A'*A],[A],[A],B,reltol=1.e-14);
431+
@test norm(X*A'*A+A*X'*A-B) < 1.e-7
432+
433+
# the same with KrylovKit
434+
@time λ, = linsolve(y -> let Y = reshape(y,m,n); vec(A*(Y'*A)+Y*(A'*A)); end, vec(B));
435+
X1 = reshape(λ, m, n)
436+
@test norm(X-X1) < 1.e-7
437+
438+
try
439+
@time X2 = lyapc(A*A', hermitianpart!(-(A*B'+B*A')));
440+
@test norm(X2*A-B)/norm(X2) < 1.e-7
441+
catch
442+
@test true
443+
end
444+
445+
# the same with lyapci
446+
@time X3, = lyapci(A*A', hermitianpart!(-(A*B'+B*A')),reltol=1.e-10);
447+
@test norm(X3*A-B)/norm(X3) < 1.e-7
448+
449+
A1 = A*A'; C = Matrix(hermitianpart!(-(A*B'+B*A')))
450+
@time λ, = linsolve(y -> let Y = reshape(y,m,m); vec(A1*Y+Y*A1'); end, -vec(C));
451+
X1 = reshape(λ, m, m);
452+
@test norm(X1*A-B)/norm(X1) < 1.e-7
453+
428454

429455

430456
A = [I,Matrix(rand(4,4))]
@@ -481,41 +507,43 @@ end
481507
A = rand(Ty,n,n); E = rand(Ty,n,n);
482508
Q = rand(Ty,n,n); C = Hermitian(Q);
483509
# Lyapunov equation, Hermitian case
484-
X, info = lyapci(A, C)
510+
@time X, info = lyapci(A, C)
485511
@test norm(A*X+X*A'+C)/norm(X) < 1.e-4 && ishermitian(X)
486-
X, info = lyapdi(A, C)
512+
@time X1 = reshape(linsolve(y -> let Y = reshape(y,n,n); vec(A*Y+Y*A'); end, -vec(C))[1],n,n);
513+
@test norm(A*X1+X1*A'+C)/norm(X1) < 1.e-4
514+
@time X, info = lyapdi(A, C)
487515
@test norm(A*X*A' -X+C)/norm(X) < 1.e-4 && ishermitian(X)
488516

489517
# Lyapunov equation, non-Hermitian case
490-
X, info = lyapci(A, Q)
518+
@time X, info = lyapci(A, Q)
491519
@test norm(A*X+X*A'+Q)/norm(X) < 1.e-4
492-
X, info = lyapdi(A, Q)
520+
@time X, info = lyapdi(A, Q)
493521
@test norm(A*X*A' -X+Q)/norm(X) < 1.e-4
494522

495523
# generalized Lyapunov equation, Hermitian case
496-
X, info = lyapci(A, E, C)
524+
@time X, info = lyapci(A, E, C)
497525
@test norm(A*X*E'+E*X*A'+C)/norm(X) < 1.e-4 && ishermitian(X)
498-
X, info = lyapdi(A, E, C)
526+
@time X, info = lyapdi(A, E, C)
499527
@test norm(A*X*A' -E*X*E'+C)/norm(X) < 1.e-4 && ishermitian(X)
500528

501529
# generalized Lyapunov equation, non-Hermitian case
502-
X, info = lyapci(A, E, Q)
530+
@time X, info = lyapci(A, E, Q)
503531
@test norm(A*X*E'+E*X*A'+Q)/norm(X) < 1.e-4
504-
X, info = lyapdi(A, E, Q)
532+
@time X, info = lyapdi(A, E, Q)
505533
@test norm(A*X*A' - E*X*E'+Q)/norm(X) < 1.e-4
506534

507535
# Sylvester equation
508536
B = rand(Ty,m,m); W = rand(Ty,n,m)
509-
X, info = sylvci(A, B, W)
537+
@time X, info = sylvci(A, B, W)
510538
@test norm(A*X+X*B-W)/norm(X) < 1.e-4
511-
X, info = sylvdi(A, B, W)
539+
@time X, info = sylvdi(A, B, W)
512540
@test norm(A*X*B+X-W)/norm(X) < 1.e-4
513541

514542
# generalized Sylvester equation
515543
C = rand(Ty,n,n); D = rand(Ty,m,m)
516-
X, info = gsylvi(A, B, C, D, W)
544+
@time X, info = gsylvi(A, B, C, D, W)
517545
@test norm(A*X*B+C*X*D-W)/norm(X) < 1.e-4
518-
X, info = gsylvi(A, B', C', D, W)
546+
@time X, info = gsylvi(A, B', C', D, W)
519547
@test norm(A*X*B'+C'*X*D-W)/norm(X) < 1.e-4
520548
end
521549

0 commit comments

Comments
 (0)