diff --git a/src/snowflake/snowpark/_internal/compiler/cte_utils.py b/src/snowflake/snowpark/_internal/compiler/cte_utils.py index 18f635c588..a25798c178 100644 --- a/src/snowflake/snowpark/_internal/compiler/cte_utils.py +++ b/src/snowflake/snowpark/_internal/compiler/cte_utils.py @@ -224,8 +224,8 @@ def get_duplicated_node_complexity_distribution( def encode_query_id(node: "TreeNode") -> Optional[str]: """ - Encode the query and its query parameter into an id using sha256. - + Encode the query, its query parameter, expr_to_alias and df_aliased_col_name_to_real_col_name + into an id using sha256. Returns: If encode succeed, return the first 10 encoded value. @@ -252,7 +252,25 @@ def encode_query_id(node: "TreeNode") -> Optional[str]: # to avoid being detected as a common subquery. return None - string = f"{query}#{query_params}" if query_params else query + def stringify(d): + if isinstance(d, dict): + key_value_pairs = list(d.items()) + key_value_pairs.sort(key=lambda x: x[0]) + return str(key_value_pairs) + else: + return str(d) + + string = query + if query_params: + string = f"{string}#{query_params}" + if hasattr(node, "expr_to_alias") and node.expr_to_alias: + string = f"{string}#{stringify(node.expr_to_alias)}" + if ( + hasattr(node, "df_aliased_col_name_to_real_col_name") + and node.df_aliased_col_name_to_real_col_name + ): + string = f"{string}#{stringify(node.df_aliased_col_name_to_real_col_name)}" + try: return hashlib.sha256(string.encode()).hexdigest()[:10] except Exception as ex: diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index 29910a355a..43589dae79 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -566,6 +566,82 @@ def test_same_duplicate_subtree(session): assert count_number_of_ctes(df_result2.queries["queries"][-1]) == 3 +@pytest.mark.parametrize("use_different_df", [True, False]) +def test_cte_preserves_join_suffix_aliases(session, use_different_df): + df_ad_group = session.create_dataframe( + [["1048771", 
"group_1", "campaign_1"]], + schema=["ACCOUNT_ID", "AD_GROUP_ID", "CAMPAIGN_ID"], + ) + + df_ad_group_excv = session.create_dataframe( + [["1048771", "group_1", "device", "8308"]], + schema=["ACCOUNT_ID", "AD_GROUP_ID", "DEVICE", "EXTERNAL_CONVERSION_ID"], + ) + + df_ad_group_excv = df_ad_group_excv.join( + df_ad_group, + df_ad_group.col("AD_GROUP_ID") == df_ad_group_excv.col("AD_GROUP_ID"), + rsuffix="_WITH_AD_GROUP", + ).select( + col("ACCOUNT_ID"), + col("CAMPAIGN_ID"), + col("AD_GROUP_ID"), + lit(None).as_("AD_ID"), + ) + + if use_different_df: + df_ad_group = session.create_dataframe( + [["1048771", "group_1", "campaign_1"]], + schema=["ACCOUNT_ID", "AD_GROUP_ID", "CAMPAIGN_ID"], + ) + + df_ad_group_ad = session.create_dataframe( + [["1048771", "ad_1", "group_1"]], + schema=["ACCOUNT_ID", "AD_ID", "AD_GROUP_ID"], + ) + + df_ad_excv = session.create_dataframe( + [["1048771", "group_1", "ad_1", "device", "8308"]], + schema=[ + "ACCOUNT_ID", + "AD_GROUP_ID", + "AD_ID", + "DEVICE", + "EXTERNAL_CONVERSION_ID", + ], + ) + + df_ad_excv = ( + df_ad_excv.join( + df_ad_group_ad, + df_ad_group_ad.col("AD_ID") == df_ad_excv.col("AD_ID"), + rsuffix="_WITH_AD_GROUP_AD", + ) + .join( + df_ad_group, + df_ad_group.col("AD_GROUP_ID") == df_ad_group_ad.col("AD_GROUP_ID"), + rsuffix="_WITH_AD_GROUP", + ) + .select( + col("ACCOUNT_ID"), + col("CAMPAIGN_ID"), + col("AD_GROUP_ID"), + col("AD_ID"), + ) + ) + + df_union = df_ad_group_excv.union_all(df_ad_excv) + union_sql = df_union.queries["queries"][-1] + + # the second one is incorrect join condition as we have rsuffix for join alias + assert 'ON ("AD_GROUP_ID_WITH_AD_GROUP" = "AD_GROUP_ID")' in union_sql + assert 'ON ("AD_GROUP_ID" = "AD_GROUP_ID")' not in union_sql + # when using different df_ad_group with disambiguation, because rsuffix in join, + # they have different alias map (expr_to_alias), so they are considered different and we can't convert them to a CTE + # However there is still a CTE for create_dataframe call + assert 
count_number_of_ctes(Utils.normalize_sql(union_sql)) == 1 + + @pytest.mark.parametrize( "mode", ["append", "truncate", "overwrite", "errorifexists", "ignore"] ) @@ -736,12 +812,12 @@ def test_sql_simplifier(session): describe_count_for_optimized=1 if session._join_alias_fix else None, ) with SqlCounter(query_count=0, describe_count=0): - # When adding a lsuffix, the columns of right dataframe don't need to be renamed, - # so we will get a common CTE with filter + # When adding an lsuffix, the expr alias map is updated, so df2 and df3 are considered + # different and get different ids. As a result, only df1 and df will be converted to a CTE assert ( - count_number_of_ctes(Utils.normalize_sql(df6.queries["queries"][-1])) == 2 + count_number_of_ctes(Utils.normalize_sql(df6.queries["queries"][-1])) == 1 ) - assert Utils.normalize_sql(df6.queries["queries"][-1]).count(filter_clause) == 2 + assert Utils.normalize_sql(df6.queries["queries"][-1]).count(filter_clause) == 3 df7 = df1.with_column("c", lit(1)) df8 = df1.with_column("c", lit(1)).with_column("d", lit(1)) diff --git a/tests/unit/test_cte.py b/tests/unit/test_cte.py index a0853bd6ae..2014350f35 100644 --- a/tests/unit/test_cte.py +++ b/tests/unit/test_cte.py @@ -2,6 +2,8 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. 
# +import hashlib +from types import SimpleNamespace from unittest import mock import pytest @@ -103,3 +105,28 @@ def test_encode_node_id_with_query_select_sql(mock_session, mock_analyzer): encode_node_id_with_query(select_statement_node) == f"{expected_hash}_SelectStatement" ) + + +def test_encode_node_id_with_query_includes_aliases(): + node = SimpleNamespace( + sql_query="select col1 from t", + query_params=(("p1", 1), ("p2", "x")), + expr_to_alias={"uuid1": "ALIAS1"}, + df_aliased_col_name_to_real_col_name={"ALIAS1": "col1"}, + ) + + def stringify_dict(d: dict) -> str: + key_value_pairs = list(d.items()) + key_value_pairs.sort(key=lambda x: x[0]) + return str(key_value_pairs) + + expected_string = node.sql_query + if node.query_params: + expected_string = f"{expected_string}#{node.query_params}" + if node.expr_to_alias: + expected_string = f"{expected_string}#{stringify_dict(node.expr_to_alias)}" + if node.df_aliased_col_name_to_real_col_name: + expected_string = f"{expected_string}#{stringify_dict(node.df_aliased_col_name_to_real_col_name)}" + + expected_hash = hashlib.sha256(expected_string.encode()).hexdigest()[:10] + assert encode_node_id_with_query(node) == f"{expected_hash}_SimpleNamespace"