Skip to content

Commit 8d9b080

Browse files
authored
[Minor] propagate distinct_count as inexact through unions (#20846)
## Which issue does this PR close? Does not close but part of #20766 ## Rationale for this change As @jonathanc-n describes here is the Trino's formula about inexact formula: ``` // for unioning A + B // calculate A overlap with B using min/max statistics overlap_a = percent of overlap that A has with B overlap_b = percent of overlap that B has with A new_distinct_count = max(overlap_a * NDV_a, overlap_b * NDV_b) // find interesect + (1 - overlap_a) * NDV_a // overlap for just a + (1 - overlap_b) * NDV_b // overlap for just b ``` ## What changes are included in this PR? Instead of absent set `distinct_count` with inexact precision depending on overlaps and distinct counts ## Are these changes tested? I've added unit tests ## Are there any user-facing changes? No
1 parent 129c58f commit 8d9b080

File tree

1 file changed

+311
-2
lines changed

1 file changed

+311
-2
lines changed

datafusion/physical-plan/src/union.rs

Lines changed: 311 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,7 @@ fn col_stats_union(
854854
mut left: ColumnStatistics,
855855
right: &ColumnStatistics,
856856
) -> ColumnStatistics {
857-
left.distinct_count = Precision::Absent;
857+
left.distinct_count = union_distinct_count(&left, right);
858858
left.min_value = left.min_value.min(&right.min_value);
859859
left.max_value = left.max_value.max(&right.max_value);
860860
left.sum_value = left.sum_value.add(&right.sum_value);
@@ -863,6 +863,123 @@ fn col_stats_union(
863863
left
864864
}
865865

866+
fn union_distinct_count(
867+
left: &ColumnStatistics,
868+
right: &ColumnStatistics,
869+
) -> Precision<usize> {
870+
let (ndv_left, ndv_right) = match (
871+
left.distinct_count.get_value(),
872+
right.distinct_count.get_value(),
873+
) {
874+
(Some(&l), Some(&r)) => (l, r),
875+
_ => return Precision::Absent,
876+
};
877+
878+
// Even with exact inputs, the union NDV depends on how
879+
// many distinct values are shared between the left and right.
880+
// We can only estimate this via range overlap. Thus both paths
881+
// below return `Inexact`.
882+
if let Some(ndv) = estimate_ndv_with_overlap(left, right, ndv_left, ndv_right) {
883+
return Precision::Inexact(ndv);
884+
}
885+
886+
Precision::Inexact(ndv_left + ndv_right)
887+
}
888+
889+
/// Estimates the distinct count for a union using range overlap,
890+
/// following the approach used by Trino:
891+
///
892+
/// Assumes values are distributed uniformly within each input's
893+
/// `[min, max]` range (the standard assumption when only summary
894+
/// statistics are available, classic for scalar-based statistics
895+
/// propagation). Under uniformity the fraction of an input's
896+
/// distinct values that land in a sub-range equals the fraction of
897+
/// the range that sub-range covers.
898+
///
899+
/// The combined value space is split into three disjoint regions:
900+
///
901+
/// ```text
902+
/// |-- only A --|-- overlap --|-- only B --|
903+
/// ```
904+
///
905+
/// * **Only in A/B** – values outside the other input's range
906+
/// contribute `(1 − overlap_a) · NDV_a` and `(1 − overlap_b) · NDV_b`.
907+
/// * **Overlap** – both inputs may produce values here. We take
908+
/// `max(overlap_a · NDV_a, overlap_b · NDV_b)` rather than the
909+
/// sum because values in the same sub-range are likely shared
910+
/// (the smaller set is assumed to be a subset of the larger).
911+
/// This is conservative: it avoids inflating the NDV estimate,
912+
/// which is safer for downstream join-order decisions.
913+
///
914+
/// The formula ranges between `[max(NDV_a, NDV_b), NDV_a + NDV_b]`,
915+
/// from full overlap to no overlap. Boundary cases confirm this:
916+
/// disjoint ranges → `NDV_a + NDV_b`, identical ranges →
917+
/// `max(NDV_a, NDV_b)`.
918+
///
919+
/// ```text
920+
/// NDV = max(overlap_a * NDV_a, overlap_b * NDV_b) [intersection]
921+
/// + (1 - overlap_a) * NDV_a [only in A]
922+
/// + (1 - overlap_b) * NDV_b [only in B]
923+
/// ```
924+
fn estimate_ndv_with_overlap(
925+
left: &ColumnStatistics,
926+
right: &ColumnStatistics,
927+
ndv_left: usize,
928+
ndv_right: usize,
929+
) -> Option<usize> {
930+
let min_left = left.min_value.get_value()?;
931+
let max_left = left.max_value.get_value()?;
932+
let min_right = right.min_value.get_value()?;
933+
let max_right = right.max_value.get_value()?;
934+
935+
let range_left = max_left.distance(min_left)?;
936+
let range_right = max_right.distance(min_right)?;
937+
938+
// Constant columns (range == 0) can't use the proportional overlap
939+
// formula below, so check interval overlap directly instead.
940+
if range_left == 0 || range_right == 0 {
941+
let overlaps = min_left <= max_right && min_right <= max_left;
942+
return Some(if overlaps {
943+
usize::max(ndv_left, ndv_right)
944+
} else {
945+
ndv_left + ndv_right
946+
});
947+
}
948+
949+
let overlap_min = if min_left >= min_right {
950+
min_left
951+
} else {
952+
min_right
953+
};
954+
let overlap_max = if max_left <= max_right {
955+
max_left
956+
} else {
957+
max_right
958+
};
959+
960+
// Short-circuit: when there's no overlap the formula naturally
961+
// degrades to ndv_left + ndv_right (overlap_range = 0 gives
962+
// overlap_left = overlap_right = 0), but returning early avoids
963+
// the floating-point math and a fallible `distance()` call.
964+
if overlap_min > overlap_max {
965+
return Some(ndv_left + ndv_right);
966+
}
967+
968+
let overlap_range = overlap_max.distance(overlap_min)? as f64;
969+
970+
let overlap_left = overlap_range / range_left as f64;
971+
let overlap_right = overlap_range / range_right as f64;
972+
973+
let intersection = f64::max(
974+
overlap_left * ndv_left as f64,
975+
overlap_right * ndv_right as f64,
976+
);
977+
let only_left = (1.0 - overlap_left) * ndv_left as f64;
978+
let only_right = (1.0 - overlap_right) * ndv_right as f64;
979+
980+
Some((intersection + only_left + only_right).round() as usize)
981+
}
982+
866983
fn stats_union(mut left: Statistics, right: Statistics) -> Statistics {
867984
let Statistics {
868985
num_rows: right_num_rows,
@@ -890,6 +1007,7 @@ mod tests {
8901007
use arrow::compute::SortOptions;
8911008
use arrow::datatypes::DataType;
8921009
use datafusion_common::ScalarValue;
1010+
use datafusion_common::stats::Precision;
8931011
use datafusion_physical_expr::equivalence::convert_to_orderings;
8941012
use datafusion_physical_expr::expressions::col;
8951013

@@ -1014,7 +1132,7 @@ mod tests {
10141132
total_byte_size: Precision::Exact(52),
10151133
column_statistics: vec![
10161134
ColumnStatistics {
1017-
distinct_count: Precision::Absent,
1135+
distinct_count: Precision::Inexact(6),
10181136
max_value: Precision::Exact(ScalarValue::Int64(Some(34))),
10191137
min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
10201138
sum_value: Precision::Exact(ScalarValue::Int64(Some(84))),
@@ -1043,6 +1161,197 @@ mod tests {
10431161
assert_eq!(result, expected);
10441162
}
10451163

1164+
#[test]
1165+
fn test_union_distinct_count() {
1166+
// (left_ndv, left_min, left_max, right_ndv, right_min, right_max, expected)
1167+
type NdvTestCase = (
1168+
Precision<usize>,
1169+
Option<i64>,
1170+
Option<i64>,
1171+
Precision<usize>,
1172+
Option<i64>,
1173+
Option<i64>,
1174+
Precision<usize>,
1175+
);
1176+
let cases: Vec<NdvTestCase> = vec![
1177+
// disjoint ranges: NDV = 5 + 3
1178+
(
1179+
Precision::Exact(5),
1180+
Some(0),
1181+
Some(10),
1182+
Precision::Exact(3),
1183+
Some(20),
1184+
Some(30),
1185+
Precision::Inexact(8),
1186+
),
1187+
// identical ranges: intersection = max(10, 8) = 10
1188+
(
1189+
Precision::Exact(10),
1190+
Some(0),
1191+
Some(100),
1192+
Precision::Exact(8),
1193+
Some(0),
1194+
Some(100),
1195+
Precision::Inexact(10),
1196+
),
1197+
// partial overlap: 50 + 50 + 25 = 125
1198+
(
1199+
Precision::Exact(100),
1200+
Some(0),
1201+
Some(100),
1202+
Precision::Exact(50),
1203+
Some(50),
1204+
Some(150),
1205+
Precision::Inexact(125),
1206+
),
1207+
// right contained in left: 50 + 50 + 0 = 100
1208+
(
1209+
Precision::Exact(100),
1210+
Some(0),
1211+
Some(100),
1212+
Precision::Exact(50),
1213+
Some(25),
1214+
Some(75),
1215+
Precision::Inexact(100),
1216+
),
1217+
// both constant, same value
1218+
(
1219+
Precision::Exact(1),
1220+
Some(5),
1221+
Some(5),
1222+
Precision::Exact(1),
1223+
Some(5),
1224+
Some(5),
1225+
Precision::Inexact(1),
1226+
),
1227+
// both constant, different values
1228+
(
1229+
Precision::Exact(1),
1230+
Some(5),
1231+
Some(5),
1232+
Precision::Exact(1),
1233+
Some(10),
1234+
Some(10),
1235+
Precision::Inexact(2),
1236+
),
1237+
// left constant within right range
1238+
(
1239+
Precision::Exact(1),
1240+
Some(5),
1241+
Some(5),
1242+
Precision::Exact(10),
1243+
Some(0),
1244+
Some(10),
1245+
Precision::Inexact(10),
1246+
),
1247+
// left constant outside right range
1248+
(
1249+
Precision::Exact(1),
1250+
Some(20),
1251+
Some(20),
1252+
Precision::Exact(10),
1253+
Some(0),
1254+
Some(10),
1255+
Precision::Inexact(11),
1256+
),
1257+
// right constant within left range
1258+
(
1259+
Precision::Exact(10),
1260+
Some(0),
1261+
Some(10),
1262+
Precision::Exact(1),
1263+
Some(5),
1264+
Some(5),
1265+
Precision::Inexact(10),
1266+
),
1267+
// right constant outside left range
1268+
(
1269+
Precision::Exact(10),
1270+
Some(0),
1271+
Some(10),
1272+
Precision::Exact(1),
1273+
Some(20),
1274+
Some(20),
1275+
Precision::Inexact(11),
1276+
),
1277+
// missing min/max falls back to sum (exact + exact)
1278+
(
1279+
Precision::Exact(10),
1280+
None,
1281+
None,
1282+
Precision::Exact(5),
1283+
None,
1284+
None,
1285+
Precision::Inexact(15),
1286+
),
1287+
// missing min/max falls back to sum (exact + inexact)
1288+
(
1289+
Precision::Exact(10),
1290+
None,
1291+
None,
1292+
Precision::Inexact(5),
1293+
None,
1294+
None,
1295+
Precision::Inexact(15),
1296+
),
1297+
// missing min/max falls back to sum (inexact + inexact)
1298+
(
1299+
Precision::Inexact(7),
1300+
None,
1301+
None,
1302+
Precision::Inexact(3),
1303+
None,
1304+
None,
1305+
Precision::Inexact(10),
1306+
),
1307+
// one side absent
1308+
(
1309+
Precision::Exact(10),
1310+
None,
1311+
None,
1312+
Precision::Absent,
1313+
None,
1314+
None,
1315+
Precision::Absent,
1316+
),
1317+
// one side absent (inexact + absent)
1318+
(
1319+
Precision::Inexact(4),
1320+
None,
1321+
None,
1322+
Precision::Absent,
1323+
None,
1324+
None,
1325+
Precision::Absent,
1326+
),
1327+
];
1328+
1329+
for (
1330+
i,
1331+
(left_ndv, left_min, left_max, right_ndv, right_min, right_max, expected),
1332+
) in cases.into_iter().enumerate()
1333+
{
1334+
let to_sv = |v| Precision::Exact(ScalarValue::Int64(Some(v)));
1335+
let left = ColumnStatistics {
1336+
distinct_count: left_ndv,
1337+
min_value: left_min.map(to_sv).unwrap_or(Precision::Absent),
1338+
max_value: left_max.map(to_sv).unwrap_or(Precision::Absent),
1339+
..Default::default()
1340+
};
1341+
let right = ColumnStatistics {
1342+
distinct_count: right_ndv,
1343+
min_value: right_min.map(to_sv).unwrap_or(Precision::Absent),
1344+
max_value: right_max.map(to_sv).unwrap_or(Precision::Absent),
1345+
..Default::default()
1346+
};
1347+
assert_eq!(
1348+
union_distinct_count(&left, &right),
1349+
expected,
1350+
"case {i} failed"
1351+
);
1352+
}
1353+
}
1354+
10461355
#[tokio::test]
10471356
async fn test_union_equivalence_properties() -> Result<()> {
10481357
let schema = create_test_schema()?;

0 commit comments

Comments
 (0)