Skip to content

Commit bbe4668

Browse files
committed
Use shared pointer to client context in QueryResult
1 parent: c8449f5 · commit: bbe4668

11 files changed

Lines changed: 80 additions & 88 deletions

src/duckdb_py/arrow/arrow_array_stream.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include "duckdb/common/assert.hpp"
1111
#include "duckdb/common/common.hpp"
1212
#include "duckdb/common/limits.hpp"
13-
#include "duckdb/main/client_config.hpp"
1413

1514
namespace duckdb {
1615

@@ -30,13 +29,12 @@ void VerifyArrowDatasetLoaded() {
3029

3130
py::object PythonTableArrowArrayStreamFactory::ProduceScanner(py::object &arrow_scanner, py::handle &arrow_obj_handle,
3231
ArrowStreamParameters &parameters,
33-
const ClientProperties &client_properties) {
32+
const shared_ptr<ClientContext> &client_context) {
3433
D_ASSERT(!py::isinstance<py::capsule>(arrow_obj_handle));
3534
ArrowSchemaWrapper schema;
3635
PythonTableArrowArrayStreamFactory::GetSchemaInternal(arrow_obj_handle, schema);
3736
ArrowTableSchema arrow_table;
38-
ArrowTableFunction::PopulateArrowTableSchema(*client_properties.client_context.get_mutable(), arrow_table,
39-
schema.arrow_schema);
37+
ArrowTableFunction::PopulateArrowTableSchema(*client_context, arrow_table, schema.arrow_schema);
4038

4139
auto filters = parameters.filters;
4240
auto &column_list = parameters.projected_columns.columns;
@@ -50,8 +48,9 @@ py::object PythonTableArrowArrayStreamFactory::ProduceScanner(py::object &arrow_
5048
}
5149

5250
if (has_filter) {
53-
auto filter = PyArrowFilterPushdown::TransformFilter(*filters, parameters.projected_columns.projection_map,
54-
filter_to_col, client_properties, arrow_table);
51+
auto filter =
52+
PyArrowFilterPushdown::TransformFilter(*filters, parameters.projected_columns.projection_map, filter_to_col,
53+
client_context->GetClientProperties(), arrow_table);
5554
if (!filter.is(py::none())) {
5655
kwargs["filter"] = filter;
5756
}
@@ -78,7 +77,7 @@ unique_ptr<ArrowArrayStreamWrapper> PythonTableArrowArrayStreamFactory::Produce(
7877
try {
7978
auto filter_expr = PolarsFilterPushdown::TransformFilter(
8079
*filters, parameters.projected_columns.projection_map, parameters.projected_columns.filter_to_col,
81-
factory->client_properties);
80+
factory->client_context->GetClientProperties());
8281
if (!filter_expr.is(py::none())) {
8382
lf = lf.attr("filter")(filter_expr);
8483
filters_pushed = true;
@@ -139,7 +138,7 @@ unique_ptr<ArrowArrayStreamWrapper> PythonTableArrowArrayStreamFactory::Produce(
139138
auto &import_cache = *DuckDBPyConnection::ImportCache();
140139
py::object arrow_batch_scanner = import_cache.pyarrow.dataset.Scanner().attr("from_batches");
141140
py::handle reader_handle = reader;
142-
auto scanner = ProduceScanner(arrow_batch_scanner, reader_handle, parameters, factory->client_properties);
141+
auto scanner = ProduceScanner(arrow_batch_scanner, reader_handle, parameters, factory->client_context);
143142
auto record_batches = scanner.attr("to_reader")();
144143
auto res = make_uniq<ArrowArrayStreamWrapper>();
145144
auto export_to_c = record_batches.attr("_export_to_c");
@@ -177,12 +176,12 @@ unique_ptr<ArrowArrayStreamWrapper> PythonTableArrowArrayStreamFactory::Produce(
177176
// If it's a scanner we have to turn it to a record batch reader, and then a scanner again since we can't stack
178177
// scanners on arrow Otherwise pushed-down projections and filters will disappear like tears in the rain
179178
auto record_batches = arrow_obj_handle.attr("to_reader")();
180-
scanner = ProduceScanner(arrow_batch_scanner, record_batches, parameters, factory->client_properties);
179+
scanner = ProduceScanner(arrow_batch_scanner, record_batches, parameters, factory->client_context);
181180
break;
182181
}
183182
case PyArrowObjectType::Dataset: {
184183
py::object arrow_scanner = arrow_obj_handle.attr("__class__").attr("scanner");
185-
scanner = ProduceScanner(arrow_scanner, arrow_obj_handle, parameters, factory->client_properties);
184+
scanner = ProduceScanner(arrow_scanner, arrow_obj_handle, parameters, factory->client_context);
186185
break;
187186
}
188187
default: {

src/duckdb_py/arrow/arrow_export_utils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@ namespace duckdb {
1818
namespace pyarrow {
1919

2020
py::object ToArrowTable(const vector<LogicalType> &types, const vector<string> &names, const py::list &batches,
21-
ClientProperties &options) {
21+
ClientContext &client_context) {
2222
py::gil_scoped_acquire acquire;
2323

2424
auto pyarrow_lib_module = py::module::import("pyarrow").attr("lib");
2525
auto from_batches_func = pyarrow_lib_module.attr("Table").attr("from_batches");
2626
auto schema_import_func = pyarrow_lib_module.attr("Schema").attr("_import_from_c");
2727
ArrowSchema schema;
28-
ArrowConverter::ToArrowSchema(&schema, types, names, options);
28+
ArrowConverter::ToArrowSchema(&schema, types, names, client_context);
2929
auto schema_obj = schema_import_func(reinterpret_cast<uint64_t>(&schema));
3030

3131
return py::cast<duckdb::pyarrow::Table>(from_batches_func(batches, schema_obj));

src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ PyArrowObjectType GetArrowType(const py::handle &obj);
6868

6969
class PythonTableArrowArrayStreamFactory {
7070
public:
71-
explicit PythonTableArrowArrayStreamFactory(PyObject *arrow_table, const ClientProperties &client_properties_p,
71+
explicit PythonTableArrowArrayStreamFactory(PyObject *arrow_table, const shared_ptr<ClientContext> &client_context,
7272
PyArrowObjectType arrow_type_p)
73-
: arrow_object(arrow_table), client_properties(client_properties_p), cached_arrow_type(arrow_type_p) {
73+
: arrow_object(arrow_table), client_context(client_context), cached_arrow_type(arrow_type_p) {
7474
cached_schema.release = nullptr;
7575
}
7676

@@ -94,7 +94,7 @@ class PythonTableArrowArrayStreamFactory {
9494
//! Arrow Object (i.e., Scanner, Record Batch Reader, Table, Dataset)
9595
PyObject *arrow_object;
9696

97-
const ClientProperties client_properties;
97+
const shared_ptr<ClientContext> client_context;
9898
const PyArrowObjectType cached_arrow_type;
9999

100100
//! Cached Arrow table from an unfiltered .collect().to_arrow() on a LazyFrame.
@@ -106,7 +106,8 @@ class PythonTableArrowArrayStreamFactory {
106106
bool schema_cached = false;
107107

108108
static py::object ProduceScanner(py::object &arrow_scanner, py::handle &arrow_obj_handle,
109-
ArrowStreamParameters &parameters, const ClientProperties &client_properties);
109+
ArrowStreamParameters &parameters,
110+
const shared_ptr<ClientContext> &client_context);
110111
};
111112
} // namespace duckdb
112113

src/duckdb_py/include/duckdb_python/arrow/arrow_export_utils.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace duckdb {
77
namespace pyarrow {
88

99
py::object ToArrowTable(const vector<LogicalType> &types, const vector<string> &names, const py::list &batches,
10-
ClientProperties &options);
10+
ClientContext &client_context);
1111

1212
} // namespace pyarrow
1313

src/duckdb_py/include/duckdb_python/pyresult.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ struct DuckDBPyResult {
5959
const vector<string> &GetNames();
6060
const vector<LogicalType> &GetTypes();
6161

62-
ClientProperties GetClientProperties();
62+
shared_ptr<ClientContext> GetClientContext() const;
6363

6464
private:
6565
void FillNumpy(py::dict &res, idx_t col_idx, NumpyResultConversion &conversion, const char *name);

src/duckdb_py/include/duckdb_python/python_replacement_scan.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ namespace duckdb {
1010

1111
struct PythonReplacementScan {
1212
public:
13-
static unique_ptr<TableRef> Replace(ClientContext &context, ReplacementScanInput &input,
13+
static unique_ptr<TableRef> Replace(ClientContext &client_context, ReplacementScanInput &input,
1414
optional_ptr<ReplacementScanData> data);
1515
//! Try to perform a replacement, returns NULL on error
1616
static unique_ptr<TableRef> TryReplacementObject(const py::object &entry, const string &name,
17-
ClientContext &context, bool relation = false);
17+
ClientContext &client_context, bool relation = false);
1818
//! Perform a replacement or throw if it failed
19-
static unique_ptr<TableRef> ReplacementObject(const py::object &entry, const string &name, ClientContext &context,
20-
bool relation = false);
19+
static unique_ptr<TableRef> ReplacementObject(const py::object &entry, const string &name,
20+
ClientContext &client_context, bool relation = false);
2121
};
2222

2323
} // namespace duckdb

src/duckdb_py/pyconnection.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -763,8 +763,7 @@ shared_ptr<DuckDBPyConnection> DuckDBPyConnection::Append(const string &name, co
763763
shared_ptr<DuckDBPyConnection> DuckDBPyConnection::RegisterPythonObject(const string &name,
764764
const py::object &python_object) {
765765
auto &connection = con.GetConnection();
766-
auto &client = *connection.context;
767-
auto object = PythonReplacementScan::ReplacementObject(python_object, name, client);
766+
auto object = PythonReplacementScan::ReplacementObject(python_object, name, *connection.context);
768767
auto view_rel = make_shared_ptr<ViewRelation>(connection.context, std::move(object), name);
769768
bool replace = registered_objects.count(name);
770769
view_rel->CreateView(name, replace, true);

src/duckdb_py/pyrelation.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,18 +1024,18 @@ PolarsDataFrame DuckDBPyRelation::ToPolars(idx_t batch_size, bool lazy) {
10241024
ArrowSchema arrow_schema;
10251025
auto result_names = names;
10261026
QueryResult::DeduplicateColumns(result_names);
1027-
ClientProperties client_properties;
1027+
shared_ptr<ClientContext> client_context;
10281028
if (rel) {
1029-
client_properties = rel->context->GetContext()->GetClientProperties();
1029+
client_context = rel->context->GetContext();
10301030
} else if (result) {
1031-
client_properties = result->GetClientProperties();
1031+
client_context = result->GetClientContext();
10321032
} else {
10331033
throw InternalException("DuckDBPyRelation To Polars must have a valid relation or result");
10341034
}
1035-
ArrowConverter::ToArrowSchema(&arrow_schema, types, result_names, client_properties);
1035+
ArrowConverter::ToArrowSchema(&arrow_schema, types, result_names, *client_context);
10361036
py::list batches;
10371037
// Now we create an empty arrow table
1038-
auto empty_table = pyarrow::ToArrowTable(types, result_names, std::move(batches), client_properties);
1038+
auto empty_table = pyarrow::ToArrowTable(types, result_names, std::move(batches), *client_context);
10391039

10401040
// And we extract the polars schema from the arrow table
10411041
auto polars_df = py::cast<PolarsDataFrame>(pybind11::module_::import("polars").attr("DataFrame")(empty_table));

src/duckdb_py/pyresult.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ DuckDBPyResult::~DuckDBPyResult() {
4343
}
4444
}
4545

46-
ClientProperties DuckDBPyResult::GetClientProperties() {
47-
return result->client_properties;
46+
shared_ptr<ClientContext> DuckDBPyResult::GetClientContext() const {
47+
return result->client_context;
4848
}
4949

5050
const vector<string> &DuckDBPyResult::GetNames() {
@@ -138,7 +138,8 @@ Optional<py::tuple> DuckDBPyResult::Fetchone() {
138138
continue;
139139
}
140140
auto val = current_chunk->data[col_idx].GetValue(chunk_offset);
141-
res[col_idx] = PythonObject::FromValue(val, result->types[col_idx], result->client_properties);
141+
res[col_idx] =
142+
PythonObject::FromValue(val, result->types[col_idx], result->client_context->GetClientProperties());
142143
}
143144
chunk_offset++;
144145
return res;
@@ -225,8 +226,8 @@ unique_ptr<NumpyResultConversion> DuckDBPyResult::InitializeNumpyConversion(bool
225226
initial_capacity = materialized.RowCount();
226227
}
227228

228-
auto conversion =
229-
make_uniq<NumpyResultConversion>(result->types, initial_capacity, result->client_properties, pandas);
229+
auto conversion = make_uniq<NumpyResultConversion>(result->types, initial_capacity,
230+
result->client_context->GetClientProperties(), pandas);
230231
return conversion;
231232
}
232233

@@ -297,7 +298,8 @@ void DuckDBPyResult::ConvertDateTimeTypes(PandasDataFrame &df, bool date_as_obje
297298
if (result->types[i] == LogicalType::TIMESTAMP_TZ) {
298299
// first localize to UTC then convert to timezone_config
299300
auto utc_local = df[names[i].c_str()].attr("dt").attr("tz_localize")("UTC");
300-
auto new_value = utc_local.attr("dt").attr("tz_convert")(result->client_properties.time_zone);
301+
auto new_value =
302+
utc_local.attr("dt").attr("tz_convert")(result->client_context->GetClientProperties().time_zone);
301303
// We need to create the column anew because the exact dt changed to a new timezone
302304
ReplaceDFColumn(df, names[i].c_str(), i, new_value);
303305
} else if (date_as_object && result->types[i] == LogicalType::DATE) {
@@ -440,8 +442,7 @@ duckdb::pyarrow::Table DuckDBPyResult::FetchArrowTable(idx_t rows_per_batch, boo
440442
}
441443
ArrowArray data = array->arrow_array;
442444
array->arrow_array.release = nullptr;
443-
ArrowConverter::ToArrowSchema(&arrow_schema, arrow_result.types, result_names,
444-
arrow_result.client_properties);
445+
ArrowConverter::ToArrowSchema(&arrow_schema, arrow_result.types, result_names, *GetClientContext());
445446
TransformDuckToArrowChunk(arrow_schema, data, batches);
446447
}
447448
} else {
@@ -453,9 +454,9 @@ duckdb::pyarrow::Table DuckDBPyResult::FetchArrowTable(idx_t rows_per_batch, boo
453454
{
454455
D_ASSERT(py::gil_check());
455456
py::gil_scoped_release release;
456-
count = ArrowUtil::FetchChunk(scan_state, query_result.client_properties, rows_per_batch, &data,
457-
ArrowTypeExtensionData::GetExtensionTypes(
458-
*query_result.client_properties.client_context, query_result.types));
457+
auto arrow_type_exts =
458+
ArrowTypeExtensionData::GetExtensionTypes(*GetClientContext(), query_result.types);
459+
count = ArrowUtil::FetchChunk(scan_state, *GetClientContext(), rows_per_batch, &data, arrow_type_exts);
459460
}
460461
if (count == 0) {
461462
break;
@@ -465,13 +466,12 @@ duckdb::pyarrow::Table DuckDBPyResult::FetchArrowTable(idx_t rows_per_batch, boo
465466
if (to_polars) {
466467
QueryResult::DeduplicateColumns(result_names);
467468
}
468-
ArrowConverter::ToArrowSchema(&arrow_schema, query_result.types, result_names,
469-
query_result.client_properties);
469+
ArrowConverter::ToArrowSchema(&arrow_schema, query_result.types, result_names, *GetClientContext());
470470
TransformDuckToArrowChunk(arrow_schema, data, batches);
471471
}
472472
}
473473

474-
return pyarrow::ToArrowTable(result->types, names, std::move(batches), result->client_properties);
474+
return pyarrow::ToArrowTable(result->types, names, std::move(batches), *GetClientContext());
475475
}
476476

477477
ArrowArrayStream DuckDBPyResult::FetchArrowArrayStream(idx_t rows_per_batch) {
@@ -623,7 +623,7 @@ struct ArrowQueryResultStreamWrapper {
623623
arrays = arrow_result.ConsumeArrays();
624624

625625
cached_schema.release = nullptr;
626-
ArrowConverter::ToArrowSchema(&cached_schema, result->types, result->names, result->client_properties);
626+
ArrowConverter::ToArrowSchema(&cached_schema, result->types, result->names, *result->client_context);
627627

628628
stream.private_data = this;
629629
stream.get_schema = GetSchema;

0 commit comments

Comments (0)