Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#### Improvements

- Improved CTE optimization to deduplicate identical subtrees in self-joins, which were previously emitted as repeated subqueries.
- Reduced the size of generated query text for repeated join operations.

#### Deprecations

Expand Down
89 changes: 72 additions & 17 deletions src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,33 @@ def lateral_join_statement(
)


# Opening and closing text of the `SELECT * FROM ( ... )` envelope that
# _unwrap_select_star_from() looks for and strips. Both pieces include the
# newline that separates the envelope from the wrapped join source, so a match
# requires the exact layout built from these constants.
_SELECT_STAR_FROM_PREFIX = SELECT + STAR + NEW_LINE + FROM + LEFT_PARENTHESIS + NEW_LINE
_SELECT_STAR_FROM_SUFFIX = NEW_LINE + RIGHT_PARENTHESIS


def _unwrap_select_star_from(sql: str) -> Optional[str]:
    """Return the join source wrapped by a ``SELECT * FROM ( ... )`` envelope.

    The envelope is recognised purely textually: ``sql`` must start with
    ``_SELECT_STAR_FROM_PREFIX`` and end with ``_SELECT_STAR_FROM_SUFFIX``.
    The content in between is returned only when it begins with an opening
    parenthesis, optionally preceded by blank lines and ``--`` trace-comment
    lines. A leading parenthesis indicates a parenthesized join operand rather
    than a wrapped SELECT statement.

    This function does not verify that ``sql`` really came from a join; the
    caller is expected to check that before unwrapping.

    Args:
        sql: SQL text to inspect.

    Returns:
        The text between the envelope prefix and suffix, unchanged (including
        any leading trace comments), or ``None`` when ``sql`` does not match
        the envelope or the wrapped content is not a parenthesized operand.
    """
    if not (
        sql.startswith(_SELECT_STAR_FROM_PREFIX)
        and sql.endswith(_SELECT_STAR_FROM_SUFFIX)
    ):
        return None

    inner = sql[len(_SELECT_STAR_FROM_PREFIX) : -len(_SELECT_STAR_FROM_SUFFIX)]
    if inner.startswith(LEFT_PARENTHESIS):
        return inner

    # In trace-SQL mode every flattening step prepends one more UUID comment
    # line to the join source, so an already-flattened multi-way join can start
    # with several comment lines (separated by blank lines). Skip all of them,
    # not just the first, before looking for the opening parenthesis.
    check = inner.lstrip("\n")
    while check.startswith("--"):
        newline_pos = check.find("\n")
        if newline_pos == -1:
            # The comment runs to the end of the text: nothing follows it.
            break
        check = check[newline_pos + 1 :].lstrip("\n")

    return inner if check.startswith(LEFT_PARENTHESIS) else None
Comment thread
sfc-gh-aling marked this conversation as resolved.


