diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 15db45a247..0fb111f538 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -269,6 +269,7 @@ def __init__( self.pre_actions: Optional[List["Query"]] = None self.post_actions: Optional[List["Query"]] = None self.flatten_disabled: bool = False + self.protect_dropped_new_columns: bool = False self._column_states: Optional[ColumnStateDict] = None self._snowflake_plan: Optional[SnowflakePlan] = None self.expr_to_alias = ( @@ -1499,7 +1500,7 @@ def select(self, cols: List[Expression]) -> "SelectStatement": can_be_flattened = False else: can_be_flattened = can_select_statement_be_flattened( - self.column_states, new_column_states + self.column_states, new_column_states, self.protect_dropped_new_columns ) if can_be_flattened: @@ -2196,7 +2197,9 @@ def parse_column_name( def can_select_statement_be_flattened( - subquery_column_states: ColumnStateDict, new_column_states: ColumnStateDict + subquery_column_states: ColumnStateDict, + new_column_states: ColumnStateDict, + protect_dropped_new_columns: bool = False, ) -> bool: for col, state in new_column_states.items(): dependent_columns = state.dependent_columns @@ -2219,7 +2222,13 @@ def can_select_statement_be_flattened( state.change_state == ColumnChangeState.DROPPED and (subquery_state := subquery_column_states.get(col)) and subquery_state.change_state == ColumnChangeState.NEW - and subquery_state.is_referenced_by_same_level_column + and ( + subquery_state.is_referenced_by_same_level_column + # If the subquery was explicitly marked (e.g. by the SCOS withColumn + # path), preserve it so that future filter/sort clauses can still + # reference the dropped NEW column. + or protect_dropped_new_columns + ) ): return False return True