Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@

#### Bug Fixes

- Fixed a bug where `DataFrame.limit()` fails if there is parameter binding in the executed SQL.

#### Deprecations

- Deprecation warnings will be triggered when using snowpark-python with Python 3.9. For more details, please refer to https://docs.snowflake.com/en/developer-guide/python-runtime-support-policy.
Expand Down
15 changes: 8 additions & 7 deletions src/snowflake/snowpark/_internal/analyzer/schema_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#
import traceback
from typing import TYPE_CHECKING, List, Union, Optional
from typing import TYPE_CHECKING, List, Union, Optional, Sequence, Any

import snowflake.snowpark
from snowflake.connector.cursor import ResultMetadata, SnowflakeCursor
Expand Down Expand Up @@ -70,6 +70,7 @@ def analyze_attributes(
sql: str,
session: "snowflake.snowpark.session.Session",
dataframe_uuid: Optional[str] = None,
query_params: Optional[Sequence[Any]] = None,
) -> List[Attribute]:
lowercase = sql.strip().lower()

Expand Down Expand Up @@ -104,7 +105,7 @@ def analyze_attributes(
stack = traceback.extract_stack(limit=10)[:-1]
stack_trace = [frame.line for frame in stack] if len(stack) > 0 else None
with measure_time() as e2e_time:
attributes = session._get_result_attributes(sql)
attributes = session._get_result_attributes(sql, query_params)
session._conn._telemetry_client.send_describe_query_details(
session._session_id, sql, e2e_time(), stack_trace
)
Expand All @@ -118,9 +119,9 @@ def analyze_attributes(

@ttl_cache(ttl_seconds=15)
def cached_analyze_attributes(
sql: str, session: "snowflake.snowpark.session.Session", dataframe_uuid: Optional[str] = None # type: ignore
sql: str, session: "snowflake.snowpark.session.Session", dataframe_uuid: Optional[str] = None, query_params: Optional[Sequence[Any]] = None # type: ignore
) -> List[Attribute]:
return analyze_attributes(sql, session, dataframe_uuid)
return analyze_attributes(sql, session, dataframe_uuid, query_params)


def convert_result_meta_to_attribute(
Expand Down Expand Up @@ -162,7 +163,7 @@ def get_new_description(


def run_new_describe(
cursor: SnowflakeCursor, query: str
cursor: SnowflakeCursor, query: str, query_params: Optional[Sequence[Any]] = None
) -> Union[List[ResultMetadata], List["ResultMetadataV2"]]: # pyright: ignore
"""Execute describe() on a cursor, returning the new metadata format if possible.

Expand All @@ -174,6 +175,6 @@ def run_new_describe(

if hasattr(cursor, "_describe_internal"):
# Pyright does not perform narrowing here
return cursor._describe_internal(query) # pyright: ignore
return cursor._describe_internal(query, params=query_params) # pyright: ignore
else:
return cursor.describe(query)
return cursor.describe(query, params=query_params)
9 changes: 7 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,10 +588,15 @@ def _analyze_attributes(self) -> List[Attribute]:
assert (
self.schema_query is not None
), "No schema query is available for the SnowflakePlan"
query_params = getattr(self.source_plan, "query_params", None)
if self.session.reduce_describe_query_enabled:
return cached_analyze_attributes(self.schema_query, self.session, self.uuid)
return cached_analyze_attributes(
self.schema_query, self.session, self.uuid, query_params
)
else:
return analyze_attributes(self.schema_query, self.session, self.uuid)
return analyze_attributes(
self.schema_query, self.session, self.uuid, query_params
)

@property
def attributes(self) -> List[Attribute]:
Expand Down
15 changes: 11 additions & 4 deletions src/snowflake/snowpark/_internal/server_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,15 +269,22 @@ def _get_string_datum(self, query: str) -> Optional[str]:
rows = result_set_to_rows(self.run_query(query)["data"])
return rows[0][0] if len(rows) > 0 else None

def get_result_attributes(self, query: str) -> List[Attribute]:
def get_result_attributes(
self, query: str, query_params: Optional[Sequence[Any]] = None
) -> List[Attribute]:
return convert_result_meta_to_attribute(
self._run_new_describe(self._cursor, query), self.max_string_size
self._run_new_describe(self._cursor, query, query_params=query_params),
self.max_string_size,
)

def _run_new_describe(
self, cursor: SnowflakeCursor, query: str, **kwargs: dict
self,
cursor: SnowflakeCursor,
query: str,
query_params: Optional[Sequence[Any]] = None,
**kwargs: dict,
) -> Union[List[ResultMetadata], List["ResultMetadataV2"]]:
result_metadata = run_new_describe(cursor, query)
result_metadata = run_new_describe(cursor, query, query_params)

with self._lock:
for listener in filter(
Expand Down
6 changes: 4 additions & 2 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2981,8 +2981,10 @@ def _run_query(
_statement_params=statement_params,
)["data"]

def _get_result_attributes(self, query: str) -> List[Attribute]:
return self._conn.get_result_attributes(query)
def _get_result_attributes(
self, query: str, query_params: Optional[Sequence[Any]] = None
) -> List[Attribute]:
return self._conn.get_result_attributes(query, query_params)

def get_session_stage(
self,
Expand Down
21 changes: 21 additions & 0 deletions tests/integ/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4690,6 +4690,27 @@ def test_limit_offset(session):
assert df.limit(1, offset=1).collect() == [Row(A=4, B=5, C=6)]


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="Not supported in local testing ",
run=False,
)
Comment thread
sfc-gh-yuwang marked this conversation as resolved.
def test_limit_param_binding(session):
    """Check that ``limit()`` works on a ``session.sql`` query that has a bound parameter.

    A temp table with a single VARIANT column is created, queried through a
    statement containing one ``?`` placeholder, limited to one row, and the
    result is compared against the expected row.
    """
    temp_table = Utils.random_name_for_temp_object(TempObjectType.TABLE)
    variant_schema = StructType([StructField("col", VariantType())])
    seed_df = session.create_dataframe([[{"name": "Alice"}]], schema=variant_schema)
    seed_df.write.save_as_table(temp_table, table_type="temp")

    # The "?" placeholder is filled from the params list handed to session.sql.
    parameterized_sql = f"""
SELECT col:name as Name
FROM {temp_table}
WHERE GET_PATH(col, cast(? as VARCHAR)) IS NOT NULL
"""
    limited_df = session.sql(parameterized_sql, ["name"]).limit(1)

    Utils.check_answer(limited_df, [Row(NAME='"Alice"')])


def test_df_join_how_on_overwrite(session):
df1 = session.create_dataframe([[1, 1, "1"], [2, 2, "3"]]).to_df(
["int", "int2", "str"]
Expand Down
Loading