Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

### Snowpark Python API Updates

#### Bug Fixes

- Fixed a bug where `cloudpickle` could not be resolved when registering a Python stored procedure or UDF with `runtime_version='3.13'`.

#### New Features

- Added `get_wif_token` to `snowflake.snowpark.secrets` for workload identity federation tokens on the Snowflake server (not available in SPCS file-based secret environments).
Expand Down
2 changes: 1 addition & 1 deletion src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,7 +1276,7 @@ def resolve_imports_and_packages(
any(pkg.startswith("cloudpickle") for pkg in packages)
)
resolved_packages = packages + (
[f"cloudpickle=={cloudpickle.__version__}"]
[f"cloudpickle>={cloudpickle.__version__}"]
if not has_cloudpickle
else []
)
Expand Down
5 changes: 4 additions & 1 deletion src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2134,7 +2134,10 @@ def _get_req_identifiers_list(
if isinstance(m, str) and m not in result_dict:
res.append(m)
elif isinstance(m, ModuleType) and m.__name__ not in result_dict:
res.append(f"{m.__name__}=={m.__version__}")
if m.__name__ == "cloudpickle":
res.append(f"{m.__name__}>={m.__version__}")
else:
res.append(f"{m.__name__}=={m.__version__}")

return res

Expand Down
24 changes: 24 additions & 0 deletions tests/integ/test_stored_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2736,3 +2736,27 @@ def normalize(rows):
]

assert normalize(df.collect()) == normalize(oracledb_real_data)


@pytest.mark.skipif(IS_IN_STORED_PROC, reason="not supported in stored proc")
def test_sproc_runtime_313_cloudpickle_ge_spec_compiles_and_executes(session):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not blocking: since this test targets Python 3.13, will the test matrices for other versions (3.11, 3.12, etc.) also execute it? If so, could it be flaky on those Python versions, even though the merge gate is passing here?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be fine even if the 3.11/3.12 matrices execute this test.
This change only makes sure that cloudpickle can be resolved correctly for the stored procedure. For example, if the 3.12 job runs this test with cloudpickle 2.2.1 installed locally, it will request cloudpickle>=2.2.1 for py313; as long as such a cloudpickle package exists on the server, the test will not fail. If the test does fail in that situation, I think it is more of a packaging issue than flakiness.

"""Regression test for SNOW-3081273.

Verifies that a sproc targeting Python 3.13 deploys and executes correctly
even when the local cloudpickle version differs from the server-resolved one.
The auto-injected cloudpickle>=X spec must satisfy the 3.13 channel.
"""
multiplier = 7

def multiply(session_: Session, x: int) -> int:
return x * multiplier

sp = session.sproc.register(
multiply,
return_type=IntegerType(),
input_types=[IntegerType()],
packages=["snowflake-snowpark-python"],
runtime_version="3.13",
is_permanent=False,
)
assert sp(6) == 42
14 changes: 14 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import logging
import os
import types
from typing import Optional
from unittest import mock
from unittest.mock import MagicMock
Expand Down Expand Up @@ -880,6 +881,19 @@ def run_query_side_effect(query, **kwargs):
ctx._aggregation_function_set = original_agg_set


def test_get_req_identifiers_list_cloudpickle_only_uses_ge(mock_server_connection):
    """Check the version operator used when modules are turned into requirement strings.

    Passes the real ``cloudpickle`` module plus a synthetic module to
    ``Session._get_req_identifiers_list`` and expects ``cloudpickle`` to be
    rendered with ``>=`` while the synthetic module keeps ``==``.
    """
    import cloudpickle as cp

    # A throwaway module object carrying only the attributes the test needs.
    pinned_module = types.ModuleType("dummy_module")
    pinned_module.__version__ = "1.2.3"
    local_cp_version = cp.__version__

    identifiers = Session(mock_server_connection)._get_req_identifiers_list(
        [cp, pinned_module], {}
    )

    expected = [f"cloudpickle>={local_cp_version}", "dummy_module==1.2.3"]
    assert identifiers == expected
    # Explicit guard against the exact-pin form.
    assert f"cloudpickle=={local_cp_version}" not in identifiers


def test_retrieve_aggregation_function_list_uses_single_internal_sync_query():
"""Sync fallback executes exactly one internal metadata query."""
import snowflake.snowpark.context as ctx
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def mock_callback(extension_function_properties):
assert extension_function_properties.all_imports == ""
# for >= 3.14, we use pypi by default, which auto adds cloudpickle
assert extension_function_properties.all_packages == (
f"'cloudpickle=={cloudpickle.__version__}'"
f"'cloudpickle>={cloudpickle.__version__}'"
if sys.version_info >= (3, 14)
else ""
)
Expand Down Expand Up @@ -134,7 +134,7 @@ def test_artifact_repository_adds_cloudpickle():
assert all_packages is not None
package_list = all_packages.split(",") if all_packages else []
assert any(
pkg.strip().strip("'").startswith("cloudpickle==") for pkg in package_list
pkg.strip().strip("'").startswith("cloudpickle>=") for pkg in package_list
), f"cloudpickle not found in packages: {all_packages}"

# Test case 2: packages already contains cloudpickle
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/test_udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,3 +437,24 @@ def finish(self):
# wrong class type should fallback to default arg names
arg_names = get_func_arg_names(SumUDAF, TempObjectType.TABLE_FUNCTION, 2, True)
assert arg_names == ["arg1", "arg2"]


def test_resolve_imports_and_packages_non_conda_injects_cloudpickle_ge():
    """Regression check for SNOW-3081273.

    Calls ``resolve_imports_and_packages`` with an ``artifact_repository`` and a
    package list that does not mention cloudpickle, then verifies that the
    fourth element of the returned 6-tuple contains ``cloudpickle>=<local
    version>`` and not the ``==`` pin.
    """
    import cloudpickle

    registration_kwargs = dict(
        session=None,
        object_type=TempObjectType.PROCEDURE,
        func=lambda: None,
        arg_names=[],
        udf_name="test_sp",
        stage_location=None,
        imports=None,
        packages=["snowflake-snowpark-python"],
        artifact_repository="SNOWPARK_PYTHON_TEST_REPOSITORY",
    )
    resolved = resolve_imports_and_packages(**registration_kwargs)
    # Only the package spec (4th of 6 values) matters for this test.
    _, _, _, package_spec, _, _ = resolved

    local_version = cloudpickle.__version__
    assert package_spec is not None
    assert f"cloudpickle>={local_version}" in package_spec
    assert f"cloudpickle=={local_version}" not in package_spec
Loading