Skip to content

Commit 013081f

Browse files
committed
tests: add TVF test cases
1 parent a9bb62d commit 013081f

5 files changed

Lines changed: 1049 additions & 0 deletions

File tree

tests/fast/tvf/test_arrow.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
from typing import Iterator
2+
3+
import pytest
4+
5+
import duckdb
6+
from duckdb.functional import PythonTVFType
7+
8+
9+
def simple_generator(count: int = 10) -> Iterator[tuple[str, int]]:
    """Yield ``count`` rows as ``("name_<i>", i)`` tuples for i in 0..count-1."""
    return ((f"name_{i}", i) for i in range(count))
12+
13+
14+
def simple_arrow_table(count: int):
    """Build a pyarrow Table with columns id (0..count-1), value (2*id),
    and name ("row_<id>")."""
    import pyarrow as pa

    ids = list(range(count))
    return pa.table(
        {
            "id": ids,
            "value": [2 * i for i in ids],
            "name": [f"row_{i}" for i in ids],
        }
    )
23+
24+
25+
def test_arrow_small(tmp_path):
    """Registering an Arrow TVF whose declared schema does not match the table
    the callable produces must raise when the function is executed.

    Fix: dropped the unused ``pa = ...`` and ``result = ...`` bindings — the
    importorskip return value and the fetchall result were never used.
    """
    pytest.importorskip("pyarrow")

    with duckdb.connect(tmp_path / "test.duckdb") as conn:
        conn.create_table_function(
            "simple_arrow",
            simple_arrow_table,
            schema=[("x", "BIGINT"), ("y", "VARCHAR")],  # Wrong schema!
            type=PythonTVFType.ARROW_TABLE,
        )

        with pytest.raises(Exception) as exc_info:
            conn.execute("SELECT * FROM simple_arrow(5)").fetchall()

        # Accept either the low-level vector error or a schema-validation message.
        message = str(exc_info.value)
        assert "Vector::Reference" in message or "schema" in message.lower()
43+
44+
45+
def test_arrow_large_1(tmp_path):
    """Exercise an Arrow-table TVF well past a single vector's worth of rows,
    through fetchone, df, fetch_arrow_table, and an aggregate.

    Fixes: dropped the unused ``pa = ...`` binding and replaced the O(n)
    Python generator sum with the closed form sum(2*i, i<n) == n*(n-1).
    """
    pytest.importorskip("pyarrow")

    n = 2048 * 1000

    with duckdb.connect(tmp_path / "test.duckdb") as conn:
        conn.create_table_function(
            "large_arrow",
            simple_arrow_table,
            schema=[("id", "BIGINT"), ("value", "BIGINT"), ("name", "VARCHAR")],
            type="arrow_table",
        )

        result = conn.execute(
            "SELECT COUNT(*) FROM large_arrow(?)", parameters=(n,)
        ).fetchone()
        assert result[0] == n

        df = conn.sql(f"SELECT * FROM large_arrow({n}) LIMIT 10").df()
        assert len(df) == 10
        assert df["id"].tolist() == list(range(10))

        arrow_result = conn.execute(
            "SELECT * FROM large_arrow(?)", parameters=(n,)
        ).fetch_arrow_table()
        assert len(arrow_result) == n

        result = conn.sql(
            "SELECT SUM(value) FROM large_arrow(?)", params=(n,)
        ).fetchone()
        # value column is 2*i for i in range(n); sum is n*(n-1) in closed form.
        assert result[0] == n * (n - 1)
77+
78+
79+
def test_large_arrow_execute(tmp_path):
    """A tuple-generator TVF fetched as one Arrow table via execute()."""
    pytest.importorskip("pyarrow")

    row_count = 2048 * 1000
    with duckdb.connect(tmp_path / "test.duckdb") as conn:
        conn.create_table_function(
            name="gen_function",
            callable=simple_generator,
            parameters=None,
            schema=[["name", "VARCHAR"], ["id", "INT"]],
            type="tuples",
        )

        table = conn.execute(
            "SELECT * FROM gen_function(?)",
            parameters=(row_count,),
        ).fetch_arrow_table()

        assert len(table) == row_count
