Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions src/snowflake/snowpark/_internal/compiler/cte_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,8 @@ def get_duplicated_node_complexity_distribution(

def encode_query_id(node: "TreeNode") -> Optional[str]:
"""
Encode the query and its query parameter into an id using sha256.

Encode the query, its query parameter, expr_to_alias and df_aliased_col_name_to_real_col_name
into an id using sha256.

Returns:
If encode succeed, return the first 10 encoded value.
Expand All @@ -252,7 +252,25 @@ def encode_query_id(node: "TreeNode") -> Optional[str]:
# to avoid being detected as a common subquery.
return None

string = f"{query}#{query_params}" if query_params else query
def stringify(d):
if isinstance(d, dict):
key_value_pairs = list(d.items())
key_value_pairs.sort(key=lambda x: x[0])
return str(key_value_pairs)
else:
return str(d)

string = query
if query_params:
string = f"{string}#{query_params}"
if hasattr(node, "expr_to_alias") and node.expr_to_alias:
string = f"{string}#{stringify(node.expr_to_alias)}"
if (
hasattr(node, "df_aliased_col_name_to_real_col_name")
and node.df_aliased_col_name_to_real_col_name
):
string = f"{string}#{stringify(node.df_aliased_col_name_to_real_col_name)}"

try:
return hashlib.sha256(string.encode()).hexdigest()[:10]
except Exception as ex:
Expand Down
84 changes: 80 additions & 4 deletions tests/integ/test_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,82 @@ def test_same_duplicate_subtree(session):
assert count_number_of_ctes(df_result2.queries["queries"][-1]) == 3


@pytest.mark.parametrize("use_different_df", [True, False])
def test_cte_preserves_join_suffix_aliases(session, use_different_df):
df_ad_group = session.create_dataframe(
[["1048771", "group_1", "campaign_1"]],
schema=["ACCOUNT_ID", "AD_GROUP_ID", "CAMPAIGN_ID"],
)

df_ad_group_excv = session.create_dataframe(
[["1048771", "group_1", "device", "8308"]],
schema=["ACCOUNT_ID", "AD_GROUP_ID", "DEVICE", "EXTERNAL_CONVERSION_ID"],
)

df_ad_group_excv = df_ad_group_excv.join(
df_ad_group,
df_ad_group.col("AD_GROUP_ID") == df_ad_group_excv.col("AD_GROUP_ID"),
rsuffix="_WITH_AD_GROUP",
).select(
col("ACCOUNT_ID"),
col("CAMPAIGN_ID"),
col("AD_GROUP_ID"),
lit(None).as_("AD_ID"),
)

if use_different_df:
df_ad_group = session.create_dataframe(
[["1048771", "group_1", "campaign_1"]],
schema=["ACCOUNT_ID", "AD_GROUP_ID", "CAMPAIGN_ID"],
)

df_ad_group_ad = session.create_dataframe(
[["1048771", "ad_1", "group_1"]],
schema=["ACCOUNT_ID", "AD_ID", "AD_GROUP_ID"],
)

df_ad_excv = session.create_dataframe(
[["1048771", "group_1", "ad_1", "device", "8308"]],
schema=[
"ACCOUNT_ID",
"AD_GROUP_ID",
"AD_ID",
"DEVICE",
"EXTERNAL_CONVERSION_ID",
],
)

df_ad_excv = (
df_ad_excv.join(
df_ad_group_ad,
df_ad_group_ad.col("AD_ID") == df_ad_excv.col("AD_ID"),
rsuffix="_WITH_AD_GROUP_AD",
)
.join(
df_ad_group,
df_ad_group.col("AD_GROUP_ID") == df_ad_group_ad.col("AD_GROUP_ID"),
rsuffix="_WITH_AD_GROUP",
)
.select(
col("ACCOUNT_ID"),
col("CAMPAIGN_ID"),
col("AD_GROUP_ID"),
col("AD_ID"),
)
)

df_union = df_ad_group_excv.union_all(df_ad_excv)
union_sql = df_union.queries["queries"][-1]

# the second one is incorrect join condition as we have rsuffix for join alias
assert 'ON ("AD_GROUP_ID_WITH_AD_GROUP" = "AD_GROUP_ID")' in union_sql
assert 'ON ("AD_GROUP_ID" = "AD_GROUP_ID")' not in union_sql
# when using different df_ad_group with disambiguation, because rsuffix in join,
# they have different alias map (expr_to_alias), so they are considered different and we can't convert them to a CTE
# However there is still a CTE for create_dataframe call
assert count_number_of_ctes(Utils.normalize_sql(union_sql)) == 1


@pytest.mark.parametrize(
"mode", ["append", "truncate", "overwrite", "errorifexists", "ignore"]
)
Expand Down Expand Up @@ -736,12 +812,12 @@ def test_sql_simplifier(session):
describe_count_for_optimized=1 if session._join_alias_fix else None,
)
with SqlCounter(query_count=0, describe_count=0):
# When adding a lsuffix, the columns of right dataframe don't need to be renamed,
# so we will get a common CTE with filter
# When adding a lsuffix, expr alias map will be updated, so df2 and df3 are considered
# different and have different ids. So only df1 and df will be converted to a CTE
assert (
count_number_of_ctes(Utils.normalize_sql(df6.queries["queries"][-1])) == 2
count_number_of_ctes(Utils.normalize_sql(df6.queries["queries"][-1])) == 1
)
assert Utils.normalize_sql(df6.queries["queries"][-1]).count(filter_clause) == 2
assert Utils.normalize_sql(df6.queries["queries"][-1]).count(filter_clause) == 3

df7 = df1.with_column("c", lit(1))
df8 = df1.with_column("c", lit(1)).with_column("d", lit(1))
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/test_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#

import hashlib
from types import SimpleNamespace
from unittest import mock

import pytest
Expand Down Expand Up @@ -103,3 +105,28 @@ def test_encode_node_id_with_query_select_sql(mock_session, mock_analyzer):
encode_node_id_with_query(select_statement_node)
== f"{expected_hash}_SelectStatement"
)


def test_encode_node_id_with_query_includes_aliases():
node = SimpleNamespace(
sql_query="select col1 from t",
query_params=(("p1", 1), ("p2", "x")),
expr_to_alias={"uuid1": "ALIAS1"},
df_aliased_col_name_to_real_col_name={"ALIAS1": "col1"},
)

def stringify_dict(d: dict) -> str:
key_value_pairs = list(d.items())
key_value_pairs.sort(key=lambda x: x[0])
return str(key_value_pairs)

expected_string = node.sql_query
if node.query_params:
expected_string = f"{expected_string}#{node.query_params}"
if node.expr_to_alias:
expected_string = f"{expected_string}#{stringify_dict(node.expr_to_alias)}"
if node.df_aliased_col_name_to_real_col_name:
expected_string = f"{expected_string}#{stringify_dict(node.df_aliased_col_name_to_real_col_name)}"

expected_hash = hashlib.sha256(expected_string.encode()).hexdigest()[:10]
assert encode_node_id_with_query(node) == f"{expected_hash}_SimpleNamespace"
Loading