Skip to content

Commit 847c2c6

Browse files
fix: stabilize ArenaNode prototype traversal and tests
1 parent 8dcfc89 commit 847c2c6

3 files changed

Lines changed: 170 additions & 31 deletions

File tree

src/ArenaNode.jl

Lines changed: 70 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ import ..NodeModule:
1010
unsafe_get_children,
1111
get_child,
1212
set_child!,
13-
set_children!
13+
set_children!,
14+
branch_equal,
15+
leaf_equal
1416

1517
"""Array-backed arena storing the fields of a tree node in a struct-of-arrays form.
1618
@@ -55,10 +57,14 @@ Core fields are accessed and mutated via `getproperty`/`setproperty!`.
5557
struct ArenaNode{T,D} <: AbstractExpressionNode{T,D}
5658
arena::Arena{T,D}
5759
idx::Int32
60+
61+
@inline function ArenaNode{T,D}(arena::Arena{T,D}, idx::Int32) where {T,D}
62+
return new{T,D}(arena, idx)
63+
end
5864
end
5965

60-
@inline ArenaNode(arena::Arena{T,D}, idx::Integer) where {T,D} =
61-
ArenaNode{T,D}(arena, Int32(idx))
66+
@inline ArenaNode(arena::Arena{T,D}, idx::Int32) where {T,D} =
67+
ArenaNode{T,D}(arena, idx)
6268

6369
@inline function _zero_children(::Val{D}) where {D}
6470
return ntuple(_ -> Int32(0), Val(D))
@@ -179,6 +185,37 @@ end
179185
end
180186
end
181187

188+
# DispatchDoctor's `@stable` checking is based on type inference from argument *types*,
189+
# and does not consider constant propagation of `getproperty(x, ::Symbol)`. Define
190+
# ArenaNode-specific equality helpers that access the underlying arena directly.
191+
@inline function branch_equal(a::ArenaNode{T,D}, b::ArenaNode{T,D})::Bool where {T,D}
192+
arena_a = getfield(a, :arena)
193+
arena_b = getfield(b, :arena)
194+
ia = Int(getfield(a, :idx))
195+
ib = Int(getfield(b, :idx))
196+
@inbounds return arena_a.op[ia] == arena_b.op[ib]
197+
end
198+
199+
@inline function leaf_equal(
200+
a::ArenaNode{T1,D}, b::ArenaNode{T2,D}
201+
)::Bool where {T1,T2,D}
202+
arena_a = getfield(a, :arena)
203+
arena_b = getfield(b, :arena)
204+
ia = Int(getfield(a, :idx))
205+
ib = Int(getfield(b, :idx))
206+
207+
@inbounds begin
208+
const_a = arena_a.constant[ia]
209+
const_b = arena_b.constant[ib]
210+
const_a == const_b || return false
211+
if const_a
212+
return arena_a.val[ia]::T1 == arena_b.val[ib]::T2
213+
else
214+
return arena_a.feature[ia] == arena_b.feature[ib]
215+
end
216+
end
217+
end
218+
182219
"""Return an `NTuple{D,Nullable{ArenaNode}}` of children wrappers.
183220
184221
Unused slots are represented as poison nodes (mirroring `Node`), so that
@@ -225,7 +262,7 @@ end
225262
end
226263

227264
@inline function set_children!(
228-
n::ArenaNode{T,D}, children::Union{Tuple,AbstractVector}
265+
n::ArenaNode{T,D}, children::Union{Tuple,AbstractVector{<:AbstractNode{D}}}
229266
) where {T,D}
230267
D2 = length(children)
231268
idxs = _zero_children(Val(D))
@@ -291,22 +328,22 @@ function arena_from_tree(tree::AbstractExpressionNode{T,D}) where {T,D}
291328
end
292329

293330
"""Convert an arena-backed node back into a heap-allocated `Node` tree."""
294-
function tree_from_arena(tree::ArenaNode{T,D}) where {T,D}
295-
function rebuild(n::ArenaNode{T,D})
296-
d = n.degree
297-
if d == 0
298-
return n.constant ? Node{T,D}(; val=n.val) : Node{T,D}(T; feature=n.feature)
299-
else
300-
# Use a vector here to avoid `Val(d)` with runtime `d`.
301-
cs = Vector{Node{T,D}}(undef, Int(d))
302-
@inbounds for i in 1:Int(d)
303-
cs[i] = rebuild(get_child(n, i))
304-
end
305-
return Node{T,D}(T; op=n.op, children=cs)
331+
@inline function _tree_from_arena(n::ArenaNode{T,D})::Node{T,D} where {T,D}
332+
d = n.degree
333+
if d == 0
334+
return n.constant ? Node{T,D}(; val=n.val) : Node{T,D}(T; feature=n.feature)
335+
else
336+
# Use a vector here to avoid `Val(d)` with runtime `d`.
337+
cs = Vector{Node{T,D}}(undef, Int(d))
338+
@inbounds for i in 1:Int(d)
339+
cs[i] = _tree_from_arena(get_child(n, i))
306340
end
341+
return Node{T,D}(T; op=n.op, children=cs)
307342
end
343+
end
308344

