Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

- Fixed a bug where calling `DataFrame.alias()` twice on the same DataFrame (e.g. for a self-join) caused both aliases to share the same internal column-mapping dictionary. This made `col("R", "col")` resolve to the same column as `col("L", "col")`, producing incorrect join conditions and filter expressions.

#### Improvements

- Improved CTE optimization to deduplicate identical subtrees in self-joins, which were previously emitted as repeated subqueries.

## 1.51.1 (2026-05-28)

#### Documentation
Expand Down
9 changes: 7 additions & 2 deletions src/snowflake/snowpark/_internal/compiler/cte_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,13 @@ def stringify(d):
string = query
if query_params:
string = f"{string}#{query_params}"
if hasattr(node, "expr_to_alias") and node.expr_to_alias:
string = f"{string}#{stringify(node.expr_to_alias)}"
# if hasattr(node, "expr_to_alias") and node.expr_to_alias:
# # Hash by alias values only, not the UUID keys, since UUID keys are regenerated on every deep-copy/re-resolve (e.g. the two
# # branches of a self-join). This lets nodes representing the same computation hash identically, enabling CTE dedup for self-joins.
# # NOTE: since nodes with different UUID keys can now share a CTE, _replace_duplicate_node_with_cte must merge each duplicate's
# # UUID→alias entries into the shared CTE so parent re-resolution can resolve any UUID variant (see companion comment there).
# # Different alias values (e.g. a "_WITH_AD_GROUP" join suffix from _disambiguate) still hash differently, preserving SNOW-2261400.
# string = f"{string}#{sorted(set(node.expr_to_alias.values()))}"
if (
hasattr(node, "df_aliased_col_name_to_real_col_name")
and node.df_aliased_col_name_to_real_col_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,24 @@ def apply(self) -> RepeatedSubqueryEliminationResult:
total_num_ctes=self._total_number_ctes,
)

@staticmethod
def _has_alias_conflict(
    node: TreeNode, existing_cte: Optional[SnowflakePlan]
) -> bool:
    """Return whether reusing ``existing_cte`` for ``node`` would remap an expr_id.

    encode_query_id hashes expr_to_alias by alias values only, so two nodes that
    map the same expr_id to different aliases can end up with the same encoded
    id. Merging such a node into the shared CTE would overwrite one of the two
    entries and corrupt parent column resolution, so callers skip the CTE and
    render the node inline when this returns True.

    Args:
        node: candidate node; its ``expr_to_alias`` may be missing, None or empty.
        existing_cte: the CTE plan already created for the same encoded id, or
            None when no CTE has been created yet.

    Returns:
        True if at least one expr_id is present in both mappings with different
        aliases, otherwise False.
    """
    if existing_cte is None:
        # First occurrence of this encoded id: nothing to conflict with.
        return False

    # Read both mappings the same defensive way so that a missing or None
    # ``expr_to_alias`` on either side means "no entries" instead of raising.
    node_aliases = getattr(node, "expr_to_alias", None) or {}
    cte_aliases = getattr(existing_cte, "expr_to_alias", None) or {}
    if not node_aliases or not cte_aliases:
        return False

    # Disjoint keys (fresh expr_ids on a copied branch) are fine; only a shared
    # key that points at a different alias is a conflict.
    return any(
        expr_id in cte_aliases and cte_aliases[expr_id] != alias
        for expr_id, alias in node_aliases.items()
    )

