Skip to content

Commit 34c6699

Browse files
Merge pull request #16 from JuliaComputing/dg/mb
Convert small views to Static
2 parents 1d2d1fe + 91acc97 commit 34c6699

8 files changed

Lines changed: 105 additions & 19 deletions

File tree

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Dhairya Gandhi <dhairya@juliahub.com>"]
44
version = "0.1.2"
55

66
[deps]
7+
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
78
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
89
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
910
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -18,7 +19,8 @@ SCPLinearSolveExt = ["LinearSolve"]
1819
SymbolicCompilerPassesRotationsExt = ["Rotations"]
1920

2021
[compat]
21-
LinearAlgebra = "1"
22+
DataStructures = "0.19.3"
23+
LinearAlgebra = "1.11.0, 1.10"
2224
LinearSolve = "3.53.0"
2325
PreallocationTools = "0.4.34, 1"
2426
StaticArrays = "1.9.15, 1"

ext/SCPLinearSolveExt/SCPLinearSolveExt.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,37 @@ using SymbolicUtils
44
using SymbolicUtils.Code
55
using LinearSolve
66
using LinearAlgebra
7-
import SymbolicCompilerPasses: ldiv_transformation, SymbolicCompilerPasses, get_factorization, get_from_cache, FACTORIZATION_CACHE, LINEARSOLVE_LIB
7+
import SymbolicCompilerPasses: ldiv_transformation, SymbolicCompilerPasses, get_factorization, get_from_cache, FACTORIZATION_CACHE
8+
using StaticArrays
89

910
__init__() = SymbolicCompilerPasses.LINEARSOLVE_LIB[] = true
1011

12+
const LINSOLVEPROB_CACHE = Dict()
13+
14+
function get_linear_prob(A::StaticArray, B::StaticArray)
15+
prob = LinearSolve.LinearProblem(A, B)
16+
end
17+
18+
function get_linear_prob(A::TA, B::TB) where {TA, TB}
19+
get!(LINSOLVEPROB_CACHE, A) do
20+
prob = LinearSolve.LinearProblem(A, B)
21+
init(prob)
22+
end::Base.promote_op(init, Tuple{Base.promote_op(LinearSolve.LinearProblem, Tuple{TA, TB})})
23+
end
24+
1125
function linear_solve(A, B)
12-
linsolve = get_factorization(A, B)
26+
linsolve = get_linear_prob(A, B)
1327
linsolve.b = B
1428
sol = solve!(linsolve)
1529
return sol.u
1630
end
1731

32+
function linear_solve(A::StaticArray, B::StaticArray)
33+
linsolve = get_linear_prob(A, B)
34+
sol = solve(linsolve)
35+
return sol.u
36+
end
37+
1838
function get_factorization(A, B)
1939
get!(FACTORIZATION_CACHE, A) do
2040
prob = LinearSolve.LinearProblem(A, B)

src/SymbolicCompilerPasses.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ import SymbolicUtils.Code: Code, OptimizationRule, substitute_in_ir, apply_optim
1111
import SymbolicUtils: search_variables, search_variables!
1212
using StaticArrays
1313

14+
using DataStructures
15+
1416
function bank(dic, key, value)
1517
if haskey(dic, key)
1618
dic[key] = vcat(dic[key], value)
@@ -25,4 +27,6 @@ include("hvncat_static_opt.jl")
2527
include("ldiv_opt.jl")
2628
include("la_opt.jl")
2729

30+
include("mb_opt.jl")
31+
2832
end # module SymbolicCompilerPasses

