Skip to content

Commit 440c64a

Browse files
committed
Cache Arrow schema
1 parent 811b135 commit 440c64a

2 files changed

Lines changed: 134 additions & 126 deletions

File tree

src/duckdb_py/pyresult.cpp

Lines changed: 71 additions & 126 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
#include "duckdb_python/arrow/arrow_export_utils.hpp"
2323
#include "duckdb/main/chunk_scan_state/query_result.hpp"
2424
#include "duckdb/common/arrow/arrow_query_result.hpp"
25+
#include "duckdb/common/arrow/nanoarrow/nanoarrow.hpp"
2526

2627
using namespace pybind11::literals;
2728

@@ -428,25 +429,39 @@ duckdb::pyarrow::Table DuckDBPyResult::FetchArrowTable(idx_t rows_per_batch, boo
428429
QueryResult::DeduplicateColumns(names);
429430
}
430431

431-
if (!result) {
432-
throw InvalidInputException("result closed");
433-
}
434432
auto pyarrow_lib_module = py::module::import("pyarrow").attr("lib");
435433

434+
// If the producing operator cached the Arrow schema (built while its own
435+
// transaction was still active), reuse it instead of rebuilding it here.
436+
// Rebuilding post-commit breaks arrow type extensions whose schema callback
437+
// does a catalog lookup -- e.g. GeoArrow CRS resolution -- which asserts an
438+
// active transaction. See duckdb-python#475 / duckdb-spatial#788.
439+
ArrowQueryResult *cached_result = nullptr;
440+
if (result->type == QueryResultType::ARROW_RESULT) {
441+
auto &arrow_result = result->Cast<ArrowQueryResult>();
442+
if (arrow_result.HasCachedSchema()) {
443+
cached_result = &arrow_result;
444+
}
445+
}
446+
436447
py::list batches;
437448
if (result->type == QueryResultType::ARROW_RESULT) {
438449
auto &arrow_result = result->Cast<ArrowQueryResult>();
439450
auto arrays = arrow_result.ConsumeArrays();
440451
for (auto &array : arrays) {
441452
ArrowSchema arrow_schema;
442-
auto result_names = arrow_result.names;
443-
if (to_polars) {
444-
QueryResult::DeduplicateColumns(result_names);
445-
}
446453
ArrowArray data = array->arrow_array;
447454
array->arrow_array.release = nullptr;
448-
ArrowConverter::ToArrowSchema(&arrow_schema, arrow_result.types, result_names,
449-
arrow_result.client_properties);
455+
if (cached_result) {
456+
cached_result->GetSchema(arrow_schema);
457+
} else {
458+
auto result_names = arrow_result.names;
459+
if (to_polars) {
460+
QueryResult::DeduplicateColumns(result_names);
461+
}
462+
ArrowConverter::ToArrowSchema(&arrow_schema, arrow_result.types, result_names,
463+
arrow_result.client_properties);
464+
}
450465
TransformDuckToArrowChunk(arrow_schema, data, batches);
451466
}
452467
} else {
@@ -476,7 +491,33 @@ duckdb::pyarrow::Table DuckDBPyResult::FetchArrowTable(idx_t rows_per_batch, boo
476491
}
477492
}
478493

479-
return pyarrow::ToArrowTable(result->types, names, std::move(batches), result->client_properties);
494+
if (!cached_result) {
495+
return pyarrow::ToArrowTable(result->types, names, std::move(batches), result->client_properties);
496+
}
497+
498+
// Assemble the table from the cached schema (avoids the ToArrowSchema call
499+
// inside pyarrow::ToArrowTable, which would also assert post-commit).
500+
auto from_batches_func = pyarrow_lib_module.attr("Table").attr("from_batches");
501+
auto schema_import_func = pyarrow_lib_module.attr("Schema").attr("_import_from_c");
502+
ArrowSchema final_schema;
503+
cached_result->GetSchema(final_schema);
504+
auto schema_obj = schema_import_func(reinterpret_cast<uint64_t>(&final_schema));
505+
auto table = py::cast<duckdb::pyarrow::Table>(from_batches_func(batches, schema_obj));
506+
if (to_polars) {
507+
// The cached schema carries the original column names; polars needs them
508+
// unique. Rename only when there are real duplicates, so unique columns
509+
// (and their field metadata, e.g. geoarrow) are left untouched.
510+
auto deduped = result->names;
511+
QueryResult::DeduplicateColumns(deduped);
512+
if (deduped != result->names) {
513+
py::list renamed;
514+
for (auto &n : deduped) {
515+
renamed.append(n);
516+
}
517+
table = py::cast<duckdb::pyarrow::Table>(table.attr("rename_columns")(renamed));
518+
}
519+
}
520+
return table;
480521
}
481522

