Skip to content

Commit 1852dfd

Browse files
committed
Unify filter pushdown across arrow and polars
1 parent 276028b commit 1852dfd

11 files changed

Lines changed: 1383 additions & 711 deletions

File tree

scripts/cache_data.json

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,9 @@
532532
"polars.DataFrame",
533533
"polars.LazyFrame",
534534
"polars.col",
535-
"polars.lit"
535+
"polars.lit",
536+
"polars.Series",
537+
"polars.Decimal"
536538
],
537539
"required": false
538540
},
@@ -822,5 +824,17 @@
822824
"full_path": "polars.lit",
823825
"name": "lit",
824826
"children": []
827+
},
828+
"polars.Series": {
829+
"type": "attribute",
830+
"full_path": "polars.Series",
831+
"name": "Series",
832+
"children": []
833+
},
834+
"polars.Decimal": {
835+
"type": "attribute",
836+
"full_path": "polars.Decimal",
837+
"name": "Decimal",
838+
"children": []
825839
}
826840
}

scripts/imports.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@
111111
polars.LazyFrame
112112
polars.col
113113
polars.lit
114+
polars.Series
115+
polars.Decimal
114116

115117
import duckdb
116118
import duckdb.filesystem

src/duckdb_py/arrow/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# this is used for clang-tidy checks
22
add_library(
3-
python_arrow OBJECT arrow_array_stream.cpp arrow_export_utils.cpp
4-
polars_filter_pushdown.cpp pyarrow_filter_pushdown.cpp)
3+
python_arrow OBJECT
4+
arrow_array_stream.cpp arrow_export_utils.cpp filter_pushdown_visitor.cpp
5+
polars_filter_pushdown.cpp pyarrow_filter_pushdown.cpp)
56

