Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
#### Bug Fixes

- Fixed a bug where `cloudpickle` could not be resolved when registering a Python stored procedure or UDF with `runtime_version='3.13'`.
- Fixed a bug where calling `DataFrame.alias()` twice on the same DataFrame (e.g. for a self-join) caused both aliases to share the same internal column-mapping dictionary. This made `col("R", "col")` resolve to the same column as `col("L", "col")`, producing incorrect join conditions and filter expressions.

#### New Features

- Added `get_wif_token` to `snowflake.snowpark.secrets` for workload identity federation tokens on the Snowflake server (not available in SPCS file-based secret environments).

#### Bug Fixes
#### Improvements

- Fixed a bug where calling `DataFrame.alias()` twice on the same DataFrame (e.g. for a self-join) caused both aliases to share the same internal column-mapping dictionary. This made `col("R", "col")` resolve to the same column as `col("L", "col")`, producing incorrect join conditions and filter expressions.
- Improved CTE optimization to deduplicate identical subtrees in self-joins, which were previously emitted as repeated subqueries.

## 1.51.1 (2026-05-28)

Expand Down
2 changes: 0 additions & 2 deletions src/snowflake/snowpark/_internal/compiler/cte_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,6 @@ def stringify(d):
string = query
if query_params:
string = f"{string}#{query_params}"
if hasattr(node, "expr_to_alias") and node.expr_to_alias:
string = f"{string}#{stringify(node.expr_to_alias)}"
if (
hasattr(node, "df_aliased_col_name_to_real_col_name")
and node.df_aliased_col_name_to_real_col_name
Expand Down
131 changes: 124 additions & 7 deletions tests/integ/test_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#

import copy
import re
import tracemalloc
from contextlib import contextmanager
Expand Down Expand Up @@ -30,6 +31,8 @@
uuid_string,
when_matched,
to_timestamp,
stddev_samp,
when,
)
from snowflake.snowpark.types import (
StructType,
Expand Down Expand Up @@ -699,9 +702,10 @@ def test_cte_preserves_join_suffix_aliases(session, use_different_df):
# the second one is incorrect join condition as we have rsuffix for join alias
assert 'ON ("AD_GROUP_ID_WITH_AD_GROUP" = "AD_GROUP_ID")' in union_sql
assert 'ON ("AD_GROUP_ID" = "AD_GROUP_ID")' not in union_sql
# when using different df_ad_group with disambiguation, because rsuffix in join,
# they have different alias map (expr_to_alias), so they are considered different and we can't convert them to a CTE
# However there is still a CTE for create_dataframe call
# Both cases produce 1 CTE: expr_to_alias is no longer part of the node hash
# string, so the disambiguated rhs_remapped wrapper nodes (same SQL, differing
# only in their expr_to_alias maps) hash identically and are merged into a
# single CTE. The raw VALUES table is absorbed inline into that CTE body rather
# than becoming its own CTE.
assert count_number_of_ctes(Utils.normalize_sql(union_sql)) == 1


Expand Down Expand Up @@ -874,12 +878,15 @@ def test_sql_simplifier(session):
join_count=2,
)
with SqlCounter(query_count=0, describe_count=0):
# When adding a lsuffix, expr alias map will be updated, so df2 and df3 are considered
# different and have different ids. So only df1 and df will be converted to a CTE
# With expr_to_alias no longer part of the node hash string, df1/df2/df3 now
# hash identically (same SQL, differing only in their expr_to_alias maps).
# df2 and df3 are replaced with a shared CTE, but df1's left-join position
# remains inline. That gives 2 CTEs (base VALUES + filtered df1) and the
# filter appears twice (once in the CTE body, once inline for the left-join
# position).
assert (
count_number_of_ctes(Utils.normalize_sql(df6.queries["queries"][-1])) == 1
count_number_of_ctes(Utils.normalize_sql(df6.queries["queries"][-1])) == 2
)
assert Utils.normalize_sql(df6.queries["queries"][-1]).count(filter_clause) == 3
assert Utils.normalize_sql(df6.queries["queries"][-1]).count(filter_clause) == 2

