Skip to content

Commit 8dcfc89

Browse files
ArenaNode: traversal-based postfix emit + cursor-based rewrite
1 parent 2faa7dd commit 8dcfc89

2 files changed

Lines changed: 115 additions & 60 deletions

File tree

src/ArenaNode.jl

Lines changed: 68 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -332,20 +332,62 @@ struct PostfixExpr{T,D}
332332
op::Vector{UInt8}
333333
end
334334

335-
"""Emit a [`PostfixExpr`](@ref) for an arena-backed tree.
335+
"""Emit a [`PostfixExpr`](@ref) (postorder / postfix) encoding of `tree`.
336336
337-
This copies the arena prefix `1:tree.idx`.
337+
This traverses the tree via child pointers and emits nodes in postorder, with the
338+
root last.
339+
340+
!!! note
341+
This is intended as a **serialization/debug utility**. It is *not* an execution
342+
strategy (i.e. we should not repeatedly convert between representations to make
343+
algorithms work).
338344
"""
339345
function emit_postfix(tree::ArenaNode{T,D}) where {T,D}
340-
n = Int(tree.idx)
341346
a = tree.arena
342-
return PostfixExpr{T,D}(
343-
copy(@view a.degree[1:n]),
344-
copy(@view a.constant[1:n]),
345-
copy(@view a.val[1:n]),
346-
copy(@view a.feature[1:n]),
347-
copy(@view a.op[1:n]),
348-
)
347+
348+
# Preallocate using the reachable node count (ignoring potential sharing).
349+
n = length(tree; break_sharing=Val(true))
350+
degree = UInt8[]
351+
constant = Bool[]
352+
val = T[]
353+
feature = UInt16[]
354+
op = UInt8[]
355+
sizehint!(degree, n)
356+
sizehint!(constant, n)
357+
sizehint!(val, n)
358+
sizehint!(feature, n)
359+
sizehint!(op, n)
360+
361+
# Iterative postorder traversal: (idx, expanded_children?).
362+
stack = Tuple{Int32,Bool}[]
363+
sizehint!(stack, n)
364+
push!(stack, (tree.idx, false))
365+
366+
while !isempty(stack)
367+
idx, expanded = pop!(stack)
368+
i = Int(idx)
369+
if expanded
370+
@inbounds begin
371+
push!(degree, a.degree[i])
372+
push!(constant, a.constant[i])
373+
push!(val, a.val[i])
374+
push!(feature, a.feature[i])
375+
push!(op, a.op[i])
376+
end
377+
else
378+
push!(stack, (idx, true))
379+
d = @inbounds a.degree[i]
380+
if d != 0
381+
child_idxs = @inbounds a.children[i]
382+
@inbounds for j in Int(d):-1:1
383+
c = child_idxs[j]
384+
c != 0 && push!(stack, (c, false))
385+
end
386+
end
387+
end
388+
end
389+
390+
return PostfixExpr{T,D}(degree, constant, val, feature, op)
349391
end
350392

351393
@generated function _postfix_pop_children(
@@ -436,13 +478,23 @@ is_commutative(_) = false
436478

437479
"""Rewrite commutative binary operators so that constants appear on the right.
438480
439-
This is a minimal in-place rewrite that *preserves postfix validity* since it does
440-
not change any node degrees; it only swaps child pointers.
481+
This is a minimal in-place rewrite that only swaps child pointers.
482+
483+
!!! note
484+
This does **not** rely on any arena index ordering (e.g. postfix layout). It traverses
485+
the tree via child pointers.
486+
487+
Since it does not change degrees, any postfix encoding that depends only on `degree`
488+
(e.g. [`emit_postfix`](@ref)) is preserved.
441489
"""
442490
function rewrite_commutative_constants_right!(tree::ArenaNode{T,D}, operators) where {T,D}
443-
root_idx = Int(tree.idx)
444-
@inbounds for i in 1:root_idx
445-
node = ArenaNode{T,D}(tree.arena, i)
491+
cursor = ArenaCursor(tree)
492+
reset!(cursor, tree)
493+
494+
while true
495+
node = next!(cursor)
496+
node === nothing && break
497+
446498
node.degree == 2 || continue
447499
f = operators.binops[node.op]
448500
is_commutative(f) || continue
@@ -454,6 +506,7 @@ function rewrite_commutative_constants_right!(tree::ArenaNode{T,D}, operators) w
454506
set_child!(node, l, 2)
455507
end
456508
end
509+
457510
return tree
458511
end
459512

test/test_arenanode.jl

Lines changed: 47 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -37,26 +37,6 @@ const AN = DynamicExpressions.ArenaNodeModule
3737
collected_idxs = map(n -> n.idx, collected)
3838
@test collected_idxs == seen
3939

40-
# Postfix stack-based utilities (mirroring symbolic_regression.rs patterns):
41-
@test AN.is_valid_postfix(atree)
42-
43-
sizes = Int[]
44-
size_stack = Int[]
45-
AN.subtree_sizes_into!(atree, sizes, size_stack)
46-
start, stop = AN.subtree_range(sizes, Int(atree.idx))
47-
@test start == 1
48-
@test stop == Int(atree.idx)
49-
50-
depth_stack = Int[]
51-
depth_postfix = AN.tree_mapreduce_postfix_with_stack(
52-
atree,
53-
_ -> 1,
54-
_ -> 0,
55-
(_, children) -> maximum(children) + 1,
56-
depth_stack,
57-
)
58-
@test depth_postfix == count_depth(atree)
59-
6040
# Evaluation should match:
6141
X = randn(Float64, 1, 50)
6242
y_tree, ok_tree = eval_tree_array(tree, X, operators)
@@ -65,31 +45,6 @@ const AN = DynamicExpressions.ArenaNodeModule
6545
@test ok_atree
6646
@test y_tree ≈ y_atree
6747

68-
# Postfix roundtrip sanity check (debug utility; not an execution strategy):
69-
pf = AN.emit_postfix(atree)
70-
atree_pf = AN.parse_postfix_to_arena(pf)
71-
@test AN.is_valid_postfix(atree_pf)
72-
@test count_nodes(atree_pf) == count_nodes(atree)
73-
@test string_tree(atree_pf, operators) == string_tree(atree, operators)
74-
y_pf, ok_pf = eval_tree_array(atree_pf, X, operators)
75-
@test ok_pf
76-
@test y_pf ≈ y_tree
77-
78-
# Minimal rewrite prototype should preserve postfix validity:
79-
tree_constleft = 3.2 * x1
80-
atree_constleft = AN.arena_from_tree(tree_constleft)
81-
@test AN.is_valid_postfix(atree_constleft)
82-
y_before, ok_before = eval_tree_array(atree_constleft, X, operators)
83-
@test ok_before
84-
@test atree_constleft.l.constant
85-
AN.rewrite_commutative_constants_right!(atree_constleft, operators)
86-
@test AN.is_valid_postfix(atree_constleft)
87-
@test !atree_constleft.l.constant
88-
@test atree_constleft.r.constant
89-
y_after, ok_after = eval_tree_array(atree_constleft, X, operators)
90-
@test ok_after
91-
@test y_after ≈ y_before
92-
9348
# In-place set_node! should work even when the source tree is from a different arena.
9449
# (This is important for API-compat with algorithms that construct new subtrees.)
9550
atree_setnode = AN.arena_from_tree(tree)
@@ -132,4 +87,51 @@ const AN = DynamicExpressions.ArenaNodeModule
13287
y_tree2, ok_tree2 = eval_tree_array(tree2, X, operators)
13388
@test ok_tree2
13489
@test y_tree2 ≈ y_mut
90+
91+
@testset "Postfix / debug utilities (not an execution strategy)" begin
92+
# Postfix stack-based utilities (mirroring symbolic_regression.rs patterns):
93+
@test AN.is_valid_postfix(atree)
94+
95+
sizes = Int[]
96+
size_stack = Int[]
97+
AN.subtree_sizes_into!(atree, sizes, size_stack)
98+
start, stop = AN.subtree_range(sizes, Int(atree.idx))
99+
@test start == 1
100+
@test stop == Int(atree.idx)
101+
102+
depth_stack = Int[]
103+
depth_postfix = AN.tree_mapreduce_postfix_with_stack(
104+
atree,
105+
_ -> 1,
106+
_ -> 0,
107+
(_, children) -> maximum(children) + 1,
108+
depth_stack,
109+
)
110+
@test depth_postfix == count_depth(atree)
111+
112+
# Postfix roundtrip sanity check (debug utility; not an execution strategy):
113+
pf = AN.emit_postfix(atree)
114+
atree_pf = AN.parse_postfix_to_arena(pf)
115+
@test AN.is_valid_postfix(atree_pf)
116+
@test count_nodes(atree_pf) == count_nodes(atree)
117+
@test string_tree(atree_pf, operators) == string_tree(atree, operators)
118+
y_pf, ok_pf = eval_tree_array(atree_pf, X, operators)
119+
@test ok_pf
120+
@test y_pf ≈ y_tree
121+
122+
# Minimal rewrite prototype should preserve postfix validity:
123+
tree_constleft = 3.2 * x1
124+
atree_constleft = AN.arena_from_tree(tree_constleft)
125+
@test AN.is_valid_postfix(atree_constleft)
126+
y_before, ok_before = eval_tree_array(atree_constleft, X, operators)
127+
@test ok_before
128+
@test atree_constleft.l.constant
129+
AN.rewrite_commutative_constants_right!(atree_constleft, operators)
130+
@test AN.is_valid_postfix(atree_constleft)
131+
@test !atree_constleft.l.constant
132+
@test atree_constleft.r.constant
133+
y_after, ok_after = eval_tree_array(atree_constleft, X, operators)
134+
@test ok_after
135+
@test y_after ≈ y_before
136+
end
135137
end

0 commit comments

Comments
 (0)