|
9 | 9 | import pytest |
10 | 10 |
|
11 | 11 | from snowflake.snowpark import Row |
12 | | -from snowflake.snowpark._internal.utils import is_in_stored_procedure |
| 12 | +from snowflake.snowpark._internal.utils import TempObjectType, is_in_stored_procedure |
13 | 13 | from snowflake.snowpark.exceptions import SnowparkSQLException |
14 | 14 | from snowflake.snowpark.functions import col, lit, max, table_function |
15 | 15 | from snowflake.snowpark.types import ( |
@@ -457,3 +457,83 @@ def test_explain(session): |
457 | 457 | params=[1, "a", 2, "b"], |
458 | 458 | ) |
459 | 459 | df.explain() |
| 460 | + |
| 461 | + |
| 462 | +@pytest.fixture(scope="module") |
| 463 | +def proc_name(session): |
| 464 | + """Create a trivial stored procedure that echoes its inputs back.""" |
| 465 | + name = f"{session.get_fully_qualified_current_schema()}.{Utils.random_name_for_temp_object(TempObjectType.PROCEDURE)}" |
| 466 | + session.sql( |
| 467 | + f""" |
| 468 | + CREATE OR REPLACE TEMPORARY PROCEDURE {name}(template VARCHAR, args VARCHAR) |
| 469 | + RETURNS VARCHAR |
| 470 | + LANGUAGE SQL |
| 471 | + AS |
| 472 | + $$ |
| 473 | + BEGIN |
| 474 | + RETURN template || ' | ' || args; |
| 475 | + END; |
| 476 | + $$ |
| 477 | + """ |
| 478 | + ).collect() |
| 479 | + return name |
| 480 | + |
| 481 | + |
| 482 | +class TestCallIdentifierBinding: |
| 483 | + """ |
| 484 | + SNOW-3061745: Bindings in CALL previously were not properly transferred through the expression tree. |
| 485 | + These previously errored out when a chained operation after `session.sql` triggered a call to |
| 486 | + `to_subqueryable`, which did not properly populate binding parameters. |
| 487 | + """ |
| 488 | + |
| 489 | + def test_call_collect(self, session, proc_name): |
| 490 | + result = session.sql( |
| 491 | + "CALL identifier(?)(?, to_varchar(parse_json(?)))", |
| 492 | + params=[proc_name, "tmpl", '{"a": 1}'], |
| 493 | + ).collect() |
| 494 | + assert result == [Row('tmpl | {"a":1}')] |
| 495 | + |
| 496 | + def test_call_select(self, session, proc_name): |
| 497 | + result = ( |
| 498 | + session.sql( |
| 499 | + "CALL identifier(?)(?, ?)", |
| 500 | + params=[proc_name, "tmpl", "args"], |
| 501 | + ) |
| 502 | + .select("*") |
| 503 | + .collect() |
| 504 | + ) |
| 505 | + assert result == [Row("tmpl | args")] |
| 506 | + |
| 507 | + def test_call_filter(self, session, proc_name): |
| 508 | + result = ( |
| 509 | + session.sql( |
| 510 | + "CALL identifier(?)(?, ?)", |
| 511 | + params=[proc_name, "tmpl", "args"], |
| 512 | + ) |
| 513 | + .filter("1=1") |
| 514 | + .collect() |
| 515 | + ) |
| 516 | + assert result == [Row("tmpl | args")] |
| 517 | + |
| 518 | + def test_call_sort(self, session, proc_name): |
| 519 | + result = ( |
| 520 | + session.sql( |
| 521 | + "CALL identifier(?)(?, ?)", |
| 522 | + params=[proc_name, "tmpl", "args"], |
| 523 | + ) |
| 524 | + .sort("$1") |
| 525 | + .collect() |
| 526 | + ) |
| 527 | + assert result == [Row("tmpl | args")] |
| 528 | + |
| 529 | + def test_call_union(self, session, proc_name): |
| 530 | + df1 = session.sql( |
| 531 | + "CALL identifier(?)(?, ?)", |
| 532 | + params=[proc_name, "tmpl1", "args1"], |
| 533 | + ) |
| 534 | + df2 = session.sql( |
| 535 | + "CALL identifier(?)(?, ?)", |
| 536 | + params=[proc_name, "tmpl2", "args2"], |
| 537 | + ) |
| 538 | + result = df1.union_all(df2).collect() |
| 539 | + assert result == [Row("tmpl1 | args1"), Row("tmpl2 | args2")] |
0 commit comments