def snowflake_supported_join_statement(
left: str,
right: str,
Expand All @@ -971,17 +998,30 @@ def snowflake_supported_join_statement(
left_uuid: Optional[str] = None,
right_uuid: Optional[str] = None,
directed: bool = False,
left_is_join: bool = False,
) -> str:
LEFT_UUID = format_uuid(left_uuid)
RIGHT_UUID = format_uuid(right_uuid)
left_alias = (
"SNOWPARK_LEFT"
Comment thread
sfc-gh-aling marked this conversation as resolved.
if use_constant_subquery_alias
else random_name_for_temp_object(TempObjectType.TABLE)
)

# If left is the output of a previous join, flatten into a multi-way join
# by unwrapping the SELECT * FROM (...) envelope and appending the new
# right operand directly to the existing join source. This avoids nested
# SELECT * layers that inflate query text without changing semantics.
#
# Though it is technically less efficient than constructing the join sub-queries
# without the SELECT in the first place, the structure of our SQL processing code
# needs top-level projections to be wrapped by a select to be well-formed, so we
# must strip it here instead.
#
# We only unwrap the left side because it is simpler to deal with than unwrapping
# both left and right, and left-deep chains are more common, as they're produced
# by calls like df1.join(df2).join(df3) etc.
unwrapped_left = _unwrap_select_star_from(left) if left_is_join else None
right_alias = (
"SNOWPARK_RIGHT"
if use_constant_subquery_alias
if use_constant_subquery_alias and unwrapped_left is None
# Multi-way join: right alias must be unique to avoid collisions
# with aliases already present in the flattened join source.
else random_name_for_temp_object(TempObjectType.TABLE)
)

Expand Down Expand Up @@ -1017,18 +1057,31 @@ def snowflake_supported_join_statement(

maybe_directed_sql = DIRECTED_JOIN if directed else JOIN

if unwrapped_left is not None:
# No need for additional parentheses around the left expression here, since it
# should already be parenthesized
left_expr = LEFT_UUID + unwrapped_left + NEW_LINE + LEFT_UUID
else:
left_alias = (
"SNOWPARK_LEFT"
if use_constant_subquery_alias
else random_name_for_temp_object(TempObjectType.TABLE)
)
left_expr = (
LEFT_PARENTHESIS
+ NEW_LINE
+ LEFT_UUID
+ left
+ NEW_LINE
+ LEFT_UUID
+ RIGHT_PARENTHESIS
+ AS
+ left_alias
+ SPACE
+ NEW_LINE
)
source = (
LEFT_PARENTHESIS
+ NEW_LINE
+ LEFT_UUID
+ left
+ NEW_LINE
+ LEFT_UUID
+ RIGHT_PARENTHESIS
+ AS
+ left_alias
+ SPACE
+ NEW_LINE
left_expr
+ join_sql
+ maybe_directed_sql
+ NEW_LINE
Expand Down Expand Up @@ -1060,6 +1113,7 @@ def join_statement(
left_uuid: Optional[str] = None,
right_uuid: Optional[str] = None,
directed: bool = False,
left_is_join: bool = False,
) -> str:
if isinstance(join_type, (LeftSemi, LeftAnti)):
return left_semi_or_anti_join_statement(
Expand Down Expand Up @@ -1087,6 +1141,7 @@ def join_statement(
left_uuid=left_uuid,
right_uuid=right_uuid,
directed=directed,
left_is_join=left_is_join,
)


Expand Down
13 changes: 13 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,19 @@ def get_snowflake_plan(self, skip_schema_query) -> SnowflakePlan:
# Add the last df ast id to the snowflake plan as the most recent
# dataframe operation to create this plan.
self._snowflake_plan.df_ast_ids = self.df_ast_ids
# Propagate join output flag through passthrough SelectStatements
# so chained joins can flatten into multi-way joins.
# Only SelectStatement has has_clause/has_projection; other
# Selectable subclasses (SelectSQL, SetStatement, etc.) skip this.
if (
isinstance(self, SelectStatement)
and not self.has_clause
and not self.has_projection
and isinstance(self.from_, SelectSnowflakePlan)
):
self._snowflake_plan._is_join_output = (
self.from_._snowflake_plan._is_join_output
)
return self._snowflake_plan

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ def __init__(
self.session = session
self.source_plan = source_plan
self.is_ddl_on_temp_object = is_ddl_on_temp_object
self._is_join_output = False
# We need to copy this list since we don't want to change it for the
# previous SnowflakePlan objects
self.api_calls = api_calls.copy() if api_calls else []
Expand Down Expand Up @@ -769,6 +770,7 @@ def __copy__(self) -> "SnowflakePlan":
referenced_ctes=self.referenced_ctes,
)
plan.df_ast_ids = self.df_ast_ids
plan._is_join_output = self._is_join_output
return plan

def __deepcopy__(self, memodict={}) -> "SnowflakePlan": # noqa: B006
Expand Down Expand Up @@ -808,6 +810,7 @@ def __deepcopy__(self, memodict={}) -> "SnowflakePlan": # noqa: B006
if copied_source_plan:
copied_source_plan._is_valid_for_replacement = True
copied_plan.df_ast_ids = self.df_ast_ids
copied_plan._is_join_output = self._is_join_output

return copied_plan

Expand Down Expand Up @@ -1231,7 +1234,8 @@ def join(
use_constant_subquery_alias: bool,
directed: bool = False,
):
return self.build_binary(
left_is_join = left._is_join_output
result = self.build_binary(
lambda x, y: join_statement(
x,
y,
Expand All @@ -1248,11 +1252,14 @@ def join(
else None
),
directed=directed,
left_is_join=left_is_join,
),
left,
right,
source_plan,
)
result._is_join_output = True
return result

def save_as_table(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def mock_snowflake_plan() -> SnowflakePlan:
fake_snowflake_plan.referenced_ctes = {with_query_block: 1}
fake_snowflake_plan._cumulative_node_complexity = {}
fake_snowflake_plan._is_valid_for_replacement = True
fake_snowflake_plan._is_join_output = False
fake_snowflake_plan._metadata = mock.create_autospec(PlanMetadata)
fake_snowflake_plan._metadata.attributes = {}
fake_snowflake_plan.query_line_intervals = []
Expand Down
130 changes: 130 additions & 0 deletions tests/unit/test_analyzer_util_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,136 @@ def test_join_statement_negative():
join_statement("", "", join_type, "cond2", "", False)


def test_join_statement_flattens_chained_joins():
    """Joining onto a join output yields one flat multi-way join, not nested SELECT * layers."""
    using_key = UsingJoin(Inner(), ["key"])

    # Inner join of table_a and table_b; its SQL is a SELECT * envelope around
    # the two aliased operands.
    inner_sql = join_statement(
        "SELECT * FROM table_a",
        "SELECT * FROM table_b",
        using_key,
        "",
        "",
        True,
    )

    # Feed that SQL back in as the left operand and flag it as a join result
    # so the envelope is eligible for unwrapping.
    outer_sql = join_statement(
        inner_sql,
        "SELECT * FROM table_c",
        using_key,
        "",
        "",
        True,
        left_is_join=True,
    )

    # Exactly one space-delimited SELECT remains: the single top-level
    # SELECT *. A nested envelope would contribute a second one.
    assert outer_sql.count(" SELECT ") == 1

    # Every operand is still referenced in the flattened statement.
    for table_name in ("table_a", "table_b", "table_c"):
        assert table_name in outer_sql


def test_join_statement_flattens_with_uuid_trace_comments():
    """Chained joins keep flattening when UUID trace comments are emitted (trace-SQL mode)."""
    using_key = UsingJoin(Inner(), ["key"])

    # Step 1: plain two-way join, annotated with trace UUIDs on both sides.
    two_way_sql = join_statement(
        "SELECT * FROM table_a",
        "SELECT * FROM table_b",
        using_key,
        "",
        "",
        True,
        left_uuid="aaaa-bbbb",
        right_uuid="cccc-dddd",
    )

    # Step 2: join onto the two-way result; the left envelope must be unwrapped
    # even though UUID comments are in play.
    three_way_sql = join_statement(
        two_way_sql,
        "SELECT * FROM table_c",
        using_key,
        "",
        "",
        True,
        left_uuid="eeee-ffff",
        right_uuid="1111-2222",
        left_is_join=True,
    )

    assert three_way_sql.count(" SELECT ") == 1
    for table_name in ("table_a", "table_b", "table_c"):
        assert table_name in three_way_sql

    # Step 3: one more level, so the left operand now carries the UUID
    # comments accumulated by the previous flattening step.
    four_way_sql = join_statement(
        three_way_sql,
        "SELECT * FROM table_d",
        using_key,
        "",
        "",
        True,
        left_uuid="3333-4444",
        right_uuid="5555-6666",
        left_is_join=True,
    )

    assert four_way_sql.count(" SELECT ") == 1
    for table_name in ("table_a", "table_b", "table_c", "table_d"):
        assert table_name in four_way_sql


def test_join_statement_does_not_flatten_user_generated_select_star():
    """With left_is_join=False, SQL that merely looks like the internal
    SELECT * envelope is kept as a nested subquery instead of being unwrapped."""
    using_key = UsingJoin(Inner(), ["key"])

    # Hand-written SQL laid out exactly like the internal envelope, as a
    # user-supplied query (e.g. via session.sql) could happen to be.
    lookalike_sql = (
        " SELECT * \n FROM (\n"
        "(SELECT id, key FROM user_table) AS t1"
        " INNER JOIN \n(SELECT id, key FROM other_table) AS t2\n"
        " USING (key)"
        "\n)"
    )

    # left_is_join is passed explicitly as False: the textual match alone must
    # not trigger flattening.
    joined_sql = join_statement(
        lookalike_sql,
        "SELECT * FROM table_c",
        using_key,
        "",
        "",
        True,
        left_is_join=False,
    )

    # Both the new outer SELECT * and the user's own SELECT * must survive.
    assert joined_sql.count(" SELECT ") >= 2

    # All table references are still present in the output.
    for table_name in ("user_table", "other_table", "table_c"):
        assert table_name in joined_sql


def test_create_iceberg_table_statement():
assert create_table_statement(
table_name="test_table",
Expand Down
Loading