1+ function detect_f (expr, f, state)
2+ args = search_variables (expr)
3+ arg_counts = Dict (arg => count_occurrences (arg, expr) for arg in args)
4+ triu_candidates_idx = findall (expr. pairs) do x
5+ r = rhs (x)
6+ iscall (r) || return false
7+ op = operation (r)
8+ op === f || return false
9+
10+ args = arguments (r)
11+ length (args) == 1 || return false
12+
13+ arg = args[1 ]
14+ symtype (arg) <: AbstractMatrix || return false
15+
16+ # The var"##cse#" variables must get counted at least twice (once in LHS and a second time in RHS)
17+ # TODO : needs to detect provenance better; see expr7
18+ get! (arg_counts, arg) do
19+ count_occurrences (arg, expr)
20+ end == 1 || return false
21+
22+ true
23+ end
24+
25+ triu_candidates = expr. pairs[triu_candidates_idx]
26+
27+ matches = map (triu_candidates_idx, triu_candidates) do idx, candidate
28+ A = arguments (rhs (candidate))[1 ]
29+ GenericRuleMatched (A, candidate, idx)
30+ end
31+
32+ f = filter (! isnothing, matches)
33+ isempty (f) ? nothing : f
34+ end
35+
36+ struct GenericRuleMatched{Ta, S} <: Code.AbstractMatched
37+ A:: Ta
38+ candidate:: S
39+ idx:: Int
40+ end
41+
42+ transform_f (expr, f!, :: Nothing , state) = expr
43+ function transform_f (expr, f!, matches, state)
44+
45+ new_pairs = []
46+ transformed_idxs = getproperty .(matches, :idx )
47+ for (idx, pair) in enumerate (expr. pairs)
48+ if idx in transformed_idxs
49+ match = matches[findfirst (== (idx), transformed_idxs)]
50+ A = match. A
51+ candidate = match. candidate
52+
53+ push! (new_pairs, Assignment (lhs (candidate), term (f!, A,)))
54+ else
55+ push! (new_pairs, pair)
56+ end
57+
58+ end
59+
60+ Code. Let (new_pairs, expr. body, false )
61+ end
62+
63+
64+ function GenericRule (name, f, f!, priority)
65+ OptimizationRule (
66+ name,
67+ (expr, state) -> detect_f (expr, f, state),
68+ (expr, matches, state) -> transform_f (expr, f!, matches, state),
69+ priority
70+ )
71+ end
72+
73+ const TRIU_RULE = GenericRule (" triu" , LinearAlgebra. triu, LinearAlgebra. triu!, 8 )
74+ const TRIL_RULE = GenericRule (" tril" , LinearAlgebra. tril, LinearAlgebra. tril!, 8 )
75+ const NORMALIZE_RULE = GenericRule (" normalize" , LinearAlgebra. normalize, LinearAlgebra. normalize!, 8 )
76+ # const CONJ_RULE = GenericRule("conj", LinearAlgebra.conj, LinearAlgebra.conj!, 8)
77+
78+ function triu_opt (expr, state:: CSEState )
79+ # Try to apply optimization rules
80+ optimized = apply_optimization_rules (expr, state, TRIU_RULE)
81+ if optimized != = nothing
82+ return optimized
83+ end
84+
85+ # If no optimization applied, return original expression
86+ return expr
87+ end
0 commit comments