df7 = df1.with_column("c", lit(1))
df8 = df1.with_column("c", lit(1)).with_column("d", lit(1))
Expand Down Expand Up @@ -1949,3 +1956,113 @@ def test_uniform_cte_optimization_depends_on_gen(session, use_bare_random, expec

vals = [row["VAL"] for row in result_df.collect()]
assert (vals[:5] == vals[5:]) == expect_cte


def test_cte_tpcds_q39_style_self_join_deduplication(session):
    """Check that a Q39-shaped self-join shares one CTE for the repeated subtree.

    A filtered aggregation DataFrame (`inv`) is joined with a shallow copy of
    itself using ``rsuffix="_r"``.  The generated SQL is expected to define a
    single CTE, reference it from both sides of the join, and contain the
    aggregation (stddev_samp / GROUP BY) only once.  Finally the result is
    compared against the non-optimized plan through ``check_result``.
    """
    if not session._sql_simplifier_enabled:
        pytest.skip("SQL simplifier is not enabled")

    # Rows are (item_sk, warehouse_sk, month, quantity), shaped like the Q39
    # inventory data after its inner join.  Each (item, warehouse, month) group
    # has mean=200 and a large spread, so every group passes cov > 1.  That
    # makes a wrong cross-join (4 rows) distinguishable from the correct
    # equi-join (2 rows).
    inventory_rows = [
        (10, 1, 1, 10),
        (10, 1, 1, 390),  # item 10, wh 1, month 1: stdev≈268.7, cov≈1.34
        (10, 1, 2, 20),
        (10, 1, 2, 380),  # item 10, wh 1, month 2: stdev≈254.6, cov≈1.27
        (20, 2, 1, 5),
        (20, 2, 1, 395),  # item 20, wh 2, month 1: stdev≈275.8, cov≈1.38
        (20, 2, 2, 30),
        (20, 2, 2, 370),  # item 20, wh 2, month 2: stdev≈240.4, cov≈1.20
    ]
    raw = session.create_dataframe(
        inventory_rows,
        schema=["i_item_sk", "w_warehouse_sk", "d_moy", "qty"],
    )

    # Counterpart of Q39's inner "foo" aggregation subquery.
    stdev_col = stddev_samp("qty").alias("stdev")
    mean_col = avg("qty").cast("double").alias("mean")
    agg = raw.group_by("i_item_sk", "w_warehouse_sk", "d_moy").agg(
        stdev_col,
        mean_col,
    )

    # Counterpart of Q39's outer "inv" CTE: add the cov column, then keep only
    # groups whose cov exceeds 1 (all four groups here).
    cov_value = when(col("mean") == 0, lit(None)).otherwise(
        col("stdev") / col("mean")
    )
    with_cov = agg.with_column("cov", cov_value)
    cov_above_one = (
        when(col("mean") == 0, lit(0)).otherwise(col("stdev") / col("mean")) > 1
    )
    inv = with_cov.filter(cov_above_one)

    # Self-join `inv` with a shallow copy; right-side columns get the "_r" suffix.
    inv_r = copy.copy(inv)
    joined = inv.join(inv_r, on=["i_item_sk", "w_warehouse_sk"], rsuffix="_r")
    month_one = joined.filter(col("d_moy") == 1)
    result = month_one.filter(col("d_moy_r") == 2)

    final_query = result.queries["queries"][-1]
    normalized = Utils.normalize_sql(final_query)

    with SqlCounter(query_count=0, describe_count=0):
        # Exactly one CTE: the shared `inv` computation.
        assert count_number_of_ctes(normalized) == 1

        # The CTE name must show up in the WITH definition plus at least once
        # on each side of the join, i.e. three or more times in total.
        cte_name_match = re.search(r"WITH\s+(\w+)\s+AS", normalized)
        assert cte_name_match is not None, "expected a WITH CTE in the generated SQL"
        cte_name = cte_name_match.group(1)
        occurrences = normalized.count(cte_name)
        assert (
            occurrences >= 3
        ), f"CTE '{cte_name}' should appear in the definition and both join branches"

        # The aggregation belongs to the CTE body only; seeing it twice would
        # mean `inv` was inlined separately for each side of the join.
        stddev_hits = normalized.lower().count("stddev_samp")
        assert (
            stddev_hits == 1
        ), "stddev_samp should appear once (in the CTE), not once per alias branch"
        group_by_hits = normalized.upper().count("GROUP BY")
        assert (
            group_by_hits == 1
        ), "GROUP BY should appear once (in the CTE), not once per alias branch"

    # Correctness: optimized and non-optimized results must agree.  The proper
    # equi-join on (i_item_sk, w_warehouse_sk) yields 2 rows (items 10 and 20,
    # each pairing month 1 with month 2), whereas a cross-join would yield 4.
    check_result(
        session,
        result,
        expect_cte_optimized=True,
        query_count=1,
        describe_count=0,
        union_count=0,
        join_count=1,
    )
Loading