Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Dhairya Gandhi <dhairya@juliahub.com>"]
version = "0.1.2"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand All @@ -18,7 +19,8 @@ SCPLinearSolveExt = ["LinearSolve"]
SymbolicCompilerPassesRotationsExt = ["Rotations"]

[compat]
LinearAlgebra = "1"
DataStructures = "0.19.3"
LinearAlgebra = "1.11.0, 1.10"
LinearSolve = "3.53.0"
PreallocationTools = "0.4.34, 1"
StaticArrays = "1.9.15, 1"
Expand Down
24 changes: 22 additions & 2 deletions ext/SCPLinearSolveExt/SCPLinearSolveExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,37 @@ using SymbolicUtils
using SymbolicUtils.Code
using LinearSolve
using LinearAlgebra
import SymbolicCompilerPasses: ldiv_transformation, SymbolicCompilerPasses, get_factorization, get_from_cache, FACTORIZATION_CACHE, LINEARSOLVE_LIB
import SymbolicCompilerPasses: ldiv_transformation, SymbolicCompilerPasses, get_factorization, get_from_cache, FACTORIZATION_CACHE
using StaticArrays

__init__() = SymbolicCompilerPasses.LINEARSOLVE_LIB[] = true

const LINSOLVEPROB_CACHE = Dict()

function get_linear_prob(A::StaticArray, B::StaticArray)
prob = LinearSolve.LinearProblem(A, B)
end

function get_linear_prob(A::TA, B::TB) where {TA, TB}
get!(LINSOLVEPROB_CACHE, A) do
prob = LinearSolve.LinearProblem(A, B)
init(prob)
end::Base.promote_op(init, Tuple{Base.promote_op(LinearSolve.LinearProblem, Tuple{TA, TB})})
end

function linear_solve(A, B)
linsolve = get_factorization(A, B)
linsolve = get_linear_prob(A, B)
linsolve.b = B
sol = solve!(linsolve)
return sol.u
end

function linear_solve(A::StaticArray, B::StaticArray)
linsolve = get_linear_prob(A, B)
sol = solve(linsolve)
return sol.u
end

function get_factorization(A, B)
get!(FACTORIZATION_CACHE, A) do
prob = LinearSolve.LinearProblem(A, B)
Expand Down
4 changes: 4 additions & 0 deletions src/SymbolicCompilerPasses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import SymbolicUtils.Code: Code, OptimizationRule, substitute_in_ir, apply_optim
import SymbolicUtils: search_variables, search_variables!
using StaticArrays

using DataStructures

function bank(dic, key, value)
if haskey(dic, key)
dic[key] = vcat(dic[key], value)
Expand All @@ -25,4 +27,6 @@ include("hvncat_static_opt.jl")
include("ldiv_opt.jl")
include("la_opt.jl")

include("mb_opt.jl")

