Skip to content

Commit 3fd0f1e

Browse files
Fix arrow stream construction (duckdb#294)
2 parents 57b8d66 + 6b690aa commit 3fd0f1e

4 files changed

Lines changed: 17 additions & 24 deletions

File tree

src/duckdb_py/arrow/arrow_array_stream.cpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ void VerifyArrowDatasetLoaded() {
2727
}
2828
}
2929

30-
py::object PythonTableArrowArrayStreamFactory::ProduceScanner(DBConfig &config, py::object &arrow_scanner,
31-
py::handle &arrow_obj_handle,
30+
py::object PythonTableArrowArrayStreamFactory::ProduceScanner(py::object &arrow_scanner, py::handle &arrow_obj_handle,
3231
ArrowStreamParameters &parameters,
3332
const ClientProperties &client_properties) {
3433
D_ASSERT(!py::isinstance<py::capsule>(arrow_obj_handle));
3534
ArrowSchemaWrapper schema;
3635
PythonTableArrowArrayStreamFactory::GetSchemaInternal(arrow_obj_handle, schema);
3736
ArrowTableSchema arrow_table;
38-
ArrowTableFunction::PopulateArrowTableSchema(config, arrow_table, schema.arrow_schema);
37+
ArrowTableFunction::PopulateArrowTableSchema(*client_properties.client_context.get_mutable(), arrow_table,
38+
schema.arrow_schema);
3939

4040
auto filters = parameters.filters;
4141
auto &column_list = parameters.projected_columns.columns;
@@ -86,26 +86,23 @@ unique_ptr<ArrowArrayStreamWrapper> PythonTableArrowArrayStreamFactory::Produce(
8686
auto arrow_dataset = import_cache.pyarrow.dataset().attr("dataset");
8787
auto dataset = arrow_dataset(arrow_obj_handle);
8888
py::object arrow_scanner = dataset.attr("__class__").attr("scanner");
89-
scanner = ProduceScanner(factory->config, arrow_scanner, dataset, parameters, factory->client_properties);
89+
scanner = ProduceScanner(arrow_scanner, dataset, parameters, factory->client_properties);
9090
break;
9191
}
9292
case PyArrowObjectType::RecordBatchReader: {
93-
scanner = ProduceScanner(factory->config, arrow_batch_scanner, arrow_obj_handle, parameters,
94-
factory->client_properties);
93+
scanner = ProduceScanner(arrow_batch_scanner, arrow_obj_handle, parameters, factory->client_properties);
9594
break;
9695
}
9796
case PyArrowObjectType::Scanner: {
9897
// If it's a scanner we have to turn it to a record batch reader, and then a scanner again since we can't stack
9998
// scanners on arrow Otherwise pushed-down projections and filters will disappear like tears in the rain
10099
auto record_batches = arrow_obj_handle.attr("to_reader")();
101-
scanner = ProduceScanner(factory->config, arrow_batch_scanner, record_batches, parameters,
102-
factory->client_properties);
100+
scanner = ProduceScanner(arrow_batch_scanner, record_batches, parameters, factory->client_properties);
103101
break;
104102
}
105103
case PyArrowObjectType::Dataset: {
106104
py::object arrow_scanner = arrow_obj_handle.attr("__class__").attr("scanner");
107-
scanner =
108-
ProduceScanner(factory->config, arrow_scanner, arrow_obj_handle, parameters, factory->client_properties);
105+
scanner = ProduceScanner(arrow_scanner, arrow_obj_handle, parameters, factory->client_properties);
109106
break;
110107
}
111108
default: {

src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,8 @@ PyArrowObjectType GetArrowType(const py::handle &obj);
6868

6969
class PythonTableArrowArrayStreamFactory {
7070
public:
71-
explicit PythonTableArrowArrayStreamFactory(PyObject *arrow_table, const ClientProperties &client_properties_p,
72-
DBConfig &config)
73-
: arrow_object(arrow_table), client_properties(client_properties_p), config(config) {};
71+
explicit PythonTableArrowArrayStreamFactory(PyObject *arrow_table, const ClientProperties &client_properties_p)
72+
: arrow_object(arrow_table), client_properties(client_properties_p) {};
7473

7574
//! Produces an Arrow Scanner, should be only called once when initializing Scan States
7675
static unique_ptr<ArrowArrayStreamWrapper> Produce(uintptr_t factory, ArrowStreamParameters &parameters);
@@ -83,10 +82,9 @@ class PythonTableArrowArrayStreamFactory {
8382
PyObject *arrow_object;
8483

8584
const ClientProperties client_properties;
86-
DBConfig &config;
8785

8886
private:
89-
static py::object ProduceScanner(DBConfig &config, py::object &arrow_scanner, py::handle &arrow_obj_handle,
87+
static py::object ProduceScanner(py::object &arrow_scanner, py::handle &arrow_obj_handle,
9088
ArrowStreamParameters &parameters, const ClientProperties &client_properties);
9189
};
9290
} // namespace duckdb

src/duckdb_py/python_replacement_scan.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace duckdb {
1818

1919
static void CreateArrowScan(const string &name, py::object entry, TableFunctionRef &table_function,
2020
vector<unique_ptr<ParsedExpression>> &children, ClientProperties &client_properties,
21-
PyArrowObjectType type, DBConfig &config, DatabaseInstance &db) {
21+
PyArrowObjectType type, DatabaseInstance &db) {
2222
shared_ptr<ExternalDependency> external_dependency = make_shared_ptr<ExternalDependency>();
2323
if (type == PyArrowObjectType::MessageReader) {
2424
if (!db.ExtensionIsLoaded("nanoarrow")) {
@@ -56,7 +56,7 @@ static void CreateArrowScan(const string &name, py::object entry, TableFunctionR
5656
type = PyArrowObjectType::PyCapsule;
5757
}
5858

59-
auto stream_factory = make_uniq<PythonTableArrowArrayStreamFactory>(entry.ptr(), client_properties, config);
59+
auto stream_factory = make_uniq<PythonTableArrowArrayStreamFactory>(entry.ptr(), client_properties);
6060
auto stream_factory_produce = PythonTableArrowArrayStreamFactory::Produce;
6161
auto stream_factory_get_schema = PythonTableArrowArrayStreamFactory::GetSchema;
6262

@@ -113,7 +113,7 @@ unique_ptr<TableRef> PythonReplacementScan::TryReplacementObject(const py::objec
113113
if (PandasDataFrame::IsPyArrowBacked(entry)) {
114114
auto table = PandasDataFrame::ToArrowTable(entry);
115115
CreateArrowScan(name, table, *table_function, children, client_properties, PyArrowObjectType::Table,
116-
DBConfig::GetConfig(context), *context.db);
116+
*context.db);
117117
} else {
118118
string name = "df_" + StringUtil::GenerateRandomName();
119119
auto new_df = PandasScanFunction::PandasReplaceCopiedNames(entry);
@@ -143,17 +143,16 @@ unique_ptr<TableRef> PythonReplacementScan::TryReplacementObject(const py::objec
143143
} else if (PolarsDataFrame::IsDataFrame(entry)) {
144144
auto arrow_dataset = entry.attr("to_arrow")();
145145
CreateArrowScan(name, arrow_dataset, *table_function, children, client_properties, PyArrowObjectType::Table,
146-
DBConfig::GetConfig(context), *context.db);
146+
*context.db);
147147
} else if (PolarsDataFrame::IsLazyFrame(entry)) {
148148
auto materialized = entry.attr("collect")();
149149
auto arrow_dataset = materialized.attr("to_arrow")();
150150
CreateArrowScan(name, arrow_dataset, *table_function, children, client_properties, PyArrowObjectType::Table,
151-
DBConfig::GetConfig(context), *context.db);
151+
*context.db);
152152
} else if (DuckDBPyConnection::GetArrowType(entry) != PyArrowObjectType::Invalid &&
153153
!(DuckDBPyConnection::GetArrowType(entry) == PyArrowObjectType::MessageReader && !relation)) {
154154
arrow_type = DuckDBPyConnection::GetArrowType(entry);
155-
CreateArrowScan(name, entry, *table_function, children, client_properties, arrow_type,
156-
DBConfig::GetConfig(context), *context.db);
155+
CreateArrowScan(name, entry, *table_function, children, client_properties, arrow_type, *context.db);
157156
} else if (DuckDBPyConnection::IsAcceptedNumpyObject(entry) != NumpyObjectType::INVALID) {
158157
numpytype = DuckDBPyConnection::IsAcceptedNumpyObject(entry);
159158
string np_name = "np_" + StringUtil::GenerateRandomName();

src/duckdb_py/python_udf.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ static void ConvertArrowTableToVector(const py::object &table, Vector &out, Clie
7474
D_ASSERT(py::gil_check());
7575
py::gil_scoped_release gil;
7676

77-
auto stream_factory =
78-
make_uniq<PythonTableArrowArrayStreamFactory>(ptr, context.GetClientProperties(), DBConfig::GetConfig(context));
77+
auto stream_factory = make_uniq<PythonTableArrowArrayStreamFactory>(ptr, context.GetClientProperties());
7978
auto stream_factory_produce = PythonTableArrowArrayStreamFactory::Produce;
8079
auto stream_factory_get_schema = PythonTableArrowArrayStreamFactory::GetSchema;
8180

0 commit comments

Comments (0)