Skip to content

Commit 0643fd7

Browse files
chore: use FACTORIZATION_CACHE for LinearSolve
1 parent ea89f25 commit 0643fd7

1 file changed

Lines changed: 11 additions & 5 deletions

File tree

ext/SCPLinearSolveExt/SCPLinearSolveExt.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,24 @@ using SymbolicUtils
44
using SymbolicUtils.Code
55
using LinearSolve
66
using LinearAlgebra
7-
import SymbolicCompilerPasses: ldiv_transformation, SymbolicCompilerPasses, get_factorization, get_from_cache
7+
import SymbolicCompilerPasses: ldiv_transformation, SymbolicCompilerPasses, get_factorization, get_from_cache, FACTORIZATION_CACHE
88

9-
@warn "here"
109
SymbolicCompilerPasses.LINEARSOLVE_LIB[] = true
1110

1211
function linear_solve(A, B)
13-
prob = LinearSolve.LinearProblem(A, B)
14-
linsolve = init(prob)
12+
linsolve = get_factorization(A, B)
13+
linsolve.b = B
1514
sol = solve!(linsolve)
1615
return sol.u
1716
end
1817

18+
function get_factorization(A, B)
19+
get!(FACTORIZATION_CACHE, A) do
20+
prob = LinearSolve.LinearProblem(A, B)
21+
linsolve = init(prob)
22+
end
23+
end
24+
1925

2026
function ldiv_transformation(safe_matches, ::Val{true})
2127
@info "Using LinearSolve.jl for in-place backsolve optimizations.
@@ -30,7 +36,7 @@ function ldiv_transformation(safe_matches, ::Val{true})
3036
T = Code.vartype(B)
3137

3238
# Create: result = ldiv!(A, B)
33-
if Code.symtype(B) isa AbstractVector
39+
if Code.symtype(B) <: AbstractVector
3440
ldiv_call = Code.Term{T}(
3541
linear_solve,
3642
[A, B];

0 commit comments

Comments
 (0)