diff --git a/CHANGELOG.md b/CHANGELOG.md index 8cb7b23a80..fd66456c17 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ ### Snowpark Python API Updates +#### Bug Fixes + +- Fixed a bug that `DataFrame.limit()` fail if there is parameter binding in the executed SQL. + #### New Features ### Snowpark pandas API Updates diff --git a/src/snowflake/snowpark/_internal/analyzer/schema_utils.py b/src/snowflake/snowpark/_internal/analyzer/schema_utils.py index 14e84806da..f0aa29bbad 100644 --- a/src/snowflake/snowpark/_internal/analyzer/schema_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/schema_utils.py @@ -2,7 +2,7 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # import traceback -from typing import TYPE_CHECKING, List, Union, Optional +from typing import TYPE_CHECKING, List, Union, Optional, Sequence, Any import snowflake.snowpark from snowflake.connector.cursor import ResultMetadata, SnowflakeCursor @@ -70,6 +70,7 @@ def analyze_attributes( sql: str, session: "snowflake.snowpark.session.Session", dataframe_uuid: Optional[str] = None, + query_params: Optional[Sequence[Any]] = None, ) -> List[Attribute]: lowercase = sql.strip().lower() @@ -104,7 +105,7 @@ def analyze_attributes( stack = traceback.extract_stack(limit=10)[:-1] stack_trace = [frame.line for frame in stack] if len(stack) > 0 else None with measure_time() as e2e_time: - attributes = session._get_result_attributes(sql) + attributes = session._get_result_attributes(sql, query_params) session._conn._telemetry_client.send_describe_query_details( session._session_id, sql, e2e_time(), stack_trace ) @@ -118,9 +119,9 @@ def analyze_attributes( @ttl_cache(ttl_seconds=15) def cached_analyze_attributes( - sql: str, session: "snowflake.snowpark.session.Session", dataframe_uuid: Optional[str] = None # type: ignore + sql: str, session: "snowflake.snowpark.session.Session", dataframe_uuid: Optional[str] = None, query_params: Optional[Sequence[Any]] = None # type: ignore ) -> List[Attribute]: - return analyze_attributes(sql, session, dataframe_uuid) + return analyze_attributes(sql, session, dataframe_uuid, query_params) def convert_result_meta_to_attribute( @@ -162,7 +163,7 @@ def get_new_description( def run_new_describe( - cursor: SnowflakeCursor, query: str + cursor: SnowflakeCursor, query: str, query_params: Optional[Sequence[Any]] = None ) -> Union[List[ResultMetadata], List["ResultMetadataV2"]]: # pyright: ignore """Execute describe() on a cursor, returning the new metadata format if possible. @@ -172,8 +173,5 @@ def run_new_describe( # ResultMetadataV2 may not currently be a type, depending on the connector # version, so the argument types are pyright ignored - if hasattr(cursor, "_describe_internal"): - # Pyright does not perform narrowing here - return cursor._describe_internal(query) # pyright: ignore - else: - return cursor.describe(query) + # Pyright does not perform narrowing here + return cursor._describe_internal(query, params=query_params) # pyright: ignore diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index d302700890..d502c8a29e 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -588,10 +588,15 @@ def _analyze_attributes(self) -> List[Attribute]: assert ( self.schema_query is not None ), "No schema query is available for the SnowflakePlan" + query_params = getattr(self.source_plan, "query_params", None) if self.session.reduce_describe_query_enabled: - return cached_analyze_attributes(self.schema_query, self.session, self.uuid) + return cached_analyze_attributes( + self.schema_query, self.session, self.uuid, query_params + ) else: - return analyze_attributes(self.schema_query, self.session, self.uuid) + return analyze_attributes( + self.schema_query, self.session, self.uuid, query_params + ) @property def attributes(self) -> List[Attribute]: diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index 6b2232f32d..7026a6189b 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -269,15 +269,22 @@ def _get_string_datum(self, query: str) -> Optional[str]: rows = result_set_to_rows(self.run_query(query)["data"]) return rows[0][0] if len(rows) > 0 else None - def get_result_attributes(self, query: str) -> List[Attribute]: + def get_result_attributes( + self, query: str, query_params: Optional[Sequence[Any]] = None + ) -> List[Attribute]: return convert_result_meta_to_attribute( - self._run_new_describe(self._cursor, query), self.max_string_size + self._run_new_describe(self._cursor, query, query_params=query_params), + self.max_string_size, ) def _run_new_describe( - self, cursor: SnowflakeCursor, query: str, **kwargs: dict + self, + cursor: SnowflakeCursor, + query: str, + query_params: Optional[Sequence[Any]] = None, + **kwargs: dict, ) -> Union[List[ResultMetadata], List["ResultMetadataV2"]]: - result_metadata = run_new_describe(cursor, query) + result_metadata = run_new_describe(cursor, query, query_params) with self._lock: for listener in filter( diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 10f3874dd5..4b0618a572 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -2981,8 +2981,10 @@ def _run_query( _statement_params=statement_params, )["data"] - def _get_result_attributes(self, query: str) -> List[Attribute]: - return self._conn.get_result_attributes(query) + def _get_result_attributes( + self, query: str, query_params: Optional[Sequence[Any]] = None + ) -> List[Attribute]: + return self._conn.get_result_attributes(query, query_params) def get_session_stage( self, diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py index 22f963f121..ee418e525c 100644 --- a/tests/integ/test_dataframe.py +++ b/tests/integ/test_dataframe.py @@ -4690,6 +4690,26 @@ def test_limit_offset(session): assert df.limit(1, offset=1).collect() == [Row(A=4, B=5, C=6)] +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="Not supported in local testing ", +) +def test_limit_param_binding(session): + table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) + session.create_dataframe( + [[{"name": "Alice"}]], schema=StructType([StructField("col", VariantType())]) + ).write.save_as_table(table_name, table_type="temp") + result = session.sql( + f""" + SELECT col:name as Name + FROM {table_name} + WHERE GET_PATH(col, cast(? as VARCHAR)) IS NOT NULL + """, + ["name"], + ).limit(1) + Utils.check_answer(result, [Row(NAME='"Alice"')]) + + def test_df_join_how_on_overwrite(session): df1 = session.create_dataframe([[1, 1, "1"], [2, 2, "3"]]).to_df( ["int", "int2", "str"]