2121 Sequence ,
2222 Set ,
2323 Union ,
24+ Literal ,
2425)
2526
2627import snowflake .snowpark ._internal .utils
8788 is_sql_select_statement ,
8889 ExprAliasUpdateDict ,
8990)
91+ import snowflake .snowpark .context as context
9092
9193# Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable
9294# Python 3.9 can use both
@@ -1418,9 +1420,9 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
14181420 ):
14191421 # TODO: Clean up, this entire if case is parameter protection
14201422 can_be_flattened = False
1421- elif (self . where or self . order_by or self . limit_ ) and has_data_generator_exp (
1422- cols
1423- ):
1423+ elif (
1424+ self . where or self . order_by or self . limit_
1425+ ) and has_data_generator_or_window_function_exp ( cols ) :
14241426 can_be_flattened = False
14251427 elif self .where and (
14261428 (subquery_dependent_columns := derive_dependent_columns (self .where ))
@@ -1431,6 +1433,23 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
14311433 subquery_dependent_columns & new_column_states .active_columns
14321434 )
14331435 )
1436+ or (
1437+ # unflattenable condition: dropped column is used in subquery WHERE clause and dropped column status is NEW or CHANGED in the subquery
1438+ # reason: we should not flatten because the dropped column is not available in the new query, leading to WHERE clause error
1439+ # sample query: 'select "b" from (select "a" as "c", "b" from table where "c" > 1)' can not be flatten to 'select "b" from table where "c" > 1'
1440+ (
1441+ context ._is_snowpark_connect_compatible_mode
1442+ and context ._snowpark_connect_flatten_select_after_sort
1443+ )
1444+ and new_column_states .dropped_columns
1445+ and any (
1446+ self .column_states [_col ].change_state
1447+ in (ColumnChangeState .NEW , ColumnChangeState .CHANGED_EXP )
1448+ for _col in (
1449+ subquery_dependent_columns & new_column_states .dropped_columns
1450+ )
1451+ )
1452+ )
14341453 ):
14351454 can_be_flattened = False
14361455 elif self .order_by and (
@@ -1443,6 +1462,23 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
14431462 subquery_dependent_columns & new_column_states .active_columns
14441463 )
14451464 )
1465+ or (
1466+ # unflattenable condition: dropped column is used in subquery ORDER BY clause and dropped column status is NEW or CHANGED in the subquery
1467+ # reason: we should not flatten because the dropped column is not available in the new query, leading to ORDER BY clause error
1468+ # sample query: 'select "b" from (select "a" as "c", "b" order by "c")' can not be flatten to 'select "b" from table order by "c"'
1469+ (
1470+ context ._is_snowpark_connect_compatible_mode
1471+ and context ._snowpark_connect_flatten_select_after_sort
1472+ )
1473+ and new_column_states .dropped_columns
1474+ and any (
1475+ self .column_states [_col ].change_state
1476+ in (ColumnChangeState .NEW , ColumnChangeState .CHANGED_EXP )
1477+ for _col in (
1478+ subquery_dependent_columns & new_column_states .dropped_columns
1479+ )
1480+ )
1481+ )
14461482 ):
14471483 can_be_flattened = False
14481484 elif self .distinct_ :
@@ -1488,8 +1524,62 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
14881524 self .df_ast_ids .copy () if self .df_ast_ids is not None else None
14891525 )
14901526 else :
1527+ new_order_by = None
1528+ new_from = self
1529+ if (
1530+ context ._is_snowpark_connect_compatible_mode
1531+ and context ._snowpark_connect_flatten_select_after_sort
1532+ ) and self .order_by :
1533+ order_by_dependent_columns = derive_dependent_columns (* self .order_by )
1534+ if order_by_dependent_columns in (
1535+ COLUMN_DEPENDENCY_DOLLAR ,
1536+ COLUMN_DEPENDENCY_ALL ,
1537+ ):
1538+ new_order_by = None
1539+ elif any (
1540+ col not in self .from_ .column_states
1541+ and col not in self .column_states
1542+ for col in order_by_dependent_columns
1543+ ):
1544+ new_order_by = None
1545+ elif any (
1546+ _col not in self .column_states
1547+ or self .column_states [_col ].change_state
1548+ in (ColumnChangeState .CHANGED_EXP , ColumnChangeState .DROPPED )
1549+ for _col in order_by_dependent_columns
1550+ ):
1551+ new_from = copy (self )
1552+ missing_columns = (
1553+ order_by_dependent_columns
1554+ - new_from .column_states .active_columns
1555+ )
1556+ base_projection : List [Expression ] = (
1557+ new_from .projection
1558+ if new_from .projection is not None
1559+ else list (new_from .column_states .projection )
1560+ )
1561+ new_from .projection = base_projection + [
1562+ Attribute (col , DataType ()) for col in missing_columns
1563+ ]
1564+ new_col_states = derive_column_states_from_subquery (
1565+ new_from .projection , new_from .from_
1566+ )
1567+ if new_col_states is not None :
1568+ new_from .column_states = new_col_states
1569+ new_from ._projection_in_str = None
1570+ new_from ._commented_sql = None
1571+ new_from ._sql_query = None
1572+ new_order_by = self .order_by
1573+ else :
1574+ new_from = self
1575+ new_order_by = None
1576+ else :
1577+ new_order_by = self .order_by
14911578 new = SelectStatement (
1492- projection = cols , from_ = self .to_subqueryable (), analyzer = self .analyzer
1579+ projection = cols ,
1580+ from_ = new_from .to_subqueryable (),
1581+ order_by = new_order_by ,
1582+ analyzer = self .analyzer ,
14931583 )
14941584 new ._merge_projection_complexity_with_subquery = (
14951585 can_select_projection_complexity_be_merged (
@@ -1510,12 +1600,20 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
15101600 return new
15111601
15121602 def filter (self , col : Expression ) -> "SelectStatement" :
1603+ self ._session ._retrieve_aggregation_function_list ()
15131604 can_be_flattened = (
15141605 (not self .flatten_disabled )
15151606 and can_clause_dependent_columns_flatten (
1516- derive_dependent_columns (col ), self .column_states
1607+ derive_dependent_columns (col ), self .column_states , "filter"
15171608 )
1518- and not has_data_generator_exp (self .projection )
1609+ and not has_data_generator_or_window_function_exp (self .projection )
1610+ and not (
1611+ (
1612+ context ._is_snowpark_connect_compatible_mode
1613+ and context ._snowpark_connect_flatten_select_after_sort
1614+ )
1615+ and has_aggregation_function_exp (self .projection )
1616+ ) # sum(col) as new_col, new_col can not be flattened in where clause
15191617 and not (self .order_by and self .limit_ is not None )
15201618 )
15211619 if can_be_flattened :
@@ -1541,16 +1639,12 @@ def filter(self, col: Expression) -> "SelectStatement":
15411639 def sort (self , cols : List [Expression ]) -> "SelectStatement" :
15421640 can_be_flattened = (
15431641 (not self .flatten_disabled )
1544- # limit order by and order by limit can cause big performance
1545- # difference, because limit can stop table scanning whenever the
1546- # number of record is satisfied.
1547- # Therefore, disallow sql simplification when the
1548- # current SelectStatement has a limit clause to avoid moving
1549- # order by in front of limit.
1642+ # Disallow flattening when the current SelectStatement has a
1643+ # limit clause to avoid moving order by in front of limit.
15501644 and (not self .limit_ )
15511645 and (not self .offset )
15521646 and can_clause_dependent_columns_flatten (
1553- derive_dependent_columns (* cols ), self .column_states
1647+ derive_dependent_columns (* cols ), self .column_states , "sort"
15541648 )
15551649 and not has_data_generator_exp (self .projection )
15561650 )
@@ -1589,7 +1683,7 @@ def distinct(self) -> "SelectStatement":
15891683 # .order_by(col1).select(col2).distinct() cannot be flattened because
15901684 # SELECT DISTINCT B FROM TABLE ORDER BY A is not valid SQL
15911685 and (not (self .order_by and self .has_projection ))
1592- and not has_data_generator_exp (self .projection )
1686+ and not has_data_generator_or_window_function_exp (self .projection )
15931687 )
15941688 if can_be_flattened :
15951689 new = copy (self )
@@ -2080,7 +2174,12 @@ def can_projection_dependent_columns_be_flattened(
20802174def can_clause_dependent_columns_flatten (
20812175 dependent_columns : Optional [AbstractSet [str ]],
20822176 subquery_column_states : ColumnStateDict ,
2177+ clause : Literal ["filter" , "sort" ],
20832178) -> bool :
2179+ assert clause in (
2180+ "filter" ,
2181+ "sort" ,
2182+ ), f"Invalid clause called in can_clause_dependent_columns_flatten: { clause } "
20842183 if dependent_columns == COLUMN_DEPENDENCY_DOLLAR :
20852184 return False
20862185 elif (
@@ -2095,15 +2194,31 @@ def can_clause_dependent_columns_flatten(
20952194 dc_state = subquery_column_states .get (dc )
20962195 if dc_state :
20972196 if dc_state .change_state == ColumnChangeState .CHANGED_EXP :
2098- return False
2197+ if clause == "filter" :
2198+ return False
2199+ # sort + CHANGED_EXP: safe in SCOS mode since ORDER BY
2200+ # is evaluated after projection. Keep checking remaining
2201+ # columns though — another column may be unsafe.
2202+ elif not (
2203+ context ._is_snowpark_connect_compatible_mode
2204+ and context ._snowpark_connect_flatten_select_after_sort
2205+ ):
2206+ return False
20992207 elif dc_state .change_state == ColumnChangeState .NEW :
2100- # Most of the time this can be flattened. But if a new column uses window function and this column
2101- # is used in a clause, the sql doesn't work in Snowflake.
2102- # For instance `select a, rank() over(order by b) as d from test_table where d = 1` doesn't work.
2103- # But `select a, b as d from test_table where d = 1` works
2104- # We can inspect whether the referenced new column uses window function. Here we are being
2105- # conservative for now to not flatten the SQL.
2106- return False
2208+ if clause == "sort" and dc_state .dependent_columns in (
2209+ COLUMN_DEPENDENCY_DOLLAR ,
2210+ COLUMN_DEPENDENCY_ALL ,
2211+ ):
2212+ # Scalar subqueries in sort can trigger Snowflake
2213+ # internal errors when ORDER BY references them
2214+ # at the same SELECT level.
2215+ return False
2216+ if not (
2217+ context ._is_snowpark_connect_compatible_mode
2218+ and context ._snowpark_connect_flatten_select_after_sort
2219+ ):
2220+ return False
2221+
21072222 return True
21082223
21092224
@@ -2327,23 +2442,92 @@ def derive_column_states_from_subquery(
23272442 return column_states
23282443
23292444
2330- def has_data_generator_exp (expressions : Optional [List ["Expression" ]]) -> bool :
2445+ def _check_expressions_for_types (
2446+ expressions : Optional [List ["Expression" ]],
2447+ check_data_gen : bool = False ,
2448+ check_window : bool = False ,
2449+ check_aggregation : bool = False ,
2450+ ) -> bool :
2451+ """Efficiently check if expressions contain specific types in a single pass.
2452+
2453+ Args:
2454+ expressions: List of expressions to check
2455+ check_data_gen: Check for data generator functions
2456+ check_window: Check for window functions
2457+ check_aggregation: Check for aggregation functions
2458+
2459+ Returns:
2460+ True if any requested type is found
2461+ """
23312462 if expressions is None :
23322463 return False
2464+
23332465 for exp in expressions :
2334- if isinstance (exp , WindowExpression ):
2466+ if exp is None :
2467+ continue
2468+
2469+ # Check window functions
2470+ if check_window and isinstance (exp , WindowExpression ):
23352471 return True
2336- if isinstance (exp , FunctionExpression ) and (
2337- exp .is_data_generator
2338- or exp .name .lower () in SEQUENCE_DEPENDENT_DATA_GENERATION
2472+
2473+ if check_data_gen :
2474+ if isinstance (exp , FunctionExpression ) and (
2475+ exp .is_data_generator
2476+ or exp .name .lower () in SEQUENCE_DEPENDENT_DATA_GENERATION
2477+ ):
2478+ # https://docs.snowflake.com/en/sql-reference/functions-data-generation
2479+ return True
2480+
2481+ # Check aggregation functions
2482+ if check_aggregation and isinstance (exp , FunctionExpression ):
2483+ if exp .name .lower () in context ._aggregation_function_set :
2484+ return True
2485+
2486+ # Recursively check children.
2487+ # Some expression types (e.g. CaseWhen) store sub-expressions in
2488+ # _child_expressions rather than children; fall back to that.
2489+ sub_exps = exp .children
2490+ if not sub_exps :
2491+ sub_exps = getattr (exp , "_child_expressions" , None )
2492+ if _check_expressions_for_types (
2493+ sub_exps , check_data_gen , check_window , check_aggregation
23392494 ):
2340- # https://docs.snowflake.com/en/sql-reference/functions-data-generation
2341- return True
2342- if exp is not None and has_data_generator_exp (exp .children ):
23432495 return True
2496+
23442497 return False
23452498
23462499
2500+ def has_data_generator_exp (expressions : Optional [List ["Expression" ]]) -> bool :
2501+ """Check if expressions contain data generator functions.
2502+
2503+ Note:
2504+ In non-connect mode, window expressions are also treated as data generators
2505+ for backward compatibility.
2506+ """
2507+ if not (
2508+ context ._is_snowpark_connect_compatible_mode
2509+ and context ._snowpark_connect_flatten_select_after_sort
2510+ ):
2511+ return _check_expressions_for_types (
2512+ expressions , check_data_gen = True , check_window = True
2513+ )
2514+ return _check_expressions_for_types (expressions , check_data_gen = True )
2515+
2516+
2517+ def has_data_generator_or_window_function_exp (
2518+ expressions : Optional [List ["Expression" ]],
2519+ ) -> bool :
2520+ """Check if expressions contain data generators or window functions."""
2521+ return _check_expressions_for_types (
2522+ expressions , check_data_gen = True , check_window = True
2523+ )
2524+
2525+
2526+ def has_aggregation_function_exp (expressions : Optional [List ["Expression" ]]) -> bool :
2527+ """Check if expressions contain aggregation functions."""
2528+ return _check_expressions_for_types (expressions , check_aggregation = True )
2529+
2530+
23472531def has_nondeterministic_data_generation_exp (
23482532 expressions : Optional [List ["Expression" ]],
23492533) -> bool :
0 commit comments