67
target_link_libraries(python_arrow PRIVATE _duckdb_dependencies)
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
#include "duckdb_python/arrow/filter_pushdown_visitor.hpp"
2+
3+
#include "duckdb/function/scalar/struct_utils.hpp"
4+
#include "duckdb/planner/expression/bound_comparison_expression.hpp"
5+
#include "duckdb/planner/expression/bound_conjunction_expression.hpp"
6+
#include "duckdb/planner/expression/bound_constant_expression.hpp"
7+
#include "duckdb/planner/expression/bound_function_expression.hpp"
8+
#include "duckdb/planner/expression/bound_operator_expression.hpp"
9+
#include "duckdb/planner/expression/bound_reference_expression.hpp"
10+
#include "duckdb/planner/filter/conjunction_filter.hpp"
11+
#include "duckdb/planner/filter/constant_filter.hpp"
12+
#include "duckdb/planner/filter/expression_filter.hpp"
13+
#include "duckdb/planner/filter/in_filter.hpp"
14+
#include "duckdb/planner/filter/optional_filter.hpp"
15+
#include "duckdb/planner/filter/struct_filter.hpp"
16+
17+
namespace duckdb {
18+
19+
namespace {
20+
21+
bool ValueIsNan(const Value &value) {
22+
if (value.type().id() == LogicalTypeId::FLOAT) {
23+
return Value::IsNan(value.GetValue<float>());
24+
}
25+
if (value.type().id() == LogicalTypeId::DOUBLE) {
26+
return Value::IsNan(value.GetValue<double>());
27+
}
28+
return false;
29+
}
30+
31+
// ResolveColumn walks a column-side expression to extract the (full path, leaf
32+
// ArrowType) pair. Accepts a bare BoundReferenceExpression and (nested)
33+
// `struct_extract` chains. Anything else throws NotImplementedException —
34+
// that gives the OPTIONAL_FILTER catch point a chance to swallow it.
35+
struct ResolvedColumn {
36+
vector<string> path;
37+
const ArrowType *leaf_type;
38+
};
39+
40+
ResolvedColumn ResolveColumn(const Expression &expr, const vector<string> &root_path, const ArrowType *root_type) {
41+
if (expr.GetExpressionClass() == ExpressionClass::BOUND_REF) {
42+
return {root_path, root_type};
43+
}
44+
if (expr.GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) {
45+
throw NotImplementedException("Cannot push down arrow scan filter on column-side expression: %s",
46+
ExpressionClassToString(expr.GetExpressionClass()));
47+
}
48+
auto &func = expr.Cast<BoundFunctionExpression>();
49+
idx_t child_idx;
50+
if (!TryGetStructExtractChildIndex(func, child_idx)) {
51+
throw NotImplementedException("Cannot push down arrow scan filter on column-side function: %s",
52+
ExpressionTypeToString(expr.GetExpressionType()));
53+
}
54+
// Recurse innermost-first so names accumulate root → leaf.
55+
auto inner = ResolveColumn(*func.children[0], root_path, root_type);
56+
inner.path.push_back(StructType::GetChildName(func.children[0]->GetReturnType(), child_idx));
57+
if (inner.leaf_type) {
58+
inner.leaf_type = &inner.leaf_type->GetTypeInfo<ArrowStructInfo>().GetChild(child_idx);
59+
}
60+
return inner;
61+
}
62+
63+
// Builds the backend expression for `col <op> constant`.
//
// NaN constants are handed to the backend's dedicated NaNCompare hook; every
// other constant is first converted with MakeScalar (using the Arrow type and
// timezone configuration) and then compared through Compare.
py::object EmitCompare(FilterBackend &backend, ExpressionType op, py::object col, const Value &constant,
                       const ArrowType *arrow_type, const string &timezone_config) {
	if (!ValueIsNan(constant)) {
		auto literal = backend.MakeScalar(constant, arrow_type, timezone_config);
		return backend.Compare(op, std::move(col), std::move(literal));
	}
	return backend.NaNCompare(op, std::move(col));
}
71+
72+
} // anonymous namespace
73+
74+
// Translates a bound expression (as carried by an EXPRESSION_FILTER) into a
// filter object of the given backend.
//
// Handled shapes:
//   - comparison between a column-side expression and a constant, with the
//     constant on either side (the comparison is flipped when it is on the left)
//   - IS NULL / IS NOT NULL on a column-side expression
//   - IN where every list member is a constant
//   - AND / OR conjunctions of the above
//
// `column_path` / `arrow_type` describe the root column the expression refers
// to; nested struct_extract chains are resolved through ResolveColumn.
//
// Returns py::none() when a conjunction yields nothing that can be pushed down.
// Throws NotImplementedException for any shape that is not handled.
py::object TransformExpression(const Expression &expression, const vector<string> &column_path, FilterBackend &backend,
                               const ArrowType *arrow_type, const string &timezone_config) {
	auto expression_class = expression.GetExpressionClass();
	auto expression_type = expression.GetExpressionType();

	if (expression_class == ExpressionClass::BOUND_FUNCTION) {
		auto &bound_function_expression = expression.Cast<BoundFunctionExpression>();
		if (BoundComparisonExpression::IsComparison(expression_type)) {
			auto &left = BoundComparisonExpression::Left(bound_function_expression);
			auto &right = BoundComparisonExpression::Right(bound_function_expression);

			optional_ptr<const Expression> column_side;
			optional_ptr<const BoundConstantExpression> constant_side;

			if (right.GetExpressionType() == ExpressionType::VALUE_CONSTANT) {
				column_side = &left;
				constant_side = &right.Cast<BoundConstantExpression>();
			} else if (left.GetExpressionType() == ExpressionType::VALUE_CONSTANT) {
				// Constant on the left: swap the operands and mirror the operator.
				column_side = &right;
				constant_side = &left.Cast<BoundConstantExpression>();
				expression_type = FlipComparisonExpression(expression_type);
			} else {
				throw NotImplementedException("Can only push down constant comparisons.");
			}

			auto resolved = ResolveColumn(*column_side, column_path, arrow_type);
			auto col = backend.MakeColumnRef(resolved.path);
			return EmitCompare(backend, expression_type, std::move(col), constant_side->value, resolved.leaf_type,
			                   timezone_config);
		}
	}

	if (expression_class == ExpressionClass::BOUND_OPERATOR) {
		auto &op_expr = expression.Cast<BoundOperatorExpression>();
		if (expression_type == ExpressionType::OPERATOR_IS_NULL) {
			auto resolved = ResolveColumn(*op_expr.children[0], column_path, arrow_type);
			auto col = backend.MakeColumnRef(resolved.path);
			return backend.IsNull(std::move(col));
		}
		if (expression_type == ExpressionType::OPERATOR_IS_NOT_NULL) {
			auto resolved = ResolveColumn(*op_expr.children[0], column_path, arrow_type);
			auto col = backend.MakeColumnRef(resolved.path);
			return backend.IsNotNull(std::move(col));
		}
		if (expression_type == ExpressionType::COMPARE_IN) {
			auto resolved = ResolveColumn(*op_expr.children[0], column_path, arrow_type);
			auto col = backend.MakeColumnRef(resolved.path);
			vector<Value> values;
			// children[0] is the probed column; the remaining children form the IN list.
			for (idx_t i = 1; i < op_expr.children.size(); i++) {
				auto &list_member = *op_expr.children[i];
				// Only constants can be turned into backend scalars. Verify before
				// casting so a non-constant member surfaces as NotImplementedException
				// (which OPTIONAL_FILTER handling can swallow) instead of an invalid cast.
				if (list_member.GetExpressionType() != ExpressionType::VALUE_CONSTANT) {
					throw NotImplementedException("Can only push down IN lists of constants.");
				}
				values.push_back(list_member.Cast<BoundConstantExpression>().value);
			}
			const auto &col_type = op_expr.children[0]->GetReturnType();
			return backend.IsIn(std::move(col), values, col_type, timezone_config);
		}
	}

	if (expression_class == ExpressionClass::BOUND_CONJUNCTION) {
		if (expression_type == ExpressionType::CONJUNCTION_OR || expression_type == ExpressionType::CONJUNCTION_AND) {
			auto &conj_expr = expression.Cast<BoundConjunctionExpression>();
			const bool is_and = expression_type == ExpressionType::CONJUNCTION_AND;
			py::object result = py::none();
			for (idx_t i = 0; i < conj_expr.children.size(); i++) {
				py::object child_expression =
				    TransformExpression(*conj_expr.children[i], column_path, backend, arrow_type, timezone_config);
				if (child_expression.is(py::none())) {
					if (!is_and) {
						// A branch with no pushable form places no constraint on the
						// rows. Dropping it from an OR would make the pushed filter
						// stricter than the original, so give up on the whole OR.
						return py::none();
					}
					// For AND, leaving out a conjunct only makes the filter weaker.
					continue;
				}
				if (result.is(py::none())) {
					result = std::move(child_expression);
				} else if (is_and) {
					result = backend.And(std::move(result), std::move(child_expression));
				} else {
					result = backend.Or(std::move(result), std::move(child_expression));
				}
			}
			return result;
		}
	}

	throw NotImplementedException("Pushdown Filter Type %s is not currently supported in arrow scans",
	                              ExpressionClassToString(expression_class));
}
156+
157+
// Translates a DuckDB TableFilter on one column into a filter object of the
// given backend.
//
// `column_path` names the column (and, for nested struct filters, the field
// path below it); `arrow_type` is the Arrow type at that path and may be null.
// `timezone_config` is forwarded to the backend when constants are converted.
//
// Returns py::none() when nothing can be pushed down for this filter:
//   - DYNAMIC_FILTER is never pushed
//   - an OPTIONAL_FILTER without a child, or whose child raised
//     NotImplementedException
//   - a conjunction whose children produced nothing pushable
// Throws NotImplementedException for filter types that are not handled.
py::object TransformFilter(const TableFilter &filter, vector<string> column_path, FilterBackend &backend,
                           const ArrowType *arrow_type, const string &timezone_config) {
	switch (filter.filter_type) {
	case TableFilterType::CONSTANT_COMPARISON: {
		auto &constant_filter = filter.Cast<ConstantFilter>();
		auto col = backend.MakeColumnRef(column_path);
		return EmitCompare(backend, constant_filter.comparison_type, std::move(col), constant_filter.constant,
		                   arrow_type, timezone_config);
	}
	case TableFilterType::IS_NULL: {
		auto col = backend.MakeColumnRef(column_path);
		return backend.IsNull(std::move(col));
	}
	case TableFilterType::IS_NOT_NULL: {
		auto col = backend.MakeColumnRef(column_path);
		return backend.IsNotNull(std::move(col));
	}
	case TableFilterType::CONJUNCTION_AND: {
		auto &and_filter = filter.Cast<ConjunctionAndFilter>();
		py::object result = py::none();
		for (auto &child_filter : and_filter.child_filters) {
			py::object child_expression =
			    TransformFilter(*child_filter, column_path, backend, arrow_type, timezone_config);
			if (child_expression.is(py::none())) {
				// Leaving out a conjunct only makes the pushed filter weaker.
				continue;
			}
			if (result.is(py::none())) {
				result = std::move(child_expression);
			} else {
				result = backend.And(std::move(result), std::move(child_expression));
			}
		}
		return result;
	}
	case TableFilterType::CONJUNCTION_OR: {
		auto &or_filter = filter.Cast<ConjunctionOrFilter>();
		py::object result = py::none();
		for (auto &child_filter : or_filter.child_filters) {
			py::object child_expression =
			    TransformFilter(*child_filter, column_path, backend, arrow_type, timezone_config);
			if (child_expression.is(py::none())) {
				// A branch with no pushable form (e.g. a dynamic filter or a failed
				// optional filter) places no constraint on the rows. Dropping it
				// from an OR would make the pushed filter stricter than the original
				// and could discard matching rows, so the whole OR is not pushed.
				return py::none();
			}
			if (result.is(py::none())) {
				result = std::move(child_expression);
			} else {
				result = backend.Or(std::move(result), std::move(child_expression));
			}
		}
		return result;
	}
	case TableFilterType::STRUCT_EXTRACT: {
		auto &struct_filter = filter.Cast<StructFilter>();
		// Descend one level: extend the path and step into the child Arrow type.
		column_path.push_back(struct_filter.child_name);
		const ArrowType *child_type = nullptr;
		if (arrow_type) {
			child_type = &arrow_type->GetTypeInfo<ArrowStructInfo>().GetChild(struct_filter.child_idx);
		}
		return TransformFilter(*struct_filter.child_filter, std::move(column_path), backend, child_type,
		                       timezone_config);
	}
	case TableFilterType::OPTIONAL_FILTER: {
		auto &optional_filter = filter.Cast<OptionalFilter>();
		if (!optional_filter.child_filter) {
			return py::none();
		}
		try {
			return TransformFilter(*optional_filter.child_filter, column_path, backend, arrow_type, timezone_config);
		} catch (const NotImplementedException &) {
			// Optional filters may be skipped; an unsupported child simply is not pushed.
			return py::none();
		}
	}
	case TableFilterType::IN_FILTER: {
		auto &in_filter = filter.Cast<InFilter>();
		auto col = backend.MakeColumnRef(column_path);
		// The column's logical type for IN is taken from the values themselves
		// (they share the comparison type). An empty list is not expected here,
		// but fall back to SQLNULL rather than reading from an empty vector.
		LogicalType col_logical_type =
		    in_filter.values.empty() ? LogicalType::SQLNULL : in_filter.values.front().type();
		return backend.IsIn(std::move(col), in_filter.values, col_logical_type, timezone_config);
	}
	case TableFilterType::DYNAMIC_FILTER:
		return py::none();
	case TableFilterType::EXPRESSION_FILTER: {
		auto &expression_filter = filter.Cast<ExpressionFilter>();
		return TransformExpression(*expression_filter.expr, column_path, backend, arrow_type, timezone_config);
	}
	default:
		throw NotImplementedException("Pushdown Filter Type %s is not currently supported in arrow scans",
		                              EnumUtil::ToString(filter.filter_type));
	}
}
250+
251+
} // namespace duckdb

0 commit comments

Comments
 (0)