Skip to content

Commit e96a6a3

Browse files
committed
address reviews
1 parent e444135 commit e96a6a3

1 file changed

Lines changed: 276 additions & 55 deletions

File tree

datafusion/physical-plan/src/union.rs

Lines changed: 276 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ use crate::stream::ObservedStream;
4949
use arrow::datatypes::{Field, Schema, SchemaRef};
5050
use arrow::record_batch::RecordBatch;
5151
use datafusion_common::config::ConfigOptions;
52+
use datafusion_common::stats::Precision;
5253
use datafusion_common::tree_node::TreeNodeRecursion;
5354
use datafusion_common::{
5455
Result, assert_or_internal_err, exec_err, internal_datafusion_err,
@@ -853,7 +854,7 @@ fn col_stats_union(
853854
mut left: ColumnStatistics,
854855
right: &ColumnStatistics,
855856
) -> ColumnStatistics {
856-
left.distinct_count = left.distinct_count.add(&right.distinct_count).to_inexact();
857+
left.distinct_count = union_distinct_count(&left, right);
857858
left.min_value = left.min_value.min(&right.min_value);
858859
left.max_value = left.max_value.max(&right.max_value);
859860
left.sum_value = left.sum_value.add(&right.sum_value);
@@ -862,6 +863,92 @@ fn col_stats_union(
862863
left
863864
}
864865

866+
fn union_distinct_count(
867+
left: &ColumnStatistics,
868+
right: &ColumnStatistics,
869+
) -> Precision<usize> {
870+
let (ndv_left, ndv_right) = match (
871+
left.distinct_count.get_value(),
872+
right.distinct_count.get_value(),
873+
) {
874+
(Some(&l), Some(&r)) => (l, r),
875+
_ => return Precision::Absent,
876+
};
877+
878+
// Even with exact inputs, the union NDV depends on how
879+
// many distinct values are shared between the left and right.
880+
// We can only estimate this via range overlap. Thus both paths
881+
// below return `Inexact`.
882+
if let Some(ndv) = estimate_ndv_with_overlap(left, right, ndv_left, ndv_right) {
883+
return Precision::Inexact(ndv);
884+
}
885+
886+
Precision::Inexact(ndv_left + ndv_right)
887+
}
888+
889+
/// Estimates the distinct count for a union using range overlap, following
890+
/// the approach used by Trino:
891+
///
892+
/// overlap_a = fraction of A's range that overlaps with B
893+
/// overlap_b = fraction of B's range that overlaps with A
894+
/// NDV = max(overlap_a * NDV_a, overlap_b * NDV_b) [intersection]
895+
/// + (1 - overlap_a) * NDV_a [only in A]
896+
/// + (1 - overlap_b) * NDV_b [only in B]
897+
fn estimate_ndv_with_overlap(
898+
left: &ColumnStatistics,
899+
right: &ColumnStatistics,
900+
ndv_left: usize,
901+
ndv_right: usize,
902+
) -> Option<usize> {
903+
let min_left = left.min_value.get_value()?;
904+
let max_left = left.max_value.get_value()?;
905+
let min_right = right.min_value.get_value()?;
906+
let max_right = right.max_value.get_value()?;
907+
908+
let range_left = max_left.distance(min_left)?;
909+
let range_right = max_right.distance(min_right)?;
910+
911+
// Constant columns (range == 0) can't use the proportional overlap
912+
// formula below, so check interval overlap directly instead.
913+
if range_left == 0 || range_right == 0 {
914+
let overlaps = min_left <= max_right && min_right <= max_left;
915+
return Some(if overlaps {
916+
usize::max(ndv_left, ndv_right)
917+
} else {
918+
ndv_left + ndv_right
919+
});
920+
}
921+
922+
let overlap_min = if min_left >= min_right {
923+
min_left
924+
} else {
925+
min_right
926+
};
927+
let overlap_max = if max_left <= max_right {
928+
max_left
929+
} else {
930+
max_right
931+
};
932+
933+
if overlap_min > overlap_max {
934+
return Some(ndv_left + ndv_right);
935+
}
936+
937+
let overlap_range = overlap_max.distance(overlap_min)? as f64;
938+
939+
let overlap_left = overlap_range / range_left as f64;
940+
let overlap_right = overlap_range / range_right as f64;
941+
942+
let intersection = f64::max(
943+
overlap_left * ndv_left as f64,
944+
overlap_right * ndv_right as f64,
945+
);
946+
let only_left = (1.0 - overlap_left) * ndv_left as f64;
947+
let only_right = (1.0 - overlap_right) * ndv_right as f64;
948+
949+
Some((intersection + only_left + only_right).round() as usize)
950+
}
951+
865952
fn stats_union(mut left: Statistics, right: Statistics) -> Statistics {
866953
let Statistics {
867954
num_rows: right_num_rows,
@@ -1014,7 +1101,7 @@ mod tests {
10141101
total_byte_size: Precision::Exact(52),
10151102
column_statistics: vec![
10161103
ColumnStatistics {
1017-
distinct_count: Precision::Inexact(8),
1104+
distinct_count: Precision::Inexact(6),
10181105
max_value: Precision::Exact(ScalarValue::Int64(Some(34))),
10191106
min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
10201107
sum_value: Precision::Exact(ScalarValue::Int64(Some(84))),
@@ -1043,60 +1130,194 @@ mod tests {
10431130
assert_eq!(result, expected);
10441131
}
10451132

1046-
#[tokio::test]
1047-
async fn test_stats_union_distinct_count_inexact() {
1048-
let left = Statistics {
1049-
num_rows: Precision::Exact(10),
1050-
total_byte_size: Precision::Absent,
1051-
column_statistics: vec![
1052-
ColumnStatistics {
1053-
distinct_count: Precision::Exact(10),
1054-
..Default::default()
1055-
},
1056-
ColumnStatistics {
1057-
distinct_count: Precision::Inexact(7),
1058-
..Default::default()
1059-
},
1060-
ColumnStatistics {
1061-
distinct_count: Precision::Inexact(4),
1062-
..Default::default()
1063-
},
1064-
],
1065-
};
1066-
1067-
let right = Statistics {
1068-
num_rows: Precision::Exact(8),
1069-
total_byte_size: Precision::Absent,
1070-
column_statistics: vec![
1071-
ColumnStatistics {
1072-
distinct_count: Precision::Inexact(5),
1073-
..Default::default()
1074-
},
1075-
ColumnStatistics {
1076-
distinct_count: Precision::Inexact(3),
1077-
..Default::default()
1078-
},
1079-
ColumnStatistics {
1080-
distinct_count: Precision::Absent,
1081-
..Default::default()
1082-
},
1083-
],
1084-
};
1085-
1086-
let result = stats_union(left, right);
1133+
#[test]
1134+
fn test_union_distinct_count() {
1135+
// (left_ndv, left_min, left_max, right_ndv, right_min, right_max, expected)
1136+
let cases: Vec<(
1137+
Precision<usize>,
1138+
Option<i64>,
1139+
Option<i64>,
1140+
Precision<usize>,
1141+
Option<i64>,
1142+
Option<i64>,
1143+
Precision<usize>,
1144+
)> = vec![
1145+
// disjoint ranges: NDV = 5 + 3
1146+
(
1147+
Precision::Exact(5),
1148+
Some(0),
1149+
Some(10),
1150+
Precision::Exact(3),
1151+
Some(20),
1152+
Some(30),
1153+
Precision::Inexact(8),
1154+
),
1155+
// identical ranges: intersection = max(10, 8) = 10
1156+
(
1157+
Precision::Exact(10),
1158+
Some(0),
1159+
Some(100),
1160+
Precision::Exact(8),
1161+
Some(0),
1162+
Some(100),
1163+
Precision::Inexact(10),
1164+
),
1165+
// partial overlap: 50 + 50 + 25 = 125
1166+
(
1167+
Precision::Exact(100),
1168+
Some(0),
1169+
Some(100),
1170+
Precision::Exact(50),
1171+
Some(50),
1172+
Some(150),
1173+
Precision::Inexact(125),
1174+
),
1175+
// right contained in left: 50 + 50 + 0 = 100
1176+
(
1177+
Precision::Exact(100),
1178+
Some(0),
1179+
Some(100),
1180+
Precision::Exact(50),
1181+
Some(25),
1182+
Some(75),
1183+
Precision::Inexact(100),
1184+
),
1185+
// both constant, same value
1186+
(
1187+
Precision::Exact(1),
1188+
Some(5),
1189+
Some(5),
1190+
Precision::Exact(1),
1191+
Some(5),
1192+
Some(5),
1193+
Precision::Inexact(1),
1194+
),
1195+
// both constant, different values
1196+
(
1197+
Precision::Exact(1),
1198+
Some(5),
1199+
Some(5),
1200+
Precision::Exact(1),
1201+
Some(10),
1202+
Some(10),
1203+
Precision::Inexact(2),
1204+
),
1205+
// left constant within right range
1206+
(
1207+
Precision::Exact(1),
1208+
Some(5),
1209+
Some(5),
1210+
Precision::Exact(10),
1211+
Some(0),
1212+
Some(10),
1213+
Precision::Inexact(10),
1214+
),
1215+
// left constant outside right range
1216+
(
1217+
Precision::Exact(1),
1218+
Some(20),
1219+
Some(20),
1220+
Precision::Exact(10),
1221+
Some(0),
1222+
Some(10),
1223+
Precision::Inexact(11),
1224+
),
1225+
// right constant within left range
1226+
(
1227+
Precision::Exact(10),
1228+
Some(0),
1229+
Some(10),
1230+
Precision::Exact(1),
1231+
Some(5),
1232+
Some(5),
1233+
Precision::Inexact(10),
1234+
),
1235+
// right constant outside left range
1236+
(
1237+
Precision::Exact(10),
1238+
Some(0),
1239+
Some(10),
1240+
Precision::Exact(1),
1241+
Some(20),
1242+
Some(20),
1243+
Precision::Inexact(11),
1244+
),
1245+
// missing min/max falls back to sum (exact + exact)
1246+
(
1247+
Precision::Exact(10),
1248+
None,
1249+
None,
1250+
Precision::Exact(5),
1251+
None,
1252+
None,
1253+
Precision::Inexact(15),
1254+
),
1255+
// missing min/max falls back to sum (exact + inexact)
1256+
(
1257+
Precision::Exact(10),
1258+
None,
1259+
None,
1260+
Precision::Inexact(5),
1261+
None,
1262+
None,
1263+
Precision::Inexact(15),
1264+
),
1265+
// missing min/max falls back to sum (inexact + inexact)
1266+
(
1267+
Precision::Inexact(7),
1268+
None,
1269+
None,
1270+
Precision::Inexact(3),
1271+
None,
1272+
None,
1273+
Precision::Inexact(10),
1274+
),
1275+
// one side absent
1276+
(
1277+
Precision::Exact(10),
1278+
None,
1279+
None,
1280+
Precision::Absent,
1281+
None,
1282+
None,
1283+
Precision::Absent,
1284+
),
1285+
// one side absent (inexact + absent)
1286+
(
1287+
Precision::Inexact(4),
1288+
None,
1289+
None,
1290+
Precision::Absent,
1291+
None,
1292+
None,
1293+
Precision::Absent,
1294+
),
1295+
];
10871296

1088-
assert_eq!(
1089-
result.column_statistics[0].distinct_count,
1090-
Precision::Inexact(15)
1091-
);
1092-
assert_eq!(
1093-
result.column_statistics[1].distinct_count,
1094-
Precision::Inexact(10)
1095-
);
1096-
assert_eq!(
1097-
result.column_statistics[2].distinct_count,
1098-
Precision::Absent
1099-
);
1297+
for (
1298+
i,
1299+
(left_ndv, left_min, left_max, right_ndv, right_min, right_max, expected),
1300+
) in cases.into_iter().enumerate()
1301+
{
1302+
let to_sv = |v| Precision::Exact(ScalarValue::Int64(Some(v)));
1303+
let left = ColumnStatistics {
1304+
distinct_count: left_ndv,
1305+
min_value: left_min.map(to_sv).unwrap_or(Precision::Absent),
1306+
max_value: left_max.map(to_sv).unwrap_or(Precision::Absent),
1307+
..Default::default()
1308+
};
1309+
let right = ColumnStatistics {
1310+
distinct_count: right_ndv,
1311+
min_value: right_min.map(to_sv).unwrap_or(Precision::Absent),
1312+
max_value: right_max.map(to_sv).unwrap_or(Precision::Absent),
1313+
..Default::default()
1314+
};
1315+
assert_eq!(
1316+
union_distinct_count(&left, &right),
1317+
expected,
1318+
"case {i} failed"
1319+
);
1320+
}
11001321
}
11011322

11021323
#[tokio::test]

0 commit comments

Comments
 (0)