1+ function detect_small_views (expr:: Code.Let , state)
2+ matches = []
3+ for (i, p) in enumerate (expr. pairs)
4+ r = rhs (p)
5+ iscall (r) || continue
6+ if operation (r) === view
7+ arr, inds... = arguments (r)
8+ myt = find_term (inds[1 ], expr)
9+ is_small_hvncat (size (Code. rhs (myt))... ) || continue
10+ push! (matches, (idx = i, expr = r))
11+ end
12+ end
13+ matches
14+ end
15+
16+ function construct_type (dims)
17+ # if length(dims) == 1
18+ # return Core.apply_type(SVector, dims[1])
19+ # else
20+ # return Core.apply_type(SVector, Tuple(dims))
21+ # end
22+ Core. apply_type (SVector, length (dims))
23+ end
24+
25+ function find_term (target, expr:: Code.Let )
26+ filter (expr. pairs) do p
27+ Code. lhs (p) === target
28+ end |> only
29+ end
30+
31+ function transform_view (expr, match_data, state)
32+ new_pairs = []
33+ idxs = Set (getproperty .(match_data, :idx ))
34+ transformations = Dict ()
35+ for match in match_data
36+ idx = match. idx
37+ r = match. expr
38+ T = symtype (r)
39+ V = vartype (r)
40+ arr, inds... = arguments (r)
41+ t = term (construct_type, inds[1 ])
42+ transformations[idx] = Term {V} (t, [r], type = T)
43+ end
44+
45+ for (i, p) in enumerate (expr. pairs)
46+ if i in idxs
47+ new_rhs = transformations[i]
48+ push! (new_pairs, Code. Assignment (lhs (p), new_rhs))
49+ else
50+ push! (new_pairs, p)
51+ end
52+ end
53+
54+ Code. Let (new_pairs, expr. body, expr. let_block)
55+ end
56+
57+
58+ const MB_VIEW_RULE = OptimizationRule (
59+ " MB_VIEW_RULE" ,
60+ detect_small_views,
61+ transform_view,
62+ 10 ,
63+ )
0 commit comments