Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

### Snowpark Python API Updates

#### Bug Fixes

- Fixed a bug where `DataFrame.limit()` fails if there is parameter binding in the executed SQL.

#### New Features

### Snowpark pandas API Updates
Expand Down
18 changes: 8 additions & 10 deletions src/snowflake/snowpark/_internal/analyzer/schema_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#
import traceback
from typing import TYPE_CHECKING, List, Union, Optional
from typing import TYPE_CHECKING, List, Union, Optional, Sequence, Any

import snowflake.snowpark
from snowflake.connector.cursor import ResultMetadata, SnowflakeCursor
Expand Down Expand Up @@ -70,6 +70,7 @@ def analyze_attributes(
sql: str,
session: "snowflake.snowpark.session.Session",
dataframe_uuid: Optional[str] = None,
query_params: Optional[Sequence[Any]] = None,
) -> List[Attribute]:
lowercase = sql.strip().lower()

Expand Down Expand Up @@ -104,7 +105,7 @@ def analyze_attributes(
stack = traceback.extract_stack(limit=10)[:-1]
stack_trace = [frame.line for frame in stack] if len(stack) > 0 else None
with measure_time() as e2e_time:
attributes = session._get_result_attributes(sql)
attributes = session._get_result_attributes(sql, query_params)
session._conn._telemetry_client.send_describe_query_details(
session._session_id, sql, e2e_time(), stack_trace
)
Expand All @@ -118,9 +119,9 @@ def analyze_attributes(

@ttl_cache(ttl_seconds=15)
def cached_analyze_attributes(
sql: str, session: "snowflake.snowpark.session.Session", dataframe_uuid: Optional[str] = None # type: ignore
sql: str, session: "snowflake.snowpark.session.Session", dataframe_uuid: Optional[str] = None, query_params: Optional[Sequence[Any]] = None # type: ignore
) -> List[Attribute]:
return analyze_attributes(sql, session, dataframe_uuid)
return analyze_attributes(sql, session, dataframe_uuid, query_params)


def convert_result_meta_to_attribute(
Expand Down Expand Up @@ -162,7 +163,7 @@ def get_new_description(


def run_new_describe(
cursor: SnowflakeCursor, query: str
cursor: SnowflakeCursor, query: str, query_params: Optional[Sequence[Any]] = None
) -> Union[List[ResultMetadata], List["ResultMetadataV2"]]: # pyright: ignore
"""Execute describe() on a cursor, returning the new metadata format if possible.

Expand All @@ -172,8 +173,5 @@ def run_new_describe(
# ResultMetadataV2 may not currently be a type, depending on the connector
# version, so the argument types are pyright ignored

if hasattr(cursor, "_describe_internal"):
# Pyright does not perform narrowing here
return cursor._describe_internal(query) # pyright: ignore
else:
return cursor.describe(query)
# Pyright does not perform narrowing here
return cursor._describe_internal(query, params=query_params) # pyright: ignore
9 changes: 7 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,10 +588,15 @@ def _analyze_attributes(self) -> List[Attribute]:
assert (
self.schema_query is not None
), "No schema query is available for the SnowflakePlan"
query_params = getattr(self.source_plan, "query_params", None)
if self.session.reduce_describe_query_enabled:
return cached_analyze_attributes(self.schema_query, self.session, self.uuid)
return cached_analyze_attributes(
self.schema_query, self.session, self.uuid, query_params
)
else:
return analyze_attributes(self.schema_query, self.session, self.uuid)
return analyze_attributes(
self.schema_query, self.session, self.uuid, query_params
)

@property
def attributes(self) -> List[Attribute]:
Expand Down
15 changes: 11 additions & 4 deletions src/snowflake/snowpark/_internal/server_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,15 +269,22 @@ def _get_string_datum(self, query: str) -> Optional[str]:
rows = result_set_to_rows(self.run_query(query)["data"])
return rows[0][0] if len(rows) > 0 else None

def get_result_attributes(self, query: str) -> List[Attribute]:
def get_result_attributes(
self, query: str, query_params: Optional[Sequence[Any]] = None
) -> List[Attribute]:
return convert_result_meta_to_attribute(
self._run_new_describe(self._cursor, query), self.max_string_size
self._run_new_describe(self._cursor, query, query_params=query_params),
self.max_string_size,
)

def _run_new_describe(
self, cursor: SnowflakeCursor, query: str, **kwargs: dict
self,
cursor: SnowflakeCursor,
query: str,
query_params: Optional[Sequence[Any]] = None,
**kwargs: dict,
) -> Union[List[ResultMetadata], List["ResultMetadataV2"]]:
result_metadata = run_new_describe(cursor, query)
result_metadata = run_new_describe(cursor, query, query_params)

with self._lock:
for listener in filter(
Expand Down
6 changes: 4 additions & 2 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2981,8 +2981,10 @@ def _run_query(
_statement_params=statement_params,
)["data"]

def _get_result_attributes(self, query: str) -> List[Attribute]:
return self._conn.get_result_attributes(query)
def _get_result_attributes(
self, query: str, query_params: Optional[Sequence[Any]] = None
) -> List[Attribute]:
return self._conn.get_result_attributes(query, query_params)

def get_session_stage(
self,
Expand Down
20 changes: 20 additions & 0 deletions tests/integ/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4690,6 +4690,26 @@ def test_limit_offset(session):
assert df.limit(1, offset=1).collect() == [Row(A=4, B=5, C=6)]


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="Not supported in local testing ",
)
def test_limit_param_binding(session):
    """Check `limit()` on a `session.sql` DataFrame whose query binds a `?` parameter.

    A temp table with one VARIANT column holding ``{"name": "Alice"}`` is
    created, then queried through SQL that takes the path ``"name"`` as a
    bound parameter. Applying ``limit(1)`` must still return the single row.
    """
    table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)

    # Seed data: a single row with a single VARIANT column named "col".
    variant_schema = StructType([StructField("col", VariantType())])
    seed_df = session.create_dataframe([[{"name": "Alice"}]], schema=variant_schema)
    seed_df.write.save_as_table(table_name, table_type="temp")

    # The `?` placeholder is filled from the parameter list passed to session.sql.
    bound_params = ["name"]
    query = f"""
SELECT col:name as Name
FROM {table_name}
WHERE GET_PATH(col, cast(? as VARCHAR)) IS NOT NULL
"""
    limited_df = session.sql(query, bound_params).limit(1)

    # The expected value includes the surrounding double quotes.
    expected_rows = [Row(NAME='"Alice"')]
    Utils.check_answer(limited_df, expected_rows)


def test_df_join_how_on_overwrite(session):
df1 = session.create_dataframe([[1, 1, "1"], [2, 2, "3"]]).to_df(
["int", "int2", "str"]
Expand Down
Loading