-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSCPLinearSolveExt.jl
More file actions
79 lines (64 loc) · 2.27 KB
/
SCPLinearSolveExt.jl
File metadata and controls
79 lines (64 loc) · 2.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
module SCPLinearSolveExt
using SymbolicUtils
using SymbolicUtils.Code
using LinearSolve
using LinearAlgebra
import SymbolicCompilerPasses: ldiv_transformation, SymbolicCompilerPasses, get_factorization, get_from_cache, FACTORIZATION_CACHE
using StaticArrays
# Module initializer: on load of this extension, set the parent package's
# `LINEARSOLVE_LIB` flag to `true`. Per the opt-out notice emitted by
# `ldiv_transformation`, setting that flag back to `false` disables the
# LinearSolve-based rewrite.
function __init__()
    SymbolicCompilerPasses.LINEARSOLVE_LIB[] = true
end
# Module-level store of LinearSolve caches created by `get_linear_prob` for
# non-static arrays. Keys and values are untyped (same type as `Dict()`).
const LINSOLVEPROB_CACHE = Dict{Any, Any}()
# Static-array method: no cache is consulted. Every call builds and returns a
# fresh `LinearProblem(A, B)` (not an `init`ed cache); the static
# `linear_solve` method passes it to the allocating `solve`.
get_linear_prob(A::StaticArray, B::StaticArray) = LinearSolve.LinearProblem(A, B)
# Return a LinearSolve cache (`init(LinearProblem(A, B))`) for `A`, creating it
# on first use and returning the stored one for later calls whose `A` is equal
# in value and whose argument types match. This lets the factorization held by
# the cache be reused across calls.
#
# The stored entry is keyed on `(copy(A), TA, TB)`:
# - `copy(A)`: storing the caller's `A` itself would let a later in-place
#   mutation of `A` change the hash of a key already inside the Dict, which
#   corrupts the table and can return a cache built from stale values.
# - `TA, TB`: `isequal`/`hash` on arrays compare values only, so equal-valued
#   arrays of different types would otherwise share an entry and violate the
#   type assertion below.
function get_linear_prob(A::TA, B::TB) where {TA, TB}
    # `Base.promote_op(f, argtypes...)` takes the argument types as varargs;
    # this is the inferred type of `init(LinearProblem(::TA, ::TB))`, used to
    # type-assert the value read back from the untyped cache.
    ProbT = Base.promote_op(LinearSolve.LinearProblem, TA, TB)
    CacheT = Base.promote_op(init, ProbT)
    linsolve = get(LINSOLVEPROB_CACHE, (A, TA, TB), nothing)
    if linsolve === nothing
        linsolve = init(LinearSolve.LinearProblem(A, B))
        LINSOLVEPROB_CACHE[(copy(A), TA, TB)] = linsolve
    end
    return linsolve::CacheT
end
# Solve `A * x = B` using the LinearSolve cache that `get_linear_prob` keeps for
# `A`. The cache's right-hand side is replaced by `B` and `solve!` is run on the
# cache, letting LinearSolve reuse the factorization it already holds.
#
# Returns a fresh array. The `u` of the solution returned by `solve!` is backed
# by the cache's own solution buffer, which the next `solve!` on the same cache
# overwrites in place. Returning it directly would make earlier results change
# silently, and two solves against an equal `A` would return the same object.
function linear_solve(A, B)
    linsolve = get_linear_prob(A, B)
    linsolve.b = B
    sol = solve!(linsolve)
    return copy(sol.u)
end
# Static-array method: build an uncached problem via the static
# `get_linear_prob` method, solve it with the allocating `solve`, and return the
# solution vector.
linear_solve(A::StaticArray, B::StaticArray) = solve(get_linear_prob(A, B)).u
# Method of the parent package's `get_factorization`. Returns the LinearSolve
# cache stored in `FACTORIZATION_CACHE` under key `A`; on the first request for
# that key it stores and returns `init(LinearProblem(A, B))`.
# NOTE(review): the key is `A` itself. If `FACTORIZATION_CACHE` is a
# value-hashed `Dict` and callers later mutate `A` in place, the entry becomes
# unreachable or stale. Confirm the cache's type and intended key semantics.
function get_factorization(A, B)
    return get!(FACTORIZATION_CACHE, A) do
        init(LinearSolve.LinearProblem(A, B))
    end
end
# LinearSolve-backed method of `ldiv_transformation` (selected by `Val(true)`).
#
# Each element of `safe_matches` is read through its `A`, `B`, `ldiv_candidate`
# and `ldiv_idx` fields. If `B`'s symtype is an `AbstractVector`, the match is
# rewritten to the assignment `lhs = linear_solve(A, B)` (where `lhs` is the
# left-hand side of `ldiv_candidate`), stored under `ldiv_idx`. Every other match
# triggers a warning and is collected for the `Val(false)` method (defined in the
# parent package; presumably the non-LinearSolve rewrite — confirm).
#
# Returns `merge` of this method's `Dict{Int, Code.Assignment}` with the
# fallback's result.
function ldiv_transformation(safe_matches, ::Val{true})
    # With `maxlog = Inf` this notice is logged on every call.
    @info("Using LinearSolve.jl for in-place backsolve optimizations.\n" *
          "In order to opt-out of using LinearSolve, set " *
          "SymbolicCompilerPasses.LINEARSOLVE_LIB[] = false.", maxlog = Inf)
    rewrites = Dict{Int, Code.Assignment}()
    fallback_matches = []
    for m in safe_matches
        mat, rhs = m.A, m.B
        target = Code.lhs(m.ldiv_candidate)
        VT = Code.vartype(rhs)
        if !(Code.symtype(rhs) <: AbstractVector)
            @warn(
                "Skipping LinearSolve optimization for match as B is not a vector.",
                maxlog = Inf
            )
            push!(fallback_matches, m)
            continue
        end
        # The emitted call is `linear_solve(A, B)` (this module's cached
        # LinearSolve path), not `ldiv!`.
        solve_term = Code.Term{VT}(linear_solve, [mat, rhs]; type = Code.symtype(rhs))
        rewrites[m.ldiv_idx] = Code.Assignment(target, solve_term)
    end
    return merge(rewrites, ldiv_transformation(fallback_matches, Val(false)))
end
end