Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 5911773

Browse files
committed
refactor: improve variable remapping for JoinNode and InNode
Update remap_variables to explicitly assign the remapped children for JoinNode and InNode instead of looking them up in a dict keyed by the original child nodes. This prevents a KeyError or incorrect node reuse in self-join scenarios, where identical child nodes must be remapped to unique tree branches. Additionally, explicitly remap join conditions so that the left and right sides each reference the column IDs of their own remapped child. Includes a unit test verifying stability and correctness for self-joins.
1 parent 61c17e3 commit 5911773

File tree

2 files changed

+70
-4
lines changed

2 files changed

+70
-4
lines changed

bigframes/core/rewrite/identifiers.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,21 @@ def remap_variables(
5050
new_child_nodes.append(new_child)
5151
new_child_mappings.append(child_mappings)
5252

53+
new_root = root
54+
5355
# Step 2: Transform children to use their new nodes.
54-
remapped_children: dict[nodes.BigFrameNode, nodes.BigFrameNode] = {
55-
child: new_child for child, new_child in zip(root.child_nodes, new_child_nodes)
56-
}
57-
new_root = root.transform_children(lambda node: remapped_children[node])
56+
if isinstance(new_root, nodes.JoinNode) or isinstance(new_root, nodes.InNode):
57+
new_root = dataclasses.replace(
58+
new_root,
59+
left_child=new_child_nodes[0],
60+
right_child=new_child_nodes[1],
61+
)
62+
else:
63+
remapped_children: dict[nodes.BigFrameNode, nodes.BigFrameNode] = {
64+
child: new_child
65+
for child, new_child in zip(root.child_nodes, new_child_nodes)
66+
}
67+
new_root = root.transform_children(lambda node: remapped_children[node])
5868

5969
# Step 3: Transform the current node using the mappings from its children.
6070
# "reversed" is required for InNode so that in case of a duplicate column ID,
@@ -70,6 +80,20 @@ def remap_variables(
7080
new_child_mappings[0], allow_partial_bindings=True
7181
),
7282
)
83+
elif isinstance(new_root, nodes.JoinNode):
84+
new_root = typing.cast(nodes.JoinNode, new_root)
85+
new_conds = tuple(
86+
(
87+
l_cond.remap_column_refs(
88+
new_child_mappings[0], allow_partial_bindings=True
89+
),
90+
r_cond.remap_column_refs(
91+
new_child_mappings[1], allow_partial_bindings=True
92+
),
93+
)
94+
for l_cond, r_cond in new_root.conditions
95+
)
96+
new_root = dataclasses.replace(new_root, conditions=new_conds)
7397
else:
7498
new_root = new_root.remap_refs(downstream_mappings)
7599

tests/unit/core/rewrite/test_identifiers.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,45 @@ def test_remap_variables_in_node_converts_dag_to_tree(leaf, leaf_too):
151151
left_col_id = new_node.left_col.id.name
152152
new_node.validate_tree()
153153
assert left_col_id.startswith("id_")
154+
155+
156+
def test_remap_variables_join_self_stability(leaf):
157+
# Create a join node with the same child twice
158+
# We wrap them in distinct SelectionNodes so they can have their IDs remapped
159+
# independently to avoid ID collisions in the resulting tree.
160+
leaf_selection_left = nodes.SelectionNode(
161+
leaf,
162+
tuple(nodes.AliasedRef.identity(f.id) for f in leaf.fields),
163+
)
164+
leaf_selection_right = nodes.SelectionNode(
165+
leaf,
166+
tuple(nodes.AliasedRef.identity(f.id) for f in leaf.fields),
167+
)
168+
169+
node = nodes.JoinNode(
170+
left_child=leaf_selection_left,
171+
right_child=leaf_selection_right,
172+
conditions=(
173+
(
174+
ex.DerefOp(leaf_selection_left.fields[0].id),
175+
ex.DerefOp(leaf_selection_right.fields[0].id),
176+
),
177+
),
178+
type="inner",
179+
propogate_order=False,
180+
)
181+
182+
# Run remap_variables
183+
id_generator = (identifiers.ColumnId(f"id_{i}") for i in range(100))
184+
# This used to raise KeyError before the fix
185+
new_node, mapping = id_rewrite.remap_variables(node, id_generator)
186+
187+
assert isinstance(new_node, nodes.JoinNode)
188+
new_node.validate_tree()
189+
190+
# Verify that conditions use child-specific IDs
191+
left_cond, right_cond = new_node.conditions[0]
192+
assert left_cond.id in new_node.left_child.ids
193+
assert right_cond.id in new_node.right_child.ids
194+
# Since it's a self-join remapped to a tree, the left and right IDs should be different
195+
assert left_cond.id != right_cond.id

0 commit comments

Comments
 (0)