Skip to content

Commit 344ffc0

Browse files
SNOW-2241912: add test coverage of dbapi on stored procedure (#3739)
1 parent 5573dc2 commit 344ffc0

1 file changed

Lines changed: 399 additions & 0 deletions

File tree

Lines changed: 399 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,399 @@
1+
#
2+
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
import datetime
6+
import functools
7+
import os
8+
import tempfile
9+
10+
import pytest
11+
import sqlite3
12+
13+
from snowflake.snowpark import Row
14+
from tests.utils import RUNNING_ON_GH
15+
16+
# Schema handed to ``session.read.dbapi(custom_schema=...)`` by the tests in
# this module: one entry per column of the SQLite ``PrimitiveTypes`` table,
# listed in the same order as the table's columns.
SQLITE3_DB_CUSTOM_SCHEMA_STRING = """
id INTEGER,
int_col INTEGER,
real_col FLOAT,
text_col STRING,
blob_col BINARY,
null_col STRING,
ts_col TIMESTAMP_NTZ,
date_col DATE,
time_col TIME,
short_col SHORT,
long_col LONG,
double_col DOUBLE,
decimal_col DECIMAL,
map_col MAP,
array_col ARRAY,
var_col VARIANT
"""
34+
35+
36+
def create_connection_to_sqlite3_db(db_path):
    """Open and return a new ``sqlite3`` connection to ``db_path``.

    Args:
        db_path: Database location, passed unchanged to ``sqlite3.connect``.

    Returns:
        The connection object returned by ``sqlite3.connect``. The caller is
        responsible for closing it.
    """
    return sqlite3.connect(db_path)
38+
39+
40+
def sqlite3_db(db_path):
    """Create and populate the ``PrimitiveTypes`` table in a SQLite database.

    The table is created if it does not exist yet and three example rows are
    inserted. The connection is always closed before returning, including
    when creating the table or inserting the rows raises.

    Args:
        db_path: Database location, forwarded to
            ``create_connection_to_sqlite3_db``.

    Returns:
        A 4-tuple ``(table_name, columns, example_data, assert_data)``:
        the table name, the list of column names, the tuples that were
        inserted, and the ``Row`` objects callers compare query results to.
    """
    table_name = "PrimitiveTypes"
    columns = [
        "id",
        "int_col",
        "real_col",
        "text_col",
        "blob_col",
        "null_col",
        "ts_col",
        "date_col",
        "time_col",
        "short_col",
        "long_col",
        "double_col",
        "decimal_col",
        "map_col",
        "array_col",
        "var_col",
    ]
    test_datetime = datetime.datetime(2021, 1, 2, 12, 34, 56)
    test_date = test_datetime.date()
    test_time = test_datetime.time()
    # Temporal values are stored as ISO-8601 text; semi-structured values are
    # stored as plain text as well.
    example_data = [
        (
            1,
            42,
            3.14,
            "Hello, world!",
            b"\x00\x01\x02\x03",
            None,
            test_datetime.isoformat(),
            test_date.isoformat(),
            test_time.isoformat(),
            1,
            2,
            3.0,
            4.0,
            '{"a": 1, "b": 2}',
            "[1, 2, 3]",
            "1",
        ),
        (
            2,
            -10,
            2.718,
            "SQLite",
            b"\x04\x05\x06\x07",
            None,
            test_datetime.isoformat(),
            test_date.isoformat(),
            test_time.isoformat(),
            1,
            2,
            3.0,
            4.0,
            '{"a": 1, "b": 2}',
            "[1, 2, 3]",
            "2",
        ),
        (
            3,
            9999,
            -0.99,
            "Python",
            b"\x08\x09\x0A\x0B",
            None,
            test_datetime.isoformat(),
            test_date.isoformat(),
            test_time.isoformat(),
            1,
            2,
            3.0,
            4.0,
            '{"a": 1, "b": 2}',
            "[1, 2, 3]",
            "3",
        ),
    ]
    # NOTE(review): the whitespace inside the MAP_COL / ARRAY_COL strings is
    # kept exactly as found; confirm it matches the service's JSON rendering.
    assert_data = [
        Row(
            ID=1,
            INT_COL=42,
            REAL_COL=3.14,
            TEXT_COL="Hello, world!",
            BLOB_COL=bytearray(b"\x00\x01\x02\x03"),
            NULL_COL=None,
            TS_COL=datetime.datetime(2021, 1, 2, 12, 34, 56),
            DATE_COL=datetime.date(2021, 1, 2),
            TIME_COL=datetime.time(12, 34, 56),
            SHORT_COL=1,
            LONG_COL=2,
            DOUBLE_COL=3.0,
            DECIMAL_COL=4,
            MAP_COL='{\n "a": 1,\n "b": 2\n}',
            ARRAY_COL='[\n "[1, 2, 3]"\n]',
            VAR_COL='"1"',
        ),
        Row(
            ID=2,
            INT_COL=-10,
            REAL_COL=2.718,
            TEXT_COL="SQLite",
            BLOB_COL=bytearray(b"\x04\x05\x06\x07"),
            NULL_COL=None,
            TS_COL=datetime.datetime(2021, 1, 2, 12, 34, 56),
            DATE_COL=datetime.date(2021, 1, 2),
            TIME_COL=datetime.time(12, 34, 56),
            SHORT_COL=1,
            LONG_COL=2,
            DOUBLE_COL=3.0,
            DECIMAL_COL=4,
            MAP_COL='{\n "a": 1,\n "b": 2\n}',
            ARRAY_COL='[\n "[1, 2, 3]"\n]',
            VAR_COL='"2"',
        ),
        Row(
            ID=3,
            INT_COL=9999,
            REAL_COL=-0.99,
            TEXT_COL="Python",
            BLOB_COL=bytearray(b"\x08\t\n\x0b"),
            NULL_COL=None,
            TS_COL=datetime.datetime(2021, 1, 2, 12, 34, 56),
            DATE_COL=datetime.date(2021, 1, 2),
            TIME_COL=datetime.time(12, 34, 56),
            SHORT_COL=1,
            LONG_COL=2,
            DOUBLE_COL=3.0,
            DECIMAL_COL=4,
            MAP_COL='{\n "a": 1,\n "b": 2\n}',
            ARRAY_COL='[\n "[1, 2, 3]"\n]',
            VAR_COL='"3"',
        ),
    ]

    conn = create_connection_to_sqlite3_db(db_path)
    try:
        cursor = conn.cursor()
        # Create a table with different primitive types
        # sqlite3 only supports 5 types: NULL, INTEGER, REAL, TEXT, BLOB
        cursor.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {table_name} (
                id INTEGER PRIMARY KEY, -- Auto-incrementing primary key
                int_col INTEGER, -- Integer column
                real_col REAL, -- Floating point column
                text_col TEXT, -- String column
                blob_col BLOB, -- Binary data column
                null_col NULL, -- Explicit NULL type (for testing purposes)
                ts_col TEXT, -- Timestamp column in TEXT format
                date_col TEXT, -- Date column in TEXT format
                time_col TEXT, -- Time column in TEXT format
                short_col INTEGER, -- Short integer column
                long_col INTEGER, -- Long integer column
                double_col REAL, -- Double column
                decimal_col REAL, -- Decimal column
                map_col TEXT, -- Map column in TEXT format
                array_col TEXT, -- Array column in TEXT format
                var_col TEXT -- Variant column in TEXT format
            )
            """
        )
        # One positional placeholder per column keeps the statement in step
        # with the ``columns`` list above.
        placeholders = ",".join("?" * len(columns))
        cursor.executemany(
            f"INSERT INTO {table_name} VALUES ({placeholders})", example_data
        )
        conn.commit()
    finally:
        # Release the database file even when setup fails part-way through.
        conn.close()
    return table_name, columns, example_data, assert_data
207+
208+
209+
# Module-wide pytest marks: every test in this file is skipped when the
# ``local_testing_mode`` option is set, and when ``RUNNING_ON_GH`` is truthy.
pytestmark = [
    pytest.mark.skipif(
        "config.getoption('local_testing_mode', default=False)",
        reason="feature not available in local testing",
    ),
    pytest.mark.skipif(
        RUNNING_ON_GH,
        reason="tests only suppose to run on snowfort",
    ),
]
219+
220+
221+
@pytest.mark.parametrize("fetch_with_process", [True, False])
def test_dbapi_local(session, fetch_with_process):
    """Read an on-disk SQLite table through ``session.read.dbapi``.

    A throwaway database is built with ``sqlite3_db`` inside a temporary
    directory, loaded with the custom schema, and the collected rows (ordered
    by ``ID``) are compared with the fixture's expected rows. The test runs
    once with ``fetch_with_process=True`` and once with ``False``.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        dbpath = os.path.join(temp_dir, "testsqlite3.db")
        table_name, _, _, assert_data = sqlite3_db(dbpath)
        df = session.read.dbapi(
            # The reader receives a zero-argument connection factory.
            functools.partial(create_connection_to_sqlite3_db, dbpath),
            table=table_name,
            custom_schema=SQLITE3_DB_CUSTOM_SCHEMA_STRING,
            fetch_size=2,
            fetch_merge_count=2,
            fetch_with_process=fetch_with_process,
        )
        # Collect while the temporary database file still exists.
        assert df.order_by("ID").collect() == assert_data
235+
236+
237+
def test_dbapi_udtf(session):
    """Read an in-memory SQLite table through ``session.read.dbapi`` with ``udtf_configs``.

    The connection factory builds and fills the ``PrimitiveTypes`` table in a
    ``:memory:`` database each time it is called. The collected rows (ordered
    by ``ID``) are compared with ``expected_data``.
    """
    udtf_configs = {"external_access_integration": ""}
    test_datetime = datetime.datetime(2021, 1, 2, 12, 34, 56)
    test_date = test_datetime.date()
    test_time = test_datetime.time()
    table_name = "PrimitiveTypes"
    # Binary values are inserted as hex text here; the expected rows below
    # hold the corresponding raw bytes.
    example_data = [
        (
            1,
            42,
            3.14,
            "Hello, world!",
            b"\x00\x01\x02\x03".hex(),
            None,
            test_datetime.isoformat(),
            test_date.isoformat(),
            test_time.isoformat(),
            1,
            2,
            3.0,
            4.0,
            '{"a": 1, "b": 2}',
            "[1, 2, 3]",
            "1",
        ),
        (
            2,
            -10,
            2.718,
            "SQLite",
            b"\x04\x05\x06\x07".hex(),
            None,
            test_datetime.isoformat(),
            test_date.isoformat(),
            test_time.isoformat(),
            1,
            2,
            3.0,
            4.0,
            '{"a": 1, "b": 2}',
            "[1, 2, 3]",
            "2",
        ),
        (
            3,
            9999,
            -0.99,
            "Python",
            b"\x08\x09\x0A\x0B".hex(),
            None,
            test_datetime.isoformat(),
            test_date.isoformat(),
            test_time.isoformat(),
            1,
            2,
            3.0,
            4.0,
            '{"a": 1, "b": 2}',
            "[1, 2, 3]",
            "3",
        ),
    ]
    # NOTE(review): the whitespace inside the MAP_COL / ARRAY_COL strings is
    # kept exactly as found; confirm it matches the service's JSON rendering.
    expected_data = [
        Row(
            ID=1,
            INT_COL=42,
            REAL_COL=3.14,
            TEXT_COL="Hello, world!",
            BLOB_COL=bytearray(b"\x00\x01\x02\x03"),
            NULL_COL=None,
            TS_COL=datetime.datetime(2021, 1, 2, 12, 34, 56),
            DATE_COL=datetime.date(2021, 1, 2),
            TIME_COL=datetime.time(12, 34, 56),
            SHORT_COL=1,
            LONG_COL=2,
            DOUBLE_COL=3.0,
            DECIMAL_COL=4,
            MAP_COL='{\n "a": 1,\n "b": 2\n}',
            ARRAY_COL='[\n "[1, 2, 3]"\n]',
            VAR_COL='"1"',
        ),
        Row(
            ID=2,
            INT_COL=-10,
            REAL_COL=2.718,
            TEXT_COL="SQLite",
            BLOB_COL=bytearray(b"\x04\x05\x06\x07"),
            NULL_COL=None,
            TS_COL=datetime.datetime(2021, 1, 2, 12, 34, 56),
            DATE_COL=datetime.date(2021, 1, 2),
            TIME_COL=datetime.time(12, 34, 56),
            SHORT_COL=1,
            LONG_COL=2,
            DOUBLE_COL=3.0,
            DECIMAL_COL=4,
            MAP_COL='{\n "a": 1,\n "b": 2\n}',
            ARRAY_COL='[\n "[1, 2, 3]"\n]',
            VAR_COL='"2"',
        ),
        Row(
            ID=3,
            INT_COL=9999,
            REAL_COL=-0.99,
            TEXT_COL="Python",
            BLOB_COL=bytearray(b"\x08\t\n\x0b"),
            NULL_COL=None,
            TS_COL=datetime.datetime(2021, 1, 2, 12, 34, 56),
            DATE_COL=datetime.date(2021, 1, 2),
            TIME_COL=datetime.time(12, 34, 56),
            SHORT_COL=1,
            LONG_COL=2,
            DOUBLE_COL=3.0,
            DECIMAL_COL=4,
            MAP_COL='{\n "a": 1,\n "b": 2\n}',
            ARRAY_COL='[\n "[1, 2, 3]"\n]',
            VAR_COL='"3"',
        ),
    ]

    def create_connection_sqlite3():
        """Return a connection to a fresh in-memory database holding the test table."""
        # NOTE(review): presumably imported locally so this function stays
        # self-contained when it is shipped to the UDTF - confirm.
        import sqlite3

        conn = sqlite3.connect(":memory:")
        cursor = conn.cursor()
        # Create a table with different primitive types
        # sqlite3 only supports 5 types: NULL, INTEGER, REAL, TEXT, BLOB
        cursor.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {table_name} (
                id INTEGER PRIMARY KEY, -- Auto-incrementing primary key
                int_col INTEGER, -- Integer column
                real_col REAL, -- Floating point column
                text_col TEXT, -- String column
                blob_col BLOB, -- Binary data column
                null_col NULL, -- Explicit NULL type (for testing purposes)
                ts_col TEXT, -- Timestamp column in TEXT format
                date_col TEXT, -- Date column in TEXT format
                time_col TEXT, -- Time column in TEXT format
                short_col INTEGER, -- Short integer column
                long_col INTEGER, -- Long integer column
                double_col REAL, -- Double column
                decimal_col REAL, -- Decimal column
                map_col TEXT, -- Map column in TEXT format
                array_col TEXT, -- Array column in TEXT format
                var_col TEXT -- Variant column in TEXT format
            )
            """
        )

        cursor.executemany(
            f"INSERT INTO {table_name} VALUES ({','.join('?' * 16)})", example_data
        )
        conn.commit()
        # The connection is returned open: the in-memory data lives only as
        # long as this connection does.
        return conn

    df = session.read.dbapi(
        create_connection_sqlite3,
        table=table_name,
        custom_schema=SQLITE3_DB_CUSTOM_SCHEMA_STRING,
        fetch_size=2,
        udtf_configs=udtf_configs,
    )
    assert df.order_by("ID").collect() == expected_data

0 commit comments

Comments
 (0)