end # module SymbolicCompilerPasses
15 changes: 5 additions & 10 deletions src/hvncat_static_opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,21 +204,16 @@ function transform_hvncat_to_static(expr::Code.Let, match_data::Vector{HvncatMat
# Column vector: SVector{n}(elements...)
n = dims[1]
t = term(Core.apply_type, StaticArrays.SVector, n; type = Any)
static_ctor = Term{T}(
t,
elements;
type=symtype(lhs_var)
)
else
# Matrix: SMatrix{m,n}(elements...)
m, n = dims
t = term(Core.apply_type, StaticArrays.SMatrix, m, n; type = Any)
static_ctor = Term{T}(
t,
elements;
type=symtype(lhs_var)
)
end
static_ctor = Term{T}(
t,
elements;
type=symtype(lhs_var)
)

new_assignment = Assignment(lhs_var, static_ctor)
transformations[match.assignment_idx] = new_assignment
Expand Down
5 changes: 4 additions & 1 deletion src/ldiv_opt.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
const FACTORIZATION_CACHE = WeakKeyDict()
const FACTORIZATION_CACHE = Dict()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Purely out of curiosity, what motivated this change?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

StaticArrays dont finalize


struct LdivMatch{Ta, Tb, S <: Assignment, P <: AbstractString} <: AbstractMatched
A::Ta
Expand Down Expand Up @@ -27,6 +27,7 @@ function detect_ldiv_pattern(expr::Code.Let, state)
all_arrays || return false

A, B = args
@show validate_ldiv_shapes(A, B)
validate_ldiv_shapes(A, B)
end

Expand All @@ -50,6 +51,7 @@ For A \\ B:
- B must have n rows: (n, m) or (n,)
"""
function validate_ldiv_shapes(A, B)
return true
A_shape = shape(A)
B_shape = shape(B)

Expand Down Expand Up @@ -149,6 +151,7 @@ function get_factorization(A)
qr_A = get!(FACTORIZATION_CACHE, A) do
qr(A)
end
# qr_A = qr(A)

qr_A
end
Expand Down
63 changes: 63 additions & 0 deletions src/mb_opt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
function detect_small_views(expr::Code.Let, state)
matches = []
for (i, p) in enumerate(expr.pairs)
r = rhs(p)
iscall(r) || continue
if operation(r) === view
arr, inds... = arguments(r)
myt = find_term(inds[1], expr)
is_small_hvncat(size(Code.rhs(myt))...) || continue
push!(matches, (idx = i, expr = r))
end
end
matches
end

function construct_type(dims)
# if length(dims) == 1
# return Core.apply_type(SVector, dims[1])
# else
# return Core.apply_type(SVector, Tuple(dims))
# end
Core.apply_type(SVector, length(dims))
end

function find_term(target, expr::Code.Let)
filter(expr.pairs) do p
Code.lhs(p) === target
end |> only
end

function transform_view(expr, match_data, state)
new_pairs = []
idxs = Set(getproperty.(match_data, :idx))
transformations = Dict()
for match in match_data
idx = match.idx
r = match.expr
T = symtype(r)
V = vartype(r)
arr, inds... = arguments(r)
t = term(construct_type, inds[1])
transformations[idx] = Term{V}(t, [r], type = T)
end

for (i, p) in enumerate(expr.pairs)
if i in idxs
new_rhs = transformations[i]
push!(new_pairs, Code.Assignment(lhs(p), new_rhs))
else
push!(new_pairs, p)
end
end

Code.Let(new_pairs, expr.body, expr.let_block)
end


const MB_VIEW_RULE = OptimizationRule(
"MB_VIEW_RULE",
detect_small_views,
transform_view,
10,
)
2 changes: 0 additions & 2 deletions src/ortho_inv_opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ function transform_inv_optimization(expr, matches, state::Code.CSEState)
)
else
t = term(is_orthogonal_type, A)
# @show t

# code = IfElse(
# t,
# transpose(A),
Expand Down
7 changes: 4 additions & 3 deletions test/ortho_inv_opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ function check_ortho_opt(expr, A, B)
current = SU.Code.cse(expr)
toexpr(current)

optimized = SC.ortho_inv_opt(current, SU.Code.CSEState())
# optimized = SC.ortho_inv_opt(current, SU.Code.CSEState())
optimized = Code.apply_optimization_rule(current, SU.Code.CSEState(), SC.ORTHO_INV_RULE)
# return optimized
# return toexpr(optimized)

Expand All @@ -54,7 +55,7 @@ function check_ortho_opt(expr, A, B)
end


@testset "Orthogonal Matrices: inv -> transpose" begin
# @testset "Orthogonal Matrices: inv -> transpose" begin
@syms A[1:3, 1:3] B[1:3, 1:3] C[1:3, 1:3] D[1:3, 1:3] E[1:3, 1:3]
Ao = SU.setmetadata(A, SC.IsOrthogonal, true)

Expand All @@ -70,4 +71,4 @@ end

expr4 = inv(Ao * B) + B
check_ortho_opt(expr4, A, B)
end
# end
Loading