Skip to content

Commit 003d251

Browse files
committed
propagate parameters through describe queries
1 parent 0aeccd9 commit 003d251

2 files changed

Lines changed: 87 additions & 4 deletions

File tree

src/snowflake/snowpark/_internal/analyzer/select_statement.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -672,9 +672,12 @@ def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]:
672672

673673
@SnowflakePlan.Decorator.wrap_exception
674674
def _analyze_attributes(
675-
sql: str, session: "snowflake.snowpark.session.Session", dataframe_uuid: Optional[str] = None # type: ignore
675+
sql: str,
676+
session: "snowflake.snowpark.session.Session",
677+
dataframe_uuid: Optional[str] = None, # type: ignore
678+
query_params: Optional[Sequence[Any]] = None,
676679
) -> List[Attribute]:
677-
return analyze_attributes(sql, session, dataframe_uuid)
680+
return analyze_attributes(sql, session, dataframe_uuid, query_params)
678681

679682

680683
class SelectSQL(Selectable):
@@ -707,7 +710,7 @@ def __init__(
707710
self.pre_actions[0].query_id_place_holder
708711
)
709712
self._schema_query = analyzer_utils.schema_value_statement(
710-
_analyze_attributes(sql, self._session, self._uuid)
713+
_analyze_attributes(sql, self._session, self._uuid, query_params=params)
711714
) # Change to subqueryable schema query so downstream query plan can describe the SQL
712715
self._query_param = None
713716
else:

tests/integ/test_bind_variable.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pytest
1010

1111
from snowflake.snowpark import Row
12-
from snowflake.snowpark._internal.utils import is_in_stored_procedure
12+
from snowflake.snowpark._internal.utils import TempObjectType, is_in_stored_procedure
1313
from snowflake.snowpark.exceptions import SnowparkSQLException
1414
from snowflake.snowpark.functions import col, lit, max, table_function
1515
from snowflake.snowpark.types import (
@@ -457,3 +457,83 @@ def test_explain(session):
457457
params=[1, "a", 2, "b"],
458458
)
459459
df.explain()
460+
461+
462+
@pytest.fixture(scope="module")
463+
def proc_name(session):
464+
"""Create a trivial stored procedure that echoes its inputs back."""
465+
name = f"{session.get_fully_qualified_current_schema()}.{Utils.random_name_for_temp_object(TempObjectType.PROCEDURE)}"
466+
session.sql(
467+
f"""
468+
CREATE OR REPLACE TEMPORARY PROCEDURE {name}(template VARCHAR, args VARCHAR)
469+
RETURNS VARCHAR
470+
LANGUAGE SQL
471+
AS
472+
$$
473+
BEGIN
474+
RETURN template || ' | ' || args;
475+
END;
476+
$$
477+
"""
478+
).collect()
479+
return name
480+
481+
482+
class TestCallIdentifierBinding:
483+
"""
484+
SNOW-3061745: Bindings in CALL previously were not properly transferred through the expression tree.
485+
These previously errored out when a chained operation after `session.sql` triggered a call to
486+
`to_subqueryable`, which did not properly populate binding parameters.
487+
"""
488+
489+
def test_call_collect(self, session, proc_name):
490+
result = session.sql(
491+
"CALL identifier(?)(?, to_varchar(parse_json(?)))",
492+
params=[proc_name, "tmpl", '{"a": 1}'],
493+
).collect()
494+
assert result == [Row('tmpl | {"a":1}')]
495+
496+
def test_call_select(self, session, proc_name):
497+
result = (
498+
session.sql(
499+
"CALL identifier(?)(?, ?)",
500+
params=[proc_name, "tmpl", "args"],
501+
)
502+
.select("*")
503+
.collect()
504+
)
505+
assert result == [Row("tmpl | args")]
506+
507+
def test_call_filter(self, session, proc_name):
508+
result = (
509+
session.sql(
510+
"CALL identifier(?)(?, ?)",
511+
params=[proc_name, "tmpl", "args"],
512+
)
513+
.filter("1=1")
514+
.collect()
515+
)
516+
assert result == [Row("tmpl | args")]
517+
518+
def test_call_sort(self, session, proc_name):
519+
result = (
520+
session.sql(
521+
"CALL identifier(?)(?, ?)",
522+
params=[proc_name, "tmpl", "args"],
523+
)
524+
.sort("$1")
525+
.collect()
526+
)
527+
assert result == [Row("tmpl | args")]
528+
529+
def test_call_union(self, session, proc_name):
530+
df1 = session.sql(
531+
"CALL identifier(?)(?, ?)",
532+
params=[proc_name, "tmpl1", "args1"],
533+
)
534+
df2 = session.sql(
535+
"CALL identifier(?)(?, ?)",
536+
params=[proc_name, "tmpl2", "args2"],
537+
)
538+
result = df1.union_all(df2).collect()
539+
assert result == [Row("tmpl1 | args1"), Row("tmpl2 | args2")]

0 commit comments

Comments
 (0)