100+
101+
102+
def test_large_arrow_sql(tmp_path):
    """A tuple-generator TVF fetched as one Arrow table via the sql() relation API."""
    pytest.importorskip("pyarrow")

    row_count = 2048 * 1000
    with duckdb.connect(tmp_path / "test.duckdb") as conn:
        conn.create_table_function(
            name="gen_function",
            callable=simple_generator,
            parameters=None,
            schema=[["name", "VARCHAR"], ["id", "INT"]],
            type="tuples",
        )

        table = conn.sql(
            "SELECT * FROM gen_function(?)",
            params=(row_count,),
        ).fetch_arrow_table()

        assert len(table) == row_count
123+
124+
125+
def test_arrowbatched_execute(tmp_path):
    """A tuple-generator TVF consumed through record-batch readers, both with a
    bound parameter and with an inlined literal argument.

    Fix: the original fetched the parameterized reader and then immediately
    overwrote it without consuming it, so the prepared-statement path was
    never actually verified — both readers are now drained and checked.
    """
    pytest.importorskip("pyarrow")

    count = 2048 * 1000
    with duckdb.connect(tmp_path / "test.duckdb") as conn:
        conn.create_table_function(
            name="gen_function",
            callable=simple_generator,
            parameters=None,
            schema=[["name", "VARCHAR"], ["id", "INT"]],
            type="tuples",
        )

        # Parameterized query path.
        reader = conn.execute(
            "SELECT * FROM gen_function(?)",
            parameters=(count,),
        ).fetch_record_batch()
        assert sum(batch.num_rows for batch in reader) == count

        # Literal-argument query path.
        reader = conn.execute(
            f"SELECT * FROM gen_function({count})",
        ).fetch_record_batch()
        assert sum(batch.num_rows for batch in reader) == count
153+
154+
155+
def test_arrowbatched_sql_relation(tmp_path):
    """Stream a sql() relation (literal TVF argument) through fetch_arrow_reader()."""
    pytest.importorskip("pyarrow")

    total = 2048 * 1000
    with duckdb.connect(tmp_path / "test.duckdb") as conn:
        conn.create_table_function(
            name="gen_function",
            callable=simple_generator,
            parameters=None,
            schema=[["name", "VARCHAR"], ["id", "INT"]],
            type="tuples",
        )

        reader = conn.sql(
            f"SELECT * FROM gen_function({total})",
        ).fetch_arrow_reader()

        seen = 0
        for batch in reader:
            seen += batch.num_rows
        assert seen == total
