Skip to content

Commit e9ab531

Browse files
authored
Merge pull request #80 from VEZY/optimize-traversal
Optimize tree traversal: implements #79 and fixes #78, adds benchmarks, adds `ancestors!`, and adds a new `descendants!` method that fills a caller-provided vector
2 parents 6f706f1 + b47ccd6 commit e9ab531

14 files changed

Lines changed: 527 additions & 108 deletions

AGENT.md

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# MultiScaleTreeGraph Performance Agent Notes
2+
3+
## Goal
4+
- Optimize traversal-heavy workloads for very large trees.
5+
- Prioritize low allocations and type-stable code paths.
6+
7+
## Benchmark Commands
8+
- Local suite:
9+
- `julia --project=benchmark benchmark/benchmarks.jl`
10+
- Package tests (workaround for current precompile deadlock on Julia 1.12.1):
11+
- `julia --project --compiled-modules=no -e 'using Pkg; Pkg.test()'`
12+
13+
## CI Benchmarks
14+
- Uses `AirspeedVelocity.jl` via `.github/workflows/Benchmarks.yml`.
15+
- Benchmark definitions live in `benchmark/benchmarks.jl` and must expose `const SUITE`.
16+
17+
## Current Hot Paths
18+
- `src/compute_MTG/traverse.jl`
19+
- `src/compute_MTG/ancestors.jl`
20+
- `src/compute_MTG/descendants.jl`
21+
- `src/compute_MTG/indexing.jl`
22+
- `src/compute_MTG/check_filters.jl`
23+
- `src/types/Node.jl`
24+
- `src/compute_MTG/node_funs.jl`
25+
26+
## Practical Optimization Rules
27+
- Avoid allocating temporary arrays in per-node loops.
28+
- Prefer in-place APIs for repeated queries:
29+
- `ancestors!(buffer, node, key; ...)`
30+
- `descendants!(buffer, node, key; ...)`
31+
- Keep filter checks branch-light when no filters are provided.
32+
- Keep key access on typed attribute containers (`NamedTuple`, `MutableNamedTuple`, typed dicts) in specialized methods when possible.
33+
- Preserve API behavior and add tests for every optimization that changes internals.

benchmark/benchmarks.jl

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -88,25 +88,12 @@ function ancestors_workload(nodes, reps::Int)
8888
return s
8989
end
9090

91-
function ancestors_workload_inplace_1(nodes, reps::Int)
92-
s = 0.0
93-
@inbounds for _ in 1:reps
94-
for n in nodes
95-
out = ancestors!(n, :mass, recursivity_level=4, type=Float64)
96-
for v in out
97-
s += v
98-
end
99-
end
100-
end
101-
return s
102-
end
103-
104-
function ancestors_workload_inplace_2(nodes, reps::Int)
91+
function ancestors_workload_inplace(nodes, reps::Int)
10592
s = 0.0
10693
buf = Float64[]
10794
@inbounds for _ in 1:reps
10895
for n in nodes
109-
ancestors!(buf, n, :mass, recursivity_level=4, type=Float64)
96+
ancestors!(buf, n, :mass, recursivity_level=4)
11097
for v in buf
11198
s += v
11299
end
@@ -130,7 +117,7 @@ end
130117

131118
function descendants_extraction_workload_inplace_2(root)
132119
vals = Float64[]
133-
descendants!(vals, root, :mass, type=Float64)
120+
descendants!(vals, root, :mass)
134121
end
135122

