Skip to content

Commit ef596ff

Browse files
Merge pull request #3 from JuliaComputing/dg/ldiv
Feat: Amortize factorization cost for `\`
2 parents 18009c1 + 0643fd7 commit ef596ff

8 files changed

Lines changed: 557 additions & 3 deletions

File tree

Project.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,18 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
99
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
1010

11+
[weakdeps]
12+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
13+
1114
[sources]
12-
SymbolicUtils = {url = "https://github.com/DhairyaLGandhi/SymbolicUtils.jl", rev = "dg/opt_api"}
15+
SymbolicUtils = {rev = "dg/opt_api", url = "https://github.com/DhairyaLGandhi/SymbolicUtils.jl"}
16+
17+
[extensions]
18+
SCPLinearSolveExt = ["LinearSolve"]
1319

1420
[compat]
1521
LinearAlgebra = "1.11.0"
22+
LinearSolve = "3.53.0"
1623
PreallocationTools = "0.4.34"
1724
SymbolicUtils = "4.1.0"
1825
julia = "1.11"
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
module SCPLinearSolveExt
2+
3+
using SymbolicUtils
4+
using SymbolicUtils.Code
5+
using LinearSolve
6+
using LinearAlgebra
7+
import SymbolicCompilerPasses: ldiv_transformation, SymbolicCompilerPasses, get_factorization, get_from_cache, FACTORIZATION_CACHE
8+
9+
SymbolicCompilerPasses.LINEARSOLVE_LIB[] = true
10+
11+
function linear_solve(A, B)
12+
linsolve = get_factorization(A, B)
13+
linsolve.b = B
14+
sol = solve!(linsolve)
15+
return sol.u
16+
end
17+
18+
function get_factorization(A, B)
19+
get!(FACTORIZATION_CACHE, A) do
20+
prob = LinearSolve.LinearProblem(A, B)
21+
linsolve = init(prob)
22+
end
23+
end
24+
25+
26+
function ldiv_transformation(safe_matches, ::Val{true})
27+
@info "Using LinearSolve.jl for in-place backsolve optimizations.
28+
In order to opt-out of using LinearSolve, set SymbolicCompilerPasses.LINEARSOLVE_LIB[] = false." maxlog=Inf
29+
# Build transformation
30+
transformations = Dict{Int, Code.Assignment}()
31+
32+
rejected_matches = []
33+
for match in safe_matches
34+
A, B = match.A, match.B
35+
result_var = Code.lhs(match.ldiv_candidate)
36+
T = Code.vartype(B)
37+
38+
# Create: result = ldiv!(A, B)
39+
if Code.symtype(B) <: AbstractVector
40+
ldiv_call = Code.Term{T}(
41+
linear_solve,
42+
[A, B];
43+
type=Code.symtype(B)
44+
)
45+
else
46+
@warn "Skipping LinearSolve optimization for match as B is not a vector." maxlog=Inf
47+
push!(rejected_matches, match)
48+
continue
49+
end
50+
51+
ldiv_assignment = Code.Assignment(result_var, ldiv_call)
52+
transformations[match.ldiv_idx] = ldiv_assignment
53+
end
54+
fallback_transformations = ldiv_transformation(rejected_matches, Val(false))
55+
merge(transformations, fallback_transformations)
56+
end
57+
58+
59+
end

src/SymbolicCompilerPasses.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@ module SymbolicCompilerPasses
33
using LinearAlgebra
44
using PreallocationTools
55
using SymbolicUtils
6-
import SymbolicUtils: symtype, vartype, Sym, BasicSymbolic, Term, iscall, operation, arguments, maketerm, Const
6+
import SymbolicUtils: symtype, vartype, Sym, BasicSymbolic, Term, iscall, operation, arguments, maketerm, Const, shape, isterm, unwrap,
7+
is_function_symbolic, is_called_function_symbolic, getname, Unknown
78
import SymbolicUtils.Code: Code, OptimizationRule, substitute_in_ir, apply_optimization_rules, AbstractMatched,
8-
Assignment, CSEState, lhs, rhs, apply_substitution_map
9+
Assignment, CSEState, lhs, rhs, apply_substitution_map, issym
10+
import SymbolicUtils: search_variables, search_variables!
911

1012
function bank(dic, key, value)
1113
if haskey(dic, key)
@@ -16,5 +18,7 @@ function bank(dic, key, value)
1618
end
1719

1820
include("matmuladd.jl")
21+
include("ldiv_opt.jl")
22+
include("la_opt.jl")
1923

2024
end # module SymbolicCompilerPasses

src/la_opt.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
function detect_f(expr, f, state)
2+
args = search_variables(expr)
3+
arg_counts = Dict(arg => count_occurrences(arg, expr) for arg in args)
4+
triu_candidates_idx = findall(expr.pairs) do x
5+
r = rhs(x)
6+
iscall(r) || return false
7+
op = operation(r)
8+
op === f || return false
9+
10+
args = arguments(r)
11+
length(args) == 1 || return false
12+
13+
arg = args[1]
14+
symtype(arg) <: AbstractMatrix || return false
15+
16+
# The var"##cse#" variables must get counted at least twice (once in LHS and a second time in RHS)
17+
# TODO: needs to detect provenance better; see expr7
18+
get!(arg_counts, arg) do
19+
count_occurrences(arg, expr)
20+
end == 1 || return false
21+
22+
true
23+
end
24+
25+
triu_candidates = expr.pairs[triu_candidates_idx]
26+
27+
matches = map(triu_candidates_idx, triu_candidates) do idx, candidate
28+
A = arguments(rhs(candidate))[1]
29+
GenericRuleMatched(A, candidate, idx)
30+
end
31+
32+
f = filter(!isnothing, matches)
33+
isempty(f) ? nothing : f
34+
end
35+
36+
struct GenericRuleMatched{Ta, S} <: Code.AbstractMatched
37+
A::Ta
38+
candidate::S
39+
idx::Int
40+
end
41+
42+
transform_f(expr, f!, ::Nothing, state) = expr
43+
function transform_f(expr, f!, matches, state)
44+
45+
new_pairs = []
46+
transformed_idxs = getproperty.(matches, :idx)
47+
for (idx, pair) in enumerate(expr.pairs)
48+
if idx in transformed_idxs
49+
match = matches[findfirst(==(idx), transformed_idxs)]
50+
A = match.A
51+
candidate = match.candidate
52+
53+
push!(new_pairs, Assignment(lhs(candidate), term(f!, A,)))
54+
else
55+
push!(new_pairs, pair)
56+
end
57+
58+
end
59+
60+
Code.Let(new_pairs, expr.body, false)
61+
end
62+
63+
64+
function GenericRule(name, f, f!, priority)
65+
OptimizationRule(
66+
name,
67+
(expr, state) -> detect_f(expr, f, state),
68+
(expr, matches, state) -> transform_f(expr, f!, matches, state),
69+
priority
70+
)
71+
end
72+
73+
const TRIU_RULE = GenericRule("triu", LinearAlgebra.triu, LinearAlgebra.triu!, 8)
74+
const TRIL_RULE = GenericRule("tril", LinearAlgebra.tril, LinearAlgebra.tril!, 8)
75+
const NORMALIZE_RULE = GenericRule("normalize", LinearAlgebra.normalize, LinearAlgebra.normalize!, 8)
76+
# const CONJ_RULE = GenericRule("conj", LinearAlgebra.conj, LinearAlgebra.conj!, 8)
77+
78+
function triu_opt(expr, state::CSEState)
79+
# Try to apply optimization rules
80+
optimized = apply_optimization_rules(expr, state, TRIU_RULE)
81+
if optimized !== nothing
82+
return optimized
83+
end
84+
85+
# If no optimization applied, return original expression
86+
return expr
87+
end

0 commit comments

Comments
 (0)