Skip to content

Commit fb5051e

Browse files
committed
Unify Arrow stream scanning via __arrow_c_stream__ and only push down filters if pyarrow is present
1 parent 3186810 commit fb5051e

6 files changed

Lines changed: 549 additions & 66 deletions

File tree

src/duckdb_py/arrow/arrow_array_stream.cpp

Lines changed: 79 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,44 @@ unique_ptr<ArrowArrayStreamWrapper> PythonTableArrowArrayStreamFactory::Produce(
6666
py::handle arrow_obj_handle(factory->arrow_object);
6767
auto arrow_object_type = DuckDBPyConnection::GetArrowType(arrow_obj_handle);
6868

69+
if (arrow_object_type == PyArrowObjectType::PyCapsuleInterface) {
70+
py::object capsule_obj = arrow_obj_handle.attr("__arrow_c_stream__")();
71+
auto capsule = py::reinterpret_borrow<py::capsule>(capsule_obj);
72+
auto stream = capsule.get_pointer<struct ArrowArrayStream>();
73+
if (!stream->release) {
74+
throw InvalidInputException(
75+
"The __arrow_c_stream__() method returned a released stream. "
76+
"If this object is single-use, implement __arrow_c_schema__() or expose a .schema attribute "
77+
"with _export_to_c() so that DuckDB can extract the schema without consuming the stream.");
78+
}
79+
80+
if (ModuleIsLoaded<PyarrowDatasetCacheItem>()) {
81+
// Tier A: full pushdown via pyarrow.dataset
82+
// Import as RecordBatchReader, feed through Scanner.from_batches for projection/filter pushdown.
83+
auto pyarrow_lib_module = py::module::import("pyarrow").attr("lib");
84+
auto import_func = pyarrow_lib_module.attr("RecordBatchReader").attr("_import_from_c");
85+
py::object reader = import_func(reinterpret_cast<uint64_t>(stream));
86+
// _import_from_c takes ownership of the stream; null out to prevent capsule double-free
87+
stream->release = nullptr;
88+
auto &import_cache = *DuckDBPyConnection::ImportCache();
89+
py::object arrow_batch_scanner = import_cache.pyarrow.dataset.Scanner().attr("from_batches");
90+
py::handle reader_handle = reader;
91+
auto scanner = ProduceScanner(arrow_batch_scanner, reader_handle, parameters, factory->client_properties);
92+
auto record_batches = scanner.attr("to_reader")();
93+
auto res = make_uniq<ArrowArrayStreamWrapper>();
94+
auto export_to_c = record_batches.attr("_export_to_c");
95+
export_to_c(reinterpret_cast<uint64_t>(&res->arrow_array_stream));
96+
return res;
97+
} else {
98+
// Tier B: no pyarrow.dataset, return raw stream (no pushdown)
99+
// DuckDB applies projection/filter post-scan via arrow_scan_dumb
100+
auto res = make_uniq<ArrowArrayStreamWrapper>();
101+
res->arrow_array_stream = *stream;
102+
stream->release = nullptr;
103+
return res;
104+
}
105+
}
106+
69107
if (arrow_object_type == PyArrowObjectType::PyCapsule) {
70108
auto res = make_uniq<ArrowArrayStreamWrapper>();
71109
auto capsule = py::reinterpret_borrow<py::capsule>(arrow_obj_handle);
@@ -78,21 +116,12 @@ unique_ptr<ArrowArrayStreamWrapper> PythonTableArrowArrayStreamFactory::Produce(
78116
return res;
79117
}
80118

119+
// Scanner and Dataset: require pyarrow.dataset for pushdown
120+
VerifyArrowDatasetLoaded();
81121
auto &import_cache = *DuckDBPyConnection::ImportCache();
82122
py::object scanner;
83123
py::object arrow_batch_scanner = import_cache.pyarrow.dataset.Scanner().attr("from_batches");
84124
switch (arrow_object_type) {
85-
case PyArrowObjectType::Table: {
86-
auto arrow_dataset = import_cache.pyarrow.dataset().attr("dataset");
87-
auto dataset = arrow_dataset(arrow_obj_handle);
88-
py::object arrow_scanner = dataset.attr("__class__").attr("scanner");
89-
scanner = ProduceScanner(arrow_scanner, dataset, parameters, factory->client_properties);
90-
break;
91-
}
92-
case PyArrowObjectType::RecordBatchReader: {
93-
scanner = ProduceScanner(arrow_batch_scanner, arrow_obj_handle, parameters, factory->client_properties);
94-
break;
95-
}
96125
case PyArrowObjectType::Scanner: {
97126
// If it's a scanner we have to turn it to a record batch reader, and then a scanner again since we can't stack
98127
// scanners on arrow. Otherwise pushed-down projections and filters will disappear like tears in the rain
@@ -119,37 +148,29 @@ unique_ptr<ArrowArrayStreamWrapper> PythonTableArrowArrayStreamFactory::Produce(
119148
}
120149

121150
void PythonTableArrowArrayStreamFactory::GetSchemaInternal(py::handle arrow_obj_handle, ArrowSchemaWrapper &schema) {
151+
// PyCapsule (from bare capsule Produce path)
122152
if (py::isinstance<py::capsule>(arrow_obj_handle)) {
123153
auto capsule = py::reinterpret_borrow<py::capsule>(arrow_obj_handle);
124154
auto stream = capsule.get_pointer<struct ArrowArrayStream>();
125155
if (!stream->release) {
126156
throw InternalException("ArrowArrayStream was released by another thread/library");
127157
}
128-
stream->get_schema(stream, &schema.arrow_schema);
129-
return;
130-
}
131-
132-
auto table_class = py::module::import("pyarrow").attr("Table");
133-
if (py::isinstance(arrow_obj_handle, table_class)) {
134-
auto obj_schema = arrow_obj_handle.attr("schema");
135-
auto export_to_c = obj_schema.attr("_export_to_c");
136-
export_to_c(reinterpret_cast<uint64_t>(&schema.arrow_schema));
158+
if (stream->get_schema(stream, &schema.arrow_schema)) {
159+
throw InvalidInputException("Failed to get Arrow schema from stream: %s",
160+
stream->get_last_error ? stream->get_last_error(stream) : "unknown error");
161+
}
137162
return;
138163
}
139164

165+
// Scanner: use projected_schema; everything else (RecordBatchReader, Dataset): use .schema
140166
VerifyArrowDatasetLoaded();
141-
142167
auto &import_cache = *DuckDBPyConnection::ImportCache();
143-
auto scanner_class = import_cache.pyarrow.dataset.Scanner();
144-
145-
if (py::isinstance(arrow_obj_handle, scanner_class)) {
168+
if (py::isinstance(arrow_obj_handle, import_cache.pyarrow.dataset.Scanner())) {
146169
auto obj_schema = arrow_obj_handle.attr("projected_schema");
147-
auto export_to_c = obj_schema.attr("_export_to_c");
148-
export_to_c(reinterpret_cast<uint64_t>(&schema));
170+
obj_schema.attr("_export_to_c")(reinterpret_cast<uint64_t>(&schema.arrow_schema));
149171
} else {
150172
auto obj_schema = arrow_obj_handle.attr("schema");
151-
auto export_to_c = obj_schema.attr("_export_to_c");
152-
export_to_c(reinterpret_cast<uint64_t>(&schema));
173+
obj_schema.attr("_export_to_c")(reinterpret_cast<uint64_t>(&schema.arrow_schema));
153174
}
154175
}
155176

@@ -158,6 +179,36 @@ void PythonTableArrowArrayStreamFactory::GetSchema(uintptr_t factory_ptr, ArrowS
158179
auto factory = static_cast<PythonTableArrowArrayStreamFactory *>(reinterpret_cast<void *>(factory_ptr)); // NOLINT
159180
D_ASSERT(factory->arrow_object);
160181
py::handle arrow_obj_handle(factory->arrow_object);
182+
183+
auto type = DuckDBPyConnection::GetArrowType(arrow_obj_handle);
184+
if (type == PyArrowObjectType::PyCapsuleInterface) {
185+
// Get __arrow_c_schema__ if it exists
186+
if (py::hasattr(arrow_obj_handle, "__arrow_c_schema__")) {
187+
auto schema_capsule = arrow_obj_handle.attr("__arrow_c_schema__")();
188+
auto capsule = py::reinterpret_borrow<py::capsule>(schema_capsule);
189+
auto arrow_schema = capsule.get_pointer<struct ArrowSchema>();
190+
schema.arrow_schema = *arrow_schema;
191+
arrow_schema->release = nullptr; // take ownership
192+
return;
193+
}
194+
// Otherwise try to use .schema with _export_to_c
195+
if (py::hasattr(arrow_obj_handle, "schema")) {
196+
auto obj_schema = arrow_obj_handle.attr("schema");
197+
if (py::hasattr(obj_schema, "_export_to_c")) {
198+
obj_schema.attr("_export_to_c")(reinterpret_cast<uint64_t>(&schema.arrow_schema));
199+
return;
200+
}
201+
}
202+
// Fallback: create a temporary stream just for the schema (consumes single-use streams!)
203+
auto stream_capsule = arrow_obj_handle.attr("__arrow_c_stream__")();
204+
auto capsule = py::reinterpret_borrow<py::capsule>(stream_capsule);
205+
auto stream = capsule.get_pointer<struct ArrowArrayStream>();
206+
if (stream->get_schema(stream, &schema.arrow_schema)) {
207+
throw InvalidInputException("Failed to get Arrow schema from stream: %s",
208+
stream->get_last_error ? stream->get_last_error(stream) : "unknown error");
209+
}
210+
return; // stream_capsule goes out of scope, stream released by capsule destructor
211+
}
161212
GetSchemaInternal(arrow_obj_handle, schema);
162213
}
163214

src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,7 @@ class Table : public py::object {
5151

5252
} // namespace pyarrow
5353

54-
enum class PyArrowObjectType {
55-
Invalid,
56-
Table,
57-
RecordBatchReader,
58-
Scanner,
59-
Dataset,
60-
PyCapsule,
61-
PyCapsuleInterface,
62-
MessageReader
63-
};
54+
enum class PyArrowObjectType { Invalid, Table, Scanner, Dataset, PyCapsule, PyCapsuleInterface, MessageReader };
6455

6556
void TransformDuckToArrowChunk(ArrowSchema &arrow_schema, ArrowArray &data, py::list &batches);
6657

src/duckdb_py/pyconnection.cpp

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2383,26 +2383,16 @@ PyArrowObjectType DuckDBPyConnection::GetArrowType(const py::handle &obj) {
23832383

23842384
if (ModuleIsLoaded<PyarrowCacheItem>()) {
23852385
auto &import_cache = *DuckDBPyConnection::ImportCache();
2386-
// First Verify Lib Types
2387-
auto table_class = import_cache.pyarrow.Table();
2388-
auto record_batch_reader_class = import_cache.pyarrow.RecordBatchReader();
2389-
auto message_reader_class = import_cache.pyarrow.ipc.MessageReader();
2390-
if (py::isinstance(obj, table_class)) {
2391-
return PyArrowObjectType::Table;
2392-
} else if (py::isinstance(obj, record_batch_reader_class)) {
2393-
return PyArrowObjectType::RecordBatchReader;
2394-
} else if (py::isinstance(obj, message_reader_class)) {
2386+
// MessageReader requires nanoarrow, separate scan function
2387+
if (py::isinstance(obj, import_cache.pyarrow.ipc.MessageReader())) {
23952388
return PyArrowObjectType::MessageReader;
23962389
}
23972390

23982391
if (ModuleIsLoaded<PyarrowDatasetCacheItem>()) {
2399-
// Then Verify dataset types
2400-
auto dataset_class = import_cache.pyarrow.dataset.Dataset();
2401-
auto scanner_class = import_cache.pyarrow.dataset.Scanner();
2402-
2403-
if (py::isinstance(obj, scanner_class)) {
2392+
// Scanner/Dataset don't have __arrow_c_stream__, need dedicated handling
2393+
if (py::isinstance(obj, import_cache.pyarrow.dataset.Scanner())) {
24042394
return PyArrowObjectType::Scanner;
2405-
} else if (py::isinstance(obj, dataset_class)) {
2395+
} else if (py::isinstance(obj, import_cache.pyarrow.dataset.Dataset())) {
24062396
return PyArrowObjectType::Dataset;
24072397
}
24082398
}

src/duckdb_py/python_replacement_scan.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,6 @@ static void CreateArrowScan(const string &name, py::object entry, TableFunctionR
5151
auto dependency_item = PythonDependencyItem::Create(stream_messages);
5252
external_dependency->AddDependency("replacement_cache", std::move(dependency_item));
5353
} else {
54-
if (type == PyArrowObjectType::PyCapsuleInterface) {
55-
entry = entry.attr("__arrow_c_stream__")();
56-
type = PyArrowObjectType::PyCapsule;
57-
}
58-
5954
auto stream_factory = make_uniq<PythonTableArrowArrayStreamFactory>(entry.ptr(), client_properties);
6055
auto stream_factory_produce = PythonTableArrowArrayStreamFactory::Produce;
6156
auto stream_factory_get_schema = PythonTableArrowArrayStreamFactory::GetSchema;
@@ -66,7 +61,10 @@ static void CreateArrowScan(const string &name, py::object entry, TableFunctionR
6661
make_uniq<ConstantExpression>(Value::POINTER(CastPointerToValue(stream_factory_get_schema))));
6762

6863
if (type == PyArrowObjectType::PyCapsule) {
69-
// Disable projection+filter pushdown
64+
// Disable projection+filter pushdown for bare capsules (single-use, no PyArrow wrapper)
65+
table_function.function = make_uniq<FunctionExpression>("arrow_scan_dumb", std::move(children));
66+
} else if (type == PyArrowObjectType::PyCapsuleInterface && !ModuleIsLoaded<PyarrowDatasetCacheItem>()) {
67+
// No pyarrow.dataset: scan without pushdown, DuckDB handles projection/filter post-scan
7068
table_function.function = make_uniq<FunctionExpression>("arrow_scan_dumb", std::move(children));
7169
} else {
7270
table_function.function = make_uniq<FunctionExpression>("arrow_scan", std::move(children));
@@ -140,7 +138,8 @@ unique_ptr<TableRef> PythonReplacementScan::TryReplacementObject(const py::objec
140138
dependency->AddDependency("replacement_cache", PythonDependencyItem::Create(entry));
141139
subquery->external_dependency = std::move(dependency);
142140
return std::move(subquery);
143-
} else if (PolarsDataFrame::IsDataFrame(entry)) {
141+
} else if (PolarsDataFrame::IsDataFrame(entry) && !py::hasattr(entry, "__arrow_c_stream__")) {
142+
// Legacy path for Polars < 1.4 (no __arrow_c_stream__); newer Polars falls through to GetArrowType
144143
auto arrow_dataset = entry.attr("to_arrow")();
145144
CreateArrowScan(name, arrow_dataset, *table_function, children, client_properties, PyArrowObjectType::Table,
146145
*context.db);
@@ -149,9 +148,8 @@ unique_ptr<TableRef> PythonReplacementScan::TryReplacementObject(const py::objec
149148
auto arrow_dataset = materialized.attr("to_arrow")();
150149
CreateArrowScan(name, arrow_dataset, *table_function, children, client_properties, PyArrowObjectType::Table,
151150
*context.db);
152-
} else if (DuckDBPyConnection::GetArrowType(entry) != PyArrowObjectType::Invalid &&
153-
!(DuckDBPyConnection::GetArrowType(entry) == PyArrowObjectType::MessageReader && !relation)) {
154-
arrow_type = DuckDBPyConnection::GetArrowType(entry);
151+
} else if ((arrow_type = DuckDBPyConnection::GetArrowType(entry)) != PyArrowObjectType::Invalid &&
152+
!(arrow_type == PyArrowObjectType::MessageReader && !relation)) {
155153
CreateArrowScan(name, entry, *table_function, children, client_properties, arrow_type, *context.db);
156154
} else if (DuckDBPyConnection::IsAcceptedNumpyObject(entry) != NumpyObjectType::INVALID) {
157155
numpytype = DuckDBPyConnection::IsAcceptedNumpyObject(entry);

tests/fast/arrow/test_arrow_pycapsule.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,24 @@ def __arrow_c_stream__(self, requested_schema=None) -> object:
2929
obj = MyObject(df)
3030

3131
# Call the __arrow_c_stream__ from within DuckDB
32+
# MyObject has no __arrow_c_schema__, so GetSchema() falls back to __arrow_c_stream__ (1 call),
33+
# then Produce() calls __arrow_c_stream__ again (1 call) = 2 calls minimum per scan.
3234
res = duckdb_cursor.sql("select * from obj")
3335
assert res.fetchall() == [(1, 5), (2, 6), (3, 7), (4, 8)]
34-
assert obj.count == 1
36+
count_after_first = obj.count
37+
assert count_after_first >= 2
3538

3639
# Call the __arrow_c_stream__ method and pass in the capsule instead
3740
capsule = obj.__arrow_c_stream__()
3841
res = duckdb_cursor.sql("select * from capsule")
3942
assert res.fetchall() == [(1, 5), (2, 6), (3, 7), (4, 8)]
40-
assert obj.count == 2
43+
assert obj.count == count_after_first + 1
4144

4245
# Ensure __arrow_c_stream__ accepts a requested_schema argument as noop
4346
capsule = obj.__arrow_c_stream__(requested_schema="foo") # noqa: F841
4447
res = duckdb_cursor.sql("select * from capsule")
4548
assert res.fetchall() == [(1, 5), (2, 6), (3, 7), (4, 8)]
46-
assert obj.count == 3
49+
assert obj.count == count_after_first + 2
4750

4851
def test_capsule_roundtrip(self, duckdb_cursor):
4952
def create_capsule():

0 commit comments

Comments
 (0)