|
| 1 | +// SPDX-License-Identifier: Apache-2.0 |
| 2 | +// SPDX-FileCopyrightText: Copyright the Vortex contributors |
| 3 | +#include "cast_pushdown.hpp" |
| 4 | +#include "table_function.hpp" |
| 5 | + |
| 6 | +#include "duckdb/planner/operator/logical_get.hpp" |
| 7 | +#include "duckdb/planner/operator/logical_projection.hpp" |
| 8 | +#include "duckdb/planner/expression/bound_cast_expression.hpp" |
| 9 | +#include "duckdb/planner/expression/bound_columnref_expression.hpp" |
| 10 | + |
| 11 | +// A GET reachable through a single-child chain of filters/projections. A join |
| 12 | +// (or any other multi-child operator) breaks the chain. |
| 13 | +// See test/sql/copy/csv/test_insert_into_types.test in duckdb (cast not pushed past a join) |
| 14 | +static bool ReachesPushdownGet(const LogicalOperator &op) { |
| 15 | + const LogicalOperator *cur = &op; |
| 16 | + while (cur->children.size() == 1) { |
| 17 | + cur = cur->children[0].get(); |
| 18 | + switch (cur->type) { |
| 19 | + case LogicalOperatorType::LOGICAL_GET: |
| 20 | + return cur->Cast<LogicalGet>().function.bind == duckdb_vx_table_function_bind; |
| 21 | + case LogicalOperatorType::LOGICAL_FILTER: |
| 22 | + case LogicalOperatorType::LOGICAL_PROJECTION: |
| 23 | + continue; |
| 24 | + default: |
| 25 | + return false; |
| 26 | + } |
| 27 | + } |
| 28 | + return false; |
| 29 | +} |
| 30 | + |
| 31 | +void CastCollect::VisitOperator(LogicalOperator &op) { |
| 32 | + /* |
| 33 | + * Logical projection expressions are columns which reference underlying |
| 34 | + * GETs. Don't process them, as they would add conflicts for every column |
| 35 | + * used in projection. Example: PROJECTION(col) -> GET(col). We don't want |
| 36 | + * to visit BoundColumnRefExpression in PROJECTION to avoid registering a |
| 37 | + * non-existent conflict. |
| 38 | + * |
| 39 | + * However, CastReplace will visit them because we need to update their |
| 40 | + * types if pushdown succeeded. |
| 41 | + */ |
| 42 | + if (op.type != LogicalOperatorType::LOGICAL_PROJECTION) { |
| 43 | + return LogicalOperatorVisitor::VisitOperator(op); |
| 44 | + } |
| 45 | + auto &projection = op.Cast<LogicalProjection>(); |
| 46 | + |
| 47 | + // Only push casts from a projection that forwards just column refs and |
| 48 | + // casts and reaches a GET without a join in between. A constant or other |
| 49 | + // expression makes the projection ineligible. |
| 50 | + // See test/sql/copy/csv/test_csv_error_message_type.test (top-level cast |
| 51 | + // to VARCHAR must still push) and test_large_integer_detection.test (a |
| 52 | + // nested cast to VARCHAR must not) in duckdb. |
| 53 | + bool clean = ReachesPushdownGet(projection); |
| 54 | + for (const auto &e : projection.expressions) { |
| 55 | + switch (e->GetExpressionClass()) { |
| 56 | + case ExpressionClass::BOUND_COLUMN_REF: |
| 57 | + case ExpressionClass::BOUND_CAST: |
| 58 | + continue; |
| 59 | + default: |
| 60 | + clean = false; |
| 61 | + break; |
| 62 | + } |
| 63 | + } |
| 64 | + if (clean) { |
| 65 | + for (const auto &e : projection.expressions) { |
| 66 | + if (e->GetExpressionClass() == ExpressionClass::BOUND_CAST) { |
| 67 | + top_level_casts.insert(e.get()); |
| 68 | + } |
| 69 | + } |
| 70 | + } |
| 71 | + if (projections.count(projection.table_index)) { |
| 72 | + VisitOperatorChildren(op); |
| 73 | + return; |
| 74 | + } |
| 75 | + |
| 76 | + LogicalOperatorVisitor::VisitOperator(op); |
| 77 | +} |
| 78 | + |
| 79 | +ExpressionPtr CastCollect::VisitReplace(BoundColumnRefExpression &expr, ExpressionPtr *ptr) { |
| 80 | + if (const auto binding = Resolve(expr.binding, analyses, projections)) { |
| 81 | + // Column is used without cast applied to it, register a conflict. |
| 82 | + // Not emplace() as we need to update the value if it was present |
| 83 | + binding->analysis.col_to_expr[binding->column_index] = nullptr; |
| 84 | + } |
| 85 | + return std::move(*ptr); |
| 86 | +} |
| 87 | + |
| 88 | +ExpressionPtr CastCollect::VisitReplace(BoundCastExpression &expr, ExpressionPtr *ptr) { |
| 89 | + if (expr.child->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { |
| 90 | + // Descend into children so e.g. fn(col, other) still sees "col" and |
| 91 | + // registers a conflict |
| 92 | + return nullptr; |
| 93 | + } |
| 94 | + const auto &bound_col = expr.child->Cast<BoundColumnRefExpression>(); |
| 95 | + const auto binding = Resolve(bound_col.binding, analyses, projections); |
| 96 | + if (!binding) { |
| 97 | + return nullptr; |
| 98 | + } |
| 99 | + auto &col_to_expr = binding->analysis.col_to_expr; |
| 100 | + |
| 101 | + if (auto it = col_to_expr.find(binding->column_index); it == col_to_expr.end()) { |
| 102 | + // Only a top-level projection cast starts a candidate. |
| 103 | + if (top_level_casts.count(&expr)) { |
| 104 | + col_to_expr.emplace(binding->column_index, &expr); |
| 105 | + } |
| 106 | + } else if (it->second == nullptr || |
| 107 | + it->second->Cast<BoundCastExpression>().return_type != expr.return_type) { |
| 108 | + // Different target type, or already a conflict. |
| 109 | + it->second = nullptr; |
| 110 | + } |
| 111 | + |
| 112 | + return std::move(*ptr); |
| 113 | +} |
| 114 | + |
| 115 | +ExpressionPtr CastReplace::VisitReplace(BoundColumnRefExpression &expr, ExpressionPtr *ptr) { |
| 116 | + const auto binding = Resolve(expr.binding, analyses, projections); |
| 117 | + if (!binding) { |
| 118 | + return std::move(*ptr); |
| 119 | + } |
| 120 | + |
| 121 | + const auto &[analysis, column_index, projection] = *binding; |
| 122 | + if (CanPushdownColumn(analysis, column_index)) { |
| 123 | + const idx_t storage_index = analysis.get.GetColumnIds()[column_index].GetPrimaryIndex(); |
| 124 | + const LogicalType return_type = analysis.get.returned_types[storage_index]; |
| 125 | + expr.return_type = return_type; |
| 126 | + // LogicalProjection types are resolved by calling |
| 127 | + // LogicalProjection::ResolveTypes, so we need to check whether types in |
| 128 | + // projection have been resolved, and updated them only if needed. |
| 129 | + if (projection != nullptr && !projection->types.empty()) { |
| 130 | + projection->types[column_index] = return_type; |
| 131 | + } |
| 132 | + } |
| 133 | + |
| 134 | + return std::move(*ptr); |
| 135 | +} |
| 136 | + |
| 137 | +ExpressionPtr CastReplace::VisitReplace(BoundCastExpression &expr, ExpressionPtr *ptr) { |
| 138 | + if (expr.child->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { |
| 139 | + return nullptr; // Same as in ScalarFnCollect::VisitReplace |
| 140 | + } |
| 141 | + auto &bound_col_base = expr.child; |
| 142 | + const auto &bound_col = bound_col_base->Cast<BoundColumnRefExpression>(); |
| 143 | + const auto binding = Resolve(bound_col.binding, analyses, projections); |
| 144 | + if (!binding) { |
| 145 | + return nullptr; |
| 146 | + } |
| 147 | + |
| 148 | + const auto &[analysis, column_index, projection] = *binding; |
| 149 | + if (!CanPushdownColumn(analysis, column_index)) { |
| 150 | + return std::move(*ptr); |
| 151 | + } |
| 152 | + |
| 153 | + const idx_t storage_index = analysis.get.GetColumnIds()[column_index].GetPrimaryIndex(); |
| 154 | + const LogicalType return_type = analysis.get.returned_types[storage_index]; |
| 155 | + bound_col_base->return_type = return_type; |
| 156 | + // Same as in CastReplace::VisitReplace(BoundColumnRefExpression) |
| 157 | + if (projection != nullptr && !projection->types.empty()) { |
| 158 | + projection->types[column_index] = return_type; |
| 159 | + } |
| 160 | + return std::move(bound_col_base); |
| 161 | +} |
| 162 | + |
| 163 | +CastCollect::CastCollect(Analyses &analyses, const Projections &projections) |
| 164 | + : analyses(analyses), projections(projections) { |
| 165 | +} |
| 166 | + |
| 167 | +CastReplace::CastReplace(Analyses &analyses, const Projections &projections) |
| 168 | + : analyses(analyses), projections(projections) { |
| 169 | +} |
0 commit comments