Skip to content

Commit ed4199a

Browse files
SNOW-2313015: add support for session init statement in udtf ingestion of dbapi (#3748)
1 parent 8601802 commit ed4199a

15 files changed

Lines changed: 319 additions & 29 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@
103103
- `create_or_replace_dynamic_table`
104104
- Added a new function `snowflake.snowpark.functions.vectorized` that allows users to mark a function as vectorized UDF.
105105
- Added support for parameter `use_vectorized_scanner` in function `Session.write_pandas()`.
106+
- Added support for parameter `session_init_statement` in udtf ingestion of `DataFrameReader.jdbc`(PrPr).
106107
- Added support for the following scalar functions in `functions.py`:
107108
- `getdate`
108109
- `getvariable`
@@ -117,6 +118,8 @@
117118

118119
#### Bug Fixes
119120

121+
- Fixed a bug that `query_timeout` does not work in udtf ingestion of `DataFrameReader.jdbc`(PrPr).
122+
120123
#### Deprecations
121124

122125
- Deprecated warnings will be triggered when using snowpark-python with Python 3.9. For more details, please refer to https://docs.snowflake.com/en/developer-guide/python-runtime-support-policy.

src/snowflake/snowpark/_internal/data_source/datasource_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def read(self, partition: str) -> Iterator[List[Any]]:
6060
cursor.execute(statement)
6161
except BaseException as exc:
6262
raise SnowparkDataframeReaderException(
63-
f"Failed to execute session init statement: '{statement}' due to exception '{exc!r}'"
63+
f"Failed to execute session init statement: '{statement}' due to exception '{exc}'"
6464
)
6565
# use server side cursor to fetch data if supported by the driver
6666
# some drivers do not support execute twice on server side cursor (e.g. psycopg2)

src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,14 +139,21 @@ def udtf_ingestion(
139139
fetch_size: int = 1000,
140140
imports: Optional[List[str]] = None,
141141
packages: Optional[List[str]] = None,
142+
session_init_statement: Optional[List[str]] = None,
143+
query_timeout: Optional[int] = 0,
142144
_emit_ast: bool = True,
143145
) -> "snowflake.snowpark.DataFrame":
144146
from snowflake.snowpark._internal.data_source.utils import UDTF_PACKAGE_MAP
145147

146148
udtf_name = random_name_for_temp_object(TempObjectType.FUNCTION)
147149
with measure_time() as udtf_register_time:
148150
session.udtf.register(
149-
self.udtf_class_builder(fetch_size=fetch_size, schema=schema),
151+
self.udtf_class_builder(
152+
fetch_size=fetch_size,
153+
schema=schema,
154+
session_init_statement=session_init_statement,
155+
query_timeout=query_timeout,
156+
),
150157
name=udtf_name,
151158
output_schema=StructType(
152159
[
@@ -166,14 +173,22 @@ def udtf_ingestion(
166173
return self.to_result_snowpark_df_udtf(res, schema, _emit_ast=_emit_ast)
167174

168175
def udtf_class_builder(
169-
self, fetch_size: int = 1000, schema: StructType = None
176+
self,
177+
fetch_size: int = 1000,
178+
schema: StructType = None,
179+
session_init_statement: List[str] = None,
180+
query_timeout: int = 0,
170181
) -> type:
171182
create_connection = self.create_connection
183+
prepare_connection = self.prepare_connection
172184

173185
class UDTFIngestion:
174186
def process(self, query: str):
175-
conn = create_connection()
187+
conn = prepare_connection(create_connection(), query_timeout)
176188
cursor = conn.cursor()
189+
if session_init_statement is not None:
190+
for statement in session_init_statement:
191+
cursor.execute(statement)
177192
cursor.execute(query)
178193
while True:
179194
rows = cursor.fetchmany(fetch_size)

src/snowflake/snowpark/_internal/data_source/drivers/databricks_driver.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,21 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
6969
return StructType(all_columns)
7070

7171
def udtf_class_builder(
72-
self, fetch_size: int = 1000, schema: StructType = None
72+
self,
73+
fetch_size: int = 1000,
74+
schema: StructType = None,
75+
session_init_statement: List[str] = None,
76+
query_timeout: int = 0,
7377
) -> type:
7478
create_connection = self.create_connection
7579

7680
class UDTFIngestion:
7781
def process(self, query: str):
7882
conn = create_connection()
7983
cursor = conn.cursor()
84+
if session_init_statement is not None:
85+
for statement in session_init_statement:
86+
cursor.execute(statement)
8087

8188
# First get schema information
8289
describe_query = f"DESCRIBE QUERY SELECT * FROM ({query})"

src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,18 @@ def prepare_connection(
105105
conn: "Connection",
106106
query_timeout: int = 0,
107107
) -> "Connection":
108-
conn.call_timeout = query_timeout * 1000
108+
if query_timeout > 0:
109+
conn.call_timeout = query_timeout * 1000
109110
if conn.outputtypehandler is None:
110111
conn.outputtypehandler = output_type_handler
111112
return conn
112113

113114
def udtf_class_builder(
114-
self, fetch_size: int = 1000, schema: StructType = None
115+
self,
116+
fetch_size: int = 1000,
117+
schema: StructType = None,
118+
session_init_statement: List[str] = None,
119+
query_timeout: int = 0,
115120
) -> type:
116121
create_connection = self.create_connection
117122

@@ -138,9 +143,14 @@ def convert_to_hex(value):
138143
class UDTFIngestion:
139144
def process(self, query: str):
140145
conn = create_connection()
146+
if query_timeout > 0:
147+
conn.call_timeout = query_timeout * 1000
141148
if conn.outputtypehandler is None:
142149
conn.outputtypehandler = oracledb_output_type_handler
143150
cursor = conn.cursor()
151+
if session_init_statement is not None:
152+
for statement in session_init_statement:
153+
cursor.execute(statement)
144154
cursor.execute(query)
145155
while True:
146156
rows = cursor.fetchmany(fetch_size)

src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,11 @@ def prepare_connection(
253253
return conn
254254

255255
def udtf_class_builder(
256-
self, fetch_size: int = 1000, schema: StructType = None
256+
self,
257+
fetch_size: int = 1000,
258+
schema: StructType = None,
259+
session_init_statement: List[str] = None,
260+
query_timeout: int = 0,
257261
) -> type:
258262
create_connection = self.create_connection
259263

@@ -275,10 +279,15 @@ def prepare_connection_in_udtf(
275279

276280
class UDTFIngestion:
277281
def process(self, query: str):
278-
conn = prepare_connection_in_udtf(create_connection())
282+
conn = prepare_connection_in_udtf(create_connection(), query_timeout)
279283
cursor = conn.cursor(
280284
f"SNOWPARK_CURSOR_{generate_random_alphanumeric(5)}"
281285
)
286+
if session_init_statement is not None:
287+
session_init_cur = conn.cursor()
288+
for statement in session_init_statement:
289+
session_init_cur.execute(statement)
290+
session_init_cur.fetchall()
282291
cursor.execute(query)
283292
while True:
284293
rows = cursor.fetchmany(fetch_size)

src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,11 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
184184
return StructType(fields)
185185

186186
def udtf_class_builder(
187-
self, fetch_size: int = 1000, schema: StructType = None
187+
self,
188+
fetch_size: int = 1000,
189+
schema: StructType = None,
190+
session_init_statement: List[str] = None,
191+
query_timeout: int = 0,
188192
) -> type:
189193
create_connection = self.create_connection
190194

@@ -194,6 +198,9 @@ def process(self, query: str):
194198

195199
conn = create_connection()
196200
cursor = pymysql.cursors.SSCursor(conn)
201+
if session_init_statement is not None:
202+
for statement in session_init_statement:
203+
cursor.execute(statement)
197204
cursor.execute(query)
198205
while True:
199206
rows = cursor.fetchmany(fetch_size)
@@ -203,14 +210,6 @@ def process(self, query: str):
203210

204211
return UDTFIngestion
205212

206-
def prepare_connection(
207-
self,
208-
conn: "Connection",
209-
query_timeout: int = 0,
210-
) -> "Connection":
211-
conn.read_timeout = query_timeout if query_timeout != 0 else None
212-
return conn
213-
214213
@staticmethod
215214
def infer_type_from_data(data: List[tuple], number_of_columns: int) -> List[Type]:
216215
# TODO: SNOW-2112938 investigate whether different types can be fit into one column

src/snowflake/snowpark/_internal/data_source/drivers/pyodbc_driver.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,14 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
7878
return StructType(fields)
7979

8080
def udtf_class_builder(
81-
self, fetch_size: int = 1000, schema: StructType = None
81+
self,
82+
fetch_size: int = 1000,
83+
schema: StructType = None,
84+
session_init_statement: List[str] = None,
85+
query_timeout: int = 0,
8286
) -> type:
8387
create_connection = self.create_connection
88+
prepare_connection = self.prepare_connection
8489

8590
def binary_converter(value):
8691
return value.hex() if value is not None else None
@@ -89,7 +94,7 @@ class UDTFIngestion:
8994
def process(self, query: str):
9095
import pyodbc
9196

92-
conn = create_connection()
97+
conn = prepare_connection(create_connection(), query_timeout)
9398
if (
9499
conn.get_output_converter(pyodbc.SQL_BINARY) is None
95100
and conn.get_output_converter(pyodbc.SQL_VARBINARY) is None
@@ -101,6 +106,9 @@ def process(self, query: str):
101106
pyodbc.SQL_LONGVARBINARY, binary_converter
102107
)
103108
cursor = conn.cursor()
109+
if session_init_statement is not None:
110+
for statement in session_init_statement:
111+
cursor.execute(statement)
104112
cursor.execute(query)
105113
while True:
106114
rows = cursor.fetchmany(fetch_size)

src/snowflake/snowpark/dataframe_reader.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1859,6 +1859,8 @@ def create_oracledb_connection():
18591859
fetch_size=fetch_size,
18601860
imports=udtf_configs.get("imports", None),
18611861
packages=udtf_configs.get("packages", None),
1862+
session_init_statement=session_init_statement,
1863+
query_timeout=query_timeout,
18621864
_emit_ast=_emit_ast,
18631865
)
18641866
end_time = time.perf_counter()

tests/integ/datasource/test_databricks.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
random_name_for_temp_object,
1818
TempObjectType,
1919
)
20-
from snowflake.snowpark.exceptions import SnowparkDataframeReaderException
20+
from snowflake.snowpark.exceptions import (
21+
SnowparkDataframeReaderException,
22+
SnowparkSQLException,
23+
)
2124
from snowflake.snowpark.types import (
2225
StructType,
2326
StructField,
@@ -205,7 +208,9 @@ def local_create_databricks_connection():
205208

206209
def test_unit_udtf_ingestion():
207210
dbx_driver = DatabricksDriver(create_databricks_connection, DBMS_TYPE.DATABRICKS_DB)
208-
udtf_ingestion_class = dbx_driver.udtf_class_builder()
211+
udtf_ingestion_class = dbx_driver.udtf_class_builder(
212+
session_init_statement=["select 1"]
213+
)
209214
udtf_ingestion_instance = udtf_ingestion_class()
210215

211216
dsp = DataSourcePartitioner(
@@ -258,3 +263,37 @@ def test_unsupported_type():
258263
create_databricks_connection, DBMS_TYPE.DATABRICKS_DB
259264
).to_snow_type([("test_col", "unsupported_type", True)])
260265
assert schema == StructType([StructField("TEST_COL", StringType(), nullable=True)])
266+
267+
268+
def test_session_init(session):
269+
with pytest.raises(
270+
SnowparkDataframeReaderException,
271+
match="syntax error command",
272+
):
273+
session.read.dbapi(
274+
create_databricks_connection,
275+
table=TEST_TABLE_NAME,
276+
session_init_statement=["syntax error command"],
277+
)
278+
279+
280+
def test_session_init_udtf(session):
281+
udtf_configs = {
282+
"external_access_integration": DATABRICKS_TEST_EXTERNAL_ACCESS_INTEGRATION
283+
}
284+
285+
def create_databricks_udtf_connection():
286+
import databricks.sql
287+
288+
return databricks.sql.connect(**DATABRICKS_CONNECTION_PARAMETERS)
289+
290+
with pytest.raises(
291+
SnowparkSQLException,
292+
match="syntax error command",
293+
):
294+
session.read.dbapi(
295+
create_databricks_udtf_connection,
296+
table=TEST_TABLE_NAME,
297+
session_init_statement=["syntax error command"],
298+
udtf_configs=udtf_configs,
299+
).collect()

0 commit comments

Comments
 (0)