309-
return rebuild(tree)
345+
function tree_from_arena(tree::ArenaNode{T,D}) where {T,D}
346+
return _tree_from_arena(tree)
310347
end
311348

312349
################################################################################
@@ -492,8 +529,9 @@ function rewrite_commutative_constants_right!(tree::ArenaNode{T,D}, operators) w
492529
reset!(cursor, tree)
493530

494531
while true
495-
node = next!(cursor)
496-
node === nothing && break
532+
maybe_node = next!(cursor)
533+
maybe_node.null && break
534+
node = maybe_node[]
497535

498536
node.degree == 2 || continue
499537
f = operators.binops[node.op]
@@ -637,7 +675,7 @@ function tree_mapreduce_postfix_with_stack(
637675
sizehint!(stack, root_idx)
638676

639677
@inbounds for i in 1:root_idx
640-
node = ArenaNode(tree.arena, i)
678+
node = ArenaNode(tree.arena, Int32(i))
641679
d = node.degree
642680
if d == 0
643681
push!(stack, f_leaf(node)::R)
@@ -679,8 +717,9 @@ mutable struct ArenaCursor{T,D}
679717
end
680718
end
681719

682-
@inline ArenaCursor(tree::ArenaNode{T,D}; capacity::Integer=0) where {T,D} =
683-
ArenaCursor(tree.arena; capacity)
720+
@inline function ArenaCursor(tree::ArenaNode{T,D}; capacity::Integer=0) where {T,D}
721+
return ArenaCursor(tree.arena; capacity=capacity)::ArenaCursor{T,D}
722+
end
684723

685724
"""Reset the cursor stack to start a preorder traversal at `root`."""
686725
@inline function reset!(c::ArenaCursor{T,D}, root::Int32) where {T,D}
@@ -691,11 +730,13 @@ end
691730
@inline reset!(c::ArenaCursor, root::ArenaNode) = reset!(c, root.idx)
692731

693732
"""Pop the next node in preorder (or return `nothing` when done)."""
694-
function next!(c::ArenaCursor{T,D}) where {T,D}
695-
isempty(c.stack) && return nothing
733+
function next!(c::ArenaCursor{T,D})::Nullable{ArenaNode{T,D}} where {T,D}
734+
if isempty(c.stack)
735+
return Nullable(true, ArenaNode{T,D}(c.arena, Int32(0)))
736+
end
696737

697738
idx = pop!(c.stack)
698-
n = ArenaNode{T,D}(c.arena, idx)
739+
node = ArenaNode{T,D}(c.arena, idx)
699740

700741
# Push children in reverse order so the leftmost child is visited next.
701742
d = @inbounds c.arena.degree[Int(idx)]
@@ -707,7 +748,7 @@ function next!(c::ArenaCursor{T,D}) where {T,D}
707748
end
708749
end
709750

710-
return n
751+
return Nullable(false, node)
711752
end
712753

713754
"""Traverse a tree in preorder using a reusable cursor."""
@@ -717,9 +758,9 @@ function foreach_preorder!(f, root::ArenaNode{T,D}, cursor::ArenaCursor{T,D}) wh
717758

718759
reset!(cursor, root)
719760
while true
720-
n = next!(cursor)
721-
n === nothing && break
722-
f(n)
761+
maybe_n = next!(cursor)
762+
maybe_n.null && break
763+
f(maybe_n[])
723764
end
724765
return nothing
725766
end

src/NodeUtils.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ import ..NodeModule:
1717
import ..ValueInterfaceModule:
1818
pack_scalar_constants!, unpack_scalar_constants, count_scalar_constants, get_number_type
1919

20+
import ..ArenaNodeModule: ArenaNode, ArenaCursor, reset!, next!
21+
2022
"""
2123
count_depth(tree::AbstractNode)::Int
2224
@@ -70,7 +72,25 @@ has_operators(tree::AbstractExpressionNode) = tree.degree != 0
7072
Check if an expression is a constant numerical value, or
7173
whether it depends on input features.
7274
"""
73-
is_constant(tree::AbstractExpressionNode) = all(t -> t.degree != 0 || t.constant, tree)
75+
is_constant(tree::AbstractExpressionNode) = !any(t -> t.degree == 0 && !t.constant, tree)
76+
77+
# Specialized implementation for arena-backed nodes to keep DispatchDoctor inference concrete.
78+
function is_constant(tree::ArenaNode{T,D}) where {T,D}
79+
cursor = ArenaCursor(tree; capacity=Int(getfield(tree, :idx)))
80+
reset!(cursor, tree)
81+
while true
82+
maybe_n = next!(cursor)
83+
maybe_n.null && break
84+
n = maybe_n[]
85+
arena = getfield(n, :arena)
86+
i = Int(getfield(n, :idx))
87+
d = @inbounds arena.degree[i]
88+
if d == 0
89+
@inbounds arena.constant[i] || return false
90+
end
91+
end
92+
return true
93+
end
7494

7595
"""
7696
count_scalar_constants(tree::AbstractExpressionNode{T})::Int64 where {T}

test/test_arenanode.jl

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ const AN = DynamicExpressions.ArenaNodeModule
117117
@test string_tree(atree_pf, operators) == string_tree(atree, operators)
118118
y_pf, ok_pf = eval_tree_array(atree_pf, X, operators)
119119
@test ok_pf
120-
@test y_pf y_tree
120+
@test y_pf y_mut
121121

122122
# Minimal rewrite prototype should preserve postfix validity:
123123
tree_constleft = 3.2 * x1
@@ -134,4 +134,82 @@ const AN = DynamicExpressions.ArenaNodeModule
134134
@test ok_after
135135
@test y_after y_before
136136
end
137+
138+
@testset "Arena allocations" begin
139+
# DispatchDoctor checks in the test environment can dominate allocation counts.
140+
# Measure these low-level allocation properties in a fresh process using the
141+
# package project (dispatch doctor disabled by default there).
142+
project_root = normpath(joinpath(@__DIR__, ".."))
143+
144+
alloc_script = raw"""
145+
local_prefs = joinpath(dirname(Base.active_project()), "LocalPreferences.toml")
146+
prefs_text = string(
147+
"[DynamicExpressions]\n",
148+
"dispatch_doctor_mode = ",
149+
repr("disable"),
150+
"\n",
151+
)
152+
write(local_prefs, prefs_text)
153+
atexit(() -> rm(local_prefs; force=true))
154+
155+
using DynamicExpressions
156+
const AN = DynamicExpressions.ArenaNodeModule
157+
158+
operators = OperatorEnum(1 => (sin, cos), 2 => (+, *))
159+
x1 = DynamicExpressions.Node{Float64}(; feature=1)
160+
161+
function alloc_push_constant!(arena)
162+
AN.push_constant!(arena, 1.0)
163+
return nothing
164+
end
165+
166+
function alloc_set_child!(parent, child)
167+
set_child!(parent, child, 1)
168+
return nothing
169+
end
170+
171+
function alloc_copy_tree!(arena, tree)
172+
AN._copy_to_arena!(arena, tree)
173+
return nothing
174+
end
175+
176+
arena_push = AN.Arena{Float64,2}(; capacity=16)
177+
178+
base_tree = sin(x1)
179+
parent_arena = AN.Arena{Float64,2}(; capacity=16)
180+
parent_idx = AN._copy_to_arena!(parent_arena, base_tree)
181+
parent = AN.ArenaNode(parent_arena, parent_idx)
182+
183+
child_tree = x1 * 3.2
184+
child_arena = AN.Arena{Float64,2}(; capacity=16)
185+
child_idx = AN._copy_to_arena!(child_arena, child_tree)
186+
child = AN.ArenaNode(child_arena, child_idx)
187+
188+
tree_large = sin(x1) + x1 * 3.2 + cos(x1)
189+
arena_large = AN.Arena{Float64,2}(; capacity=64)
190+
191+
alloc_push_constant!(arena_push) # warmup
192+
alloc_set_child!(parent, child) # warmup
193+
alloc_copy_tree!(arena_large, tree_large) # warmup
194+
195+
println("push_constant=$(@allocated alloc_push_constant!(arena_push))")
196+
println("set_child=$(@allocated alloc_set_child!(parent, child))")
197+
println("copy_tree=$(@allocated alloc_copy_tree!(arena_large, tree_large))")
198+
"""
199+
200+
julia_bin = joinpath(Sys.BINDIR, Base.julia_exename())
201+
cmd = `$(julia_bin) --startup-file=no --project=$(project_root) -e $(alloc_script)`
202+
out = read(cmd, String)
203+
204+
allocs = Dict{String,Int}()
205+
for m in eachmatch(r"(push_constant|set_child|copy_tree)=(\d+)", out)
206+
allocs[m.captures[1]] = parse(Int, m.captures[2])
207+
end
208+
209+
@test all(k -> haskey(allocs, k), ("push_constant", "set_child", "copy_tree"))
210+
@test allocs["push_constant"] <= 1024
211+
@test allocs["set_child"] <= 1024
212+
@test allocs["copy_tree"] <= 1024
213+
end
214+
137215
end

0 commit comments

Comments
 (0)