Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@

### Snowpark Python API Updates

#### Bug Fixes

- Fixed a bug where `cloudpickle` could not be resolved when registering a Python stored procedure or UDF with `runtime_version='3.13'`.

#### New Features

- Added `get_wif_token` to `snowflake.snowpark.secrets` for workload identity federation tokens on the Snowflake server (not available in SPCS file-based secret environments).

#### Bug Fixes

- Fixed a bug where calling `DataFrame.alias()` twice on the same DataFrame (e.g. for a self-join) caused both aliases to share the same internal column-mapping dictionary. This made `col("R", "col")` resolve to the same column as `col("L", "col")`, producing incorrect join conditions and filter expressions.
- Fixed a bug where `cloudpickle` could not be resolved when registering a Python stored procedure or UDF with `runtime_version='3.13'`.

#### Improvements

- Improved CTE optimization to deduplicate identical subtrees in self-joins, which were previously emitted as repeated subqueries.

## 1.51.1 (2026-05-28)

Expand Down
7 changes: 6 additions & 1 deletion src/snowflake/snowpark/_internal/compiler/cte_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,12 @@ def stringify(d):
if query_params:
string = f"{string}#{query_params}"
if hasattr(node, "expr_to_alias") and node.expr_to_alias:
string = f"{string}#{stringify(node.expr_to_alias)}"
# Hash by alias values only, not the UUID keys, since UUID keys are regenerated on every deep-copy/re-resolve (e.g. the two
# branches of a self-join). This lets nodes representing the same computation hash identically, enabling CTE dedup for self-joins.
# NOTE: since nodes with different UUID keys can now share a CTE, _replace_duplicate_node_with_cte must merge each duplicate's
# UUID→alias entries into the shared CTE so parent re-resolution can resolve any UUID variant (see companion comment there).
# Different alias values (e.g. a "_WITH_AD_GROUP" join suffix from _disambiguate) still hash differently, preserving SNOW-2261400.
string = f"{string}#{sorted(set(node.expr_to_alias.values()))}"
Comment on lines +310 to +315

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In what kinds of cases will this produce the same hash keys? And when it is wrong, what kind of risk are we dealing with?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. I assume you are asking about the hash keys for a node -- the same hash key is produced by nodes with the same sql, sql params, df aliases, and expr_to_alias values (before the change it was computed from expr_to_alias.items()).

(The keys of expr_to_alias are generated from UUIDs, so it can only rarely happen that two of them collide.)

  2. The only risk is parent column resolution, because we relax the criteria for identifying repeated nodes. I introduced extra logic to detect conflicting expr_to_alias entries (the same expr mapping to two different aliases) across nodes and to skip the CTE in that case. Thank you for bringing it up.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any chance of duplicate value entries here that would silently get swallowed by the sorted(set(...)) call?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, duplicate aliases can happen, but consolidating them is harmless in our case: they are not reflected in the newly generated CTE-optimized node, and the generated sql in the hash is enough to detect duplicate nodes in the self-join case.

the expr_to_alias is mostly for internal column name resolution in the downstream compilation stage.

Theoretically, I think this info should be excluded from the hash computation of a node, but for now I am keeping it as a defensive layer to distinguish nodes.

if (
hasattr(node, "df_aliased_col_name_to_real_col_name")
and node.df_aliased_col_name_to_real_col_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,24 @@ def apply(self) -> RepeatedSubqueryEliminationResult:
total_num_ctes=self._total_number_ctes,
)

@staticmethod
def _has_alias_conflict(
    node: TreeNode, existing_cte: Optional[SnowflakePlan]
) -> bool:
    """Report whether ``node`` disagrees with ``existing_cte`` on any alias.

    encode_query_id hashes expr_to_alias by alias values only, so two nodes
    that map the same expr_id to different aliases can end up with the same
    hash. Merging such a node into the shared CTE would silently drop an
    entry and corrupt parent column resolution, so the caller skips the CTE
    and renders the node inline when this returns True.

    Args:
        node: The candidate node; its ``expr_to_alias`` attribute may be
            missing, ``None`` or empty.
        existing_cte: The CTE plan already created for the node's hash, or
            ``None`` when no CTE has been created yet.

    Returns:
        True if at least one expr_id present in both mappings is bound to a
        different alias in ``existing_cte`` than in ``node``; False otherwise.
    """
    # No CTE has been built for this hash yet, so there is nothing to clash with.
    if existing_cte is None:
        return False

    # A missing, None or empty mapping contributes no entries to compare.
    candidate_aliases = getattr(node, "expr_to_alias", None) or {}
    for expr_id, alias in candidate_aliases.items():
        # Keys absent from the CTE are fine: they simply coexist after a merge.
        if (
            expr_id in existing_cte.expr_to_alias
            and existing_cte.expr_to_alias[expr_id] != alias
        ):
            return True
    return False

