Skip to content

Commit 07689e0

Browse files
sfc-gh-yixiesfc-gh-alingsfc-gh-mayliu
authored
SNOW-2203826: Loosen flattening rules for sort and filter (#4026)
Co-authored-by: Adam Ling <adam.ling@snowflake.com> Co-authored-by: May Liu <may.liu@snowflake.com>
1 parent c5362e4 commit 07689e0

8 files changed

Lines changed: 1003 additions & 57 deletions

File tree

src/snowflake/snowpark/_internal/analyzer/select_statement.py

Lines changed: 214 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
Sequence,
2222
Set,
2323
Union,
24+
Literal,
2425
)
2526

2627
import snowflake.snowpark._internal.utils
@@ -87,6 +88,7 @@
8788
is_sql_select_statement,
8889
ExprAliasUpdateDict,
8990
)
91+
import snowflake.snowpark.context as context
9092

9193
# Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable
9294
# Python 3.9 can use both
@@ -1418,9 +1420,9 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
14181420
):
14191421
# TODO: Clean up, this entire if case is parameter protection
14201422
can_be_flattened = False
1421-
elif (self.where or self.order_by or self.limit_) and has_data_generator_exp(
1422-
cols
1423-
):
1423+
elif (
1424+
self.where or self.order_by or self.limit_
1425+
) and has_data_generator_or_window_function_exp(cols):
14241426
can_be_flattened = False
14251427
elif self.where and (
14261428
(subquery_dependent_columns := derive_dependent_columns(self.where))
@@ -1431,6 +1433,23 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
14311433
subquery_dependent_columns & new_column_states.active_columns
14321434
)
14331435
)
1436+
or (
1437+
# unflattenable condition: dropped column is used in subquery WHERE clause and dropped column status is NEW or CHANGED in the subquery
1438+
# reason: we should not flatten because the dropped column is not available in the new query, leading to WHERE clause error
1439+
# sample query: 'select "b" from (select "a" as "c", "b" from table where "c" > 1)' can not be flatten to 'select "b" from table where "c" > 1'
1440+
(
1441+
context._is_snowpark_connect_compatible_mode
1442+
and context._snowpark_connect_flatten_select_after_sort
1443+
)
1444+
and new_column_states.dropped_columns
1445+
and any(
1446+
self.column_states[_col].change_state
1447+
in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP)
1448+
for _col in (
1449+
subquery_dependent_columns & new_column_states.dropped_columns
1450+
)
1451+
)
1452+
)
14341453
):
14351454
can_be_flattened = False
14361455
elif self.order_by and (
@@ -1443,6 +1462,23 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
14431462
subquery_dependent_columns & new_column_states.active_columns
14441463
)
14451464
)
1465+
or (
1466+
# unflattenable condition: dropped column is used in subquery ORDER BY clause and dropped column status is NEW or CHANGED in the subquery
1467+
# reason: we should not flatten because the dropped column is not available in the new query, leading to ORDER BY clause error
1468+
# sample query: 'select "b" from (select "a" as "c", "b" order by "c")' can not be flatten to 'select "b" from table order by "c"'
1469+
(
1470+
context._is_snowpark_connect_compatible_mode
1471+
and context._snowpark_connect_flatten_select_after_sort
1472+
)
1473+
and new_column_states.dropped_columns
1474+
and any(
1475+
self.column_states[_col].change_state
1476+
in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP)
1477+
for _col in (
1478+
subquery_dependent_columns & new_column_states.dropped_columns
1479+
)
1480+
)
1481+
)
14461482
):
14471483
can_be_flattened = False
14481484
elif self.distinct_:
@@ -1488,8 +1524,62 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
14881524
self.df_ast_ids.copy() if self.df_ast_ids is not None else None
14891525
)
14901526
else:
1527+
new_order_by = None
1528+
new_from = self
1529+
if (
1530+
context._is_snowpark_connect_compatible_mode
1531+
and context._snowpark_connect_flatten_select_after_sort
1532+
) and self.order_by:
1533+
order_by_dependent_columns = derive_dependent_columns(*self.order_by)
1534+
if order_by_dependent_columns in (
1535+
COLUMN_DEPENDENCY_DOLLAR,
1536+
COLUMN_DEPENDENCY_ALL,
1537+
):
1538+
new_order_by = None
1539+
elif any(
1540+
col not in self.from_.column_states
1541+
and col not in self.column_states
1542+
for col in order_by_dependent_columns
1543+
):
1544+
new_order_by = None
1545+
elif any(
1546+
_col not in self.column_states
1547+
or self.column_states[_col].change_state
1548+
in (ColumnChangeState.CHANGED_EXP, ColumnChangeState.DROPPED)
1549+
for _col in order_by_dependent_columns
1550+
):
1551+
new_from = copy(self)
1552+
missing_columns = (
1553+
order_by_dependent_columns
1554+
- new_from.column_states.active_columns
1555+
)
1556+
base_projection: List[Expression] = (
1557+
new_from.projection
1558+
if new_from.projection is not None
1559+
else list(new_from.column_states.projection)
1560+
)
1561+
new_from.projection = base_projection + [
1562+
Attribute(col, DataType()) for col in missing_columns
1563+
]
1564+
new_col_states = derive_column_states_from_subquery(
1565+
new_from.projection, new_from.from_
1566+
)
1567+
if new_col_states is not None:
1568+
new_from.column_states = new_col_states
1569+
new_from._projection_in_str = None
1570+
new_from._commented_sql = None
1571+
new_from._sql_query = None
1572+
new_order_by = self.order_by
1573+
else:
1574+
new_from = self
1575+
new_order_by = None
1576+
else:
1577+
new_order_by = self.order_by
14911578
new = SelectStatement(
1492-
projection=cols, from_=self.to_subqueryable(), analyzer=self.analyzer
1579+
projection=cols,
1580+
from_=new_from.to_subqueryable(),
1581+
order_by=new_order_by,
1582+
analyzer=self.analyzer,
14931583
)
14941584
new._merge_projection_complexity_with_subquery = (
14951585
can_select_projection_complexity_be_merged(
@@ -1510,12 +1600,20 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
15101600
return new
15111601

15121602
def filter(self, col: Expression) -> "SelectStatement":
1603+
self._session._retrieve_aggregation_function_list()
15131604
can_be_flattened = (
15141605
(not self.flatten_disabled)
15151606
and can_clause_dependent_columns_flatten(
1516-
derive_dependent_columns(col), self.column_states
1607+
derive_dependent_columns(col), self.column_states, "filter"
15171608
)
1518-
and not has_data_generator_exp(self.projection)
1609+
and not has_data_generator_or_window_function_exp(self.projection)
1610+
and not (
1611+
(
1612+
context._is_snowpark_connect_compatible_mode
1613+
and context._snowpark_connect_flatten_select_after_sort
1614+
)
1615+
and has_aggregation_function_exp(self.projection)
1616+
) # sum(col) as new_col, new_col can not be flattened in where clause
15191617
and not (self.order_by and self.limit_ is not None)
15201618
)
15211619
if can_be_flattened:
@@ -1541,16 +1639,12 @@ def filter(self, col: Expression) -> "SelectStatement":
15411639
def sort(self, cols: List[Expression]) -> "SelectStatement":
15421640
can_be_flattened = (
15431641
(not self.flatten_disabled)
1544-
# limit order by and order by limit can cause big performance
1545-
# difference, because limit can stop table scanning whenever the
1546-
# number of record is satisfied.
1547-
# Therefore, disallow sql simplification when the
1548-
# current SelectStatement has a limit clause to avoid moving
1549-
# order by in front of limit.
1642+
# Disallow flattening when the current SelectStatement has a
1643+
# limit clause to avoid moving order by in front of limit.
15501644
and (not self.limit_)
15511645
and (not self.offset)
15521646
and can_clause_dependent_columns_flatten(
1553-
derive_dependent_columns(*cols), self.column_states
1647+
derive_dependent_columns(*cols), self.column_states, "sort"
15541648
)
15551649
and not has_data_generator_exp(self.projection)
15561650
)
@@ -1589,7 +1683,7 @@ def distinct(self) -> "SelectStatement":
15891683
# .order_by(col1).select(col2).distinct() cannot be flattened because
15901684
# SELECT DISTINCT B FROM TABLE ORDER BY A is not valid SQL
15911685
and (not (self.order_by and self.has_projection))
1592-
and not has_data_generator_exp(self.projection)
1686+
and not has_data_generator_or_window_function_exp(self.projection)
15931687
)
15941688
if can_be_flattened:
15951689
new = copy(self)
@@ -2080,7 +2174,12 @@ def can_projection_dependent_columns_be_flattened(
20802174
def can_clause_dependent_columns_flatten(
20812175
dependent_columns: Optional[AbstractSet[str]],
20822176
subquery_column_states: ColumnStateDict,
2177+
clause: Literal["filter", "sort"],
20832178
) -> bool:
2179+
assert clause in (
2180+
"filter",
2181+
"sort",
2182+
), f"Invalid clause called in can_clause_dependent_columns_flatten: {clause}"
20842183
if dependent_columns == COLUMN_DEPENDENCY_DOLLAR:
20852184
return False
20862185
elif (
@@ -2095,15 +2194,31 @@ def can_clause_dependent_columns_flatten(
20952194
dc_state = subquery_column_states.get(dc)
20962195
if dc_state:
20972196
if dc_state.change_state == ColumnChangeState.CHANGED_EXP:
2098-
return False
2197+
if clause == "filter":
2198+
return False
2199+
# sort + CHANGED_EXP: safe in SCOS mode since ORDER BY
2200+
# is evaluated after projection. Keep checking remaining
2201+
# columns though — another column may be unsafe.
2202+
elif not (
2203+
context._is_snowpark_connect_compatible_mode
2204+
and context._snowpark_connect_flatten_select_after_sort
2205+
):
2206+
return False
20992207
elif dc_state.change_state == ColumnChangeState.NEW:
2100-
# Most of the time this can be flattened. But if a new column uses window function and this column
2101-
# is used in a clause, the sql doesn't work in Snowflake.
2102-
# For instance `select a, rank() over(order by b) as d from test_table where d = 1` doesn't work.
2103-
# But `select a, b as d from test_table where d = 1` works
2104-
# We can inspect whether the referenced new column uses window function. Here we are being
2105-
# conservative for now to not flatten the SQL.
2106-
return False
2208+
if clause == "sort" and dc_state.dependent_columns in (
2209+
COLUMN_DEPENDENCY_DOLLAR,
2210+
COLUMN_DEPENDENCY_ALL,
2211+
):
2212+
# Scalar subqueries in sort can trigger Snowflake
2213+
# internal errors when ORDER BY references them
2214+
# at the same SELECT level.
2215+
return False
2216+
if not (
2217+
context._is_snowpark_connect_compatible_mode
2218+
and context._snowpark_connect_flatten_select_after_sort
2219+
):
2220+
return False
2221+
21072222
return True
21082223

21092224

@@ -2327,23 +2442,92 @@ def derive_column_states_from_subquery(
23272442
return column_states
23282443

23292444

2330-
def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool:
2445+
def _check_expressions_for_types(
2446+
expressions: Optional[List["Expression"]],
2447+
check_data_gen: bool = False,
2448+
check_window: bool = False,
2449+
check_aggregation: bool = False,
2450+
) -> bool:
2451+
"""Efficiently check if expressions contain specific types in a single pass.
2452+
2453+
Args:
2454+
expressions: List of expressions to check
2455+
check_data_gen: Check for data generator functions
2456+
check_window: Check for window functions
2457+
check_aggregation: Check for aggregation functions
2458+
2459+
Returns:
2460+
True if any requested type is found
2461+
"""
23312462
if expressions is None:
23322463
return False
2464+
23332465
for exp in expressions:
2334-
if isinstance(exp, WindowExpression):
2466+
if exp is None:
2467+
continue
2468+
2469+
# Check window functions
2470+
if check_window and isinstance(exp, WindowExpression):
23352471
return True
2336-
if isinstance(exp, FunctionExpression) and (
2337-
exp.is_data_generator
2338-
or exp.name.lower() in SEQUENCE_DEPENDENT_DATA_GENERATION
2472+
2473+
if check_data_gen:
2474+
if isinstance(exp, FunctionExpression) and (
2475+
exp.is_data_generator
2476+
or exp.name.lower() in SEQUENCE_DEPENDENT_DATA_GENERATION
2477+
):
2478+
# https://docs.snowflake.com/en/sql-reference/functions-data-generation
2479+
return True
2480+
2481+
# Check aggregation functions
2482+
if check_aggregation and isinstance(exp, FunctionExpression):
2483+
if exp.name.lower() in context._aggregation_function_set:
2484+
return True
2485+
2486+
# Recursively check children.
2487+
# Some expression types (e.g. CaseWhen) store sub-expressions in
2488+
# _child_expressions rather than children; fall back to that.
2489+
sub_exps = exp.children
2490+
if not sub_exps:
2491+
sub_exps = getattr(exp, "_child_expressions", None)
2492+
if _check_expressions_for_types(
2493+
sub_exps, check_data_gen, check_window, check_aggregation
23392494
):
2340-
# https://docs.snowflake.com/en/sql-reference/functions-data-generation
2341-
return True
2342-
if exp is not None and has_data_generator_exp(exp.children):
23432495
return True
2496+
23442497
return False
23452498

23462499

2500+
def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool:
2501+
"""Check if expressions contain data generator functions.
2502+
2503+
Note:
2504+
In non-connect mode, window expressions are also treated as data generators
2505+
for backward compatibility.
2506+
"""
2507+
if not (
2508+
context._is_snowpark_connect_compatible_mode
2509+
and context._snowpark_connect_flatten_select_after_sort
2510+
):
2511+
return _check_expressions_for_types(
2512+
expressions, check_data_gen=True, check_window=True
2513+
)
2514+
return _check_expressions_for_types(expressions, check_data_gen=True)
2515+
2516+
2517+
def has_data_generator_or_window_function_exp(
2518+
expressions: Optional[List["Expression"]],
2519+
) -> bool:
2520+
"""Check if expressions contain data generators or window functions."""
2521+
return _check_expressions_for_types(
2522+
expressions, check_data_gen=True, check_window=True
2523+
)
2524+
2525+
2526+
def has_aggregation_function_exp(expressions: Optional[List["Expression"]]) -> bool:
2527+
"""Check if expressions contain aggregation functions."""
2528+
return _check_expressions_for_types(expressions, check_aggregation=True)
2529+
2530+
23472531
def has_nondeterministic_data_generation_exp(
23482532
expressions: Optional[List["Expression"]],
23492533
) -> bool:

0 commit comments

Comments
 (0)