-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathexponential.jl
More file actions
66 lines (57 loc) · 1.88 KB
/
exponential.jl
File metadata and controls
66 lines (57 loc) · 1.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# Inputs
# ------
# Make an independent copy of `A` for `exponential`. The destination has the same shape as
# `A` and element type `float(eltype(A))`, so integer inputs end up floating point.
function copy_input(::typeof(exponential), A::AbstractMatrix)
    T = float(eltype(A))
    dest = similar(A, T)
    copy!(dest, A)
    return dest
end
# Make an independent copy of a `Diagonal` input for `exponential`, promoting the element
# type to floating point exactly like the `AbstractMatrix` method. Without the promotion a
# `Diagonal{Int}` input yields an integer output container, and storing `exp` of its
# entries into it throws an `InexactError`. `similar` on a `Diagonal` returns a `Diagonal`,
# so the structure is preserved.
copy_input(::typeof(exponential), A::Diagonal) = copy!(similar(A, float(eltype(A))), A)
# Validate the arguments of a generic `exponential!` call.
# Requires `A` to be square (otherwise throws `DimensionMismatch`) and then delegates to the
# `@check_size` and `@check_scalar` macros. `expA` must be `n × n`, and its scalar type is
# checked against `A`. Returns whatever `@check_scalar` produces.
function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix,
                     alg::AbstractAlgorithm)
    nrows, ncols = size(A)
    if nrows != ncols
        throw(DimensionMismatch("square input matrix expected. Got ($nrows,$ncols)"))
    end
    @check_size(expA, (nrows, nrows))
    return @check_scalar(expA, A)
end
# Validate the arguments of `exponential!` when a `DiagonalAlgorithm` is used.
#
# Throws
# - `DimensionMismatch` if `A` is not square (same message as the generic method);
# - `ArgumentError` if `A` has nonzero off-diagonal entries (`isdiag(A)` is false);
# - `ArgumentError` if the output `expA` is not a `Diagonal`.
# Size and scalar-type checks on `expA` are then delegated to `@check_size`/`@check_scalar`.
#
# Explicit exceptions are used instead of `@assert`, because assertions are meant for
# internal invariants and may be disabled. They also produce uninformative errors for
# invalid user input.
function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix,
                     ::DiagonalAlgorithm)
    m, n = size(A)
    m == n || throw(DimensionMismatch("square input matrix expected. Got ($m,$n)"))
    isdiag(A) || throw(ArgumentError("diagonal input matrix expected"))
    expA isa Diagonal ||
        throw(ArgumentError("`Diagonal` output matrix expected. Got $(typeof(expA))"))
    @check_size(expA, (m, m))
    @check_scalar(expA, A)
    return nothing
end
# Outputs
# -------
# Allocate an uninitialized `n × n` output for `exponential!`, where `n = size(A, 1)`, with
# the same element type as `A`. Squareness of `A` is not checked here; that is left to
# `check_input`.
function initialize_output(::typeof(exponential!), A::AbstractMatrix, ::AbstractAlgorithm)
    dim = size(A, 1)
    return similar(A, eltype(A), (dim, dim))
end
# Allocate the output of the diagonal `exponential!` path: an uninitialized array `similar`
# to the `Diagonal` input `A`.
initialize_output(::typeof(exponential!), A::Diagonal, ::DiagonalAlgorithm) = similar(A)
# Implementation
# --------------
# Matrix exponential computed by `LinearAlgebra.exp`. The freshly allocated result is
# copied into `expA`, which is returned.
function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaLA)
    result = LinearAlgebra.exp(A)
    return copyto!(expA, result)
end
# Matrix exponential via a Hermitian eigendecomposition `A = V * D * V⁻¹` (using
# `eigh_full` with `alg.eigh_alg`). The result `V * exp(D) * V⁻¹` is written into `expA`,
# which is returned.
#
# Assumes `eigh_full` returns orthonormal eigenvectors, as a Hermitian eigendecomposition
# does. Then `V⁻¹ == V'`, and we use the adjoint instead of an O(n³) LU-based `inv(V)`,
# which is both cheaper and at least as accurate.
function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaEigh)
    D, V = eigh_full(A, alg.eigh_alg)
    # Materialize the adjoint now: `rmul!` below overwrites `V` in place, and a lazy `V'`
    # would observe that modification.
    iV = copy(adjoint(V))
    map!(exp, diagview(D), diagview(D))
    mul!(expA, rmul!(V, D), iV)
    return expA
end
# Matrix exponential via a general eigendecomposition `A = V * D * V⁻¹` (using `eig_full`
# with `alg.eig_alg`). The result `V * exp(D) * V⁻¹` is written into `expA`, which is
# returned.
#
# A real, non-normal `A` can have complex eigenpairs, so `V * exp(D) * V⁻¹` is computed in
# complex arithmetic. Its imaginary parts are only round-off, since the exponential of a
# real matrix is real. Writing that complex product straight into a real `expA` with `mul!`
# fails on the complex-to-real conversion. We therefore keep only the real part in that
# case.
function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaEig)
    D, V = eig_full(A, alg.eig_alg)
    iV = inv(V)  # computed before `rmul!` overwrites `V`
    map!(exp, diagview(D), diagview(D))
    VexpD = rmul!(V, D)
    if eltype(expA) <: Real
        # the branch is decided by the type alone, so it is resolved at compile time
        expA .= real.(VexpD * iV)
    else
        mul!(expA, VexpD, iV)
    end
    return expA
end
# Diagonal logic
# --------------
# Exponential of a `Diagonal` matrix. The arguments are validated with `check_input`, then
# `exp` is applied entrywise to the diagonal of `A` and the results are stored on the
# diagonal of `expA`, which is returned.
function exponential!(A::Diagonal, expA, alg::DiagonalAlgorithm)
    check_input(exponential!, A, expA, alg)
    out = diagview(expA)
    out .= exp.(diagview(A))
    return expA
end