diff --git a/CHANGELOG.md b/CHANGELOG.md
index fdcd6a6acd..1d17cdc3e7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,7 @@
 #### Improvements
 
 - Improved CTE optimization to deduplicate identical subtrees in self-joins, which were previously emitted as repeated subqueries.
+- Reduced the size of generated query text for repeated join operations.
 
 #### Deprecations
 
diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py
index 0429f26576..3c8a275ba4 100644
--- a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py
+++ b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py
@@ -961,6 +961,33 @@ def lateral_join_statement(
     )
 
 
+_SELECT_STAR_FROM_PREFIX = SELECT + STAR + NEW_LINE + FROM + LEFT_PARENTHESIS + NEW_LINE
+_SELECT_STAR_FROM_SUFFIX = NEW_LINE + RIGHT_PARENTHESIS
+
+
+def _unwrap_select_star_from(sql: str) -> Optional[str]:
+    """Return the join source wrapped by a join-produced ``SELECT * FROM (...)``.
+
+    Unwraps only when ``sql`` is ``_SELECT_STAR_FROM_PREFIX + inner +
+    _SELECT_STAR_FROM_SUFFIX`` and ``inner`` starts with '(' after any leading
+    blank lines and ``--`` trace-comment lines, i.e. it is a parenthesized join
+    operand rather than a wrapped SELECT statement. Returns ``inner`` with its
+    comments intact, or ``None`` when ``sql`` does not qualify.
+    """
+    if not (
+        sql.startswith(_SELECT_STAR_FROM_PREFIX)
+        and sql.endswith(_SELECT_STAR_FROM_SUFFIX)
+    ):
+        return None
+    inner = sql[len(_SELECT_STAR_FROM_PREFIX) : -len(_SELECT_STAR_FROM_SUFFIX)]
+    # In trace-SQL mode every flattening step puts one more UUID comment in
+    # front of the operand, so skip all leading comment lines, not only one.
+    rest = inner.lstrip("\n")
+    while rest.startswith("--"):
+        rest = rest.partition("\n")[2].lstrip("\n")
+    return inner if rest.startswith(LEFT_PARENTHESIS) else None
+
+
 def snowflake_supported_join_statement(
     left: str,
     right: str,
@@ -971,17 +998,30 @@ def snowflake_supported_join_statement(
     left_uuid: Optional[str] = None,
     right_uuid: Optional[str] = None,
     directed: bool = False,
+    left_is_join: bool = False,
 ) -> str:
     LEFT_UUID = format_uuid(left_uuid)
     RIGHT_UUID = format_uuid(right_uuid)
-    left_alias = (
-        "SNOWPARK_LEFT"
-        if use_constant_subquery_alias
-        else random_name_for_temp_object(TempObjectType.TABLE)
-    )
+
+    # If left is the output of a previous join, flatten into a multi-way join
+    # by unwrapping the SELECT * FROM (...) envelope and appending the new
+    # right operand directly to the existing join source. This avoids nested
+    # SELECT * layers that inflate query text without changing semantics.
+    #
+    # Though it is technically less efficient than constructing the join sub-queries
+    # without the SELECT in the first place, the structure of our SQL processing code
+    # needs top-level projections to be wrapped by a select to be well-formed, so we
+    # must strip it here instead.
+    #
+    # We only unwrap the left side because it is simpler to deal with than unwrapping
+    # both left and right, and left-deep chains are more common, as they're produced
+    # by calls like df1.join(df2).join(df3) etc.
+    unwrapped_left = _unwrap_select_star_from(left) if left_is_join else None
     right_alias = (
         "SNOWPARK_RIGHT"
-        if use_constant_subquery_alias
+        if use_constant_subquery_alias and unwrapped_left is None
+        # Multi-way join: right alias must be unique to avoid collisions
+        # with aliases already present in the flattened join source.
else random_name_for_temp_object(TempObjectType.TABLE) ) @@ -1017,18 +1057,31 @@ def snowflake_supported_join_statement( maybe_directed_sql = DIRECTED_JOIN if directed else JOIN + if unwrapped_left is not None: + # No need for additional parentheses around the left expression here, since it + # should already be parenthesized + left_expr = LEFT_UUID + unwrapped_left + NEW_LINE + LEFT_UUID + else: + left_alias = ( + "SNOWPARK_LEFT" + if use_constant_subquery_alias + else random_name_for_temp_object(TempObjectType.TABLE) + ) + left_expr = ( + LEFT_PARENTHESIS + + NEW_LINE + + LEFT_UUID + + left + + NEW_LINE + + LEFT_UUID + + RIGHT_PARENTHESIS + + AS + + left_alias + + SPACE + + NEW_LINE + ) source = ( - LEFT_PARENTHESIS - + NEW_LINE - + LEFT_UUID - + left - + NEW_LINE - + LEFT_UUID - + RIGHT_PARENTHESIS - + AS - + left_alias - + SPACE - + NEW_LINE + left_expr + join_sql + maybe_directed_sql + NEW_LINE @@ -1060,6 +1113,7 @@ def join_statement( left_uuid: Optional[str] = None, right_uuid: Optional[str] = None, directed: bool = False, + left_is_join: bool = False, ) -> str: if isinstance(join_type, (LeftSemi, LeftAnti)): return left_semi_or_anti_join_statement( @@ -1087,6 +1141,7 @@ def join_statement( left_uuid=left_uuid, right_uuid=right_uuid, directed=directed, + left_is_join=left_is_join, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 8130a6d24a..15db45a247 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -444,6 +444,19 @@ def get_snowflake_plan(self, skip_schema_query) -> SnowflakePlan: # Add the last df ast id to the snowflake plan as the most recent # dataframe operation to create this plan. self._snowflake_plan.df_ast_ids = self.df_ast_ids + # Propagate join output flag through passthrough SelectStatements + # so chained joins can flatten into multi-way joins. 
+ # Only SelectStatement has has_clause/has_projection; other + # Selectable subclasses (SelectSQL, SetStatement, etc.) skip this. + if ( + isinstance(self, SelectStatement) + and not self.has_clause + and not self.has_projection + and isinstance(self.from_, SelectSnowflakePlan) + ): + self._snowflake_plan._is_join_output = ( + self.from_._snowflake_plan._is_join_output + ) return self._snowflake_plan @property diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index d42549fcff..45cfbeec8c 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -458,6 +458,7 @@ def __init__( self.session = session self.source_plan = source_plan self.is_ddl_on_temp_object = is_ddl_on_temp_object + self._is_join_output = False # We need to copy this list since we don't want to change it for the # previous SnowflakePlan objects self.api_calls = api_calls.copy() if api_calls else [] @@ -769,6 +770,7 @@ def __copy__(self) -> "SnowflakePlan": referenced_ctes=self.referenced_ctes, ) plan.df_ast_ids = self.df_ast_ids + plan._is_join_output = self._is_join_output return plan def __deepcopy__(self, memodict={}) -> "SnowflakePlan": # noqa: B006 @@ -808,6 +810,7 @@ def __deepcopy__(self, memodict={}) -> "SnowflakePlan": # noqa: B006 if copied_source_plan: copied_source_plan._is_valid_for_replacement = True copied_plan.df_ast_ids = self.df_ast_ids + copied_plan._is_join_output = self._is_join_output return copied_plan @@ -1231,7 +1234,8 @@ def join( use_constant_subquery_alias: bool, directed: bool = False, ): - return self.build_binary( + left_is_join = left._is_join_output + result = self.build_binary( lambda x, y: join_statement( x, y, @@ -1248,11 +1252,14 @@ def join( else None ), directed=directed, + left_is_join=left_is_join, ), left, right, source_plan, ) + result._is_join_output = True + return result def save_as_table( 
self, diff --git a/tests/unit/compiler/test_replace_child_and_update_node.py b/tests/unit/compiler/test_replace_child_and_update_node.py index 3a78bd2c93..4c1eb3649f 100644 --- a/tests/unit/compiler/test_replace_child_and_update_node.py +++ b/tests/unit/compiler/test_replace_child_and_update_node.py @@ -70,6 +70,7 @@ def mock_snowflake_plan() -> SnowflakePlan: fake_snowflake_plan.referenced_ctes = {with_query_block: 1} fake_snowflake_plan._cumulative_node_complexity = {} fake_snowflake_plan._is_valid_for_replacement = True + fake_snowflake_plan._is_join_output = False fake_snowflake_plan._metadata = mock.create_autospec(PlanMetadata) fake_snowflake_plan._metadata.attributes = {} fake_snowflake_plan.query_line_intervals = [] diff --git a/tests/unit/test_analyzer_util_suite.py b/tests/unit/test_analyzer_util_suite.py index e1ad1e28dd..08f293964c 100644 --- a/tests/unit/test_analyzer_util_suite.py +++ b/tests/unit/test_analyzer_util_suite.py @@ -451,6 +451,136 @@ def test_join_statement_negative(): join_statement("", "", join_type, "cond2", "", False) +def test_join_statement_flattens_chained_joins(): + """Chained joins should produce a flat multi-way join, not nested SELECT * wrappers.""" + join_type = UsingJoin(Inner(), ["key"]) + + # First join: produces SELECT * FROM ((left) AS L JOIN (right) AS R USING (key)) + first_join = join_statement( + "SELECT * FROM table_a", + "SELECT * FROM table_b", + join_type, + "", + "", + True, + ) + + # Second join uses first join's output as left operand. + # left_is_join=True signals that the left operand is a join result. 
+    second_join = join_statement(
+        first_join,
+        "SELECT * FROM table_c",
+        join_type,
+        "",
+        "",
+        True,
+        left_is_join=True,
+    )
+
+    # Should NOT have nested SELECT * FROM (SELECT * FROM (...))
+    # Count occurrences of "SELECT" — expect exactly one top-level SELECT *
+    assert second_join.count(" SELECT ") == 1
+
+    # The SQL should contain all three table references at the same nesting level
+    assert "table_a" in second_join
+    assert "table_b" in second_join
+    assert "table_c" in second_join
+
+
+def test_join_statement_flattens_with_uuid_trace_comments():
+    """UUID trace comments (trace-SQL mode) must not prevent join flattening."""
+    using_key = UsingJoin(Inner(), ["key"])
+
+    def join_with_uuids(left_sql, right_sql, left_uuid, right_uuid, **flags):
+        # Every step passes the same join type and the same "", "", True
+        # positional arguments; only operands, UUIDs and left_is_join vary.
+        return join_statement(
+            left_sql,
+            right_sql,
+            using_key,
+            "",
+            "",
+            True,
+            left_uuid=left_uuid,
+            right_uuid=right_uuid,
+            **flags,
+        )
+
+    def assert_flat(sql, expected_tables):
+        # Flattened output has exactly one top-level SELECT * ...
+        assert sql.count(" SELECT ") == 1
+        # ... and still references every joined table.
+        for table_name in expected_tables:
+            assert table_name in sql
+
+    # First join: two plain tables, tagged with UUID trace comments.
+    two_way = join_with_uuids(
+        "SELECT * FROM table_a",
+        "SELECT * FROM table_b",
+        "aaaa-bbbb",
+        "cccc-dddd",
+    )
+
+    # Second join: the first join's output is the left operand and is flattened.
+    three_way = join_with_uuids(
+        two_way,
+        "SELECT * FROM table_c",
+        "eeee-ffff",
+        "1111-2222",
+        left_is_join=True,
+    )
+    assert_flat(three_way, ("table_a", "table_b", "table_c"))
+
+    # Third join: must also flatten despite the accumulated UUID comments.
+    four_way = join_with_uuids(
+        three_way,
+        "SELECT * FROM table_d",
+        "3333-4444",
+        "5555-6666",
+        left_is_join=True,
+    )
+    assert_flat(four_way, ("table_a", "table_b", "table_c", "table_d"))
+
+
+def test_join_statement_does_not_flatten_user_generated_select_star():
+    """User SQL that merely resembles the internal SELECT * envelope must stay
+    nested when left_is_join is False (the default)."""
+    using_key = UsingJoin(Inner(), ["key"])
+
+    # Hand-written SQL shaped like the internal _SELECT_STAR_FROM_PREFIX/SUFFIX
+    # envelope, standing in for what session.sql("SELECT * FROM (...)") or a
+    # similar user-provided query might produce.
+    lookalike_sql = "".join(
+        [
+            " SELECT * \n FROM (\n",
+            "(SELECT id, key FROM user_table) AS t1",
+            " INNER JOIN \n(SELECT id, key FROM other_table) AS t2\n",
+            " USING (key)",
+            "\n)",
+        ]
+    )
+
+    # left_is_join is explicitly False (also its default), so join_statement
+    # must keep the left operand as an opaque subquery instead of unwrapping it.
+    joined_sql = join_statement(
+        lookalike_sql,
+        "SELECT * FROM table_c",
+        using_key,
+        "",
+        "",
+        True,
+        left_is_join=False,
+    )
+
+    # Nested output keeps at least two SELECT levels: the outer wrapper plus
+    # the user's original SELECT *.
+    assert joined_sql.count(" SELECT ") >= 2
+
+    # The user's tables and the new right operand must all still be present.
+    for table_name in ("user_table", "other_table", "table_c"):
+        assert table_name in joined_sql
+
+
 def test_create_iceberg_table_statement():
     assert create_table_statement(
         table_name="test_table",