Skip to content

Commit b4ffbf3

Browse files
sfc-gh-qding and Cortex Code committed
SNOW-3392625: Add artifact_repository support to dbapi udtf_configs
Forward the artifact_repository parameter from dbapi() udtf_configs through to session.udtf.register(), enabling users to specify a custom artifact repository (e.g. PyPI) for packages used by the internal UDTF created during distributed dbapi ingestion.

Changes:
- dataframe_reader.py: extract artifact_repository from udtf_configs
- datasource_partitioner.py: forward through _udtf_ingestion wrapper
- base_driver.py: accept and forward to session.udtf.register()
- CHANGELOG.md: add new feature entry
- test_data_source_api.py: add test_dbapi_udtf_artifact_repository

.... Generated with [Cortex Code](https://docs.snowflake.com/en/user-guide/cortex-code/cortex-code)

Co-Authored-By: Cortex Code <noreply@snowflake.com>
1 parent 7c853e1 commit b4ffbf3

5 files changed

Lines changed: 73 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
### Snowpark Python API Updates
66

7+
#### New Features
8+
9+
- Added `artifact_repository` support to `udtf_configs` in `session.read.dbapi()`, enabling users to specify a custom artifact repository (e.g. PyPI) for packages used by the internal UDTF during distributed ingestion.
10+
711
#### Bug Fixes
812

913
- Fixed a bug where `TRY_CAST` reader option is ignored when calling `DataFrameReader.schema().csv()`.

src/snowflake/snowpark/_internal/data_source/datasource_partitioner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def _udtf_ingestion(
184184
fetch_size: int = 1000,
185185
imports: Optional[List[str]] = None,
186186
packages: Optional[List[str]] = None,
187+
artifact_repository: Optional[str] = None,
187188
session_init_statement: Optional[List[str]] = None,
188189
query_timeout: Optional[int] = 0,
189190
statement_params: Optional[Dict[str, str]] = None,
@@ -197,6 +198,7 @@ def _udtf_ingestion(
197198
fetch_size,
198199
imports,
199200
packages,
201+
artifact_repository,
200202
session_init_statement,
201203
query_timeout,
202204
statement_params,

src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def udtf_ingestion(
161161
fetch_size: int = 1000,
162162
imports: Optional[List[str]] = None,
163163
packages: Optional[List[str]] = None,
164+
artifact_repository: Optional[str] = None,
164165
session_init_statement: Optional[List[str]] = None,
165166
query_timeout: Optional[int] = 0,
166167
statement_params: Optional[Dict[str, str]] = None,
@@ -187,6 +188,7 @@ def udtf_ingestion(
187188
external_access_integrations=[external_access_integrations],
188189
packages=packages or UDTF_PACKAGE_MAP.get(self.dbms_type),
189190
imports=imports,
191+
artifact_repository=artifact_repository,
190192
statement_params=statement_params,
191193
_emit_ast=_emit_ast, # internal function call, _emit_ast will be set to False by the caller
192194
)

src/snowflake/snowpark/dataframe_reader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2305,6 +2305,7 @@ def create_oracledb_connection():
23052305
fetch_size=fetch_size,
23062306
imports=udtf_configs.get("imports", None),
23072307
packages=udtf_configs.get("packages", None),
2308+
artifact_repository=udtf_configs.get("artifact_repository", None),
23082309
session_init_statement=session_init_statement,
23092310
query_timeout=query_timeout,
23102311
statement_params=statements_params_for_telemetry,

tests/integ/test_data_source_api.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979
sql_server_create_connection_unicode_data,
8080
sql_server_create_connection_double_quoted_data,
8181
)
82-
from tests.utils import Utils, IS_WINDOWS, RUNNING_ON_JENKINS
82+
from tests.utils import Utils, IS_WINDOWS, RUNNING_ON_JENKINS, IS_NOT_ON_GITHUB
8383

8484
try:
8585
import pandas # noqa: F401
@@ -1924,3 +1924,66 @@ def create_mock_oracledb_v34_or_higher():
19241924
dbms_type, driver_type = detect_dbms(create_mock_oracledb_v34_or_higher())
19251925
assert dbms_type == DBMS_TYPE.ORACLE_DB
19261926
assert driver_type == DRIVER_TYPE.ORACLEDB
1927+
1928+
1929+
@pytest.mark.skipif(
1930+
"config.getoption('local_testing_mode', default=False)",
1931+
reason="artifact repository not supported in local testing",
1932+
)
1933+
@pytest.mark.skipif(IS_NOT_ON_GITHUB, reason="need resources")
1934+
@pytest.mark.skipif(
1935+
RUNNING_ON_JENKINS,
1936+
reason="SNOW-2089683: oracledb real connection test failed on jenkins",
1937+
)
1938+
@pytest.mark.udf
1939+
def test_dbapi_udtf_artifact_repository(session, ast_enabled):
1940+
"""Verify artifact_repository in udtf_configs is forwarded to the UDTF."""
1941+
if ast_enabled:
1942+
pytest.skip("TODO: dbapi has not implemented ast yet, skip the test for now")
1943+
1944+
import sys
1945+
1946+
if sys.version_info < (3, 9):
1947+
pytest.skip("artifact repository requires Python 3.9+")
1948+
1949+
def create_connection():
1950+
class FakeConnection:
1951+
def cursor(self):
1952+
class FakeCursor:
1953+
def execute(self, query):
1954+
pass
1955+
1956+
@property
1957+
def description(self):
1958+
return [("c1", int, None, None, None, None, None)]
1959+
1960+
def fetchmany(self, *args, **kwargs):
1961+
return None
1962+
1963+
return FakeCursor()
1964+
1965+
return FakeConnection()
1966+
1967+
his = session.query_history()
1968+
df = session.read.dbapi(
1969+
create_connection,
1970+
table="Fake",
1971+
custom_schema="c1 INT",
1972+
udtf_configs={
1973+
"external_access_integration": ORACLEDB_TEST_EXTERNAL_ACCESS_INTEGRATION,
1974+
"packages": ["snowflake-snowpark-python", "cloudpickle"],
1975+
"artifact_repository": "SNOWPARK_PYTHON_TEST_REPOSITORY",
1976+
},
1977+
)
1978+
df.select("*").collect()
1979+
1980+
# Verify the CREATE FUNCTION DDL includes the artifact repository
1981+
create_function_queries = [
1982+
q.sql_text
1983+
for q in his.queries
1984+
if "CREATE" in q.sql_text.upper() and "FUNCTION" in q.sql_text.upper()
1985+
]
1986+
assert len(create_function_queries) > 0, "Expected at least one CREATE FUNCTION query"
1987+
assert any(
1988+
"SNOWPARK_PYTHON_TEST_REPOSITORY" in q for q in create_function_queries
1989+
), f"artifact_repository not found in UDTF DDL: {create_function_queries}"

0 commit comments

Comments
 (0)