136123
suite_name = "mstg"
@@ -151,17 +138,18 @@ SUITE[suite_name] = BenchmarkGroup([
151138
root, leaves, sample_nodes = synthetic_tree()
152139
SUITE[suite_name]["traverse"]["full_tree_nodes"] = @benchmarkable traverse!($root, _ -> nothing)
153140
SUITE[suite_name]["traverse_extract"]["descendants_mass"] = @benchmarkable descendants_extraction_workload($root)
154-
SUITE[suite_name]["traverse_extract"]["descendants_mass_inplace"] = @benchmarkable descendants_extraction_workload_inplace_1($root)
141+
SUITE[suite_name]["traverse_extract"]["descendants_mass_inplace_1"] = @benchmarkable descendants_extraction_workload_inplace_1($root)
155142

156143
# Add this one only if the package has the in-place method `descendants!(vals, node, key)`
157144
if hasmethod(descendants!, Tuple{AbstractVector,Node,Symbol})
158-
SUITE[suite_name]["traverse_extract"]["descendants_mass_inplace"] = @benchmarkable descendants_extraction_workload_inplace_2($root)
145+
SUITE[suite_name]["traverse_extract"]["descendants_mass_inplace_2"] = @benchmarkable descendants_extraction_workload_inplace_2($root)
159146
end
160147

161148
SUITE[suite_name]["many_queries"]["children_repeated"] = @benchmarkable children_workload($sample_nodes, 300)
162149
SUITE[suite_name]["many_queries"]["parent_repeated"] = @benchmarkable parent_workload($sample_nodes, 300)
163150
SUITE[suite_name]["many_queries"]["ancestors_repeated"] = @benchmarkable ancestors_workload($leaves, 40)
164-
if hasmethod(ancestors!, Tuple{AbstractVector,Node,Symbol})
151+
# Test if ancestors! exists in the package first:
152+
if isdefined(MultiScaleTreeGraph, :ancestors!)
165153
SUITE[suite_name]["many_queries"]["ancestors_repeated_inplace"] = @benchmarkable ancestors_workload_inplace($leaves, 40)
166154
end
167155

src/MultiScaleTreeGraph.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ export insert_parents!, insert_generations!, insert_children!, insert_siblings!
7979
export insert_parent!, insert_generation!, insert_child!, insert_sibling!
8080
export write_mtg
8181
export is_segment!
82-
export descendants, ancestors, descendants!
82+
export descendants, ancestors, ancestors!, descendants!
8383
export Node
8484
export AbstractNodeMTG
8585
export NodeMTG

src/compute_MTG/ancestors.jl

Lines changed: 123 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,13 @@ function ancestors(
7777

7878
# Change the filtering function if we also want to remove nodes with nothing values.
7979
filter_fun_ = filter_fun_nothing(filter_fun, ignore_nothing, key)
80+
use_no_filter = no_node_filters(scale, symbol, link, filter_fun_)
8081

8182
val = Array{type,1}()
8283
# The recursivity level is passed as-is; the traversal helpers count it down in a local variable:
8384

8485
if self
85-
if is_filtered(node, scale, symbol, link, filter_fun_)
86+
if use_no_filter || is_filtered(node, scale, symbol, link, filter_fun_)
8687
val_ = unsafe_getindex(node, key)
8788
push!(val, val_)
8889
elseif !all
@@ -91,31 +92,47 @@ function ancestors(
9192
end
9293
end
9394

94-
ancestors_(node, key, scale, symbol, link, all, filter_fun, val, recursivity_level)
95+
if use_no_filter
96+
ancestors_values_no_filter!(node, key, val, recursivity_level)
97+
else
98+
ancestors_values!(node, key, scale, symbol, link, all, filter_fun_, val, recursivity_level)
99+
end
95100
return val
96101
end
97102

98103

99-
function ancestors_(node, key, scale, symbol, link, all, filter_fun, val, recursivity_level)
100-
101-
if !isroot(node) && recursivity_level != 0
102-
parent_ = parent(node)
104+
function ancestors_values!(node, key, scale, symbol, link, all, filter_fun, val, recursivity_level)
105+
current = node
106+
remaining = recursivity_level
103107

104-
# Is there any filter happening for the current node? (FALSE if filtered out):
108+
while !isroot(current) && remaining != 0
109+
parent_ = parent(current)
105110
keep = is_filtered(parent_, scale, symbol, link, filter_fun)
106111

107112
if keep
108-
val_ = unsafe_getindex(parent_, key)
109-
push!(val, val_)
113+
push!(val, unsafe_getindex(parent_, key))
110114
# Only decrement the recursivity level when the current node is not filtered-out
111-
recursivity_level -= 1
115+
remaining -= 1
112116
end
113117

114118
# Stop at the first filtered-out ancestor, unless `all` is true (then skip it and keep going up)
115-
if all || keep
116-
ancestors_(parent_, key, scale, symbol, link, all, filter_fun, val, recursivity_level)
117-
end
119+
(all || keep) || break
120+
current = parent_
121+
end
122+
return val
123+
end
124+
125+
function ancestors_values_no_filter!(node, key, val, recursivity_level)
126+
current = node
127+
remaining = recursivity_level
128+
129+
while !isroot(current) && remaining != 0
130+
parent_ = parent(current)
131+
push!(val, unsafe_getindex(parent_, key))
132+
remaining -= 1
133+
current = parent_
118134
end
135+
return val
119136
end
120137

121138
# Version that returns the nodes instead of the values:
@@ -133,41 +150,123 @@ function ancestors(
133150
# Check the filters once, and then collect the ancestors iteratively using `ancestors_nodes!` (or `ancestors_nodes_no_filter!` when no filter is given)
134151
check_filters(node, scale=scale, symbol=symbol, link=link)
135152

136-
# Change the filtering function if we also want to remove nodes with nothing values.
153+
use_no_filter = no_node_filters(scale, symbol, link, filter_fun)
137154
val = Array{typeof(node),1}()
138155
# The recursivity level is passed as-is; the traversal helpers count it down in a local variable:
139156

140157
if self
141-
if is_filtered(node, scale, symbol, link, filter_fun)
158+
if use_no_filter || is_filtered(node, scale, symbol, link, filter_fun)
142159
push!(val, node)
143160
elseif !all
144161
# We don't keep the value and we have to stop at the first filtered-out value
145162
return val
146163
end
147164
end
148165

149-
ancestors_(node, scale, symbol, link, all, filter_fun, val, recursivity_level)
166+
if use_no_filter
167+
ancestors_nodes_no_filter!(node, val, recursivity_level)
168+
else
169+
ancestors_nodes!(node, scale, symbol, link, all, filter_fun, val, recursivity_level)
170+
end
150171
return val
151172
end
152173

153174

154-
function ancestors_(node, scale, symbol, link, all, filter_fun, val, recursivity_level)
155-
156-
if !isroot(node) && recursivity_level != 0
157-
parent_ = parent(node)
175+
function ancestors_nodes!(node, scale, symbol, link, all, filter_fun, val, recursivity_level)
176+
current = node
177+
remaining = recursivity_level
158178

159-
# Is there any filter happening for the current node? (FALSE if filtered out):
179+
while !isroot(current) && remaining != 0
180+
parent_ = parent(current)
160181
keep = is_filtered(parent_, scale, symbol, link, filter_fun)
161182

162183
if keep
163184
push!(val, parent_)
164185
# Only decrement the recursivity level when the current node is not filtered-out
165-
recursivity_level -= 1
186+
remaining -= 1
166187
end
167188

168189
# Stop at the first filtered-out ancestor, unless `all` is true (then skip it and keep going up)
169-
if all || keep
170-
ancestors_(parent_, scale, symbol, link, all, filter_fun, val, recursivity_level)
190+
(all || keep) || break
191+
current = parent_
192+
end
193+
return val
194+
end
195+
196+
function ancestors_nodes_no_filter!(node, val, recursivity_level)
197+
current = node
198+
remaining = recursivity_level
199+
200+
while !isroot(current) && remaining != 0
201+
parent_ = parent(current)
202+
push!(val, parent_)
203+
remaining -= 1
204+
current = parent_
205+
end
206+
return val
207+
end
208+
209+
function ancestors!(
210+
out::AbstractVector,
211+
node, key;
212+
scale=nothing,
213+
symbol=nothing,
214+
link=nothing,
215+
all::Bool=true,
216+
self=false,
217+
filter_fun=nothing,
218+
recursivity_level=-1,
219+
ignore_nothing=false,
220+
type::Union{Union,DataType}=Any,
221+
)
222+
check_filters(node, scale=scale, symbol=symbol, link=link)
223+
filter_fun_ = filter_fun_nothing(filter_fun, ignore_nothing, key)
224+
use_no_filter = no_node_filters(scale, symbol, link, filter_fun_)
225+
226+
empty!(out)
227+
if self
228+
if use_no_filter || is_filtered(node, scale, symbol, link, filter_fun_)
229+
push!(out, unsafe_getindex(node, key))
230+
elseif !all
231+
return out
232+
end
233+
end
234+
235+
if use_no_filter
236+
ancestors_values_no_filter!(node, key, out, recursivity_level)
237+
else
238+
ancestors_values!(node, key, scale, symbol, link, all, filter_fun_, out, recursivity_level)
239+
end
240+
return out
241+
end
242+
243+
function ancestors!(
244+
out::AbstractVector,
245+
node;
246+
scale=nothing,
247+
symbol=nothing,
248+
link=nothing,
249+
all::Bool=true,
250+
self=false,
251+
filter_fun=nothing,
252+
recursivity_level=-1,
253+
)
254+
check_filters(node, scale=scale, symbol=symbol, link=link)
255+
use_no_filter = no_node_filters(scale, symbol, link, filter_fun)
256+
257+
empty!(out)
258+
if self
259+
if use_no_filter || is_filtered(node, scale, symbol, link, filter_fun)
260+
push!(out, node)
261+
elseif !all
262+
return out
171263
end
172264
end
265+
266+
if use_no_filter
267+
ancestors_nodes_no_filter!(node, out, recursivity_level)
268+
else
269+
ancestors_nodes!(node, scale, symbol, link, all, filter_fun, out, recursivity_level)
270+
end
271+
return out
173272
end

src/compute_MTG/check_filters.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@ check_filters(mtg, scale = (1,2))
1111
check_filters(mtg, scale = (1,2), symbol = "Leaf", link = "<")
1212
```
1313
"""
14+
@inline no_node_filters(scale, symbol, link, filter_fun=nothing) =
15+
isnothing(scale) && isnothing(symbol) && isnothing(link) && isnothing(filter_fun)
16+
1417
function check_filters(node::Node{N,A}; scale=nothing, symbol=nothing, link=nothing) where {N<:AbstractNodeMTG,A}
18+
no_node_filters(scale, symbol, link) && return nothing
1519

1620
root_node = get_root(node)
1721

@@ -76,7 +80,10 @@ end
7680
end
7781

7882
@inline function is_filtered(filter, value::T) where {T<:Union{Tuple,Array}}
79-
all(map(x -> is_filtered(filter, x), value))
83+
for x in value
84+
is_filtered(filter, x) || return false
85+
end
86+
return true
8087
end
8188

8289

0 commit comments

Comments
 (0)