From 6f56d6e0c420e5310002d09701658ce844cf6a05 Mon Sep 17 00:00:00 2001 From: Jianzhun Du Date: Thu, 25 Sep 2025 14:44:57 -0700 Subject: [PATCH 1/3] d --- .../snowpark/_internal/compiler/cte_utils.py | 16 +++- tests/integ/test_cte.py | 87 ++++++++++++++++++- tests/unit/test_cte.py | 24 +++++ 3 files changed, 120 insertions(+), 7 deletions(-) diff --git a/src/snowflake/snowpark/_internal/compiler/cte_utils.py b/src/snowflake/snowpark/_internal/compiler/cte_utils.py index 18f635c588..f9254bb476 100644 --- a/src/snowflake/snowpark/_internal/compiler/cte_utils.py +++ b/src/snowflake/snowpark/_internal/compiler/cte_utils.py @@ -224,8 +224,8 @@ def get_duplicated_node_complexity_distribution( def encode_query_id(node: "TreeNode") -> Optional[str]: """ - Encode the query and its query parameter into an id using sha256. - + Encode the query, its query parameter, expr_to_alias and df_aliased_col_name_to_real_col_name + into an id using sha256. Returns: If encode succeed, return the first 10 encoded value. @@ -252,7 +252,17 @@ def encode_query_id(node: "TreeNode") -> Optional[str]: # to avoid being detected as a common subquery. return None - string = f"{query}#{query_params}" if query_params else query + string = query + if query_params: + string = f"{string}#{query_params}" + if hasattr(node, "expr_to_alias") and node.expr_to_alias: + string = f"{string}#{node.expr_to_alias}" + if ( + hasattr(node, "df_aliased_col_name_to_real_col_name") + and node.df_aliased_col_name_to_real_col_name + ): + string = f"{string}#{node.df_aliased_col_name_to_real_col_name}" + try: return hashlib.sha256(string.encode()).hexdigest()[:10] except Exception as ex: diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index 29910a355a..96421340a6 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -566,6 +566,85 @@ def test_same_duplicate_subtree(session): assert count_number_of_ctes(df_result2.queries["queries"][-1]) == 3 +@pytest.mark.parametrize("use_different_df", [True, False]) +def test_cte_preserves_join_suffix_aliases(session, use_different_df): + df_ad_group = session.create_dataframe( + [["1048771", "group_1", "campaign_1"]], + schema=["ACCOUNT_ID", "AD_GROUP_ID", "CAMPAIGN_ID"], + ) + + df_ad_group_excv = session.create_dataframe( + [["1048771", "group_1", "device", "8308"]], + schema=["ACCOUNT_ID", "AD_GROUP_ID", "DEVICE", "EXTERNAL_CONVERSION_ID"], + ) + + df_ad_group_excv = df_ad_group_excv.join( + df_ad_group, + df_ad_group.col("AD_GROUP_ID") == df_ad_group_excv.col("AD_GROUP_ID"), + rsuffix="_WITH_AD_GROUP", + ).select( + col("ACCOUNT_ID"), + col("CAMPAIGN_ID"), + col("AD_GROUP_ID"), + lit(None).as_("AD_ID"), + ) + + if use_different_df: + df_ad_group = session.create_dataframe( + [["1048771", "group_1", "campaign_1"]], + schema=["ACCOUNT_ID", "AD_GROUP_ID", "CAMPAIGN_ID"], + ) + + df_ad_group_ad = session.create_dataframe( + [["1048771", "ad_1", "group_1"]], + schema=["ACCOUNT_ID", "AD_ID", "AD_GROUP_ID"], + ) + + df_ad_excv = session.create_dataframe( + [["1048771", "group_1", "ad_1", "device", "8308"]], + schema=[ + "ACCOUNT_ID", + "AD_GROUP_ID", + "AD_ID", + "DEVICE", + "EXTERNAL_CONVERSION_ID", + ], + ) + + df_ad_excv = ( + df_ad_excv.join( + df_ad_group_ad, + df_ad_group_ad.col("AD_ID") == df_ad_excv.col("AD_ID"), + rsuffix="_WITH_AD_GROUP_AD", + ) + .join( + df_ad_group, + df_ad_group.col("AD_GROUP_ID") == df_ad_group_ad.col("AD_GROUP_ID"), + rsuffix="_WITH_AD_GROUP", + ) + .select( + col("ACCOUNT_ID"), + col("CAMPAIGN_ID"), + col("AD_GROUP_ID"), + col("AD_ID"), + ) + ) + + df_union = df_ad_group_excv.union_all(df_ad_excv) + union_sql = df_union.queries["queries"][-1] + + # the second one is incorrect join condition as we have rsuffix for join alias + assert 'ON ("AD_GROUP_ID_WITH_AD_GROUP" = "AD_GROUP_ID")' in union_sql + assert 'ON ("AD_GROUP_ID" = "AD_GROUP_ID")' not in union_sql + # when using different df_ad_group, because rsuffix in join, they have different alias map (expr_to_alias), + # so they are considered different and we can't convert them to a CTE + assert ( + count_number_of_ctes(Utils.normalize_sql(union_sql)) == 0 + if use_different_df + else 1 + ) + + @pytest.mark.parametrize( "mode", ["append", "truncate", "overwrite", "errorifexists", "ignore"] ) @@ -736,12 +815,12 @@ def test_sql_simplifier(session): describe_count_for_optimized=1 if session._join_alias_fix else None, ) with SqlCounter(query_count=0, describe_count=0): - # When adding a lsuffix, the columns of right dataframe don't need to be renamed, - # so we will get a common CTE with filter + # When adding a lsuffix, expr alias map will be updated, so df2 and df3 are considered + # different and have different ids. So only df1 and df will be converted to a CTE assert ( - count_number_of_ctes(Utils.normalize_sql(df6.queries["queries"][-1])) == 2 + count_number_of_ctes(Utils.normalize_sql(df6.queries["queries"][-1])) == 1 ) - assert Utils.normalize_sql(df6.queries["queries"][-1]).count(filter_clause) == 2 + assert Utils.normalize_sql(df6.queries["queries"][-1]).count(filter_clause) == 3 df7 = df1.with_column("c", lit(1)) df8 = df1.with_column("c", lit(1)).with_column("d", lit(1)) diff --git a/tests/unit/test_cte.py b/tests/unit/test_cte.py index a0853bd6ae..120e580feb 100644 --- a/tests/unit/test_cte.py +++ b/tests/unit/test_cte.py @@ -2,6 +2,8 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +import hashlib +from types import SimpleNamespace from unittest import mock import pytest @@ -103,3 +105,25 @@ def test_encode_node_id_with_query_select_sql(mock_session, mock_analyzer): encode_node_id_with_query(select_statement_node) == f"{expected_hash}_SelectStatement" ) + + +def test_encode_node_id_with_query_includes_aliases(): + node = SimpleNamespace( + sql_query="select col1 from t", + query_params=(("p1", 1), ("p2", "x")), + expr_to_alias=(("expr1", "alias1"),), + df_aliased_col_name_to_real_col_name=(("ALIAS1", "COL1"),), + ) + + expected_string = node.sql_query + if node.query_params: + expected_string = f"{expected_string}#{node.query_params}" + if node.expr_to_alias: + expected_string = f"{expected_string}#{node.expr_to_alias}" + if node.df_aliased_col_name_to_real_col_name: + expected_string = ( + f"{expected_string}#{node.df_aliased_col_name_to_real_col_name}" + ) + + expected_hash = hashlib.sha256(expected_string.encode()).hexdigest()[:10] + assert encode_node_id_with_query(node) == f"{expected_hash}_SimpleNamespace" From 11e6d56e49ff5ee037bf1efb12fb2d596c4fcb80 Mon Sep 17 00:00:00 2001 From: Jianzhun Du Date: Thu, 25 Sep 2025 15:03:32 -0700 Subject: [PATCH 2/3] d --- .../snowpark/_internal/compiler/cte_utils.py | 12 ++++++++++-- tests/unit/test_cte.py | 15 +++++++++------ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/src/snowflake/snowpark/_internal/compiler/cte_utils.py b/src/snowflake/snowpark/_internal/compiler/cte_utils.py index f9254bb476..a25798c178 100644 --- a/src/snowflake/snowpark/_internal/compiler/cte_utils.py +++ b/src/snowflake/snowpark/_internal/compiler/cte_utils.py @@ -252,16 +252,24 @@ def encode_query_id(node: "TreeNode") -> Optional[str]: # to avoid being detected as a common subquery. return None + def stringify(d): + if isinstance(d, dict): + key_value_pairs = list(d.items()) + key_value_pairs.sort(key=lambda x: x[0]) + return str(key_value_pairs) + else: + return str(d) + string = query if query_params: string = f"{string}#{query_params}" if hasattr(node, "expr_to_alias") and node.expr_to_alias: - string = f"{string}#{node.expr_to_alias}" + string = f"{string}#{stringify(node.expr_to_alias)}" if ( hasattr(node, "df_aliased_col_name_to_real_col_name") and node.df_aliased_col_name_to_real_col_name ): - string = f"{string}#{node.df_aliased_col_name_to_real_col_name}" + string = f"{string}#{stringify(node.df_aliased_col_name_to_real_col_name)}" try: return hashlib.sha256(string.encode()).hexdigest()[:10] diff --git a/tests/unit/test_cte.py b/tests/unit/test_cte.py index 120e580feb..2014350f35 100644 --- a/tests/unit/test_cte.py +++ b/tests/unit/test_cte.py @@ -111,19 +111,22 @@ def test_encode_node_id_with_query_includes_aliases(): node = SimpleNamespace( sql_query="select col1 from t", query_params=(("p1", 1), ("p2", "x")), - expr_to_alias=(("expr1", "alias1"),), - df_aliased_col_name_to_real_col_name=(("ALIAS1", "COL1"),), + expr_to_alias={"uuid1": "ALIAS1"}, + df_aliased_col_name_to_real_col_name={"ALIAS1": "col1"}, ) + def stringify_dict(d: dict) -> str: + key_value_pairs = list(d.items()) + key_value_pairs.sort(key=lambda x: x[0]) + return str(key_value_pairs) + expected_string = node.sql_query if node.query_params: expected_string = f"{expected_string}#{node.query_params}" if node.expr_to_alias: - expected_string = f"{expected_string}#{node.expr_to_alias}" + expected_string = f"{expected_string}#{stringify_dict(node.expr_to_alias)}" if node.df_aliased_col_name_to_real_col_name: - expected_string = ( - f"{expected_string}#{node.df_aliased_col_name_to_real_col_name}" - ) + expected_string = f"{expected_string}#{stringify_dict(node.df_aliased_col_name_to_real_col_name)}" expected_hash = hashlib.sha256(expected_string.encode()).hexdigest()[:10] assert encode_node_id_with_query(node) == f"{expected_hash}_SimpleNamespace" From c5dcced5c759de7d0c0d2a775fe75ffb2a923506 Mon Sep 17 00:00:00 2001 From: Jianzhun Du Date: Thu, 25 Sep 2025 16:14:23 -0700 Subject: [PATCH 3/3] fix --- tests/integ/test_cte.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index 96421340a6..43589dae79 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -636,13 +636,10 @@ def test_cte_preserves_join_suffix_aliases(session, use_different_df): # the second one is incorrect join condition as we have rsuffix for join alias assert 'ON ("AD_GROUP_ID_WITH_AD_GROUP" = "AD_GROUP_ID")' in union_sql assert 'ON ("AD_GROUP_ID" = "AD_GROUP_ID")' not in union_sql - # when using different df_ad_group, because rsuffix in join, they have different alias map (expr_to_alias), - # so they are considered different and we can't convert them to a CTE - assert ( - count_number_of_ctes(Utils.normalize_sql(union_sql)) == 0 - if use_different_df - else 1 - ) + # when using different df_ad_group with disambiguation, because rsuffix in join, + # they have different alias map (expr_to_alias), so they are considered different and we can't convert them to a CTE + # However there is still a CTE for create_dataframe call + assert count_number_of_ctes(Utils.normalize_sql(union_sql)) == 1 @pytest.mark.parametrize(