Skip to content

Commit e22bc94

Browse files
authored
SNOW-3485482: Eliminate unnecessary SELECT * from joins (#4248)
1 parent 86aa0e9 commit e22bc94

6 files changed

Lines changed: 225 additions & 18 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#### Improvements
1818

1919
- Improved CTE optimization to deduplicate identical subtrees in self-joins, which were previously emitted as repeated subqueries.
20+
- Reduced the size of generated query text for repeated join operations.
2021

2122
#### Deprecations
2223

src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py

Lines changed: 72 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -961,6 +961,33 @@ def lateral_join_statement(
961961
)
962962

963963

964+
_SELECT_STAR_FROM_PREFIX = SELECT + STAR + NEW_LINE + FROM + LEFT_PARENTHESIS + NEW_LINE
965+
_SELECT_STAR_FROM_SUFFIX = NEW_LINE + RIGHT_PARENTHESIS
966+
967+
968+
def _unwrap_select_star_from(sql: str) -> Optional[str]:
969+
"""If sql is a join-produced `SELECT * FROM (\n<join_source>\n)` (the
970+
output of project_statement([], join_source)), return <join_source>.
971+
Only unwraps when the inner content starts with '(' (possibly preceded
972+
by UUID trace comments) which indicates a parenthesized join operand
973+
rather than a wrapped SELECT statement."""
974+
if sql.startswith(_SELECT_STAR_FROM_PREFIX) and sql.endswith(
975+
_SELECT_STAR_FROM_SUFFIX
976+
):
977+
inner = sql[len(_SELECT_STAR_FROM_PREFIX) : -len(_SELECT_STAR_FROM_SUFFIX)]
978+
# In trace-SQL mode, UUID comments (\n-- <uuid>\n) may precede the
979+
# opening parenthesis. Strip them before checking.
980+
check = inner.lstrip("\n")
981+
if check.startswith("--"):
982+
# Skip the comment line and any trailing newline
983+
newline_pos = check.find("\n")
984+
if newline_pos != -1:
985+
check = check[newline_pos + 1 :]
986+
if check.startswith(LEFT_PARENTHESIS) or inner.startswith(LEFT_PARENTHESIS):
987+
return inner
988+
return None
989+
990+
964991
def snowflake_supported_join_statement(
965992
left: str,
966993
right: str,
@@ -971,17 +998,30 @@ def snowflake_supported_join_statement(
971998
left_uuid: Optional[str] = None,
972999
right_uuid: Optional[str] = None,
9731000
directed: bool = False,
1001+
left_is_join: bool = False,
9741002
) -> str:
9751003
LEFT_UUID = format_uuid(left_uuid)
9761004
RIGHT_UUID = format_uuid(right_uuid)
977-
left_alias = (
978-
"SNOWPARK_LEFT"
979-
if use_constant_subquery_alias
980-
else random_name_for_temp_object(TempObjectType.TABLE)
981-
)
1005+
1006+
# If left is the output of a previous join, flatten into a multi-way join
1007+
# by unwrapping the SELECT * FROM (...) envelope and appending the new
1008+
# right operand directly to the existing join source. This avoids nested
1009+
# SELECT * layers that inflate query text without changing semantics.
1010+
#
1011+
# Though it is technically less efficient than constructing the join sub-queries
1012+
# without the SELECT in the first place, the structure of our SQL processing code
1013+
# needs top-level projections to be wrapped by a select to be well-formed, so we
1014+
# must strip it here instead.
1015+
#
1016+
# We only unwrap the left side because it is simpler to deal with than unwrapping
1017+
# both left and right, and left-deep chains are more common, as they're produced
1018+
# by calls like df1.join(df2).join(df3) etc.
1019+
unwrapped_left = _unwrap_select_star_from(left) if left_is_join else None
9821020
right_alias = (
9831021
"SNOWPARK_RIGHT"
984-
if use_constant_subquery_alias
1022+
if use_constant_subquery_alias and unwrapped_left is None
1023+
# Multi-way join: right alias must be unique to avoid collisions
1024+
# with aliases already present in the flattened join source.
9851025
else random_name_for_temp_object(TempObjectType.TABLE)
9861026
)
9871027

@@ -1017,18 +1057,31 @@ def snowflake_supported_join_statement(
10171057

10181058
maybe_directed_sql = DIRECTED_JOIN if directed else JOIN
10191059

1060+
if unwrapped_left is not None:
1061+
# No need for additional parentheses around the left expression here, since it
1062+
# should already be parenthesized
1063+
left_expr = LEFT_UUID + unwrapped_left + NEW_LINE + LEFT_UUID
1064+
else:
1065+
left_alias = (
1066+
"SNOWPARK_LEFT"
1067+
if use_constant_subquery_alias
1068+
else random_name_for_temp_object(TempObjectType.TABLE)
1069+
)
1070+
left_expr = (
1071+
LEFT_PARENTHESIS
1072+
+ NEW_LINE
1073+
+ LEFT_UUID
1074+
+ left
1075+
+ NEW_LINE
1076+
+ LEFT_UUID
1077+
+ RIGHT_PARENTHESIS
1078+
+ AS
1079+
+ left_alias
1080+
+ SPACE
1081+
+ NEW_LINE
1082+
)
10201083
source = (
1021-
LEFT_PARENTHESIS
1022-
+ NEW_LINE
1023-
+ LEFT_UUID
1024-
+ left
1025-
+ NEW_LINE
1026-
+ LEFT_UUID
1027-
+ RIGHT_PARENTHESIS
1028-
+ AS
1029-
+ left_alias
1030-
+ SPACE
1031-
+ NEW_LINE
1084+
left_expr
10321085
+ join_sql
10331086
+ maybe_directed_sql
10341087
+ NEW_LINE
@@ -1060,6 +1113,7 @@ def join_statement(
10601113
left_uuid: Optional[str] = None,
10611114
right_uuid: Optional[str] = None,
10621115
directed: bool = False,
1116+
left_is_join: bool = False,
10631117
) -> str:
10641118
if isinstance(join_type, (LeftSemi, LeftAnti)):
10651119
return left_semi_or_anti_join_statement(
@@ -1087,6 +1141,7 @@ def join_statement(
10871141
left_uuid=left_uuid,
10881142
right_uuid=right_uuid,
10891143
directed=directed,
1144+
left_is_join=left_is_join,
10901145
)
10911146

10921147

src/snowflake/snowpark/_internal/analyzer/select_statement.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,19 @@ def get_snowflake_plan(self, skip_schema_query) -> SnowflakePlan:
444444
# Add the last df ast id to the snowflake plan as the most recent
445445
# dataframe operation to create this plan.
446446
self._snowflake_plan.df_ast_ids = self.df_ast_ids
447+
# Propagate join output flag through passthrough SelectStatements
448+
# so chained joins can flatten into multi-way joins.
449+
# Only SelectStatement has has_clause/has_projection; other
450+
# Selectable subclasses (SelectSQL, SetStatement, etc.) skip this.
451+
if (
452+
isinstance(self, SelectStatement)
453+
and not self.has_clause
454+
and not self.has_projection
455+
and isinstance(self.from_, SelectSnowflakePlan)
456+
):
457+
self._snowflake_plan._is_join_output = (
458+
self.from_._snowflake_plan._is_join_output
459+
)
447460
return self._snowflake_plan
448461

449462
@property

src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,7 @@ def __init__(
458458
self.session = session
459459
self.source_plan = source_plan
460460
self.is_ddl_on_temp_object = is_ddl_on_temp_object
461+
self._is_join_output = False
461462
# We need to copy this list since we don't want to change it for the
462463
# previous SnowflakePlan objects
463464
self.api_calls = api_calls.copy() if api_calls else []
@@ -769,6 +770,7 @@ def __copy__(self) -> "SnowflakePlan":
769770
referenced_ctes=self.referenced_ctes,
770771
)
771772
plan.df_ast_ids = self.df_ast_ids
773+
plan._is_join_output = self._is_join_output
772774
return plan
773775

774776
def __deepcopy__(self, memodict={}) -> "SnowflakePlan": # noqa: B006
@@ -808,6 +810,7 @@ def __deepcopy__(self, memodict={}) -> "SnowflakePlan": # noqa: B006
808810
if copied_source_plan:
809811
copied_source_plan._is_valid_for_replacement = True
810812
copied_plan.df_ast_ids = self.df_ast_ids
813+
copied_plan._is_join_output = self._is_join_output
811814

812815
return copied_plan
813816

@@ -1231,7 +1234,8 @@ def join(
12311234
use_constant_subquery_alias: bool,
12321235
directed: bool = False,
12331236
):
1234-
return self.build_binary(
1237+
left_is_join = left._is_join_output
1238+
result = self.build_binary(
12351239
lambda x, y: join_statement(
12361240
x,
12371241
y,
@@ -1248,11 +1252,14 @@ def join(
12481252
else None
12491253
),
12501254
directed=directed,
1255+
left_is_join=left_is_join,
12511256
),
12521257
left,
12531258
right,
12541259
source_plan,
12551260
)
1261+
result._is_join_output = True
1262+
return result
12561263

12571264
def save_as_table(
12581265
self,

tests/unit/compiler/test_replace_child_and_update_node.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def mock_snowflake_plan() -> SnowflakePlan:
7070
fake_snowflake_plan.referenced_ctes = {with_query_block: 1}
7171
fake_snowflake_plan._cumulative_node_complexity = {}
7272
fake_snowflake_plan._is_valid_for_replacement = True
73+
fake_snowflake_plan._is_join_output = False
7374
fake_snowflake_plan._metadata = mock.create_autospec(PlanMetadata)
7475
fake_snowflake_plan._metadata.attributes = {}
7576
fake_snowflake_plan.query_line_intervals = []

tests/unit/test_analyzer_util_suite.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,136 @@ def test_join_statement_negative():
451451
join_statement("", "", join_type, "cond2", "", False)
452452

453453

454+
def test_join_statement_flattens_chained_joins():
455+
"""Chained joins should produce a flat multi-way join, not nested SELECT * wrappers."""
456+
join_type = UsingJoin(Inner(), ["key"])
457+
458+
# First join: produces SELECT * FROM ((left) AS L JOIN (right) AS R USING (key))
459+
first_join = join_statement(
460+
"SELECT * FROM table_a",
461+
"SELECT * FROM table_b",
462+
join_type,
463+
"",
464+
"",
465+
True,
466+
)
467+
468+
# Second join uses first join's output as left operand.
469+
# left_is_join=True signals that the left operand is a join result.
470+
second_join = join_statement(
471+
first_join,
472+
"SELECT * FROM table_c",
473+
join_type,
474+
"",
475+
"",
476+
True,
477+
left_is_join=True,
478+
)
479+
480+
# Should NOT have nested SELECT * FROM (SELECT * FROM (...))
481+
# Count occurrences of "SELECT" — expect exactly one top-level SELECT *
482+
assert second_join.count(" SELECT ") == 1
483+
484+
# The SQL should contain all three table references at the same nesting level
485+
assert "table_a" in second_join
486+
assert "table_b" in second_join
487+
assert "table_c" in second_join
488+
489+
490+
def test_join_statement_flattens_with_uuid_trace_comments():
491+
"""Flattening must work when UUID trace comments are present (trace-SQL mode)."""
492+
join_type = UsingJoin(Inner(), ["key"])
493+
494+
# First join with UUID trace comments
495+
first_join = join_statement(
496+
"SELECT * FROM table_a",
497+
"SELECT * FROM table_b",
498+
join_type,
499+
"",
500+
"",
501+
True,
502+
left_uuid="aaaa-bbbb",
503+
right_uuid="cccc-dddd",
504+
)
505+
506+
# Second join: flattening with UUIDs
507+
second_join = join_statement(
508+
first_join,
509+
"SELECT * FROM table_c",
510+
join_type,
511+
"",
512+
"",
513+
True,
514+
left_uuid="eeee-ffff",
515+
right_uuid="1111-2222",
516+
left_is_join=True,
517+
)
518+
519+
# Should still flatten — only one top-level SELECT *
520+
assert second_join.count(" SELECT ") == 1
521+
assert "table_a" in second_join
522+
assert "table_b" in second_join
523+
assert "table_c" in second_join
524+
525+
# Third join: must also flatten successfully despite accumulated UUID comments
526+
third_join = join_statement(
527+
second_join,
528+
"SELECT * FROM table_d",
529+
join_type,
530+
"",
531+
"",
532+
True,
533+
left_uuid="3333-4444",
534+
right_uuid="5555-6666",
535+
left_is_join=True,
536+
)
537+
538+
assert third_join.count(" SELECT ") == 1
539+
assert "table_a" in third_join
540+
assert "table_b" in third_join
541+
assert "table_c" in third_join
542+
assert "table_d" in third_join
543+
544+
545+
def test_join_statement_does_not_flatten_user_generated_select_star():
546+
"""A user-generated SELECT * that coincidentally matches the internal
547+
pattern must NOT be flattened when left_is_join is False (the default)."""
548+
join_type = UsingJoin(Inner(), ["key"])
549+
550+
# Craft a SQL string that matches the internal _SELECT_STAR_FROM_PREFIX/SUFFIX
551+
# pattern exactly — this simulates what session.sql("SELECT * FROM (...)") or
552+
# a similar user-provided query might produce.
553+
user_sql = (
554+
" SELECT * \n FROM (\n"
555+
"(SELECT id, key FROM user_table) AS t1"
556+
" INNER JOIN \n(SELECT id, key FROM other_table) AS t2\n"
557+
" USING (key)"
558+
"\n)"
559+
)
560+
561+
# join_statement with left_is_join=False (default) — should NOT unwrap because
562+
# left_is_join is only True when a plan object is constructed from a dataframe
563+
# join operation
564+
result = join_statement(
565+
user_sql,
566+
"SELECT * FROM table_c",
567+
join_type,
568+
"",
569+
"",
570+
True,
571+
left_is_join=False,
572+
)
573+
574+
# The user SQL should be preserved as a nested subquery, producing
575+
# two SELECT levels (the outer wrapper + the user's original SELECT *)
576+
assert result.count(" SELECT ") >= 2
577+
578+
# The user's original SQL should appear within the output intact
579+
assert "user_table" in result
580+
assert "other_table" in result
581+
assert "table_c" in result
582+
583+
454584
def test_create_iceberg_table_statement():
455585
assert create_table_statement(
456586
table_name="test_table",

0 commit comments

Comments
 (0)