Skip to content

Commit c017657

Browse files
committed
fix tests
1 parent b652cae commit c017657

3 files changed

Lines changed: 187 additions & 2 deletions

File tree

Lines changed: 186 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44

55
import pytest
66

7+
from snowflake.snowpark._internal.data_source.utils import DBMS_TYPE
78
from snowflake.snowpark.types import StringType
89

910
from tests.parameters import SQL_SERVER_CONNECTION_PARAMETERS
1011
from tests.utils import IS_IN_STORED_PROC, Utils
11-
from tests.resources.test_data_source_dir.test_pyodbc_data import (
12+
from tests.resources.test_data_source_dir.test_sql_server_data import (
1213
SQL_SERVER_TABLE_NAME,
1314
EXPECTED_TEST_DATA,
1415
SQL_SERVER_TEST_EXTERNAL_ACCESS_INTEGRATION,
@@ -17,6 +18,10 @@
1718
SQL_SERVER_UNICODE_SCHEMA,
1819
SQL_SEVER_UNICODE_TABLE_NAME,
1920
)
21+
from snowflake.snowpark.exceptions import (
22+
SnowparkDataframeReaderException,
23+
SnowparkSQLException,
24+
)
2025

2126
DEPENDENCIES_PACKAGE_UNAVAILABLE = True
2227
try:
@@ -188,3 +193,183 @@ def local_create_connection_sql_server():
188193
apply_order,
189194
ignore_string_size=True,
190195
)
196+
197+
198+
@pytest.mark.parametrize(
199+
"input_type, input_value, error_message, udtf_configs",
200+
[
201+
("table", "NONEXISTTABLE", "Invalid object name", None),
202+
("query", "SELEC ** FORM TABLE", "Incorrect syntax near", None),
203+
(
204+
"table",
205+
"NONEXISTTABLE",
206+
"Invalid object name",
207+
SQL_SERVER_TEST_EXTERNAL_ACCESS_INTEGRATION,
208+
),
209+
(
210+
"query",
211+
"SELEC ** FORM TABLE",
212+
"Incorrect syntax near",
213+
SQL_SERVER_TEST_EXTERNAL_ACCESS_INTEGRATION,
214+
),
215+
],
216+
)
217+
def test_error_case(session, input_type, input_value, error_message, udtf_configs):
218+
# Use local connection function when udtf_configs is provided
219+
if udtf_configs:
220+
local_parameters = SQL_SERVER_CONNECTION_PARAMETERS.copy()
221+
222+
def connection_func():
223+
return pyodbc.connect(
224+
"DRIVER=" + local_parameters["DRIVER"] + ";"
225+
"SERVER=" + local_parameters["SERVER"] + ";"
226+
"UID=" + local_parameters["UID"] + ";"
227+
"PWD=" + local_parameters["PWD"] + ";"
228+
"TrustServerCertificate="
229+
+ local_parameters["TrustServerCertificate"]
230+
+ ";"
231+
"Encrypt=" + local_parameters["Encrypt"] + ";"
232+
)
233+
234+
else:
235+
connection_func = create_connection_sql_server
236+
237+
# Prepare kwargs for dbapi call
238+
dbapi_kwargs = construct_input_dict(input_type, input_value)
239+
if udtf_configs:
240+
dbapi_kwargs["udtf_configs"] = udtf_configs
241+
242+
with pytest.raises(SnowparkDataframeReaderException, match=error_message):
243+
session.read.dbapi(connection_func, **dbapi_kwargs)
244+
245+
246+
@pytest.mark.parametrize(
247+
"udtf_configs",
248+
[
249+
None,
250+
SQL_SERVER_TEST_EXTERNAL_ACCESS_INTEGRATION,
251+
],
252+
)
253+
def test_partitions_and_predicates(session, udtf_configs):
254+
# Use local connection function when udtf_configs is provided
255+
if udtf_configs:
256+
local_parameters = SQL_SERVER_CONNECTION_PARAMETERS.copy()
257+
258+
def connection_func():
259+
return pyodbc.connect(
260+
"DRIVER=" + local_parameters["DRIVER"] + ";"
261+
"SERVER=" + local_parameters["SERVER"] + ";"
262+
"UID=" + local_parameters["UID"] + ";"
263+
"PWD=" + local_parameters["PWD"] + ";"
264+
"TrustServerCertificate="
265+
+ local_parameters["TrustServerCertificate"]
266+
+ ";"
267+
"Encrypt=" + local_parameters["Encrypt"] + ";"
268+
)
269+
270+
else:
271+
connection_func = create_connection_sql_server
272+
273+
# Prepare kwargs for dbapi call
274+
dbapi_kwargs = {
275+
"table": SQL_SERVER_TABLE_NAME,
276+
"column": "ID",
277+
"num_partitions": 3,
278+
"upper_bound": 10,
279+
"lower_bound": 0,
280+
}
281+
if udtf_configs:
282+
dbapi_kwargs["udtf_configs"] = udtf_configs
283+
284+
df = session.read.dbapi(connection_func, **dbapi_kwargs)
285+
286+
# Use ignore_string_size=True for UDTF scenarios like in other tests
287+
verify_save_table_result(
288+
session,
289+
df,
290+
EXPECTED_TEST_DATA,
291+
SQL_SERVER_SCHEMA,
292+
True,
293+
ignore_string_size=bool(udtf_configs),
294+
)
295+
296+
dbapi_kwargs = {
297+
"table": SQL_SERVER_TABLE_NAME,
298+
"predicates": ["ID < 6", "ID >= 6"],
299+
}
300+
if udtf_configs:
301+
dbapi_kwargs["udtf_configs"] = udtf_configs
302+
303+
df = session.read.dbapi(connection_func, **dbapi_kwargs)
304+
305+
verify_save_table_result(
306+
session,
307+
df,
308+
EXPECTED_TEST_DATA,
309+
SQL_SERVER_SCHEMA,
310+
True,
311+
ignore_string_size=bool(udtf_configs),
312+
)
313+
314+
315+
@pytest.mark.parametrize(
316+
"udtf_configs",
317+
[
318+
None,
319+
SQL_SERVER_TEST_EXTERNAL_ACCESS_INTEGRATION,
320+
],
321+
)
322+
def test_session_init_statement(session, udtf_configs):
323+
# Use local connection function when udtf_configs is provided
324+
if udtf_configs:
325+
local_parameters = SQL_SERVER_CONNECTION_PARAMETERS.copy()
326+
327+
def connection_func():
328+
return pyodbc.connect(
329+
"DRIVER=" + local_parameters["DRIVER"] + ";"
330+
"SERVER=" + local_parameters["SERVER"] + ";"
331+
"UID=" + local_parameters["UID"] + ";"
332+
"PWD=" + local_parameters["PWD"] + ";"
333+
"TrustServerCertificate="
334+
+ local_parameters["TrustServerCertificate"]
335+
+ ";"
336+
"Encrypt=" + local_parameters["Encrypt"] + ";"
337+
)
338+
339+
else:
340+
connection_func = create_connection_sql_server
341+
342+
# here we use a statement that will fail to verify the session init statement is executed
343+
statements = [
344+
"DECLARE @VAR1 INT;",
345+
"DECLARE @VAR2 INT;",
346+
"SET @VAR_NON_EXIST = 12345;",
347+
]
348+
349+
# Prepare kwargs for dbapi call
350+
dbapi_kwargs = {
351+
"table": SQL_SERVER_TABLE_NAME,
352+
"session_init_statement": statements,
353+
}
354+
if udtf_configs:
355+
dbapi_kwargs["udtf_configs"] = udtf_configs
356+
357+
with pytest.raises(SnowparkSQLException, match="Must declare the scalar variable"):
358+
# TODO: 2362041, UDTF error experience is different from parquet ingestion
359+
# 1. UDTF needs .collect() to trigger the error while parquet ingestion triggers on .dbapi()
360+
# 2. error exception is different
361+
session.read.dbapi(connection_func, **dbapi_kwargs).collect()
362+
363+
364+
def test_pyodbc_driver_class_builder():
365+
from snowflake.snowpark._internal.data_source.drivers.pyodbc_driver import (
366+
PyodbcDriver,
367+
)
368+
369+
driver = PyodbcDriver(create_connection_sql_server, DBMS_TYPE.SQL_SERVER_DB)
370+
udtf_class = driver.udtf_class_builder(
371+
fetch_size=2,
372+
)
373+
ingestion = udtf_class()
374+
results = list(ingestion.process(f"SELECT * FROM {SQL_SERVER_TABLE_NAME}"))
375+
assert len(results) == len(EXPECTED_TEST_DATA)

tests/resources/test_data_source_dir/test_pyodbc_data.py renamed to tests/resources/test_data_source_dir/test_sql_server_data.py

File renamed without changes.

tests/unit/scala/test_utils_suite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ def check_zip_files_and_close_stream(input_stream, expected_files):
365365
"resources/test_data_source_dir/test_databricks_data.py",
366366
"resources/test_data_source_dir/test_jdbc_data.py",
367367
"resources/test_data_source_dir/test_mysql_data.py",
368-
"resources/test_data_source_dir/test_pyodbc_data.py",
368+
"resources/test_data_source_dir/test_sql_server_data.py",
369369
"resources/test_debug_utils_dir/",
370370
"resources/test_debug_utils_dir/dataframe_generator1.py",
371371
"resources/test_debug_utils_dir/dataframe_generator2.py",

0 commit comments

Comments
 (0)