Skip to content

Commit 5b2ddf5

Browse files
committed
better
1 parent 1498305 commit 5b2ddf5

3 files changed

Lines changed: 234 additions & 148 deletions

File tree

vortex-duckdb/cpp/optimizer.cpp

Lines changed: 162 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -18,135 +18,181 @@ using namespace duckdb;
1818
*/
1919

2020
// Collect CAST(bound_column, T) patterns where bound_column binds into given GET's index.
21-
void CollectCastTypes(const Expression &expr, idx_t index, const vector<ColumnIndex> &column_ids,
22-
unordered_map<column_t, LogicalType> &cast_map, unordered_set<column_t> &conflicts) {
23-
auto collect_children = [&] {
24-
ExpressionIterator::EnumerateChildren(
25-
expr, [&](const Expression &child) { CollectCastTypes(child, index, column_ids, cast_map, conflicts); });
26-
};
27-
28-
if (expr.GetExpressionClass() != ExpressionClass::BOUND_CAST) {
29-
return collect_children();
30-
}
31-
auto &bound_cast = expr.Cast<BoundCastExpression>();
32-
33-
if (bound_cast.child->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) {
34-
return collect_children();
35-
}
36-
auto &bound_column = bound_cast.child->Cast<BoundColumnRefExpression>();
37-
38-
if (bound_column.depth > 0 || bound_column.binding.table_index != index) {
39-
return collect_children();
40-
}
41-
42-
// We are in a leaf
43-
const column_t projection_id = bound_column.binding.column_index;
44-
if (IsVirtualColumn(projection_id)) {
45-
return;
46-
}
47-
D_ASSERT(projection_id < column_ids.size());
48-
const column_t column_id = column_ids[projection_id].GetPrimaryIndex();
49-
if (auto it = cast_map.find(column_id); it == cast_map.end()) {
50-
cast_map.emplace(column_id, bound_cast.return_type);
51-
} else if (it->second != bound_cast.return_type) {
52-
conflicts.insert(column_id);
53-
}
21+
// A bare bound_column ref (outside any CAST) is recorded as a conflict: the column is
22+
// consumed at its original type and its scan type must not change.
23+
static void CollectCastTypes(const Expression &expr,
24+
idx_t index,
25+
const vector<ColumnIndex> &column_ids,
26+
unordered_map<column_t, LogicalType> &cast_map,
27+
unordered_set<column_t> &conflicts) {
28+
auto collect_children = [&] {
29+
ExpressionIterator::EnumerateChildren(expr, [&](const Expression &child) {
30+
CollectCastTypes(child, index, column_ids, cast_map, conflicts);
31+
});
32+
};
33+
34+
// Bare column ref pointing to this GET: the column is used at its original type.
35+
if (expr.GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) {
36+
auto &colref = expr.Cast<BoundColumnRefExpression>();
37+
if (colref.depth == 0 && colref.binding.table_index == index) {
38+
const column_t proj_id = colref.binding.column_index;
39+
if (!IsVirtualColumn(proj_id) && proj_id < column_ids.size()) {
40+
conflicts.insert(column_ids[proj_id].GetPrimaryIndex());
41+
}
42+
}
43+
return;
44+
}
45+
46+
if (expr.GetExpressionClass() != ExpressionClass::BOUND_CAST) {
47+
return collect_children();
48+
}
49+
auto &bound_cast = expr.Cast<BoundCastExpression>();
50+
51+
if (bound_cast.child->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) {
52+
return collect_children();
53+
}
54+
auto &bound_column = bound_cast.child->Cast<BoundColumnRefExpression>();
55+
56+
if (bound_column.depth > 0 || bound_column.binding.table_index != index) {
57+
return collect_children();
58+
}
59+
60+
// We are in a leaf: CAST(colref, T) where colref binds into this GET.
61+
const column_t projection_id = bound_column.binding.column_index;
62+
if (IsVirtualColumn(projection_id)) {
63+
return;
64+
}
65+
D_ASSERT(projection_id < column_ids.size());
66+
const column_t column_id = column_ids[projection_id].GetPrimaryIndex();
67+
if (auto it = cast_map.find(column_id); it == cast_map.end()) {
68+
cast_map.emplace(column_id, bound_cast.return_type);
69+
} else if (it->second != bound_cast.return_type) {
70+
conflicts.insert(column_id);
71+
}
5472
}
5573

5674
// Replace every CAST(bound_column, T) with a bare bound_column at type T when T
5775
// is listed in projection_cast.
58-
static void ReplaceCastTypes(unique_ptr<Expression> &expr, idx_t index,
76+
static void ReplaceCastTypes(unique_ptr<Expression> &expr,
77+
idx_t index,
5978
const unordered_map<column_t, LogicalType> &projection_cast) {
60-
auto replace_children = [&] {
61-
ExpressionIterator::EnumerateChildren(
62-
*expr, [&](unique_ptr<Expression> &child) { ReplaceCastTypes(child, index, projection_cast); });
63-
};
64-
65-
if (expr->GetExpressionClass() != ExpressionClass::BOUND_CAST) {
66-
return replace_children();
67-
}
68-
auto &bound_cast = expr->Cast<BoundCastExpression>();
69-
70-
if (bound_cast.child->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) {
71-
return replace_children();
72-
}
73-
auto &bound_column = bound_cast.child->Cast<BoundColumnRefExpression>();
74-
75-
if (bound_column.depth > 0 || bound_column.binding.table_index != index) {
76-
return replace_children();
77-
}
78-
79-
const column_t projection_id = bound_column.binding.column_index;
80-
auto it = projection_cast.find(projection_id);
81-
if (it == projection_cast.end() || it->second != bound_cast.return_type) {
82-
return replace_children();
83-
}
84-
85-
expr = make_uniq<BoundColumnRefExpression>(it->second, bound_column.binding);
79+
auto replace_children = [&] {
80+
ExpressionIterator::EnumerateChildren(*expr, [&](unique_ptr<Expression> &child) {
81+
ReplaceCastTypes(child, index, projection_cast);
82+
});
83+
};
84+
85+
if (expr->GetExpressionClass() != ExpressionClass::BOUND_CAST) {
86+
return replace_children();
87+
}
88+
auto &bound_cast = expr->Cast<BoundCastExpression>();
89+
90+
if (bound_cast.child->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) {
91+
return replace_children();
92+
}
93+
auto &bound_column = bound_cast.child->Cast<BoundColumnRefExpression>();
94+
95+
if (bound_column.depth > 0 || bound_column.binding.table_index != index) {
96+
return replace_children();
97+
}
98+
99+
const column_t projection_id = bound_column.binding.column_index;
100+
auto it = projection_cast.find(projection_id);
101+
if (it == projection_cast.end() || it->second != bound_cast.return_type) {
102+
return replace_children();
103+
}
104+
105+
expr = make_uniq<BoundColumnRefExpression>(it->second, bound_column.binding);
106+
}
107+
108+
// Collect cast-type candidates from every operator in the plan tree.
109+
static void CollectFromPlan(LogicalOperator &op,
110+
idx_t index,
111+
const vector<ColumnIndex> &column_ids,
112+
unordered_map<column_t, LogicalType> &cast_map,
113+
unordered_set<column_t> &conflicts) {
114+
LogicalOperatorVisitor::EnumerateExpressions(op, [&](unique_ptr<Expression> *expr_ptr) {
115+
CollectCastTypes(**expr_ptr, index, column_ids, cast_map, conflicts);
116+
});
117+
for (auto &child : op.children) {
118+
CollectFromPlan(*child, index, column_ids, cast_map, conflicts);
119+
}
86120
}
87121

88-
// Walk the plan bottom-up and, for each node whose direct child is a GET that
89-
// supports type_pushdown, push every CAST(colref, T) found in that node's
90-
// expressions into the GET so the scan produces T directly.
91-
unique_ptr<LogicalOperator> TryPushdownCastTypes(ClientContext& context, unique_ptr<LogicalOperator> op) {
92-
for (auto &child : op->children) {
93-
child = TryPushdownCastTypes(context, std::move(child));
94-
}
95-
96-
for (const auto &child : op->children) {
97-
if (child->type != LogicalOperatorType::LOGICAL_GET) {
98-
continue;
99-
}
100-
auto &get = child->Cast<LogicalGet>();
101-
if (!get.function.type_pushdown) {
102-
continue;
103-
}
104-
105-
const vector<ColumnIndex> &column_ids = get.GetColumnIds();
106-
const idx_t index = get.table_index;
107-
unordered_map<column_t, LogicalType> cast_map;
108-
unordered_set<column_t> conflicts;
109-
110-
LogicalOperatorVisitor::EnumerateExpressions(*op, [&](unique_ptr<Expression> *expr_ptr) {
111-
CollectCastTypes(**expr_ptr, index, column_ids, cast_map, conflicts);
112-
});
113-
114-
for (column_t col_id : conflicts) {
115-
cast_map.erase(col_id);
116-
}
117-
if (cast_map.empty()) {
118-
continue;
119-
}
120-
121-
get.function.type_pushdown(context, get.bind_data, cast_map);
122-
for (const auto &[col_id, new_type] : cast_map) {
123-
get.returned_types[col_id] = new_type;
124-
}
125-
126-
unordered_map<idx_t, LogicalType> proj_to_type;
127-
for (idx_t i = 0; i < column_ids.size(); i++) {
128-
const column_t col_idx = column_ids[i].GetPrimaryIndex();
129-
if (auto it = cast_map.find(col_idx); it != cast_map.end()) {
130-
proj_to_type[i] = it->second;
131-
}
132-
}
133-
134-
LogicalOperatorVisitor::EnumerateExpressions(
135-
*op, [&](unique_ptr<Expression> *expr_ptr) { ReplaceCastTypes(*expr_ptr, get.table_index, proj_to_type); });
136-
}
137-
138-
return op;
122+
// Replace cast expressions in every operator in the plan tree.
123+
static void
124+
ReplaceInPlan(LogicalOperator &op, idx_t index, const unordered_map<column_t, LogicalType> &proj_to_type) {
125+
LogicalOperatorVisitor::EnumerateExpressions(op, [&](unique_ptr<Expression> *expr_ptr) {
126+
ReplaceCastTypes(*expr_ptr, index, proj_to_type);
127+
});
128+
for (auto &child : op.children) {
129+
ReplaceInPlan(*child, index, proj_to_type);
130+
}
131+
}
132+
133+
static void FindGetWithTypePushdown(LogicalOperator &op, vector<LogicalGet *> &gets) {
134+
if (op.type == LogicalOperatorType::LOGICAL_GET) {
135+
auto &get = op.Cast<LogicalGet>();
136+
if (get.function.type_pushdown) {
137+
gets.push_back(&get);
138+
}
139+
}
140+
for (auto &child : op.children) {
141+
FindGetWithTypePushdown(*child, gets);
142+
}
143+
}
144+
145+
// For each GET that supports type_pushdown, collect CAST(col, T) patterns from
146+
// the *entire* plan. Columns that appear bare (outside any cast) or are cast to
147+
// multiple conflicting types are excluded. The surviving types are pushed into
148+
// the GET's bind_data and returned_types, and the redundant CASTs are stripped
149+
// from all operator expressions throughout the plan.
150+
static unique_ptr<LogicalOperator> TryPushdownCastTypes(ClientContext &context,
151+
unique_ptr<LogicalOperator> plan) {
152+
vector<LogicalGet *> gets;
153+
FindGetWithTypePushdown(*plan, gets);
154+
155+
for (LogicalGet *get : gets) {
156+
const vector<ColumnIndex> &column_ids = get->GetColumnIds();
157+
const idx_t index = get->table_index;
158+
unordered_map<column_t, LogicalType> cast_map;
159+
unordered_set<column_t> conflicts;
160+
161+
CollectFromPlan(*plan, index, column_ids, cast_map, conflicts);
162+
163+
for (column_t col_id : conflicts) {
164+
cast_map.erase(col_id);
165+
}
166+
if (cast_map.empty()) {
167+
continue;
168+
}
169+
170+
get->function.type_pushdown(context, get->bind_data, cast_map);
171+
for (const auto &[col_id, new_type] : cast_map) {
172+
get->returned_types[col_id] = new_type;
173+
}
174+
175+
unordered_map<idx_t, LogicalType> proj_to_type;
176+
for (idx_t i = 0; i < column_ids.size(); i++) {
177+
const column_t col_idx = column_ids[i].GetPrimaryIndex();
178+
if (auto it = cast_map.find(col_idx); it != cast_map.end()) {
179+
proj_to_type[i] = it->second;
180+
}
181+
}
182+
183+
ReplaceInPlan(*plan, index, proj_to_type);
184+
}
185+
186+
return plan;
139187
}
140188

141189
static void VortexOptimizeFunction(OptimizerExtensionInput &input, unique_ptr<LogicalOperator> &plan) {
142190
plan = TryPushdownCastTypes(input.context, std::move(plan));
143191
}
144192

145-
class VortexOptimizerExtension final : public OptimizerExtension {
146-
public:
147-
VortexOptimizerExtension() {
148-
optimize_function = VortexOptimizeFunction;
149-
}
193+
struct VortexOptimizerExtension final : OptimizerExtension {
194+
VortexOptimizerExtension() : OptimizerExtension(VortexOptimizeFunction, nullptr, {}) {
195+
}
150196
};
151197

152198
extern "C" duckdb_state duckdb_vx_optimizer_extension_register(duckdb_database ffi_db) {

vortex-duckdb/cpp/table_function.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ extern "C" duckdb_state duckdb_vx_tfunc_register(duckdb_database ffi_db, const d
408408
tf.filter_prune = true;
409409
tf.sampling_pushdown = false;
410410

411-
//tf.type_pushdown = type_pushdown;
411+
tf.type_pushdown = type_pushdown;
412412
tf.pushdown_complex_filter = c_pushdown_complex_filter;
413413
tf.cardinality = c_cardinality;
414414
tf.get_partition_info = get_partition_info;

0 commit comments

Comments
 (0)