Skip to content

Commit 9f91638

Browse files
committed
fix
Signed-off-by: Mikhail Kot <to@myrrc.dev>
1 parent 2ee2033 commit 9f91638

4 files changed

Lines changed: 93 additions & 21 deletions

File tree

vortex-duckdb/cpp/include/duckdb_vx/table_function.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ void duckdb_vx_string_map_insert(duckdb_vx_string_map map, const char *key, cons
4242

4343
// Input data passed into the init_global and init_local callbacks.
4444
typedef struct {
45-
const void *bind_data;
45+
void *bind_data;
4646

4747
/**
4848
* Projected columns that are requested to be read. These are not

vortex-duckdb/cpp/table_function.cpp

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,52 @@ static Value &UnwrapValue(duckdb_value value) {
9494
return *(reinterpret_cast<Value *>(value));
9595
}
9696

97+
// For boolean or integral types, derive distinct count from min/max pair.
98+
idx_t integer_distinct(LogicalTypeId id, const Value &min, const Value &max) {
99+
switch (id) {
100+
case LogicalTypeId::BOOLEAN:
101+
return 1 + max.GetValueUnsafe<bool>() - min.GetValueUnsafe<bool>();
102+
case LogicalTypeId::UTINYINT:
103+
return 1 + max.GetValueUnsafe<uint8_t>() - min.GetValueUnsafe<uint8_t>();
104+
case LogicalTypeId::USMALLINT:
105+
return 1 + max.GetValueUnsafe<uint16_t>() - min.GetValueUnsafe<uint16_t>();
106+
case LogicalTypeId::UINTEGER:
107+
return 1 + max.GetValueUnsafe<uint32_t>() - min.GetValueUnsafe<uint32_t>();
108+
case LogicalTypeId::UBIGINT:
109+
return 1 + max.GetValueUnsafe<uint64_t>() - min.GetValueUnsafe<uint64_t>();
110+
case LogicalTypeId::TINYINT:
111+
return 1 + abs(max.GetValueUnsafe<int8_t>() - min.GetValueUnsafe<int8_t>());
112+
case LogicalTypeId::SMALLINT:
113+
return 1 + abs(max.GetValueUnsafe<int16_t>() - min.GetValueUnsafe<int16_t>());
114+
case LogicalTypeId::INTEGER:
115+
return 1 + labs(max.GetValueUnsafe<int32_t>() - min.GetValueUnsafe<int32_t>());
116+
case LogicalTypeId::BIGINT:
117+
return 1 + llabs(max.GetValueUnsafe<int64_t>() - min.GetValueUnsafe<int64_t>());
118+
// Don't estimate distinct for huge ints since result may not fit in u64.
119+
default:
120+
return 0;
121+
}
122+
}
123+
97124
unique_ptr<BaseStatistics> numeric_stats(duckdb_column_statistics &stats, LogicalType type) {
98125
BaseStatistics out = StringStats::CreateUnknown(type);
99-
if (stats.min) {
126+
if (stats.min && stats.max) {
127+
const Value &min = UnwrapValue(stats.min);
128+
NumericStats::SetMin(out, min);
129+
130+
const Value &max = UnwrapValue(stats.max);
131+
NumericStats::SetMax(out, max);
132+
133+
if (const idx_t distinct = integer_distinct(type.id(), min, max); distinct > 0) {
134+
out.SetDistinctCount(distinct);
135+
}
136+
137+
duckdb_destroy_value(&stats.min);
138+
duckdb_destroy_value(&stats.max);
139+
} else if (stats.min) {
100140
NumericStats::SetMin(out, UnwrapValue(stats.min));
101141
duckdb_destroy_value(&stats.min);
102-
}
103-
if (stats.max) {
142+
} else if (stats.max) {
104143
NumericStats::SetMax(out, UnwrapValue(stats.max));
105144
duckdb_destroy_value(&stats.max);
106145
}
@@ -112,14 +151,26 @@ unique_ptr<BaseStatistics> numeric_stats(duckdb_column_statistics &stats, Logica
112151

113152
unique_ptr<BaseStatistics> string_stats(duckdb_column_statistics &stats, LogicalType type) {
114153
BaseStatistics out = StringStats::CreateUnknown(type);
115-
if (stats.min) {
154+
if (stats.min && stats.max) {
155+
const std::string &min = StringValue::Get(UnwrapValue(stats.min));
156+
StringStats::SetMin(out, min);
157+
duckdb_destroy_value(&stats.min);
158+
159+
const std::string &max = StringValue::Get(UnwrapValue(stats.max));
160+
StringStats::SetMax(out, max);
161+
duckdb_destroy_value(&stats.max);
162+
163+
if (min == max) {
164+
out.SetDistinctCount(1);
165+
}
166+
} else if (stats.min) {
116167
StringStats::SetMin(out, StringValue::Get(UnwrapValue(stats.min)));
117168
duckdb_destroy_value(&stats.min);
118-
}
119-
if (stats.max) {
169+
} else if (stats.max) {
120170
StringStats::SetMax(out, StringValue::Get(UnwrapValue(stats.max)));
121171
duckdb_destroy_value(&stats.max);
122172
}
173+
123174
if (stats.max_string_length >> 63) {
124175
StringStats::SetMaxStringLength(out, uint32_t(stats.max_string_length));
125176
}

vortex-duckdb/src/datasource.rs

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//! to get a blanket [`TableFunction`] implementation covering init, scan, progress, filter
88
//! pushdown, cardinality, and partitioning.
99
10+
use std::cmp::max;
1011
use std::fmt::Debug;
1112
use std::ops::Range;
1213
use std::sync::Arc;
@@ -76,6 +77,7 @@ use crate::duckdb::DuckdbStringMapRef;
7677
use crate::duckdb::ExpressionRef;
7778
use crate::duckdb::LogicalType;
7879
use crate::duckdb::PartitionData;
80+
use crate::duckdb::TableFilterClass;
7981
use crate::duckdb::TableFilterSetRef;
8082
use crate::duckdb::TableFunction;
8183
use crate::duckdb::TableInitInput;
@@ -120,6 +122,7 @@ pub struct DataSourceBindData {
120122
data_source: Arc<MultiLayoutDataSource>,
121123
filter_exprs: Vec<Expression>,
122124
column_fields: Vec<DuckdbField>,
125+
has_non_optional_filter: bool,
123126
}
124127

125128
impl Clone for DataSourceBindData {
@@ -129,6 +132,7 @@ impl Clone for DataSourceBindData {
129132
// filter_exprs are consumed once in `init_global`.
130133
filter_exprs: vec![],
131134
column_fields: self.column_fields.clone(),
135+
has_non_optional_filter: self.has_non_optional_filter,
132136
}
133137
}
134138
}
@@ -254,6 +258,20 @@ impl ColumnStatisticsAggregate {
254258
}
255259
}
256260

261+
// Duckdb requires post-filter cardinality estimates, otherwise join
262+
// planner may flip join sides which is a huge regression for some
263+
// queries i.e. 1000x for tpcds 85.
264+
//
265+
// See duckdb/src/optimizer/join_order/relation_statistics_helper.cpp
266+
const DEFAULT_SELECTIVITY: f64 = 0.2;
267+
fn postfilter_cardinality(cardinality: u64, has_non_optional_filter: bool) -> u64 {
268+
if has_non_optional_filter {
269+
max(1, (cardinality as f64 * DEFAULT_SELECTIVITY) as u64)
270+
} else {
271+
cardinality
272+
}
273+
}
274+
257275
impl<T: DataSourceTableFunction> TableFunction for T {
258276
type BindData = DataSourceBindData;
259277
type GlobalState = DataSourceGlobal;
@@ -277,6 +295,7 @@ impl<T: DataSourceTableFunction> TableFunction for T {
277295
data_source: Arc::new(data_source),
278296
filter_exprs: vec![],
279297
column_fields,
298+
has_non_optional_filter: false,
280299
})
281300
}
282301

@@ -299,13 +318,15 @@ impl<T: DataSourceTableFunction> TableFunction for T {
299318
row_range,
300319
file_selection,
301320
file_range,
321+
has_non_optional_filter,
302322
} = extract_table_filter_expr(
303323
init_input.table_filter_set(),
304324
column_ids,
305325
&bind_data.column_fields,
306326
&bind_data.filter_exprs,
307327
bind_data.data_source.dtype(),
308328
)?;
329+
bind_data.has_non_optional_filter = has_non_optional_filter;
309330

310331
let filter_expr_str = filter
311332
.as_ref()
@@ -504,17 +525,9 @@ impl<T: DataSourceTableFunction> TableFunction for T {
504525
let Some(expr) = try_from_bound_expression(expr)? else {
505526
return Ok(false);
506527
};
507-
bind_data.filter_exprs.push(expr);
508528

509-
// NOTE(ngates): Vortex does indeed run exact filters, so in theory we should return `true`
510-
// here to tell DuckDB we've handled the filter. However, DuckDB applies some crude
511-
// cardinality estimation heuristics (e.g. an equality filter => 20% selectivity) that
512-
// means by returning false, DuckDB runs an additional filter (a little bit of overhead)
513-
// but tends to end up with a better query plan.
514-
// If we plumb row count estimation into the layout tree, perhaps we could use zone maps
515-
// etc. to return estimates. But this function is probably called too late anyway. Maybe
516-
// we need our own cardinality heuristics.
517-
Ok(false)
529+
bind_data.filter_exprs.push(expr);
530+
Ok(true)
518531
}
519532

520533
/// Get column-wise statistics. Available only if we're reading a single
@@ -542,8 +555,10 @@ impl<T: DataSourceTableFunction> TableFunction for T {
542555

543556
fn cardinality(bind_data: &Self::BindData) -> Cardinality {
544557
match bind_data.data_source.row_count() {
545-
Some(Precision::Exact(v)) => Cardinality::Maximum(v),
546-
Some(Precision::Inexact(v)) => Cardinality::Estimate(v),
558+
Some(Precision::Exact(v) | Precision::Inexact(v)) => {
559+
// Post-filter estimate is always a heuristic.
560+
Cardinality::Estimate(postfilter_cardinality(v, bind_data.has_non_optional_filter))
561+
}
547562
None => Cardinality::Unknown,
548563
}
549564
}
@@ -687,6 +702,7 @@ struct FilterWithVirtualColumns {
687702
row_range: Option<Range<u64>>,
688703
file_selection: Selection,
689704
file_range: Option<Range<u64>>,
705+
has_non_optional_filter: bool,
690706
}
691707

692708
/// Creates a table filter expression, row selection, and row range from the table filter set,
@@ -698,6 +714,8 @@ fn extract_table_filter_expr(
698714
additional_filters: &[Expression],
699715
dtype: &DType,
700716
) -> VortexResult<FilterWithVirtualColumns> {
717+
let mut has_non_optional_filter = false;
718+
701719
let mut table_filter_exprs: HashSet<Expression> = if let Some(filter) = table_filter_set {
702720
filter
703721
.into_iter()
@@ -706,6 +724,8 @@ fn extract_table_filter_expr(
706724
!is_virtual_column(column_ids[idx_u])
707725
})
708726
.map(|(idx, ex)| {
727+
has_non_optional_filter |= !matches!(ex.as_class(), TableFilterClass::Optional(_));
728+
709729
let idx_u: usize = idx.as_();
710730
let col_idx: usize = column_ids[idx_u].as_();
711731
let name = &column_fields.get(col_idx).vortex_expect("exists").name;
@@ -741,6 +761,7 @@ fn extract_table_filter_expr(
741761
row_range,
742762
file_selection,
743763
file_range,
764+
has_non_optional_filter,
744765
};
745766
Ok(out)
746767
}

vortex-duckdb/src/duckdb/table_function/init.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ impl<'a, T: TableFunction> TableInitInput<'a, T> {
7575
}
7676

7777
/// Returns the bind data for the table function.
78-
pub fn bind_data(&self) -> &T::BindData {
79-
unsafe { &*self.input.bind_data.cast::<T::BindData>() }
78+
pub fn bind_data(&self) -> &mut T::BindData {
79+
unsafe { &mut *self.input.bind_data.cast::<T::BindData>() }
8080
}
8181

8282
pub fn column_ids(&self) -> &[u64] {

0 commit comments

Comments
 (0)