def _replace_duplicate_node_with_cte(
self,
root: TreeNode,
Expand Down Expand Up @@ -159,16 +177,20 @@ def _update_parents(
if node in visited_nodes:
continue

# if the node is a duplicated node and deduplication is not done for the node,
# start the deduplication transformation use CTE
if node.encoded_node_id_with_query in duplicated_node_ids:
if node.encoded_node_id_with_query in resolved_with_block_map:
# if the corresponding CTE block has been created, use the existing
# one.
resolved_with_block = resolved_with_block_map[
node.encoded_node_id_with_query
]
else:
# Decide whether this node should be represented by a (new or shared) CTE:
# it must be a detected duplicate, and sharing the CTE must not introduce an
# alias conflict (see _has_alias_conflict). When it cannot be a CTE, the node
# is left inline and only the parent-propagation path applies, exactly like a
# non-duplicated node.
resolved_with_block = resolved_with_block_map.get(
node.encoded_node_id_with_query
)
is_cte_node = node.encoded_node_id_with_query in duplicated_node_ids and (
not self._has_alias_conflict(node, resolved_with_block)
)
if is_cte_node:
if resolved_with_block is None:
# no CTE block has been created for this node yet, create one.
if (
self._query_generator.session.reduce_describe_query_enabled
and context._is_snowpark_connect_compatible_mode
Expand All @@ -187,6 +209,12 @@ def _update_parents(
node.encoded_node_id_with_query
] = resolved_with_block
self._total_number_ctes += 1
elif getattr(node, "expr_to_alias", None):
# reuse the existing CTE block. expr_ids are regenerated on copy, so
# this node's keys differ from the node the CTE was built from; merge
# this node's entries so every expr_id variant stays resolvable during
# parent re-resolution.
resolved_with_block.expr_to_alias.update(node.expr_to_alias)
Comment thread
sfc-gh-joshi marked this conversation as resolved.
_update_parents(
node, should_replace_child=True, new_child=resolved_with_block
)
Expand Down
131 changes: 124 additions & 7 deletions tests/integ/test_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#

import copy
import re
import tracemalloc
from contextlib import contextmanager
Expand Down Expand Up @@ -30,6 +31,8 @@
uuid_string,
when_matched,
to_timestamp,
stddev_samp,
when,
)
from snowflake.snowpark.types import (
StructType,
Expand Down Expand Up @@ -699,9 +702,10 @@ def test_cte_preserves_join_suffix_aliases(session, use_different_df):
# the second one is incorrect join condition as we have rsuffix for join alias
assert 'ON ("AD_GROUP_ID_WITH_AD_GROUP" = "AD_GROUP_ID")' in union_sql
assert 'ON ("AD_GROUP_ID" = "AD_GROUP_ID")' not in union_sql
# when using different df_ad_group with disambiguation, because rsuffix in join,
# they have different alias map (expr_to_alias), so they are considered different and we can't convert them to a CTE
# However there is still a CTE for create_dataframe call
# Both cases produce 1 CTE: the disambiguated rhs_remapped wrapper nodes hash
# identically (same SQL + same alias values, different UUID keys), so they're
# merged into a single CTE via the expr_to_alias merge fix. The raw VALUES
# table is absorbed inline into that CTE body rather than becoming its own CTE.
assert count_number_of_ctes(Utils.normalize_sql(union_sql)) == 1


Expand Down Expand Up @@ -874,12 +878,15 @@ def test_sql_simplifier(session):
join_count=2,
)
with SqlCounter(query_count=0, describe_count=0):
# When adding a lsuffix, expr alias map will be updated, so df2 and df3 are considered
# different and have different ids. So only df1 and df will be converted to a CTE
# With value-sort hashing, df1/df2/df3 now hash identically (same SQL +
# same alias values, different UUID keys). df2 and df3 are replaced with a
# shared CTE, but df1's left-join position remains inline. That gives 2
# CTEs (base VALUES + filtered df1) and the filter appears twice (once in
# the CTE body, once inline for the left-join position).
assert (
count_number_of_ctes(Utils.normalize_sql(df6.queries["queries"][-1])) == 1
count_number_of_ctes(Utils.normalize_sql(df6.queries["queries"][-1])) == 2
)
assert Utils.normalize_sql(df6.queries["queries"][-1]).count(filter_clause) == 3
assert Utils.normalize_sql(df6.queries["queries"][-1]).count(filter_clause) == 2
Comment on lines -877 to +889

@sfc-gh-aling sfc-gh-aling May 29, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For clarity: having more CTEs is better in this case. See the comparison of the generated sql below:

Before (1 CTE, filter appears 3 times):

  WITH CTE_values AS (
    SELECT $1 AS "A", $2 AS "B" FROM VALUES (1,2),(3,4)                                                                                                                             
  )                                                                                                                                                                                 
  SELECT * FROM (                                                                                                                                                                   
    (                                                                                                                                                                               
      SELECT "A_XXX", "B_XXX", "A" AS "A_YYY", "B" AS "B_YYY"                                                                                                                       
      FROM (                                                                                                                                                                        
        (                                                                                                                                                                           
          SELECT "A" AS "A_XXX", "B" AS "B_XXX"                                                                                                                                     
          FROM CTE_values WHERE ("A" = 1)       -- filter #1 (df1, inline)                                                                                                          
        ) AS SNOWPARK_LEFT                                                                                                                                                          
        INNER JOIN (                                                                                                                                                                
          SELECT "A", "B"                                                                                                                                                           
          FROM CTE_values WHERE ("A" = 1)       -- filter #2 (df2, inline)                                                                                                          
        ) AS SNOWPARK_RIGHT                                                                                                                                                         
      )                                                           
    ) AS SNOWPARK_LEFT
    INNER JOIN (      
      SELECT "A", "B"
      FROM CTE_values WHERE ("A" = 1)           -- filter #3 (df3, inline)
    ) AS SNOWPARK_RIGHT                                                   
  )                                                                                                                                                                                 

After (2 CTEs, filter appears 2 times):

  WITH CTE_values AS (                                                                                                                                                              
    SELECT $1 AS "A", $2 AS "B" FROM VALUES (1,2),(3,4)                                                                                                                             
  ),                                                                                                                                                                                
  CTE_filtered AS (                              -- df2 and df3 deduplicated here                                                                                                   
    SELECT "A", "B"                                                                                                                                                                 
    FROM CTE_values WHERE ("A" = 1)             -- filter #1 (in CTE body)                                                                                                          
  )                                                                                                                                                                                 
  SELECT * FROM (                                                                                                                                                                   
    (                                                             
      SELECT "A_XXX", "B_XXX", "A" AS "A_YYY", "B" AS "B_YYY"
      FROM (                                                 
        (   
          SELECT "A" AS "A_XXX", "B" AS "B_XXX"
          FROM CTE_values WHERE ("A" = 1)       -- filter #2 (df1, still inline)
        ) AS SNOWPARK_LEFT                                                      
        INNER JOIN (SELECT * FROM CTE_filtered) AS SNOWPARK_RIGHT                                                                                                                   
      )                                                                                                                                                                             
    ) AS SNOWPARK_LEFT                                                                                                                                                              
    INNER JOIN (SELECT * FROM CTE_filtered) AS SNOWPARK_RIGHT                                                                                                                       
  )   


df7 = df1.with_column("c", lit(1))
df8 = df1.with_column("c", lit(1)).with_column("d", lit(1))
Expand Down Expand Up @@ -1949,3 +1956,113 @@ def test_uniform_cte_optimization_depends_on_gen(session, use_bare_random, expec

vals = [row["VAL"] for row in result_df.collect()]
assert (vals[:5] == vals[5:]) == expect_cte


def test_cte_tpcds_q39_style_self_join_deduplication(session):
    """Self-join of a filtered aggregation, shaped like TPC-DS Q39.

    The ``inv`` DataFrame (group-by + aggregation + cov filter) is joined with
    a shallow copy of itself. The generated SQL is expected to keep that shared
    computation in a single CTE referenced by both join branches, so
    ``stddev_samp`` and ``GROUP BY`` must each occur exactly once.
    """
    if not session._sql_simplifier_enabled:
        pytest.skip("SQL simplifier is not enabled")

    # Synthetic rows shaped like the Q39 inventory result after the inner join:
    # (item_sk, warehouse_sk, month, quantity). The quantities are spread far
    # apart so that BOTH months pass cov > 1 for every item/warehouse pair; an
    # incorrect cross-join (4 rows) is then distinguishable from the correct
    # equi-join (2 rows).
    inventory_rows = [
        # item 10, wh 1, month 1: mean=200, stdev≈268.7, cov≈1.34 > 1
        (10, 1, 1, 10),
        (10, 1, 1, 390),
        # item 10, wh 1, month 2: mean=200, stdev≈254.6, cov≈1.27 > 1
        (10, 1, 2, 20),
        (10, 1, 2, 380),
        # item 20, wh 2, month 1: mean=200, stdev≈275.8, cov≈1.38 > 1
        (20, 2, 1, 5),
        (20, 2, 1, 395),
        # item 20, wh 2, month 2: mean=200, stdev≈240.4, cov≈1.20 > 1
        (20, 2, 2, 30),
        (20, 2, 2, 370),
    ]
    raw = session.create_dataframe(
        inventory_rows,
        schema=["i_item_sk", "w_warehouse_sk", "d_moy", "qty"],
    )

    # Counterpart of Q39's inner "foo" aggregation subquery.
    grouped = raw.group_by("i_item_sk", "w_warehouse_sk", "d_moy")
    agg = grouped.agg(
        stddev_samp("qty").alias("stdev"),
        avg("qty").cast("double").alias("mean"),
    )

    # Counterpart of Q39's outer "inv" CTE: add the cov column, then keep only
    # rows with cov > 1. Every (item, warehouse, month) combination passes.
    cov_column = when(col("mean") == 0, lit(None)).otherwise(
        col("stdev") / col("mean")
    )
    with_cov = agg.with_column("cov", cov_column)
    cov_for_filter = when(col("mean") == 0, lit(0)).otherwise(
        col("stdev") / col("mean")
    )
    inv = with_cov.filter(cov_for_filter > 1)

    inv_r = copy.copy(inv)
    joined = inv.join(inv_r, on=["i_item_sk", "w_warehouse_sk"], rsuffix="_r")
    result = joined.filter(col("d_moy") == 1).filter(col("d_moy_r") == 2)

    normalized = Utils.normalize_sql(result.queries["queries"][-1])

    with SqlCounter(query_count=0, describe_count=0):
        # The shared `inv` computation must collapse into exactly one CTE.
        assert count_number_of_ctes(normalized) == 1

        # The CTE name should occur at least 3 times: once in the WITH
        # definition and at least twice in the body (one per join branch).
        cte_name_match = re.search(r"WITH\s+(\w+)\s+AS", normalized)
        assert cte_name_match is not None, "expected a WITH CTE in the generated SQL"
        cte_name = cte_name_match.group(1)
        cte_occurrences = normalized.count(cte_name)
        assert (
            cte_occurrences >= 3
        ), f"CTE '{cte_name}' should appear in the definition and both join branches"

        # The aggregation (stddev_samp / GROUP BY) has to occur exactly once,
        # inside the CTE body. Two occurrences would mean `inv` was inlined
        # separately for each join branch instead of being shared.
        stddev_occurrences = normalized.lower().count("stddev_samp")
        assert (
            stddev_occurrences == 1
        ), "stddev_samp should appear once (in the CTE), not once per alias branch"
        group_by_occurrences = normalized.upper().count("GROUP BY")
        assert (
            group_by_occurrences == 1
        ), "GROUP BY should appear once (in the CTE), not once per alias branch"

    # Correctness: the CTE-optimized result must equal the non-optimized one.
    # The correct equi-join on (i_item_sk, w_warehouse_sk) yields 2 rows (item
    # 10 and item 20, each pairing its month-1 and month-2 stats), whereas a
    # wrong cross-join would yield 4, so this comparison is meaningful.
    check_result(
        session,
        result,
        expect_cte_optimized=True,
        query_count=1,
        describe_count=0,
        union_count=0,
        join_count=1,
    )
72 changes: 71 additions & 1 deletion tests/unit/test_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
encode_node_id_with_query,
find_duplicate_subtrees,
)
from snowflake.snowpark._internal.compiler.repeated_subquery_elimination import (
RepeatedSubqueryElimination,
)


def create_test_case1():
Expand Down Expand Up @@ -118,6 +121,9 @@ def test_encode_node_id_with_query_select_sql(mock_session, mock_analyzer):


def test_encode_node_id_with_query_includes_aliases():
# expr_to_alias is hashed by sorted(set(values())) so two nodes with the
# same alias values but different UUID keys (e.g. deep-copied self-join
# branches) produce the same hash.
node = SimpleNamespace(
sql_query="select col1 from t",
query_params=(("p1", 1), ("p2", "x")),
Expand All @@ -134,13 +140,77 @@ def stringify_dict(d: dict) -> str:
if node.query_params:
expected_string = f"{expected_string}#{node.query_params}"
if node.expr_to_alias:
expected_string = f"{expected_string}#{stringify_dict(node.expr_to_alias)}"
# Values-only sort (no UUID keys) normalizes away UUID differences
expected_string = (
f"{expected_string}#{sorted(set(node.expr_to_alias.values()))}"
)
if node.df_aliased_col_name_to_real_col_name:
expected_string = f"{expected_string}#{stringify_dict(node.df_aliased_col_name_to_real_col_name)}"

expected_hash = hashlib.sha256(expected_string.encode()).hexdigest()[:10]
assert encode_node_id_with_query(node) == f"{expected_hash}_SimpleNamespace"

# Two nodes with the same SQL and same alias values but different UUID keys
# must hash identically — this is the Q39 self-join case.
node_same_values_diff_keys = SimpleNamespace(
sql_query="select col1 from t",
query_params=(("p1", 1), ("p2", "x")),
expr_to_alias={"uuid_different": "ALIAS1"},
df_aliased_col_name_to_real_col_name={"ALIAS1": "col1"},
)
assert encode_node_id_with_query(node) == encode_node_id_with_query(
node_same_values_diff_keys
)

# Two nodes with the same SQL but different alias values must hash
# differently — this preserves the SNOW-2261400 join-suffix fix.
node_different_values = SimpleNamespace(
sql_query="select col1 from t",
query_params=(("p1", 1), ("p2", "x")),
expr_to_alias={"uuid1": "ALIAS1_WITH_SUFFIX"},
df_aliased_col_name_to_real_col_name={"ALIAS1": "col1"},
)
assert encode_node_id_with_query(node) != encode_node_id_with_query(
node_different_values
)


def test_has_alias_conflict():
    """Exercise ``RepeatedSubqueryElimination._has_alias_conflict`` case by case.

    encode_query_id hashes expr_to_alias by alias values only, so two nodes can
    share a CTE while carrying different expr_id keys. The helper guards the one
    unsafe situation: the same expr_id bound to a *different* alias, where a
    merge would silently drop an entry and corrupt parent column resolution.
    """
    has_conflict = RepeatedSubqueryElimination._has_alias_conflict

    def holder(mapping):
        # Minimal stand-in exposing only the attribute the helper reads.
        return SimpleNamespace(expr_to_alias=mapping)

    node = holder({"uuid1": "ALIAS1"})
    existing_conflict = holder({"uuid1": "ALIAS2"})

    cases = [
        # No CTE exists yet (first occurrence) -> nothing to conflict with.
        (node, None, False),
        # Same expr_id bound to the same alias -> safe to merge.
        (node, holder({"uuid1": "ALIAS1"}), False),
        # Disjoint expr_id keys (the usual self-join shape: same alias values,
        # fresh UUIDs) -> no conflict, the entries coexist after the merge.
        (node, holder({"uuid2": "ALIAS1"}), False),
        # Same expr_id bound to a *different* alias -> must not share the CTE.
        (node, existing_conflict, True),
        # One conflicting key is enough, even when the other keys agree.
        (
            holder({"uuid1": "ALIAS1", "uuid2": "ALIAS2"}),
            holder({"uuid1": "ALIAS1", "uuid2": "DIFFERENT"}),
            True,
        ),
        # A node with no expr_to_alias entries can never conflict.
        (holder({}), existing_conflict, False),
    ]
    for candidate, existing, expected in cases:
        assert has_conflict(candidate, existing) is expected


def test_select_statement_contains_data_generation(mock_session, mock_analyzer):
"""SelectStatement.contains_data_generation should detect zero-arg
Expand Down
Loading