Skip to content

Commit c9cccc0

Browse files
committed
propagate through plan
1 parent 12a15d5 commit c9cccc0

4 files changed

Lines changed: 69 additions & 8 deletions

File tree

src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -996,24 +996,30 @@ def snowflake_supported_join_statement(
996996
left_uuid: Optional[str] = None,
997997
right_uuid: Optional[str] = None,
998998
directed: bool = False,
999+
left_is_join: bool = False,
9991000
) -> str:
10001001
LEFT_UUID = format_uuid(left_uuid)
10011002
RIGHT_UUID = format_uuid(right_uuid)
10021003

1003-
# If left is a simple SELECT * FROM (\n<join_source>\n) wrapper from a
1004-
# previous join, flatten into a multi-way join by appending the new right
1005-
# operand directly to the existing join source. This avoids nested
1004+
# If left is the output of a previous join, flatten into a multi-way join
1005+
# by unwrapping the SELECT * FROM (...) envelope and appending the new
1006+
# right operand directly to the existing join source. This avoids nested
10061007
# SELECT * layers that inflate query text without changing semantics.
10071008
#
10081009
# Though it is technically less efficient than constructing the join sub-queries
10091010
# without the SELECT in the first place, the structure of our SQL processing code
1010-
# top-level projections to be wrapped by a select.
1011-
unwrapped_left = _unwrap_select_star_from(left)
1011+
# needs top-level projections to be wrapped by a select to be well-formed, so we
1012+
# must strip it here instead.
1013+
#
1014+
# We only unwrap the left side because it is simpler to deal with than unwrapping
1015+
# both left and right, and left-deep chains are more common, as they're produced
1016+
# by calls like df1.join(df2).join(df3) etc.
1017+
unwrapped_left = _unwrap_select_star_from(left) if left_is_join else None
10121018
right_alias = (
10131019
"SNOWPARK_RIGHT"
1020+
if use_constant_subquery_alias and unwrapped_left is None
10141021
# Multi-way join: right alias must be unique to avoid collisions
10151022
# with aliases already present in the flattened join source.
1016-
if use_constant_subquery_alias and unwrapped_left is None
10171023
else random_name_for_temp_object(TempObjectType.TABLE)
10181024
)
10191025

@@ -1105,6 +1111,7 @@ def join_statement(
11051111
left_uuid: Optional[str] = None,
11061112
right_uuid: Optional[str] = None,
11071113
directed: bool = False,
1114+
left_is_join: bool = False,
11081115
) -> str:
11091116
if isinstance(join_type, (LeftSemi, LeftAnti)):
11101117
return left_semi_or_anti_join_statement(
@@ -1132,6 +1139,7 @@ def join_statement(
11321139
left_uuid=left_uuid,
11331140
right_uuid=right_uuid,
11341141
directed=directed,
1142+
left_is_join=left_is_join,
11351143
)
11361144

11371145

src/snowflake/snowpark/_internal/analyzer/select_statement.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,13 @@ def get_snowflake_plan(self, skip_schema_query) -> SnowflakePlan:
451451
# Add the last df ast id to the snowflake plan as the most recent
452452
# dataframe operation to create this plan.
453453
self._snowflake_plan.df_ast_ids = self.df_ast_ids
454+
# Propagate the join-output flag through passthrough SelectStatements (those
455+
# with no clause and no projection) so chained joins can flatten into multi-way joins.
456+
if not self.has_clause and not self.has_projection:
457+
if isinstance(self.from_, SelectSnowflakePlan):
458+
self._snowflake_plan._is_join_output = (
459+
self.from_._snowflake_plan._is_join_output
460+
)
454461
return self._snowflake_plan
455462

456463
@property

src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,7 @@ def __init__(
464464
self.session = session
465465
self.source_plan = source_plan
466466
self.is_ddl_on_temp_object = is_ddl_on_temp_object
467+
self._is_join_output = False
467468
# We need to copy this list since we don't want to change it for the
468469
# previous SnowflakePlan objects
469470
self.api_calls = api_calls.copy() if api_calls else []
@@ -1237,7 +1238,8 @@ def join(
12371238
use_constant_subquery_alias: bool,
12381239
directed: bool = False,
12391240
):
1240-
return self.build_binary(
1241+
left_is_join = left._is_join_output
1242+
result = self.build_binary(
12411243
lambda x, y: join_statement(
12421244
x,
12431245
y,
@@ -1254,11 +1256,14 @@ def join(
12541256
else None
12551257
),
12561258
directed=directed,
1259+
left_is_join=left_is_join,
12571260
),
12581261
left,
12591262
right,
12601263
source_plan,
12611264
)
1265+
result._is_join_output = True
1266+
return result
12621267

12631268
def save_as_table(
12641269
self,

tests/unit/test_analyzer_util_suite.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,14 +465,16 @@ def test_join_statement_flattens_chained_joins():
465465
True,
466466
)
467467

468-
# Second join uses first join's output as left operand
468+
# Second join uses first join's output as left operand.
469+
# left_is_join=True signals that the left operand is a join result.
469470
second_join = join_statement(
470471
first_join,
471472
"SELECT * FROM table_c",
472473
join_type,
473474
"",
474475
"",
475476
True,
477+
left_is_join=True,
476478
)
477479

478480
# Should NOT have nested SELECT * FROM (SELECT * FROM (...))
@@ -485,6 +487,45 @@ def test_join_statement_flattens_chained_joins():
485487
assert "table_c" in second_join
486488

487489

490+
def test_join_statement_does_not_flatten_user_generated_select_star():
491+
"""A user-generated SELECT * that coincidentally matches the internal
492+
pattern must NOT be flattened when left_is_join is False (the default)."""
493+
join_type = UsingJoin(Inner(), ["key"])
494+
495+
# Craft a SQL string that matches the internal _SELECT_STAR_FROM_PREFIX/SUFFIX
496+
# pattern exactly — this simulates what session.sql("SELECT * FROM (...)") or
497+
# a similar user-provided query might produce.
498+
user_sql = (
499+
" SELECT * \n FROM (\n"
500+
"(SELECT id, key FROM user_table) AS t1"
501+
" INNER JOIN \n(SELECT id, key FROM other_table) AS t2\n"
502+
" USING (key)"
503+
"\n)"
504+
)
505+
506+
# Call join_statement with left_is_join=False (the default). It must NOT unwrap,
507+
# because left_is_join is only True when the plan object was constructed from a
508+
# DataFrame join operation.
509+
result = join_statement(
510+
user_sql,
511+
"SELECT * FROM table_c",
512+
join_type,
513+
"",
514+
"",
515+
True,
516+
left_is_join=False,
517+
)
518+
519+
# The user SQL should be preserved as a nested subquery, producing
520+
# two SELECT levels (the outer wrapper + the user's original SELECT *)
521+
assert result.count(" SELECT ") >= 2
522+
523+
# The user's original SQL should appear within the output intact
524+
assert "user_table" in result
525+
assert "other_table" in result
526+
assert "table_c" in result
527+
528+
488529
def test_create_iceberg_table_statement():
489530
assert create_table_statement(
490531
table_name="test_table",

0 commit comments

Comments
 (0)