def _replace_duplicate_node_with_cte(
self,
root: TreeNode,
Expand Down Expand Up @@ -159,16 +177,20 @@ def _update_parents(
if node in visited_nodes:
continue

# if the node is a duplicated node and deduplication is not done for the node,
# start the deduplication transformation using a CTE
if node.encoded_node_id_with_query in duplicated_node_ids:
if node.encoded_node_id_with_query in resolved_with_block_map:
# if the corresponding CTE block has been created, use the existing
# one.
resolved_with_block = resolved_with_block_map[
node.encoded_node_id_with_query
]
else:
# Decide whether this node should be represented by a (new or shared) CTE:
# it must be a detected duplicate, and sharing the CTE must not introduce an
# alias conflict (see _has_alias_conflict). When it cannot be a CTE, the node
# is left inline and only the parent-propagation path applies, exactly like a
# non-duplicated node.
resolved_with_block = resolved_with_block_map.get(
node.encoded_node_id_with_query
)
is_cte_node = node.encoded_node_id_with_query in duplicated_node_ids and (
not self._has_alias_conflict(node, resolved_with_block)
)
if is_cte_node:
if resolved_with_block is None:
# no CTE block has been created for this node yet, create one.
if (
self._query_generator.session.reduce_describe_query_enabled
and context._is_snowpark_connect_compatible_mode
Expand All @@ -187,6 +209,12 @@ def _update_parents(
node.encoded_node_id_with_query
] = resolved_with_block
self._total_number_ctes += 1
elif getattr(node, "expr_to_alias", None):
# reuse the existing CTE block. expr_ids are regenerated on copy, so
# this node's keys differ from the node the CTE was built from; merge
# this node's entries so every expr_id variant stays resolvable during
# parent re-resolution.
resolved_with_block.expr_to_alias.update(node.expr_to_alias)
_update_parents(
node, should_replace_child=True, new_child=resolved_with_block
)
Expand Down
131 changes: 124 additions & 7 deletions tests/integ/test_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#

import copy
import re
import tracemalloc
from contextlib import contextmanager
Expand Down Expand Up @@ -30,6 +31,8 @@
uuid_string,
when_matched,
to_timestamp,
stddev_samp,
when,
)
from snowflake.snowpark.types import (
StructType,
Expand Down Expand Up @@ -699,9 +702,10 @@ def test_cte_preserves_join_suffix_aliases(session, use_different_df):
# the second one is incorrect join condition as we have rsuffix for join alias
assert 'ON ("AD_GROUP_ID_WITH_AD_GROUP" = "AD_GROUP_ID")' in union_sql
assert 'ON ("AD_GROUP_ID" = "AD_GROUP_ID")' not in union_sql
# when using different df_ad_group with disambiguation, because rsuffix in join,
# they have different alias map (expr_to_alias), so they are considered different and we can't convert them to a CTE
# However there is still a CTE for create_dataframe call
# Both cases produce 1 CTE: the disambiguated rhs_remapped wrapper nodes hash
# identically (same SQL + same alias values, different UUID keys), so they're
# merged into a single CTE via the expr_to_alias merge fix. The raw VALUES
# table is absorbed inline into that CTE body rather than becoming its own CTE.
assert count_number_of_ctes(Utils.normalize_sql(union_sql)) == 1


Expand Down Expand Up @@ -874,12 +878,15 @@ def test_sql_simplifier(session):
join_count=2,
)
with SqlCounter(query_count=0, describe_count=0):
# When adding a lsuffix, expr alias map will be updated, so df2 and df3 are considered
# different and have different ids. So only df1 and df will be converted to a CTE
# With value-sort hashing, df1/df2/df3 now hash identically (same SQL +
# same alias values, different UUID keys). df2 and df3 are replaced with a
# shared CTE, but df1's left-join position remains inline. That gives 2
# CTEs (base VALUES + filtered df1) and the filter appears twice (once in
# the CTE body, once inline for the left-join position).
assert (
count_number_of_ctes(Utils.normalize_sql(df6.queries["queries"][-1])) == 1
count_number_of_ctes(Utils.normalize_sql(df6.queries["queries"][-1])) == 2
)
assert Utils.normalize_sql(df6.queries["queries"][-1]).count(filter_clause) == 3
assert Utils.normalize_sql(df6.queries["queries"][-1]).count(filter_clause) == 2

df7 = df1.with_column("c", lit(1))
df8 = df1.with_column("c", lit(1)).with_column("d", lit(1))
Expand Down Expand Up @@ -1949,3 +1956,113 @@ def test_uniform_cte_optimization_depends_on_gen(session, use_bare_random, expec

vals = [row["VAL"] for row in result_df.collect()]
assert (vals[:5] == vals[5:]) == expect_cte


def test_cte_tpcds_q39_style_self_join_deduplication(session):
    """TPCDS_Q39-style self-join: a filtered aggregation DataFrame joined with a copy of itself.

    The right-hand side is produced with ``copy.copy(inv)`` and joined using
    ``rsuffix="_r"``. The test verifies that the shared `inv` computation
    (group-by + agg + cov filter) is pushed into a single CTE rather than being
    inlined once per join branch. The generated SQL should contain
    stddev_samp / GROUP BY exactly once, and the CTE name should occur at least
    three times (its definition plus one reference per join branch).

    The test is skipped when the session's SQL simplifier is disabled.
    """
    if not session._sql_simplifier_enabled:
        pytest.skip("SQL simplifier is not enabled")

    # Synthetic data shaped like the Q39 inventory result after the inner join:
    # (item_sk, warehouse_sk, month, quantity). High-variance values so that
    # BOTH months pass cov > 1 for each item/warehouse pair, making an incorrect
    # cross-join (4 rows) detectable vs. the correct equi-join (2 rows).
    raw = session.create_dataframe(
        [
            (10, 1, 1, 10),
            (
                10,
                1,
                1,
                390,
            ),  # item 10, wh 1, month 1: mean=200, stdev≈268.7, cov≈1.34 > 1
            (10, 1, 2, 20),
            (
                10,
                1,
                2,
                380,
            ),  # item 10, wh 1, month 2: mean=200, stdev≈254.6, cov≈1.27 > 1
            (20, 2, 1, 5),
            (
                20,
                2,
                1,
                395,
            ),  # item 20, wh 2, month 1: mean=200, stdev≈275.8, cov≈1.38 > 1
            (20, 2, 2, 30),
            (
                20,
                2,
                2,
                370,
            ),  # item 20, wh 2, month 2: mean=200, stdev≈240.4, cov≈1.20 > 1
        ],
        schema=["i_item_sk", "w_warehouse_sk", "d_moy", "qty"],
    )

    # Mirrors Q39's inner "foo" aggregation subquery.
    agg = raw.group_by("i_item_sk", "w_warehouse_sk", "d_moy").agg(
        stddev_samp("qty").alias("stdev"),
        avg("qty").cast("double").alias("mean"),
    )

    # Mirrors Q39's outer "inv" CTE: compute cov and filter on cov > 1.
    # All four (item, warehouse, month) combinations pass cov > 1.
    # Note the two CASE expressions differ on purpose: the projected column
    # yields NULL when mean == 0, while the filter predicate yields 0.
    inv = agg.with_column(
        "cov",
        when(col("mean") == 0, lit(None)).otherwise(col("stdev") / col("mean")),
    ).filter(when(col("mean") == 0, lit(0)).otherwise(col("stdev") / col("mean")) > 1)

    # Right-hand side of the self-join is a copy of `inv`; the join columns are
    # shared and the remaining right-side columns receive the "_r" suffix.
    inv_r = copy.copy(inv)
    result = (
        inv.join(inv_r, on=["i_item_sk", "w_warehouse_sk"], rsuffix="_r")
        .filter(col("d_moy") == 1)
        .filter(col("d_moy_r") == 2)
    )

    # Inspect the last generated query; normalization makes the string checks
    # below insensitive to formatting differences.
    sql = result.queries["queries"][-1]
    normalized = Utils.normalize_sql(sql)

    # NOTE(review): presumably SqlCounter asserts that the enclosed string-only
    # checks issue no queries or describes — confirm against the helper.
    with SqlCounter(query_count=0, describe_count=0):
        # The shared `inv` computation should be deduplicated into exactly one CTE.
        assert count_number_of_ctes(normalized) == 1

        # The CTE should appear at least 3 times: once in the WITH definition
        # and at least twice in the body (one per join branch).
        cte_name_match = re.search(r"WITH\s+(\w+)\s+AS", normalized)
        assert cte_name_match is not None, "expected a WITH CTE in the generated SQL"
        cte_name = cte_name_match.group(1)
        assert (
            normalized.count(cte_name) >= 3
        ), f"CTE '{cte_name}' should appear in the definition and both join branches"

        # The aggregation (stddev_samp / GROUP BY) must appear exactly once —
        # inside the CTE body. Two occurrences would mean `inv` is inlined
        # separately for each join branch instead of being shared.
        assert (
            normalized.lower().count("stddev_samp") == 1
        ), "stddev_samp should appear once (in the CTE), not once per alias branch"
        assert (
            normalized.upper().count("GROUP BY") == 1
        ), "GROUP BY should appear once (in the CTE), not once per alias branch"

    # Correctness: the CTE-optimized result must match the non-optimized result.
    # Correct equi-join on (i_item_sk, w_warehouse_sk) produces 2 rows (item 10
    # and item 20, each pairing their month-1 and month-2 stats). A wrong
    # cross-join would produce 4 rows, so this check_result is meaningful.
    check_result(
        session,
        result,
        expect_cte_optimized=True,
        query_count=1,
        describe_count=0,
        union_count=0,
        join_count=1,
    )
72 changes: 71 additions & 1 deletion tests/unit/test_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
encode_node_id_with_query,
find_duplicate_subtrees,
)
from snowflake.snowpark._internal.compiler.repeated_subquery_elimination import (
RepeatedSubqueryElimination,
)


def create_test_case1():
Expand Down Expand Up @@ -118,6 +121,9 @@ def test_encode_node_id_with_query_select_sql(mock_session, mock_analyzer):


def test_encode_node_id_with_query_includes_aliases():
# expr_to_alias is hashed by sorted(set(values())) so two nodes with the
# same alias values but different UUID keys (e.g. deep-copied self-join
# branches) produce the same hash.
node = SimpleNamespace(
sql_query="select col1 from t",
query_params=(("p1", 1), ("p2", "x")),
Expand All @@ -134,13 +140,77 @@ def stringify_dict(d: dict) -> str:
if node.query_params:
expected_string = f"{expected_string}#{node.query_params}"
if node.expr_to_alias:
expected_string = f"{expected_string}#{stringify_dict(node.expr_to_alias)}"
# Values-only sort (no UUID keys) normalizes away UUID differences
expected_string = (
f"{expected_string}#{sorted(set(node.expr_to_alias.values()))}"
)
if node.df_aliased_col_name_to_real_col_name:
expected_string = f"{expected_string}#{stringify_dict(node.df_aliased_col_name_to_real_col_name)}"

expected_hash = hashlib.sha256(expected_string.encode()).hexdigest()[:10]
assert encode_node_id_with_query(node) == f"{expected_hash}_SimpleNamespace"

# Two nodes with the same SQL and same alias values but different UUID keys
# must hash identically — this is the Q39 self-join case.
node_same_values_diff_keys = SimpleNamespace(
sql_query="select col1 from t",
query_params=(("p1", 1), ("p2", "x")),
expr_to_alias={"uuid_different": "ALIAS1"},
df_aliased_col_name_to_real_col_name={"ALIAS1": "col1"},
)
assert encode_node_id_with_query(node) == encode_node_id_with_query(
node_same_values_diff_keys
)

# Two nodes with the same SQL but different alias values must hash
# differently — this preserves the SNOW-2261400 join-suffix fix.
node_different_values = SimpleNamespace(
sql_query="select col1 from t",
query_params=(("p1", 1), ("p2", "x")),
expr_to_alias={"uuid1": "ALIAS1_WITH_SUFFIX"},
df_aliased_col_name_to_real_col_name={"ALIAS1": "col1"},
)
assert encode_node_id_with_query(node) != encode_node_id_with_query(
node_different_values
)


def test_has_alias_conflict():
    """Table-driven check of RepeatedSubqueryElimination._has_alias_conflict.

    encode_query_id hashes expr_to_alias by alias values only, so two nodes can
    share a CTE while carrying different expr_id keys. The helper guards the
    only unsafe case: the same expr_id mapping to a *different* alias, where
    merging would silently drop an entry and corrupt parent column resolution.
    """
    check = RepeatedSubqueryElimination._has_alias_conflict

    single = SimpleNamespace(expr_to_alias={"uuid1": "ALIAS1"})
    double = SimpleNamespace(expr_to_alias={"uuid1": "ALIAS1", "uuid2": "ALIAS2"})
    empty = SimpleNamespace(expr_to_alias={})

    cte_same = SimpleNamespace(expr_to_alias={"uuid1": "ALIAS1"})
    cte_disjoint = SimpleNamespace(expr_to_alias={"uuid2": "ALIAS1"})
    cte_conflict = SimpleNamespace(expr_to_alias={"uuid1": "ALIAS2"})
    cte_partial_conflict = SimpleNamespace(
        expr_to_alias={"uuid1": "ALIAS1", "uuid2": "DIFFERENT"}
    )

    # (candidate node, existing CTE, expected verdict)
    scenarios = [
        # First occurrence: no CTE exists yet, so nothing can conflict.
        (single, None, False),
        # Same expr_id, same alias: merging is a no-op.
        (single, cte_same, False),
        # Disjoint expr_ids (normal self-join: same alias values, fresh UUIDs):
        # the entries simply coexist after the merge.
        (single, cte_disjoint, False),
        # Same expr_id, different alias: the CTE must not be shared.
        (single, cte_conflict, True),
        # One conflicting key is enough even when the other keys agree.
        (double, cte_partial_conflict, True),
        # A node with no expr_to_alias entries can never conflict.
        (empty, cte_conflict, False),
    ]
    for candidate, shared_cte, expected in scenarios:
        assert check(candidate, shared_cte) is expected


def test_select_statement_contains_data_generation(mock_session, mock_analyzer):
"""SelectStatement.contains_data_generation should detect zero-arg
Expand Down
Loading