Skip to content

Commit f502acd

Browse files
committed
keep connection alive for arrow streams
1 parent b71639f commit f502acd

4 files changed

Lines changed: 24 additions & 9 deletions

File tree

external/duckdb

Submodule duckdb updated 50 files

src/duckdb_py/include/duckdb_python/pyresult.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ struct DuckDBPyResult {
4646

4747
ArrowArrayStream FetchArrowArrayStream(idx_t rows_per_batch = 1000000);
4848
duckdb::pyarrow::RecordBatchReader FetchRecordBatchReader(idx_t rows_per_batch = 1000000);
49-
py::object FetchArrowCapsule(idx_t rows_per_batch = 1000000);
49+
py::object FetchArrowCapsule(shared_ptr<ClientContext> context, idx_t rows_per_batch = 1000000);
5050

5151
static py::list GetDescription(const vector<string> &names, const vector<LogicalType> &types);
5252

src/duckdb_py/pyrelation.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1006,7 +1006,8 @@ py::object DuckDBPyRelation::ToArrowCapsule(const py::object &requested_schema)
10061006
ExecuteOrThrow();
10071007
}
10081008
AssertResultOpen();
1009-
auto res = result->FetchArrowCapsule();
1009+
auto context = rel ? rel->context->GetContext() : nullptr;
1010+
auto res = result->FetchArrowCapsule(std::move(context));
10101011
result = nullptr;
10111012
return res;
10121013
}

src/duckdb_py/pyresult.cpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,8 @@ duckdb::pyarrow::RecordBatchReader DuckDBPyResult::FetchRecordBatchReader(idx_t
500500
// This avoids the double-materialization that happens when using ResultArrowArrayStreamWrapper
501501
// with an ArrowQueryResult (which throws NotImplementedException from FetchInternal).
502502
struct ArrowQueryResultStreamWrapper {
503-
ArrowQueryResultStreamWrapper(unique_ptr<QueryResult> result_p) : result(std::move(result_p)), index(0) {
503+
ArrowQueryResultStreamWrapper(unique_ptr<QueryResult> result_p, shared_ptr<ClientContext> context_p)
504+
: result(std::move(result_p)), context(std::move(context_p)), index(0) {
504505
auto &arrow_result = result->Cast<ArrowQueryResult>();
505506
arrays = arrow_result.ConsumeArrays();
506507
types = result->types;
@@ -562,6 +563,7 @@ struct ArrowQueryResultStreamWrapper {
562563

563564
ArrowArrayStream stream;
564565
unique_ptr<QueryResult> result;
566+
shared_ptr<ClientContext> context;
565567
vector<unique_ptr<ArrowArrayWrapper>> arrays;
566568
vector<LogicalType> types;
567569
vector<string> names;
@@ -570,7 +572,9 @@ struct ArrowQueryResultStreamWrapper {
570572
string last_error;
571573
};
572574

573-
// Destructor for capsules that own a heap-allocated ArrowArrayStream (slow path).
575+
// Destructor for capsules that own a heap-allocated ArrowArrayStream.
576+
// If PyCapsule_GetContext is set, it points to a shared_ptr<ClientContext> that
577+
// keeps the connection's ClientContext alive for the lifetime of the capsule.
574578
static void ArrowArrayStreamPyCapsuleDestructor(PyObject *object) {
575579
auto data = PyCapsule_GetPointer(object, "arrow_array_stream");
576580
if (!data) {
@@ -581,25 +585,35 @@ static void ArrowArrayStreamPyCapsuleDestructor(PyObject *object) {
581585
stream->release(stream);
582586
}
583587
delete stream;
588+
auto ctx = PyCapsule_GetContext(object);
589+
if (ctx) {
590+
delete reinterpret_cast<shared_ptr<ClientContext> *>(ctx);
591+
}
584592
}
585593

586-
py::object DuckDBPyResult::FetchArrowCapsule(idx_t rows_per_batch) {
594+
py::object DuckDBPyResult::FetchArrowCapsule(shared_ptr<ClientContext> context, idx_t rows_per_batch) {
587595
if (result && result->type == QueryResultType::ARROW_RESULT) {
588596
// Fast path: yield pre-built Arrow arrays directly.
589597
// The wrapper is heap-allocated; Release() deletes it via private_data.
590598
// We heap-allocate a separate ArrowArrayStream for the capsule so that the capsule
591599
// holds a stable pointer even after the wrapper is consumed and deleted by a scan.
592-
auto wrapper = new ArrowQueryResultStreamWrapper(std::move(result));
600+
auto wrapper = new ArrowQueryResultStreamWrapper(std::move(result), std::move(context));
593601
auto stream = new ArrowArrayStream();
594602
*stream = wrapper->stream;
595603
wrapper->stream.release = nullptr;
596604
return py::capsule(stream, "arrow_array_stream", ArrowArrayStreamPyCapsuleDestructor);
597605
}
598-
// Existing slow path for MaterializedQueryResult / StreamQueryResult
606+
// Existing slow path for MaterializedQueryResult / StreamQueryResult.
607+
// Keep the ClientContext alive via the capsule's context pointer so that
608+
// deferred get_schema / get_next calls can still dereference client_context.
599609
auto stream_p = FetchArrowArrayStream(rows_per_batch);
600610
auto stream = new ArrowArrayStream();
601611
*stream = stream_p;
602-
return py::capsule(stream, "arrow_array_stream", ArrowArrayStreamPyCapsuleDestructor);
612+
auto capsule = py::capsule(stream, "arrow_array_stream", ArrowArrayStreamPyCapsuleDestructor);
613+
if (context) {
614+
PyCapsule_SetContext(capsule.ptr(), new shared_ptr<ClientContext>(std::move(context)));
615+
}
616+
return capsule;
603617
}
604618

605619
py::list DuckDBPyResult::GetDescription(const vector<string> &names, const vector<LogicalType> &types) {

0 commit comments

Comments
 (0)