From 7542a851cf664e0f33d639cf9a5f1581d67e6f86 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 4 Sep 2025 15:31:59 -0700 Subject: [PATCH 01/16] add support for session init statement in udtf ingetion of dbapi --- CHANGELOG.md | 2 + .../data_source/datasource_reader.py | 2 +- .../data_source/drivers/base_driver.py | 20 ++++++- .../data_source/drivers/databricks_driver.py | 8 ++- .../data_source/drivers/oracledb_driver.py | 11 +++- .../data_source/drivers/psycopg2_driver.py | 11 +++- .../data_source/drivers/pymsql_driver.py | 15 ++++- .../data_source/drivers/pyodbc_driver.py | 11 +++- src/snowflake/snowpark/dataframe_reader.py | 2 + tests/integ/datasource/test_databricks.py | 39 +++++++++++- tests/integ/datasource/test_mysql.py | 60 +++++++++++++++++++ tests/integ/datasource/test_oracledb.py | 57 ++++++++++++++++++ tests/integ/datasource/test_postgres.py | 28 ++++++++- tests/integ/test_data_source_api.py | 2 +- 14 files changed, 250 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cef5f904c2..fefde396c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ - Added support for `FileOperation.remove` to remove files in a stage. - Added support for parameter `use_vectorized_scanner` in function `Session.write_pandas()`. +- Added support for parameter `session_init_statement` in udtf ingestion of `DataFrameReader.jdbc`(PrPr). + #### Bug Fixes #### Deprecations diff --git a/src/snowflake/snowpark/_internal/data_source/datasource_reader.py b/src/snowflake/snowpark/_internal/data_source/datasource_reader.py index 507141b9fa..cb2ea1e01d 100644 --- a/src/snowflake/snowpark/_internal/data_source/datasource_reader.py +++ b/src/snowflake/snowpark/_internal/data_source/datasource_reader.py @@ -60,7 +60,7 @@ def read(self, partition: str) -> Iterator[List[Any]]: cursor.execute(statement) except BaseException as exc: raise SnowparkDataframeReaderException( - f"Failed to execute session init statement: '{statement}' due to exception '{exc!r}'" + f"Failed to execute session init statement: '{statement}' due to exception '{exc}'" ) # use server side cursor to fetch data if supported by the driver # some drivers do not support execute twice on server side cursor (e.g. psycopg2) diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py index a9de3af7b7..05acc485c0 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py @@ -139,6 +139,8 @@ def udtf_ingestion( fetch_size: int = 1000, imports: Optional[List[str]] = None, packages: Optional[List[str]] = None, + session_init_statement: Optional[List[str]] = None, + query_timeout: Optional[int] = 0, _emit_ast: bool = True, ) -> "snowflake.snowpark.DataFrame": from snowflake.snowpark._internal.data_source.utils import UDTF_PACKAGE_MAP @@ -146,7 +148,12 @@ def udtf_ingestion( udtf_name = random_name_for_temp_object(TempObjectType.FUNCTION) with measure_time() as udtf_register_time: session.udtf.register( - self.udtf_class_builder(fetch_size=fetch_size, schema=schema), + self.udtf_class_builder( + fetch_size=fetch_size, + schema=schema, + session_init_statement=session_init_statement, + query_timeout=query_timeout, + ), name=udtf_name, output_schema=StructType( [ @@ -166,14 +173,21 @@ def udtf_ingestion( return self.to_result_snowpark_df_udtf(res, schema, _emit_ast=_emit_ast) def udtf_class_builder( - self, fetch_size: int = 1000, schema: StructType = None + self, + fetch_size: int = 1000, + schema: StructType = None, + session_init_statement: List[str] = None, + query_timeout: int = 0, ) -> type: create_connection = self.create_connection + prepare_connection = self.prepare_connection class UDTFIngestion: def process(self, query: str): - conn = create_connection() + conn = prepare_connection(create_connection(), query_timeout) cursor = conn.cursor() + for statement in session_init_statement: + cursor.execute(statement) cursor.execute(query) while True: rows = cursor.fetchmany(fetch_size) diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/databricks_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/databricks_driver.py index 226b45236d..30532a5405 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/databricks_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/databricks_driver.py @@ -69,7 +69,11 @@ def to_snow_type(self, schema: List[Any]) -> StructType: return StructType(all_columns) def udtf_class_builder( - self, fetch_size: int = 1000, schema: StructType = None + self, + fetch_size: int = 1000, + schema: StructType = None, + session_init_statement: List[str] = None, + query_timeout: int = 0, ) -> type: create_connection = self.create_connection @@ -77,6 +81,8 @@ class UDTFIngestion: def process(self, query: str): conn = create_connection() cursor = conn.cursor() + for statement in session_init_statement: + cursor.execute(statement) # First get schema information describe_query = f"DESCRIBE QUERY SELECT * FROM ({query})" diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py index 11d7b9ec07..0ca526f0ec 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py @@ -111,9 +111,14 @@ def prepare_connection( return conn def udtf_class_builder( - self, fetch_size: int = 1000, schema: StructType = None + self, + fetch_size: int = 1000, + schema: StructType = None, + session_init_statement: List[str] = None, + query_timeout: int = 0, ) -> type: create_connection = self.create_connection + prepare_connection = self.prepare_connection def oracledb_output_type_handler(cursor, metadata): from oracledb import ( @@ -137,10 +142,12 @@ def convert_to_hex(value): class UDTFIngestion: def process(self, query: str): - conn = create_connection() + conn = prepare_connection(create_connection(), query_timeout) if conn.outputtypehandler is None: conn.outputtypehandler = oracledb_output_type_handler cursor = conn.cursor() + for statement in session_init_statement: + cursor.execute(statement) cursor.execute(query) while True: rows = cursor.fetchmany(fetch_size) diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py index 8bfa734f92..ecdfaff32c 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py @@ -253,7 +253,11 @@ def prepare_connection( return conn def udtf_class_builder( - self, fetch_size: int = 1000, schema: StructType = None + self, + fetch_size: int = 1000, + schema: StructType = None, + session_init_statement: List[str] = None, + query_timeout: int = 0, ) -> type: create_connection = self.create_connection @@ -275,10 +279,13 @@ def prepare_connection_in_udtf( class UDTFIngestion: def process(self, query: str): - conn = prepare_connection_in_udtf(create_connection()) + conn = prepare_connection_in_udtf(create_connection(), query_timeout) cursor = conn.cursor( f"SNOWPARK_CURSOR_{generate_random_alphanumeric(5)}" ) + for statement in session_init_statement: + cursor.execute(statement) + cursor.fetchall() cursor.execute(query) while True: rows = cursor.fetchmany(fetch_size) diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py index 268a9145ae..b81e51dd4c 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py @@ -184,16 +184,23 @@ def to_snow_type(self, schema: List[Any]) -> StructType: return StructType(fields) def udtf_class_builder( - self, fetch_size: int = 1000, schema: StructType = None + self, + fetch_size: int = 1000, + schema: StructType = None, + session_init_statement: List[str] = None, + query_timeout: int = 0, ) -> type: create_connection = self.create_connection + prepare_connection = self.prepare_connection class UDTFIngestion: def process(self, query: str): import pymysql - conn = create_connection() + conn = prepare_connection(create_connection()) cursor = pymysql.cursors.SSCursor(conn) + for statement in session_init_statement: + cursor.execute(statement) cursor.execute(query) while True: rows = cursor.fetchmany(fetch_size) @@ -208,7 +215,9 @@ def prepare_connection( conn: "Connection", query_timeout: int = 0, ) -> "Connection": - conn.read_timeout = query_timeout if query_timeout != 0 else None + if query_timeout > 0: + cursor = conn.cursor() + cursor.execute(f"SET SESSION MAX_EXECUTION_TIME={1000 * query_timeout}") return conn @staticmethod diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/pyodbc_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/pyodbc_driver.py index 9ecbac9ea9..4191231cbd 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/pyodbc_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/pyodbc_driver.py @@ -78,9 +78,14 @@ def to_snow_type(self, schema: List[Any]) -> StructType: return StructType(fields) def udtf_class_builder( - self, fetch_size: int = 1000, schema: StructType = None + self, + fetch_size: int = 1000, + schema: StructType = None, + session_init_statement: List[str] = None, + query_timeout: int = 0, ) -> type: create_connection = self.create_connection + prepare_connection = self.prepare_connection def binary_converter(value): return value.hex() if value is not None else None @@ -89,7 +94,7 @@ class UDTFIngestion: def process(self, query: str): import pyodbc - conn = create_connection() + conn = prepare_connection(create_connection(), query_timeout) if ( conn.get_output_converter(pyodbc.SQL_BINARY) is None and conn.get_output_converter(pyodbc.SQL_VARBINARY) is None @@ -101,6 +106,8 @@ def process(self, query: str): pyodbc.SQL_LONGVARBINARY, binary_converter ) cursor = conn.cursor() + for statement in session_init_statement: + cursor.execute(statement) cursor.execute(query) while True: rows = cursor.fetchmany(fetch_size) diff --git a/src/snowflake/snowpark/dataframe_reader.py b/src/snowflake/snowpark/dataframe_reader.py index 5533ebb246..d376751d8b 100644 --- a/src/snowflake/snowpark/dataframe_reader.py +++ b/src/snowflake/snowpark/dataframe_reader.py @@ -1859,6 +1859,8 @@ def create_oracledb_connection(): fetch_size=fetch_size, imports=udtf_configs.get("imports", None), packages=udtf_configs.get("packages", None), + session_init_statement=session_init_statement, + query_timeout=query_timeout, _emit_ast=_emit_ast, ) end_time = time.perf_counter() diff --git a/tests/integ/datasource/test_databricks.py b/tests/integ/datasource/test_databricks.py index 954dec122d..8fe7d220ce 100644 --- a/tests/integ/datasource/test_databricks.py +++ b/tests/integ/datasource/test_databricks.py @@ -17,7 +17,10 @@ random_name_for_temp_object, TempObjectType, ) -from snowflake.snowpark.exceptions import SnowparkDataframeReaderException +from snowflake.snowpark.exceptions import ( + SnowparkDataframeReaderException, + SnowparkSQLException, +) from snowflake.snowpark.types import ( StructType, StructField, @@ -258,3 +261,37 @@ def test_unsupported_type(): create_databricks_connection, DBMS_TYPE.DATABRICKS_DB ).to_snow_type([("test_col", "unsupported_type", True)]) assert schema == StructType([StructField("TEST_COL", StringType(), nullable=True)]) + + +def test_session_init(session): + with pytest.raises( + SnowparkDataframeReaderException, + match="syntax error command", + ): + session.read.dbapi( + create_databricks_connection, + table=TEST_TABLE_NAME, + session_init_statement=["syntax error command"], + ) + + +def test_session_init_udtf(session): + udtf_configs = { + "external_access_integration": DATABRICKS_TEST_EXTERNAL_ACCESS_INTEGRATION + } + + def create_databricks_udtf_connection(): + import databricks.sql + + return databricks.sql.connect(**DATABRICKS_CONNECTION_PARAMETERS) + + with pytest.raises( + SnowparkSQLException, + match="syntax error command", + ): + session.read.dbapi( + create_databricks_udtf_connection, + table=TEST_TABLE_NAME, + session_init_statement=["syntax error command"], + udtf_configs=udtf_configs, + ).collect() diff --git a/tests/integ/datasource/test_mysql.py b/tests/integ/datasource/test_mysql.py index 6878437a68..6565d8bbb1 100644 --- a/tests/integ/datasource/test_mysql.py +++ b/tests/integ/datasource/test_mysql.py @@ -15,6 +15,10 @@ ) from snowflake.snowpark._internal.data_source.utils import DBMS_TYPE from snowflake.snowpark.types import StructType, StructField, StringType +from snowflake.snowpark.exceptions import ( + SnowparkDataframeReaderException, + SnowparkSQLException, +) from tests.resources.test_data_source_dir.test_mysql_data import ( mysql_real_data, MysqlType, @@ -297,3 +301,59 @@ def test_unsupported_type(): [("test_col", "unsupported_type", None, None, 0, 0, True)] ) assert schema == StructType([StructField("TEST_COL", StringType(), nullable=True)]) + + +def test_query_timeout(session): + with pytest.raises( + SnowparkDataframeReaderException, + match="Query execution was interrupted, maximum statement execution time exceeded", + ): + session.read.dbapi( + create_connection_mysql, + query="SELECT COUNT(*) AS a FROM (SELECT SLEEP(2) AS x) AS t", + query_timeout=1, + ) + + +def test_session_init(session): + with pytest.raises( + SnowparkDataframeReaderException, + match="Mock error to test init_statement", + ): + session.read.dbapi( + create_connection_mysql, + table=TEST_TABLE_NAME, + session_init_statement=[ + "SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = 'Mock error to test init_statement'" + ], + ) + + +def test_session_init_udtf(session): + udtf_configs = { + "external_access_integration": MYSQL_TEST_EXTERNAL_ACCESS_INTEGRATION + } + + def create_connection_udtf_oracledb(): + import pymysql # noqa: F811 + + conn = pymysql.connect( + user=MYSQL_CONNECTION_PARAMETERS["username"], + password=MYSQL_CONNECTION_PARAMETERS["password"], + host=MYSQL_CONNECTION_PARAMETERS["host"], + database=MYSQL_CONNECTION_PARAMETERS["database"], + ) + return conn + + with pytest.raises( + SnowparkSQLException, + match="Mock error to test init_statement", + ): + session.read.dbapi( + create_connection_udtf_oracledb, + table=TEST_TABLE_NAME, + session_init_statement=[ + "SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = 'Mock error to test init_statement'" + ], + udtf_configs=udtf_configs, + ).collect() diff --git a/tests/integ/datasource/test_oracledb.py b/tests/integ/datasource/test_oracledb.py index 7360f8a74c..cd4b64208e 100644 --- a/tests/integ/datasource/test_oracledb.py +++ b/tests/integ/datasource/test_oracledb.py @@ -20,6 +20,10 @@ DBMS_TYPE, ) from snowflake.snowpark.types import StructType, StructField, StringType +from snowflake.snowpark.exceptions import ( + SnowparkDataframeReaderException, + SnowparkSQLException, +) from tests.parameters import ORACLEDB_CONNECTION_PARAMETERS from tests.resources.test_data_source_dir.test_data_source_data import ( OracleDBType, @@ -248,3 +252,56 @@ def test_unsupported_type(): create_connection_oracledb, DBMS_TYPE.ORACLE_DB ).to_snow_type([MockDescription("test_col", invalid_type, 0, 0, True)]) assert schema == StructType([StructField("TEST_COL", StringType(), nullable=True)]) + + +def test_query_timeout_and_session_init(session): + statement = """ + BEGIN + DBMS_LOCK.SLEEP(5); + END; +""" + with pytest.raises( + SnowparkDataframeReaderException, + match="socket timed out while recovering from previous socket timeout", + ): + session.read.dbapi( + create_connection_oracledb, + table=ORACLEDB_TABLE_NAME, + query_timeout=1, + session_init_statement=[statement], + ) + + +def test_query_timeout_and_session_init_udtf(session): + udtf_configs = { + "external_access_integration": ORACLEDB_TEST_EXTERNAL_ACCESS_INTEGRATION + } + statement = """ + BEGIN + DBMS_LOCK.SLEEP(5); + END; + """ + + def create_connection_udtf_oracledb(): + import oracledb + + host = ORACLEDB_CONNECTION_PARAMETERS["host"] + port = ORACLEDB_CONNECTION_PARAMETERS["port"] + service_name = ORACLEDB_CONNECTION_PARAMETERS["service_name"] + username = ORACLEDB_CONNECTION_PARAMETERS["username"] + password = ORACLEDB_CONNECTION_PARAMETERS["password"] + dsn = f"{host}:{port}/{service_name}" + connection = oracledb.connect(user=username, password=password, dsn=dsn) + return connection + + with pytest.raises( + SnowparkSQLException, + match="call timeout of 1000 ms exceeded", + ): + session.read.dbapi( + create_connection_udtf_oracledb, + table=ORACLEDB_TABLE_NAME, + query_timeout=1, + session_init_statement=[statement], + udtf_configs=udtf_configs, + ).collect() diff --git a/tests/integ/datasource/test_postgres.py b/tests/integ/datasource/test_postgres.py index 82e1b13a8f..a2650b634d 100644 --- a/tests/integ/datasource/test_postgres.py +++ b/tests/integ/datasource/test_postgres.py @@ -11,7 +11,10 @@ Psycopg2TypeCode, ) from snowflake.snowpark._internal.data_source.utils import DBMS_TYPE -from snowflake.snowpark.exceptions import SnowparkDataframeReaderException +from snowflake.snowpark.exceptions import ( + SnowparkDataframeReaderException, + SnowparkSQLException, +) from snowflake.snowpark.types import ( DecimalType, BinaryType, @@ -102,7 +105,7 @@ def test_error_case(session, input_type, input_value, error_message): session.read.dbapi(create_postgres_connection, **input_dict) -def test_query_timeout(session): +def test_query_timeout_and_session_init(session): with pytest.raises( SnowparkDataframeReaderException, match=r"due to exception 'QueryCanceled\('canceling statement due to statement timeout", @@ -115,6 +118,27 @@ def test_query_timeout(session): ) +def test_query_timeout_and_session_init_udtf(session): + udtf_configs = { + "external_access_integration": POSTGRES_TEST_EXTERNAL_ACCESS_INTEGRATION + } + + def create_postgres_udtf_connection(): + return psycopg2.connect(**POSTGRES_CONNECTION_PARAMETERS) + + with pytest.raises( + SnowparkSQLException, + match="canceling statement due to statement timeout", + ): + session.read.dbapi( + create_postgres_udtf_connection, + table=POSTGRES_TABLE_NAME, + query_timeout=1, + session_init_statement=["SELECT pg_sleep(5)"], + udtf_configs=udtf_configs, + ).collect() + + def test_external_access_integration_not_set(session): with pytest.raises( ValueError, diff --git a/tests/integ/test_data_source_api.py b/tests/integ/test_data_source_api.py index 7de529432e..ac28ee7bf0 100644 --- a/tests/integ/test_data_source_api.py +++ b/tests/integ/test_data_source_api.py @@ -558,7 +558,7 @@ def test_session_init_statement(session, fetch_with_process): with pytest.raises( SnowparkDataframeReaderException, - match=r'Failed to execute session init statement: \'SELECT FROM NOTHING;\' due to exception \'OperationalError\(\'near "FROM": syntax error\'\)\'', + match="Failed to execute session init statement:", ): session.read.dbapi( functools.partial(create_connection_to_sqlite3_db, dbpath), From 0577f8ec6f1a4406d4ffc374d5b37a859e6a3b6d Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 9 Sep 2025 11:42:42 -0700 Subject: [PATCH 02/16] fix test --- .../_internal/data_source/drivers/base_driver.py | 5 +++-- .../_internal/data_source/drivers/databricks_driver.py | 5 +++-- .../_internal/data_source/drivers/oracledb_driver.py | 9 +++++---- .../_internal/data_source/drivers/psycopg2_driver.py | 6 +++--- .../_internal/data_source/drivers/pymsql_driver.py | 5 +++-- .../_internal/data_source/drivers/pyodbc_driver.py | 5 +++-- 6 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py index 05acc485c0..28b45b5770 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py @@ -186,8 +186,9 @@ class UDTFIngestion: def process(self, query: str): conn = prepare_connection(create_connection(), query_timeout) cursor = conn.cursor() - for statement in session_init_statement: - cursor.execute(statement) + if session_init_statement is not None: + for statement in session_init_statement: + cursor.execute(statement) cursor.execute(query) while True: rows = cursor.fetchmany(fetch_size) diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/databricks_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/databricks_driver.py index 30532a5405..62d1e2bcce 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/databricks_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/databricks_driver.py @@ -81,8 +81,9 @@ class UDTFIngestion: def process(self, query: str): conn = create_connection() cursor = conn.cursor() - for statement in session_init_statement: - cursor.execute(statement) + if session_init_statement is not None: + for statement in session_init_statement: + cursor.execute(statement) # First get schema information describe_query = f"DESCRIBE QUERY SELECT * FROM ({query})" diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py index 0ca526f0ec..e0eafa0356 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py @@ -118,7 +118,6 @@ def udtf_class_builder( query_timeout: int = 0, ) -> type: create_connection = self.create_connection - prepare_connection = self.prepare_connection def oracledb_output_type_handler(cursor, metadata): from oracledb import ( @@ -142,12 +141,14 @@ def convert_to_hex(value): class UDTFIngestion: def process(self, query: str): - conn = prepare_connection(create_connection(), query_timeout) + conn = create_connection() + conn.call_timeout = query_timeout * 1000 if conn.outputtypehandler is None: conn.outputtypehandler = oracledb_output_type_handler cursor = conn.cursor() - for statement in session_init_statement: - cursor.execute(statement) + if session_init_statement is not None: + for statement in session_init_statement: + cursor.execute(statement) cursor.execute(query) while True: rows = cursor.fetchmany(fetch_size) diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py index ecdfaff32c..18f262e92e 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py @@ -283,9 +283,9 @@ def process(self, query: str): cursor = conn.cursor( f"SNOWPARK_CURSOR_{generate_random_alphanumeric(5)}" ) - for statement in session_init_statement: - cursor.execute(statement) - cursor.fetchall() + if session_init_statement is not None: + for statement in session_init_statement: + cursor.execute(statement) cursor.execute(query) while True: rows = cursor.fetchmany(fetch_size) diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py index b81e51dd4c..7fea1937e0 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py @@ -199,8 +199,9 @@ def process(self, query: str): conn = prepare_connection(create_connection()) cursor = pymysql.cursors.SSCursor(conn) - for statement in session_init_statement: - cursor.execute(statement) + if session_init_statement is not None: + for statement in session_init_statement: + cursor.execute(statement) cursor.execute(query) while True: rows = cursor.fetchmany(fetch_size) diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/pyodbc_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/pyodbc_driver.py index 4191231cbd..cbdedc3ccb 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/pyodbc_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/pyodbc_driver.py @@ -106,8 +106,9 @@ def process(self, query: str): pyodbc.SQL_LONGVARBINARY, binary_converter ) cursor = conn.cursor() - for statement in session_init_statement: - cursor.execute(statement) + if session_init_statement is not None: + for statement in session_init_statement: + cursor.execute(statement) cursor.execute(query) while True: rows = cursor.fetchmany(fetch_size) From 28f08b4240eedb2ccbbda6dff9e6545721484483 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 9 Sep 2025 14:48:26 -0700 Subject: [PATCH 03/16] fix test --- .../_internal/data_source/drivers/psycopg2_driver.py | 1 + tests/integ/datasource/test_oracledb.py | 9 +++++---- tests/integ/datasource/test_postgres.py | 2 +- tests/integ/test_data_source_api.py | 2 +- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py index 18f262e92e..dad13eca8a 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py @@ -286,6 +286,7 @@ def process(self, query: str): if session_init_statement is not None: for statement in session_init_statement: cursor.execute(statement) + cursor.fetchall() cursor.execute(query) while True: rows = cursor.fetchmany(fetch_size) diff --git a/tests/integ/datasource/test_oracledb.py b/tests/integ/datasource/test_oracledb.py index cd4b64208e..6a837c0e42 100644 --- a/tests/integ/datasource/test_oracledb.py +++ b/tests/integ/datasource/test_oracledb.py @@ -260,16 +260,17 @@ def test_query_timeout_and_session_init(session): DBMS_LOCK.SLEEP(5); END; """ - with pytest.raises( - SnowparkDataframeReaderException, - match="socket timed out while recovering from previous socket timeout", - ): + with pytest.raises(SnowparkDataframeReaderException) as error: session.read.dbapi( create_connection_oracledb, table=ORACLEDB_TABLE_NAME, query_timeout=1, session_init_statement=[statement], ) + assert ( + "socket timed out while recovering from previous socket timeout" in error + or "call timeout of 1000 ms exceeded" in error + ) def test_query_timeout_and_session_init_udtf(session): diff --git a/tests/integ/datasource/test_postgres.py b/tests/integ/datasource/test_postgres.py index a2650b634d..9d6587d28a 100644 --- a/tests/integ/datasource/test_postgres.py +++ b/tests/integ/datasource/test_postgres.py @@ -108,7 +108,7 @@ def test_error_case(session, input_type, input_value, error_message): def test_query_timeout_and_session_init(session): with pytest.raises( SnowparkDataframeReaderException, - match=r"due to exception 'QueryCanceled\('canceling statement due to statement timeout", + match="canceling statement due to statement timeout", ): session.read.dbapi( create_postgres_connection, diff --git a/tests/integ/test_data_source_api.py b/tests/integ/test_data_source_api.py index ac28ee7bf0..2798acae0c 100644 --- a/tests/integ/test_data_source_api.py +++ b/tests/integ/test_data_source_api.py @@ -1018,7 +1018,7 @@ def fetchmany(self, row_count: int): driver.to_snow_type(raw_schema), partitions_table, "", - packages=["pyodbc"], + packages=["pyodbc", "snowflake-snowpark-python"], ) Utils.check_answer(df, sql_server_udtf_ingestion_data) From 4233f1a88648e6deeab24a6b2e07b543af2d0d05 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 10 Sep 2025 18:25:16 -0700 Subject: [PATCH 04/16] fix changelog --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fefde396c1..dd0fbd2399 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,11 +11,12 @@ - Added support for `FileOperation.list` to list files in a stage with metadata. - Added support for `FileOperation.remove` to remove files in a stage. - Added support for parameter `use_vectorized_scanner` in function `Session.write_pandas()`. - - Added support for parameter `session_init_statement` in udtf ingestion of `DataFrameReader.jdbc`(PrPr). #### Bug Fixes +- Fixed a bug that `query_timeout` does not work in udtf ingestion of `DataFrameReader.jdbc`(PrPr). + #### Deprecations #### Dependency Updates From c6e42be00b3be84fbe25cc87a1c4e4e6fb9ab62f Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 11 Sep 2025 10:14:15 -0700 Subject: [PATCH 05/16] address comments --- .../data_source/drivers/pymsql_driver.py | 13 +------------ tests/integ/datasource/test_mysql.py | 16 ++-------------- tests/integ/datasource/test_oracledb.py | 8 ++++---- 3 files changed, 7 insertions(+), 30 deletions(-) diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py index 7fea1937e0..7428bce4af 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py @@ -191,13 +191,12 @@ def udtf_class_builder( query_timeout: int = 0, ) -> type: create_connection = self.create_connection - prepare_connection = self.prepare_connection class UDTFIngestion: def process(self, query: str): import pymysql - conn = prepare_connection(create_connection()) + conn = create_connection() cursor = pymysql.cursors.SSCursor(conn) if session_init_statement is not None: for statement in session_init_statement: @@ -211,16 +210,6 @@ def process(self, query: str): return UDTFIngestion - def prepare_connection( - self, - conn: "Connection", - query_timeout: int = 0, - ) -> "Connection": - if query_timeout > 0: - cursor = conn.cursor() - cursor.execute(f"SET SESSION MAX_EXECUTION_TIME={1000 * query_timeout}") - return conn - @staticmethod def infer_type_from_data(data: List[tuple], number_of_columns: int) -> List[Type]: # TODO: SNOW-2112938 investigate whether different types can be fit into one column diff --git a/tests/integ/datasource/test_mysql.py b/tests/integ/datasource/test_mysql.py index 6565d8bbb1..9bb2e2acb3 100644 --- a/tests/integ/datasource/test_mysql.py +++ b/tests/integ/datasource/test_mysql.py @@ -303,18 +303,6 @@ def test_unsupported_type(): assert schema == StructType([StructField("TEST_COL", StringType(), nullable=True)]) -def test_query_timeout(session): - with pytest.raises( - SnowparkDataframeReaderException, - match="Query execution was interrupted, maximum statement execution time exceeded", - ): - session.read.dbapi( - create_connection_mysql, - query="SELECT COUNT(*) AS a FROM (SELECT SLEEP(2) AS x) AS t", - query_timeout=1, - ) - - def test_session_init(session): with pytest.raises( SnowparkDataframeReaderException, @@ -334,7 +322,7 @@ def test_session_init_udtf(session): "external_access_integration": MYSQL_TEST_EXTERNAL_ACCESS_INTEGRATION } - def create_connection_udtf_oracledb(): + def create_connection_udtf_mysql(): import pymysql # noqa: F811 conn = pymysql.connect( @@ -350,7 +338,7 @@ def create_connection_udtf_oracledb(): match="Mock error to test init_statement", ): session.read.dbapi( - create_connection_udtf_oracledb, + create_connection_udtf_mysql, table=TEST_TABLE_NAME, session_init_statement=[ "SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = 'Mock error to test init_statement'" diff --git a/tests/integ/datasource/test_oracledb.py b/tests/integ/datasource/test_oracledb.py index 6a837c0e42..c5810344a5 100644 --- a/tests/integ/datasource/test_oracledb.py +++ b/tests/integ/datasource/test_oracledb.py @@ -267,10 +267,10 @@ def test_query_timeout_and_session_init(session): query_timeout=1, session_init_statement=[statement], ) - assert ( - "socket timed out while recovering from previous socket timeout" in error - or "call timeout of 1000 ms exceeded" in error - ) + assert ( + "socket timed out while recovering from previous socket timeout" in error + or "call timeout of 1000 ms exceeded" in error + ) def test_query_timeout_and_session_init_udtf(session): From ef6b41ffe0d6e3a1047279e84fee0ff4ee8190e7 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 11 Sep 2025 14:17:56 -0700 Subject: [PATCH 06/16] fix test --- tests/integ/datasource/test_oracledb.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/integ/datasource/test_oracledb.py b/tests/integ/datasource/test_oracledb.py index c5810344a5..5a71822588 100644 --- a/tests/integ/datasource/test_oracledb.py +++ b/tests/integ/datasource/test_oracledb.py @@ -267,10 +267,9 @@ def test_query_timeout_and_session_init(session): query_timeout=1, session_init_statement=[statement], ) - assert ( - "socket timed out while recovering from previous socket timeout" in error - or "call timeout of 1000 ms exceeded" in error - ) + assert "socket timed out while recovering from previous socket timeout" in str( + error.value + ) or "call timeout of 1000 ms exceeded" in str(error.value) def test_query_timeout_and_session_init_udtf(session): From 7940a9e30fbe5eedbcf5c495dc1938c32463545c Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 11 Sep 2025 16:27:16 -0700 Subject: [PATCH 07/16] address comments --- .../_internal/data_source/drivers/oracledb_driver.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py index e0eafa0356..745adec64f 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py @@ -105,7 +105,8 @@ def prepare_connection( conn: "Connection", query_timeout: int = 0, ) -> "Connection": - conn.call_timeout = query_timeout * 1000 + if query_timeout > 0: + conn.call_timeout = query_timeout * 1000 if conn.outputtypehandler is None: conn.outputtypehandler = output_type_handler return conn @@ -142,7 +143,8 @@ def convert_to_hex(value): class UDTFIngestion: def process(self, query: str): conn = create_connection() - conn.call_timeout = query_timeout * 1000 + if query_timeout > 0: + conn.call_timeout = query_timeout * 1000 if conn.outputtypehandler is None: conn.outputtypehandler = oracledb_output_type_handler cursor = conn.cursor() From 1e006cac4a5511283d5002ec40a58adc8e804edb Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 15 Sep 2025 16:14:44 -0700 Subject: [PATCH 08/16] coverage placeholder --- tests/integ/datasource/test_postgres.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/integ/datasource/test_postgres.py b/tests/integ/datasource/test_postgres.py index 9d6587d28a..99088268e2 100644 --- a/tests/integ/datasource/test_postgres.py +++ b/tests/integ/datasource/test_postgres.py @@ -210,7 +210,9 @@ def test_psycopg2_driver_udtf_class_builder(): driver = Psycopg2Driver(create_postgres_connection, DBMS_TYPE.POSTGRES_DB) # Get the UDTF class with a small fetch size to test batching - UDTFClass = driver.udtf_class_builder(fetch_size=2) + UDTFClass = driver.udtf_class_builder( + fetch_size=2, session_init_statement=["SELECT pg_sleep(1)"] + ) # Instantiate the UDTF class udtf_instance = UDTFClass() From 3236293a366935c97e5307742a8a60351551e3fb Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 15 Sep 2025 16:46:32 -0700 Subject: [PATCH 09/16] increase coverage --- tests/integ/datasource/test_databricks.py | 4 +++- tests/integ/datasource/test_mysql.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/integ/datasource/test_databricks.py b/tests/integ/datasource/test_databricks.py index 8fe7d220ce..63c496f1c5 100644 --- a/tests/integ/datasource/test_databricks.py +++ b/tests/integ/datasource/test_databricks.py @@ -208,7 +208,9 @@ def local_create_databricks_connection(): def test_unit_udtf_ingestion(): dbx_driver = DatabricksDriver(create_databricks_connection, DBMS_TYPE.DATABRICKS_DB) - udtf_ingestion_class = dbx_driver.udtf_class_builder() + udtf_ingestion_class = dbx_driver.udtf_class_builder( + session_init_statement=["select 1"] + ) udtf_ingestion_instance = udtf_ingestion_class() dsp = DataSourcePartitioner( diff --git a/tests/integ/datasource/test_mysql.py b/tests/integ/datasource/test_mysql.py index 9bb2e2acb3..842f386dcf 100644 --- a/tests/integ/datasource/test_mysql.py +++ b/tests/integ/datasource/test_mysql.py @@ -265,7 +265,9 @@ def test_pymysql_driver_udtf_class_builder(): driver = PymysqlDriver(create_connection_mysql, DBMS_TYPE.MYSQL_DB) # Get the UDTF class with a small fetch size to test batching - UDTFClass = driver.udtf_class_builder(fetch_size=2) + UDTFClass = driver.udtf_class_builder( + fetch_size=2, session_init_statement=["select 1"] + ) # Instantiate the UDTF class udtf_instance = UDTFClass() From d375e5401bffc7e994b795c51d3c108d8173d1be Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 16 Sep 2025 11:52:28 -0700 Subject: [PATCH 10/16] use client cursor on session init --- .../_internal/data_source/drivers/psycopg2_driver.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py index dad13eca8a..ee9575f090 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py @@ -284,9 +284,10 @@ def process(self, query: str): f"SNOWPARK_CURSOR_{generate_random_alphanumeric(5)}" ) if session_init_statement is not None: + session_init_cur = conn.cursor() for statement in session_init_statement: - cursor.execute(statement) - cursor.fetchall() + session_init_cur.execute(statement) + session_init_cur.fetchall() cursor.execute(query) while True: rows = cursor.fetchmany(fetch_size) From a6593a43cc91af182b4914439620c18a4d47fa85 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 16 Sep 2025 14:03:16 -0700 Subject: [PATCH 11/16] add coverage --- tests/integ/datasource/test_oracledb.py | 29 +++++++++++++++++++ tests/integ/test_data_source_api.py | 21 ++++++++++++++ .../test_data_source_data.py | 6 ++++ 3 files changed, 56 insertions(+) diff --git a/tests/integ/datasource/test_oracledb.py b/tests/integ/datasource/test_oracledb.py index 5a71822588..696ac03767 100644 --- a/tests/integ/datasource/test_oracledb.py +++ b/tests/integ/datasource/test_oracledb.py @@ -305,3 +305,32 @@ def create_connection_udtf_oracledb(): session_init_statement=[statement], udtf_configs=udtf_configs, ).collect() + + +def test_oracledb_driver_udtf_class_builder(): + """Test the UDTF class builder in Psycopg2Driver using a real PostgreSQL connection""" + # Create the driver with the real connection function + driver = OracledbDriver(create_connection_oracledb, DBMS_TYPE.ORACLE_DB) + + # Get the UDTF class with a small fetch size to test batching + UDTFClass = driver.udtf_class_builder( + fetch_size=2, session_init_statement=["select 1 from dual"], query_timeout=1 + ) + + # Instantiate the UDTF class + udtf_instance = UDTFClass() + + # Test with a simple query that should return a few rows + test_query = f"SELECT * FROM {ORACLEDB_TABLE_NAME}" + result_rows = list(udtf_instance.process(test_query)) + + # Verify we got some data back (we know the test table has data from other tests) + assert len(result_rows) > 0 + + # Test with a query that returns specific columns + test_columns_query = f"SELECT ID, NUMBER_COL FROM {ORACLEDB_TABLE_NAME}" + column_result_rows = list(udtf_instance.process(test_columns_query)) + + # Verify we got data with the right structure (2 columns) + assert len(column_result_rows) > 0 + assert len(column_result_rows[0]) == 2 # Two columns diff --git a/tests/integ/test_data_source_api.py b/tests/integ/test_data_source_api.py index 2798acae0c..a53856de55 100644 --- a/tests/integ/test_data_source_api.py +++ b/tests/integ/test_data_source_api.py @@ -1697,3 +1697,24 @@ def test_error_in_upload_is_raised(session): create_connection=sql_server_create_connection, table=SQL_SERVER_TABLE_NAME, ) + + +def test_pyodbc_driver_udtf_class_builder(): + """Test the UDTF class builder in Psycopg2Driver using a real PostgreSQL connection""" + # Create the driver with the real connection function + driver = PyodbcDriver(sql_server_create_connection, DBMS_TYPE.ORACLE_DB) + + # Get the UDTF class with a small fetch size to test batching + UDTFClass = driver.udtf_class_builder( + fetch_size=2, session_init_statement=["select 1"], query_timeout=1 + ) + + # Instantiate the UDTF class + udtf_instance = UDTFClass() + + # Test with a simple query that should return a few rows + test_query = f"SELECT * FROM {SQL_SERVER_TABLE_NAME}" + result_rows = list(udtf_instance.process(test_query)) + + # Verify we got some data back (we know the test table has data from other tests) + assert len(result_rows) > 0 diff --git a/tests/resources/test_data_source_dir/test_data_source_data.py b/tests/resources/test_data_source_dir/test_data_source_data.py index 2d7e7a3949..6d42ad08ee 100644 --- a/tests/resources/test_data_source_dir/test_data_source_data.py +++ b/tests/resources/test_data_source_dir/test_data_source_data.py @@ -49,6 +49,12 @@ def cursor(self): def close(self): pass + def get_output_converter(self, type): + pass + + def add_output_converter(self, type1, type2): + pass + @property def description(self): return self.schema From eb5ec661c3358e86283cce36e33402f70a867d6d Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 18 Sep 2025 14:30:23 -0700 Subject: [PATCH 12/16] add test --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index 5514898a7e..a6a9e7dab9 100644 --- a/tox.ini +++ b/tox.ini @@ -248,6 +248,7 @@ deps = oracledb psycopg2-binary pymysql + pyodbc commands = {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE}" {posargs:} tests/integ/datasource -n 8 From 027a79c7e71123219746af66641b030731b0229b Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 18 Sep 2025 15:38:30 -0700 Subject: [PATCH 13/16] fix test --- tests/integ/datasource/test_oracledb.py | 2 +- tests/integ/test_data_source_api.py | 21 --------------------- tox.ini | 1 - 3 files changed, 1 insertion(+), 23 deletions(-) diff --git a/tests/integ/datasource/test_oracledb.py b/tests/integ/datasource/test_oracledb.py index 696ac03767..6b5178590c 100644 --- a/tests/integ/datasource/test_oracledb.py +++ b/tests/integ/datasource/test_oracledb.py @@ -308,7 +308,7 @@ def create_connection_udtf_oracledb(): def test_oracledb_driver_udtf_class_builder(): - """Test the UDTF class builder in Psycopg2Driver using a real PostgreSQL connection""" + """Test the UDTF class builder in OracledbDriver using a real Oracledb connection""" # Create the driver with the real connection function driver = OracledbDriver(create_connection_oracledb, DBMS_TYPE.ORACLE_DB) diff --git a/tests/integ/test_data_source_api.py b/tests/integ/test_data_source_api.py index a53856de55..2798acae0c 100644 --- a/tests/integ/test_data_source_api.py +++ b/tests/integ/test_data_source_api.py @@ -1697,24 +1697,3 @@ def test_error_in_upload_is_raised(session): create_connection=sql_server_create_connection, table=SQL_SERVER_TABLE_NAME, ) - - -def test_pyodbc_driver_udtf_class_builder(): - """Test the UDTF class builder in Psycopg2Driver using a real PostgreSQL connection""" - # Create the driver with the real connection function - driver = PyodbcDriver(sql_server_create_connection, DBMS_TYPE.ORACLE_DB) - - # Get the UDTF class with a small fetch size to test batching - UDTFClass = driver.udtf_class_builder( - fetch_size=2, session_init_statement=["select 1"], query_timeout=1 - ) - - # Instantiate the UDTF class - udtf_instance = UDTFClass() - - # Test with a simple query that should return a few rows - test_query = f"SELECT * FROM {SQL_SERVER_TABLE_NAME}" - result_rows = list(udtf_instance.process(test_query)) - - # Verify we got some data back (we know the test table has data from other tests) - assert len(result_rows) > 0 diff --git a/tox.ini b/tox.ini index a6a9e7dab9..5514898a7e 100644 --- a/tox.ini +++ b/tox.ini @@ -248,7 +248,6 @@ deps = oracledb psycopg2-binary pymysql - pyodbc commands = {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE}" {posargs:} tests/integ/datasource -n 8 From 917c0a893ae5e4c7b180d4f3170ade3b08d9a109 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 22 Sep 2025 10:36:29 -0700 Subject: [PATCH 14/16] add coverage --- tests/integ/test_data_source_api.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/integ/test_data_source_api.py b/tests/integ/test_data_source_api.py index 2798acae0c..08f121a58e 100644 --- a/tests/integ/test_data_source_api.py +++ b/tests/integ/test_data_source_api.py @@ -1697,3 +1697,29 @@ def test_error_in_upload_is_raised(session): create_connection=sql_server_create_connection, table=SQL_SERVER_TABLE_NAME, ) + + +def test_base_driver_udtf_class_builder(): + with tempfile.TemporaryDirectory() as temp_dir: + dbpath = os.path.join(temp_dir, "testsqlite3.db") + table_name, columns, example_data, _ = sqlite3_db(dbpath) + # Create the driver with the real connection function + driver = BaseDriver( + functools.partial(create_connection_to_sqlite3_db, dbpath), + DBMS_TYPE.UNKNOWN, + ) + + # Get the UDTF class with a small fetch size to test batching + UDTFClass = driver.udtf_class_builder( + fetch_size=2, session_init_statement=["select 1"] + ) + + # Instantiate the UDTF class + udtf_instance = UDTFClass() + + # Test with a simple query that should return a few rows + test_query = f"SELECT * FROM {table_name}" + result_rows = list(udtf_instance.process(test_query)) + + # Verify we got some data back (we know the test table has data from other tests) + assert len(result_rows) > 0 From 03d19d3ccdf56a977ddb2e85243d803b768074cf Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 22 Sep 2025 13:25:33 -0700 Subject: [PATCH 15/16] fix test --- tests/integ/test_data_source_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integ/test_data_source_api.py b/tests/integ/test_data_source_api.py index 08f121a58e..5a4fd4dcb0 100644 --- a/tests/integ/test_data_source_api.py +++ b/tests/integ/test_data_source_api.py @@ -1701,7 +1701,7 @@ def test_error_in_upload_is_raised(session): def test_base_driver_udtf_class_builder(): with tempfile.TemporaryDirectory() as temp_dir: - dbpath = os.path.join(temp_dir, "testsqlite3.db") + dbpath = os.path.join(temp_dir, "sqlite3udtf.db") table_name, columns, example_data, _ = sqlite3_db(dbpath) # Create the driver with the real connection function driver = BaseDriver( From ed3131b9b125045ef87b2406a0cbfb31c0d64f14 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 22 Sep 2025 14:20:07 -0700 Subject: [PATCH 16/16] fix test --- tests/integ/test_data_source_api.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/integ/test_data_source_api.py b/tests/integ/test_data_source_api.py index 5a4fd4dcb0..5d7d10de84 100644 --- a/tests/integ/test_data_source_api.py +++ b/tests/integ/test_data_source_api.py @@ -1699,6 +1699,10 @@ def test_error_in_upload_is_raised(session): ) +@pytest.mark.skipif( + IS_WINDOWS, + reason="sqlite3 file can not be shared across processes on windows", +) def test_base_driver_udtf_class_builder(): with tempfile.TemporaryDirectory() as temp_dir: dbpath = os.path.join(temp_dir, "sqlite3udtf.db")