Skip to content

Commit c7fd89e

Browse files
committed
initial
Signed-off-by: Mikhail Kot <mikhail@spiraldb.com>
1 parent 201661f commit c7fd89e

20 files changed

Lines changed: 891 additions & 254 deletions

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vortex-duckdb/build.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,13 @@ const DEFAULT_DUCKDB_VERSION: &str = "1.5.3";
2727

2828
const BUILD_ARTIFACTS: [&str; 3] = ["libduckdb.dylib", "libduckdb.so", "libduckdb_static.a"];
2929

30-
const SOURCE_FILES: [&str; 7] = [
30+
const SOURCE_FILES: [&str; 9] = [
3131
"cpp/vortex_duckdb.cpp",
3232
"cpp/copy_function.cpp",
3333
"cpp/expr.cpp",
34+
"cpp/optimizer.cpp",
3435
"cpp/scalar_fn_pushdown.cpp",
36+
"cpp/cast_pushdown.cpp",
3537
"cpp/table_filter.cpp",
3638
"cpp/table_function.cpp",
3739
"cpp/vector.cpp",
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
#include "cast_pushdown.hpp"
4+
#include "table_function.hpp"
5+
6+
#include "duckdb/planner/operator/logical_get.hpp"
7+
#include "duckdb/planner/operator/logical_projection.hpp"
8+
#include "duckdb/planner/expression/bound_cast_expression.hpp"
9+
#include "duckdb/planner/expression/bound_columnref_expression.hpp"
10+
11+
// A GET reachable through a single-child chain of filters/projections. A join
12+
// (or any other multi-child operator) breaks the chain.
13+
// See test/sql/copy/csv/test_insert_into_types.test in duckdb (cast not pushed past a join)
14+
static bool ReachesPushdownGet(const LogicalOperator &op) {
15+
const LogicalOperator *cur = &op;
16+
while (cur->children.size() == 1) {
17+
cur = cur->children[0].get();
18+
switch (cur->type) {
19+
case LogicalOperatorType::LOGICAL_GET:
20+
return cur->Cast<LogicalGet>().function.bind == duckdb_vx_table_function_bind;
21+
case LogicalOperatorType::LOGICAL_FILTER:
22+
case LogicalOperatorType::LOGICAL_PROJECTION:
23+
continue;
24+
default:
25+
return false;
26+
}
27+
}
28+
return false;
29+
}
30+
31+
void CastCollect::VisitOperator(LogicalOperator &op) {
32+
/*
33+
* Logical projection expressions are columns which reference underlying
34+
* GETs. Don't process them, as they would add conflicts for every column
35+
* used in projection. Example: PROJECTION(col) -> GET(col). We don't want
36+
* to visit BoundColumnRefExpression in PROJECTION to avoid registering a
37+
* non-existent conflict.
38+
*
39+
* However, CastReplace will visit them because we need to update their
40+
* types if pushdown succeeded.
41+
*/
42+
if (op.type != LogicalOperatorType::LOGICAL_PROJECTION) {
43+
return LogicalOperatorVisitor::VisitOperator(op);
44+
}
45+
auto &projection = op.Cast<LogicalProjection>();
46+
47+
// Only push casts from a projection that forwards just column refs and
48+
// casts and reaches a GET without a join in between. A constant or other
49+
// expression makes the projection ineligible.
50+
// See test/sql/copy/csv/test_csv_error_message_type.test (top-level cast
51+
// to VARCHAR must still push) and test_large_integer_detection.test (a
52+
// nested cast to VARCHAR must not) in duckdb.
53+
bool clean = ReachesPushdownGet(projection);
54+
for (const auto &e : projection.expressions) {
55+
switch (e->GetExpressionClass()) {
56+
case ExpressionClass::BOUND_COLUMN_REF:
57+
case ExpressionClass::BOUND_CAST:
58+
continue;
59+
default:
60+
clean = false;
61+
break;
62+
}
63+
}
64+
if (clean) {
65+
for (const auto &e : projection.expressions) {
66+
if (e->GetExpressionClass() == ExpressionClass::BOUND_CAST) {
67+
top_level_casts.insert(e.get());
68+
}
69+
}
70+
}
71+
if (projections.count(projection.table_index)) {
72+
VisitOperatorChildren(op);
73+
return;
74+
}
75+
76+
LogicalOperatorVisitor::VisitOperator(op);
77+
}
78+
79+
ExpressionPtr CastCollect::VisitReplace(BoundColumnRefExpression &expr, ExpressionPtr *ptr) {
80+
if (const auto binding = Resolve(expr.binding, analyses, projections)) {
81+
// Column is used without cast applied to it, register a conflict.
82+
// Not emplace() as we need to update the value if it was present
83+
binding->analysis.col_to_expr[binding->column_index] = nullptr;
84+
}
85+
return std::move(*ptr);
86+
}
87+
88+
ExpressionPtr CastCollect::VisitReplace(BoundCastExpression &expr, ExpressionPtr *ptr) {
89+
if (expr.child->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) {
90+
// Descend into children so e.g. fn(col, other) still sees "col" and
91+
// registers a conflict
92+
return nullptr;
93+
}
94+
const auto &bound_col = expr.child->Cast<BoundColumnRefExpression>();
95+
const auto binding = Resolve(bound_col.binding, analyses, projections);
96+
if (!binding) {
97+
return nullptr;
98+
}
99+
auto &col_to_expr = binding->analysis.col_to_expr;
100+
101+
if (auto it = col_to_expr.find(binding->column_index); it == col_to_expr.end()) {
102+
// Only a top-level projection cast starts a candidate.
103+
if (top_level_casts.count(&expr)) {
104+
col_to_expr.emplace(binding->column_index, &expr);
105+
}
106+
} else if (it->second == nullptr ||
107+
it->second->Cast<BoundCastExpression>().return_type != expr.return_type) {
108+
// Different target type, or already a conflict.
109+
it->second = nullptr;
110+
}
111+
112+
return std::move(*ptr);
113+
}
114+
115+
ExpressionPtr CastReplace::VisitReplace(BoundColumnRefExpression &expr, ExpressionPtr *ptr) {
116+
const auto binding = Resolve(expr.binding, analyses, projections);
117+
if (!binding) {
118+
return std::move(*ptr);
119+
}
120+
121+
const auto &[analysis, column_index, projection] = *binding;
122+
if (CanPushdownColumn(analysis, column_index)) {
123+
const idx_t storage_index = analysis.get.GetColumnIds()[column_index].GetPrimaryIndex();
124+
const LogicalType return_type = analysis.get.returned_types[storage_index];
125+
expr.return_type = return_type;
126+
// LogicalProjection types are resolved by calling
127+
// LogicalProjection::ResolveTypes, so we need to check whether types in
128+
// projection have been resolved, and updated them only if needed.
129+
if (projection != nullptr && !projection->types.empty()) {
130+
projection->types[column_index] = return_type;
131+
}
132+
}
133+
134+
return std::move(*ptr);
135+
}
136+
137+
ExpressionPtr CastReplace::VisitReplace(BoundCastExpression &expr, ExpressionPtr *ptr) {
138+
if (expr.child->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) {
139+
return nullptr; // Same as in ScalarFnCollect::VisitReplace
140+
}
141+
auto &bound_col_base = expr.child;
142+
const auto &bound_col = bound_col_base->Cast<BoundColumnRefExpression>();
143+
const auto binding = Resolve(bound_col.binding, analyses, projections);
144+
if (!binding) {
145+
return nullptr;
146+
}
147+
148+
const auto &[analysis, column_index, projection] = *binding;
149+
if (!CanPushdownColumn(analysis, column_index)) {
150+
return std::move(*ptr);
151+
}
152+
153+
const idx_t storage_index = analysis.get.GetColumnIds()[column_index].GetPrimaryIndex();
154+
const LogicalType return_type = analysis.get.returned_types[storage_index];
155+
bound_col_base->return_type = return_type;
156+
// Same as in CastReplace::VisitReplace(BoundColumnRefExpression)
157+
if (projection != nullptr && !projection->types.empty()) {
158+
projection->types[column_index] = return_type;
159+
}
160+
return std::move(bound_col_base);
161+
}
162+
163+
CastCollect::CastCollect(Analyses &analyses, const Projections &projections)
164+
: analyses(analyses), projections(projections) {
165+
}
166+
167+
CastReplace::CastReplace(Analyses &analyses, const Projections &projections)
168+
: analyses(analyses), projections(projections) {
169+
}

vortex-duckdb/cpp/expr.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
#include "expr.h"
5+
#include "duckdb/common/type_visitor.hpp"
56
#include "duckdb/function/scalar_function.hpp"
67
#include "duckdb/planner/expression/bound_between_expression.hpp"
8+
#include "duckdb/planner/expression/bound_cast_expression.hpp"
79
#include "duckdb/planner/expression/bound_columnref_expression.hpp"
810
#include "duckdb/planner/expression/bound_comparison_expression.hpp"
911
#include "duckdb/planner/expression/bound_constant_expression.hpp"
@@ -129,3 +131,17 @@ extern "C" void duckdb_vx_expr_get_bound_function(duckdb_vx_expr ffi_expr,
129131
out->scalar_function = reinterpret_cast<duckdb_vx_sfunc>(&expr.function);
130132
out->bind_info = expr.bind_info.get();
131133
}
134+
135+
extern "C" duckdb_vx_expr duckdb_vx_expr_get_bound_cast_child(duckdb_vx_expr ffi_expr) {
136+
D_ASSERT(ffi_expr);
137+
auto &expr = reinterpret_cast<Expression *>(ffi_expr)->Cast<BoundCastExpression>();
138+
return reinterpret_cast<duckdb_vx_expr>(expr.child.get());
139+
}
140+
141+
extern "C" bool duckdb_vx_logical_type_contains_128bit(duckdb_logical_type ffi_type) {
142+
D_ASSERT(ffi_type);
143+
auto &type = *reinterpret_cast<LogicalType *>(ffi_type);
144+
return TypeVisitor::Contains(type, [](const LogicalType &t) {
145+
return t.id() == LogicalTypeId::HUGEINT || t.id() == LogicalTypeId::UHUGEINT;
146+
});
147+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
#pragma once
4+
#include "optimizer.hpp"
5+
6+
#include "duckdb/common/unordered_set.hpp"
7+
#include "duckdb/main/client_context.hpp"
8+
#include "duckdb/planner/expression.hpp"
9+
#include "duckdb/planner/logical_operator.hpp"
10+
11+
using namespace duckdb;
12+
13+
/**
14+
* Collect CAST(col) expressions. If "col" is used without CAST in "plan",
15+
* record in "analyses.conflicts"
16+
*/
17+
struct CastCollect final : LogicalOperatorVisitor {
18+
Analyses &analyses;
19+
const Projections &projections;
20+
// Casts that are direct outputs of a clean projection over a GET. Only these
21+
// start a pushdown candidate; a nested cast may push down a different value.
22+
// See test/sql/copy/csv/auto/test_large_integer_detection.test in duckdb
23+
unordered_set<const Expression *> top_level_casts;
24+
25+
CastCollect(Analyses &analyses, const Projections &projections);
26+
void VisitOperator(LogicalOperator &op) override;
27+
ExpressionPtr VisitReplace(BoundColumnRefExpression &expr, ExpressionPtr *ptr) override;
28+
ExpressionPtr VisitReplace(BoundCastExpression &expr, ExpressionPtr *ptr) override;
29+
};
30+
31+
/*
32+
* For "col" in columns collected by ScalarFnCollect, replace CAST(col) to "col"
33+
* if "col" doesn't have conflicting usage. Update return types for bound
34+
* columns and logical projections referencing this column.
35+
*/
36+
struct CastReplace final : LogicalOperatorVisitor {
37+
Analyses &analyses;
38+
const Projections &projections;
39+
40+
CastReplace(Analyses &analyses, const Projections &aliases);
41+
ExpressionPtr VisitReplace(BoundColumnRefExpression &expr, ExpressionPtr *ptr) override;
42+
ExpressionPtr VisitReplace(BoundCastExpression &expr, ExpressionPtr *ptr) override;
43+
};

vortex-duckdb/cpp/include/expr.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,12 @@ typedef struct {
264264

265265
void duckdb_vx_expr_get_bound_function(duckdb_vx_expr expr, duckdb_vx_expr_bound_function *out);
266266

267+
duckdb_vx_expr duckdb_vx_expr_get_bound_cast_child(duckdb_vx_expr expr);
268+
269+
// Check if type or contained types i.e. List(T) contains HUGEINT/UHUGEINT
270+
// These are not present in DType so we can't convert.
271+
bool duckdb_vx_logical_type_contains_128bit(duckdb_logical_type type);
272+
267273
#ifdef __cplusplus /* End C ABI */
268274
}
269275
#endif

0 commit comments

Comments
 (0)