178+
179+
180+
def test_arrowbatched_sql_materialized(tmp_path):
    """Stream a materialized sql() relation through fetch_arrow_reader().

    Passing bound parameters to sql() makes the relation non-lazy
    (materialized) instead of a lazy query plan.
    """
    pytest.importorskip("pyarrow")

    total = 2048 * 1000
    with duckdb.connect(tmp_path / "test.duckdb") as conn:
        conn.create_table_function(
            name="gen_function",
            callable=simple_generator,
            parameters=None,
            schema=[["name", "VARCHAR"], ["id", "INT"]],
            type="tuples",
        )

        reader = conn.sql(
            "SELECT * FROM gen_function(?)",
            params=(total,),
        ).fetch_arrow_reader()

        seen = 0
        for batch in reader:
            seen += batch.num_rows
        assert seen == total
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""Test Arrow TVF schema validation"""
2+
3+
import pytest
4+
5+
import duckdb
6+
from duckdb.functional import PythonTVFType
7+
8+
9+
def simple_arrow_table(count: int = 10):
    """Return a pyarrow Table: id = 0..count-1, value = 2*id, name = "row_<id>"."""
    import pyarrow as pa

    ids = list(range(count))
    return pa.table(
        {
            "id": ids,
            "value": [2 * i for i in ids],
            "name": [f"row_{i}" for i in ids],
        }
    )
18+
19+
20+
def test_arrow_correct_schema(tmp_path):
    """A declared schema that matches the produced Arrow table works end to end.

    Fix: dropped the unused ``pa = ...`` binding from importorskip.
    """
    pytest.importorskip("pyarrow")

    with duckdb.connect(tmp_path / "test.duckdb") as conn:
        conn.create_table_function(
            "arrow_func",
            simple_arrow_table,
            schema=[("id", "BIGINT"), ("value", "BIGINT"), ("name", "VARCHAR")],
            type=PythonTVFType.ARROW_TABLE,
        )

        result = conn.execute("SELECT * FROM arrow_func(5)").fetchall()
        assert len(result) == 5
        assert result[0] == (0, 0, "row_0")
34+
35+
36+
def test_arrow_more_columns(tmp_path):
    """Declaring fewer columns than the Arrow table produces must fail.

    Fix: dropped the unused ``pa = ...`` binding from importorskip.
    """
    pytest.importorskip("pyarrow")

    with duckdb.connect(tmp_path / "test.duckdb") as conn:
        # table has 3 cols, but declare only 2
        conn.create_table_function(
            "arrow_func",
            simple_arrow_table,
            schema=[("x", "BIGINT"), ("y", "BIGINT")],  # Missing third column
            type=PythonTVFType.ARROW_TABLE,
        )

        with pytest.raises(duckdb.InvalidInputException) as exc_info:
            conn.execute("SELECT * FROM arrow_func(5)").fetchall()

        error_msg = str(exc_info.value).lower()
        assert (
            "schema mismatch" in error_msg
            or "3 columns" in error_msg
            or "2 were declared" in error_msg
        )
57+
58+
59+
def test_arrow_fewer_columns(tmp_path):
    """Declaring more columns than the Arrow table produces must fail.

    Fix: dropped the unused ``pa = ...`` binding from importorskip.
    """
    pytest.importorskip("pyarrow")

    with duckdb.connect(tmp_path / "test.duckdb") as conn:
        # table has 3 columns, but declare 4
        conn.create_table_function(
            "arrow_func",
            simple_arrow_table,
            schema=[
                ("id", "BIGINT"),
                ("value", "BIGINT"),
                ("name", "VARCHAR"),
                ("extra", "INT"),  # Extra column that doesn't exist
            ],
            type=PythonTVFType.ARROW_TABLE,
        )

        with pytest.raises(duckdb.InvalidInputException) as exc_info:
            conn.execute("SELECT * FROM arrow_func(5)").fetchall()

        error_msg = str(exc_info.value).lower()
        assert (
            "schema mismatch" in error_msg
            or "3 columns" in error_msg
            or "4 were declared" in error_msg
        )
85+
86+
87+
def test_arrow_type_mismatch(tmp_path):
    """A declared column type that disagrees with the Arrow column must fail.

    Fix: dropped the unused ``pa = ...`` binding from importorskip.
    """
    pytest.importorskip("pyarrow")

    with duckdb.connect(tmp_path / "test.duckdb") as conn:
        conn.create_table_function(
            "arrow_func",
            simple_arrow_table,
            schema=[
                ("id", "VARCHAR"),  # Wrong type - should be BIGINT
                ("value", "BIGINT"),
                ("name", "VARCHAR"),
            ],
            type=PythonTVFType.ARROW_TABLE,
        )

        with pytest.raises(duckdb.InvalidInputException) as exc_info:
            conn.execute("SELECT * FROM arrow_func(5)").fetchall()

        error_msg = str(exc_info.value).lower()
        assert "type" in error_msg or "mismatch" in error_msg
107+
108+
109+
def test_arrow_name_mismatch_allowed(tmp_path):
    """Declared column names may differ from the Arrow table's names; the query
    still succeeds (presumably columns are matched positionally — the test
    only pins that no error is raised and the row count is right).

    Fix: dropped the unused ``pa = ...`` binding from importorskip.
    """
    pytest.importorskip("pyarrow")

    with duckdb.connect(tmp_path / "test.duckdb") as conn:
        conn.create_table_function(
            "arrow_func",
            simple_arrow_table,
            schema=[
                ("a", "BIGINT"),  # Arrow has 'id'
                ("b", "BIGINT"),  # Arrow has 'value'
                ("c", "VARCHAR"),  # Arrow has 'name'
            ],
            type=PythonTVFType.ARROW_TABLE,
        )

        result = conn.execute("SELECT * FROM arrow_func(3)").fetchall()
        assert len(result) == 3

0 commit comments

Comments
 (0)