482523
ArrowArrayStream DuckDBPyResult::FetchArrowArrayStream(idx_t rows_per_batch) {
@@ -501,119 +542,6 @@ duckdb::pyarrow::RecordBatchReader DuckDBPyResult::FetchRecordBatchReader(idx_t
501542
return py::cast<duckdb::pyarrow::RecordBatchReader>(record_batch_reader);
502543
}
503544

504-
// Holds owned copies of the string data for a deep-copied ArrowSchema node.
505-
struct ArrowSchemaCopyData {
506-
string format;
507-
string name;
508-
string metadata;
509-
};
510-
511-
static void ReleaseCopiedArrowSchema(ArrowSchema *schema) {
512-
if (!schema || !schema->release) {
513-
return;
514-
}
515-
for (int64_t i = 0; i < schema->n_children; i++) {
516-
if (schema->children[i]->release) {
517-
schema->children[i]->release(schema->children[i]);
518-
}
519-
delete schema->children[i];
520-
}
521-
delete[] schema->children;
522-
if (schema->dictionary) {
523-
if (schema->dictionary->release) {
524-
schema->dictionary->release(schema->dictionary);
525-
}
526-
delete schema->dictionary;
527-
}
528-
delete reinterpret_cast<ArrowSchemaCopyData *>(schema->private_data);
529-
schema->release = nullptr;
530-
}
531-
532-
static idx_t ArrowMetadataSize(const char *metadata) {
533-
if (!metadata) {
534-
return 0;
535-
}
536-
// Arrow metadata format: int32 num_entries, then for each entry:
537-
// int32 key_len, key_bytes, int32 value_len, value_bytes
538-
auto ptr = metadata;
539-
int32_t num_entries;
540-
memcpy(&num_entries, ptr, sizeof(int32_t));
541-
ptr += sizeof(int32_t);
542-
for (int32_t i = 0; i < num_entries; i++) {
543-
int32_t len;
544-
memcpy(&len, ptr, sizeof(int32_t));
545-
ptr += sizeof(int32_t) + len;
546-
memcpy(&len, ptr, sizeof(int32_t));
547-
ptr += sizeof(int32_t) + len;
548-
}
549-
return ptr - metadata;
550-
}
551-
552-
// Deep-copy an ArrowSchema. The Arrow C Data Interface specifies that get_schema
553-
// transfers ownership to the caller, so each call must produce an independent copy.
554-
// Each node owns its string data via an ArrowSchemaCopyData in private_data.
555-
static int ArrowSchemaDeepCopy(const ArrowSchema &source, ArrowSchema *out, string &error) {
556-
out->release = nullptr;
557-
try {
558-
auto data = new ArrowSchemaCopyData();
559-
data->format = source.format ? source.format : "";
560-
data->name = source.name ? source.name : "";
561-
if (source.metadata) {
562-
auto metadata_size = ArrowMetadataSize(source.metadata);
563-
data->metadata.assign(source.metadata, metadata_size);
564-
}
565-
566-
out->format = data->format.c_str();
567-
out->name = data->name.c_str();
568-
out->metadata = source.metadata ? data->metadata.data() : nullptr;
569-
out->flags = source.flags;
570-
out->n_children = source.n_children;
571-
out->dictionary = nullptr;
572-
out->private_data = data;
573-
out->release = ReleaseCopiedArrowSchema;
574-
575-
if (source.n_children > 0) {
576-
out->children = new ArrowSchema *[source.n_children];
577-
for (int64_t i = 0; i < source.n_children; i++) {
578-
out->children[i] = new ArrowSchema();
579-
auto rc = ArrowSchemaDeepCopy(*source.children[i], out->children[i], error);
580-
if (rc != 0) {
581-
for (int64_t j = 0; j <= i; j++) {
582-
if (out->children[j]->release) {
583-
out->children[j]->release(out->children[j]);
584-
}
585-
delete out->children[j];
586-
}
587-
delete[] out->children;
588-
out->children = nullptr;
589-
out->n_children = 0;
590-
// Release the partially constructed node
591-
delete data;
592-
out->private_data = nullptr;
593-
out->release = nullptr;
594-
return rc;
595-
}
596-
}
597-
} else {
598-
out->children = nullptr;
599-
}
600-
601-
if (source.dictionary) {
602-
out->dictionary = new ArrowSchema();
603-
auto rc = ArrowSchemaDeepCopy(*source.dictionary, out->dictionary, error);
604-
if (rc != 0) {
605-
delete out->dictionary;
606-
out->dictionary = nullptr;
607-
return rc;
608-
}
609-
}
610-
} catch (std::exception &e) {
611-
error = e.what();
612-
return -1;
613-
}
614-
return 0;
615-
}
616-
617545
// Wraps pre-built Arrow arrays from an ArrowQueryResult into an ArrowArrayStream.
618546
// This avoids the double-materialization that happens when using ResultArrowArrayStreamWrapper
619547
// with an ArrowQueryResult (which throws NotImplementedException from FetchInternal).
@@ -628,7 +556,16 @@ struct ArrowQueryResultStreamWrapper {
628556
arrays = arrow_result.ConsumeArrays();
629557

630558
cached_schema.release = nullptr;
631-
ArrowConverter::ToArrowSchema(&cached_schema, result->types, result->names, result->client_properties);
559+
if (arrow_result.HasCachedSchema()) {
560+
// Reuse the schema the collector built under the producing transaction.
561+
// Rebuilding it here (post-commit) would assert for arrow type
562+
// extensions that do a catalog lookup, e.g. GeoArrow CRS -- this is the
563+
// capsule / Arrow C Stream form of #475 (pa.table(rel), pl.DataFrame(rel),
564+
// ADBC). See duckdb-python#475 / duckdb-spatial#788.
565+
arrow_result.GetSchema(cached_schema);
566+
} else {
567+
ArrowConverter::ToArrowSchema(&cached_schema, result->types, result->names, result->client_properties);
568+
}
632569

633570
stream.private_data = this;
634571
stream.get_schema = GetSchema;
@@ -648,7 +585,11 @@ struct ArrowQueryResultStreamWrapper {
648585
return -1;
649586
}
650587
auto self = reinterpret_cast<ArrowQueryResultStreamWrapper *>(stream->private_data);
651-
return ArrowSchemaDeepCopy(self->cached_schema, out, self->last_error);
588+
auto rc = duckdb_nanoarrow::ArrowSchemaDeepCopy(&self->cached_schema, out);
589+
if (rc != NANOARROW_OK) {
590+
self->last_error = "failed to copy cached Arrow schema";
591+
}
592+
return rc;
652593
}
653594

654595
static int GetNext(ArrowArrayStream *stream, ArrowArray *out) {
@@ -731,7 +672,11 @@ struct SchemaCachingStreamWrapper {
731672
if (!self->schema_ok) {
732673
return -1;
733674
}
734-
return ArrowSchemaDeepCopy(self->cached_schema, out, self->schema_error);
675+
auto rc = duckdb_nanoarrow::ArrowSchemaDeepCopy(&self->cached_schema, out);
676+
if (rc != NANOARROW_OK) {
677+
self->schema_error = "failed to copy cached Arrow schema";
678+
}
679+
return rc;
735680
}
736681

737682
static int GetNext(ArrowArrayStream *stream, ArrowArray *out) {
Lines changed: 63 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,63 @@
1+
"""Regression test for duckdb-python#475 / duckdb-spatial#788.
2+
3+
Converting a GEOMETRY column that carries a CRS to Arrow used to raise
4+
``InternalException: TransactionContext::ActiveTransaction called without
5+
active transaction``. The GeoArrow schema callback does a catalog lookup to
6+
resolve the CRS, which needs an active transaction -- but the Arrow schema was
7+
rebuilt at consume time, after the producing (auto-commit) transaction had
8+
already closed.
9+
10+
The fix builds and caches the schema on ArrowQueryResult while the producing
11+
transaction is still active, and the consumers below reuse it. Each test
12+
exercises one of those consumers. The geometry-with-CRS value is built with a
13+
pure-core cast, so no spatial extension is required.
14+
"""
15+
16+
from __future__ import annotations
17+
18+
import pytest
19+
20+
import duckdb
21+
22+
pa = pytest.importorskip("pyarrow")
23+
24+
# An authority-code CRS forces the catalog lookup that used to require an open
25+
# transaction. No spatial extension needed -- the cast and geoarrow.wkb mapping
26+
# are both in core.
27+
GEOM_SQL = "SELECT 'POINT(0 1)'::GEOMETRY('OGC:CRS84') AS g"
28+
29+
30+
def _assert_geoarrow_with_crs(field: pa.Field) -> None:
    """Assert that ``field`` is tagged as ``geoarrow.wkb`` and its extension metadata mentions a CRS.

    A field whose ``metadata`` is falsy (``None`` or empty) is treated as having
    an empty mapping, so it fails on the extension-name assertion.
    """
    field_meta = field.metadata
    if not field_meta:
        field_meta = {}

    # The Arrow extension name must identify the column as GeoArrow WKB.
    extension_name = field_meta.get(b"ARROW:extension:name")
    assert extension_name == b"geoarrow.wkb"

    # The serialized extension metadata must contain a "crs" entry.
    extension_payload = field_meta.get(b"ARROW:extension:metadata", b"")
    assert b"crs" in extension_payload
34+
35+
36+
def test_475_to_arrow_table_geometry_with_crs():
    """``to_arrow_table()`` on the geometry-with-CRS query returns one row with a GeoArrow field."""
    connection = duckdb.connect()
    relation = connection.sql(GEOM_SQL)
    arrow_table = relation.to_arrow_table()
    assert arrow_table.num_rows == 1
    geometry_field = arrow_table.schema.field("g")
    _assert_geoarrow_with_crs(geometry_field)
41+
42+
43+
def test_475_arrow_capsule_geometry_with_crs():
    """``pa.table(rel)`` on the geometry-with-CRS query returns one row with a GeoArrow field."""
    connection = duckdb.connect()
    relation = connection.sql(GEOM_SQL)
    # pa.table(rel) consumes via __arrow_c_stream__ (the capsule / ADBC path).
    arrow_table = pa.table(relation)
    assert arrow_table.num_rows == 1
    geometry_field = arrow_table.schema.field("g")
    _assert_geoarrow_with_crs(geometry_field)
49+
50+
51+
def test_475_record_batch_reader_geometry_with_crs():
    """``to_arrow_reader().read_all()`` on the geometry-with-CRS query returns one row with a GeoArrow field."""
    connection = duckdb.connect()
    batch_reader = connection.sql(GEOM_SQL).to_arrow_reader()
    arrow_table = batch_reader.read_all()
    assert arrow_table.num_rows == 1
    geometry_field = arrow_table.schema.field("g")
    _assert_geoarrow_with_crs(geometry_field)
56+
57+
58+
def test_475_polars_geometry_with_crs():
    """Building a polars DataFrame from the geometry-with-CRS relation yields one row.

    Skipped when polars is not installed.
    """
    polars = pytest.importorskip("polars")
    connection = duckdb.connect()
    relation = connection.sql(GEOM_SQL)
    # polars.DataFrame(rel) pulls the relation's Arrow C stream directly.
    frame = polars.DataFrame(relation)
    assert frame.height == 1

0 commit comments

Comments
 (0)