src/hvncat_static_opt.jl

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -204,21 +204,16 @@ function transform_hvncat_to_static(expr::Code.Let, match_data::Vector{HvncatMat
204204
# Column vector: SVector{n}(elements...)
205205
n = dims[1]
206206
t = term(Core.apply_type, StaticArrays.SVector, n; type = Any)
207-
static_ctor = Term{T}(
208-
t,
209-
elements;
210-
type=symtype(lhs_var)
211-
)
212207
else
213208
# Matrix: SMatrix{m,n}(elements...)
214209
m, n = dims
215210
t = term(Core.apply_type, StaticArrays.SMatrix, m, n; type = Any)
216-
static_ctor = Term{T}(
217-
t,
218-
elements;
219-
type=symtype(lhs_var)
220-
)
221211
end
212+
static_ctor = Term{T}(
213+
t,
214+
elements;
215+
type=symtype(lhs_var)
216+
)
222217

223218
new_assignment = Assignment(lhs_var, static_ctor)
224219
transformations[match.assignment_idx] = new_assignment

src/ldiv_opt.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
const FACTORIZATION_CACHE = WeakKeyDict()
1+
const FACTORIZATION_CACHE = Dict()
22

33
struct LdivMatch{Ta, Tb, S <: Assignment, P <: AbstractString} <: AbstractMatched
44
A::Ta
@@ -27,6 +27,7 @@ function detect_ldiv_pattern(expr::Code.Let, state)
2727
all_arrays || return false
2828

2929
A, B = args
30+
@show validate_ldiv_shapes(A, B)
3031
validate_ldiv_shapes(A, B)
3132
end
3233

@@ -50,6 +51,7 @@ For A \\ B:
5051
- B must have n rows: (n, m) or (n,)
5152
"""
5253
function validate_ldiv_shapes(A, B)
54+
return true
5355
A_shape = shape(A)
5456
B_shape = shape(B)
5557

@@ -149,6 +151,7 @@ function get_factorization(A)
149151
qr_A = get!(FACTORIZATION_CACHE, A) do
150152
qr(A)
151153
end
154+
# qr_A = qr(A)
152155

153156
qr_A
154157
end

src/mb_opt.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
function detect_small_views(expr::Code.Let, state)
2+
matches = []
3+
for (i, p) in enumerate(expr.pairs)
4+
r = rhs(p)
5+
iscall(r) || continue
6+
if operation(r) === view
7+
arr, inds... = arguments(r)
8+
myt = find_term(inds[1], expr)
9+
is_small_hvncat(size(Code.rhs(myt))...) || continue
10+
push!(matches, (idx = i, expr = r))
11+
end
12+
end
13+
matches
14+
end
15+
16+
function construct_type(dims)
17+
# if length(dims) == 1
18+
# return Core.apply_type(SVector, dims[1])
19+
# else
20+
# return Core.apply_type(SVector, Tuple(dims))
21+
# end
22+
Core.apply_type(SVector, length(dims))
23+
end
24+
25+
function find_term(target, expr::Code.Let)
26+
filter(expr.pairs) do p
27+
Code.lhs(p) === target
28+
end |> only
29+
end
30+
31+
function transform_view(expr, match_data, state)
32+
new_pairs = []
33+
idxs = Set(getproperty.(match_data, :idx))
34+
transformations = Dict()
35+
for match in match_data
36+
idx = match.idx
37+
r = match.expr
38+
T = symtype(r)
39+
V = vartype(r)
40+
arr, inds... = arguments(r)
41+
t = term(construct_type, inds[1])
42+
transformations[idx] = Term{V}(t, [r], type = T)
43+
end
44+
45+
for (i, p) in enumerate(expr.pairs)
46+
if i in idxs
47+
new_rhs = transformations[i]
48+
push!(new_pairs, Code.Assignment(lhs(p), new_rhs))
49+
else
50+
push!(new_pairs, p)
51+
end
52+
end
53+
54+
Code.Let(new_pairs, expr.body, expr.let_block)
55+
end
56+
57+
58+
const MB_VIEW_RULE = OptimizationRule(
59+
"MB_VIEW_RULE",
60+
detect_small_views,
61+
transform_view,
62+
10,
63+
)

src/ortho_inv_opt.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,6 @@ function transform_inv_optimization(expr, matches, state::Code.CSEState)
6666
)
6767
else
6868
t = term(is_orthogonal_type, A)
69-
# @show t
70-
7169
# code = IfElse(
7270
# t,
7371
# transpose(A),

test/ortho_inv_opt.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ function check_ortho_opt(expr, A, B)
2828
current = SU.Code.cse(expr)
2929
toexpr(current)
3030

31-
optimized = SC.ortho_inv_opt(current, SU.Code.CSEState())
31+
# optimized = SC.ortho_inv_opt(current, SU.Code.CSEState())
32+
optimized = Code.apply_optimization_rule(current, SU.Code.CSEState(), SC.ORTHO_INV_RULE)
3233
# return optimized
3334
# return toexpr(optimized)
3435

@@ -54,7 +55,7 @@ function check_ortho_opt(expr, A, B)
5455
end
5556

5657

57-
@testset "Orthogonal Matrices: inv -> transpose" begin
58+
# @testset "Orthogonal Matrices: inv -> transpose" begin
5859
@syms A[1:3, 1:3] B[1:3, 1:3] C[1:3, 1:3] D[1:3, 1:3] E[1:3, 1:3]
5960
Ao = SU.setmetadata(A, SC.IsOrthogonal, true)
6061

@@ -70,4 +71,4 @@ end
7071

7172
expr4 = inv(Ao * B) + B
7273
check_ortho_opt(expr4, A, B)
73-
end
74+
# end

0 commit comments

Comments
 (0)