Skip to content

Commit e88ae02

Browse files
committed
refactor(stats): document and normalize sum_value widening
1 parent 8f75815 commit e88ae02

File tree

1 file changed

+56
-34
lines changed

1 file changed

+56
-34
lines changed

datafusion/common/src/stats.rs

Lines changed: 56 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use std::fmt::{self, Debug, Display};
2222
use crate::{Result, ScalarValue};
2323

2424
use crate::error::_plan_err;
25+
use crate::utils::aggregate::precision_add;
2526
use arrow::datatypes::{DataType, Schema};
2627

2728
/// Represents a value with a degree of certainty. `Precision` is used to
@@ -210,6 +211,16 @@ impl Precision<ScalarValue> {
210211
}
211212
}
212213

214+
fn cast_scalar_to_sum_type(value: &ScalarValue) -> Result<ScalarValue> {
215+
let source_type = value.data_type();
216+
let target_type = Self::sum_data_type(&source_type);
217+
if source_type == target_type {
218+
Ok(value.clone())
219+
} else {
220+
value.cast_to(&target_type)
221+
}
222+
}
223+
213224
/// Calculates the sum of two (possibly inexact) [`ScalarValue`] values,
214225
/// conservatively propagating exactness information. If one of the input
215226
/// values is [`Precision::Absent`], the result is `Absent` too.
@@ -240,39 +251,21 @@ impl Precision<ScalarValue> {
240251
/// This narrows overflow risk when `sum_value` statistics are merged:
241252
/// `Int8/Int16/Int32 -> Int64` and `UInt8/UInt16/UInt32 -> UInt64`.
242253
pub fn cast_to_sum_type(&self) -> Precision<ScalarValue> {
243-
match self {
244-
Precision::Exact(value) => {
245-
let source_type = value.data_type();
246-
let target_type = Self::sum_data_type(&source_type);
247-
if source_type == target_type {
248-
Precision::Exact(value.clone())
249-
} else {
250-
value
251-
.cast_to(&target_type)
252-
.map(Precision::Exact)
253-
.unwrap_or(Precision::Absent)
254-
}
255-
}
256-
Precision::Inexact(value) => {
257-
let source_type = value.data_type();
258-
let target_type = Self::sum_data_type(&source_type);
259-
if source_type == target_type {
260-
Precision::Inexact(value.clone())
261-
} else {
262-
value
263-
.cast_to(&target_type)
264-
.map(Precision::Inexact)
265-
.unwrap_or(Precision::Absent)
266-
}
267-
}
268-
Precision::Absent => Precision::Absent,
254+
match (self.is_exact(), self.get_value()) {
255+
(Some(true), Some(value)) => Self::cast_scalar_to_sum_type(value)
256+
.map(Precision::Exact)
257+
.unwrap_or(Precision::Absent),
258+
(Some(false), Some(value)) => Self::cast_scalar_to_sum_type(value)
259+
.map(Precision::Inexact)
260+
.unwrap_or(Precision::Absent),
261+
(_, _) => Precision::Absent,
269262
}
270263
}
271264

272265
/// SUM-style addition with integer widening to match SQL `SUM` return
273266
/// types for smaller integral inputs.
274267
pub fn add_for_sum(&self, other: &Precision<ScalarValue>) -> Precision<ScalarValue> {
275-
self.cast_to_sum_type().add(&other.cast_to_sum_type())
268+
precision_add(&self.cast_to_sum_type(), &other.cast_to_sum_type())
276269
}
277270

278271
/// Calculates the difference of two (possibly inexact) [`ScalarValue`] values,
@@ -727,8 +720,7 @@ impl Statistics {
727720

728721
col_stats.null_count = col_stats.null_count.add(&item_cs.null_count);
729722
col_stats.byte_size = col_stats.byte_size.add(&item_cs.byte_size);
730-
col_stats.sum_value =
731-
col_stats.sum_value.add_for_sum(&item_cs.sum_value);
723+
col_stats.sum_value = col_stats.sum_value.add_for_sum(&item_cs.sum_value);
732724
col_stats.min_value = col_stats.min_value.min(&item_cs.min_value);
733725
col_stats.max_value = col_stats.max_value.max(&item_cs.max_value);
734726
}
@@ -823,7 +815,15 @@ pub struct ColumnStatistics {
823815
pub max_value: Precision<ScalarValue>,
824816
/// Minimum value of column
825817
pub min_value: Precision<ScalarValue>,
826-
/// Sum value of a column
818+
/// Sum value of a column.
819+
///
820+
/// For integral columns, values should be kept in SUM-compatible widened
821+
/// types (`Int8/Int16/Int32 -> Int64`, `UInt8/UInt16/UInt32 -> UInt64`) to
822+
/// reduce overflow risk during statistics propagation.
823+
///
824+
/// Callers should prefer [`ColumnStatistics::with_sum_value`] for setting
825+
/// this field and [`Precision<ScalarValue>::add_for_sum`] /
826+
/// [`Precision<ScalarValue>::cast_to_sum_type`] for sum arithmetic.
827827
pub sum_value: Precision<ScalarValue>,
828828
/// Number of distinct values
829829
pub distinct_count: Precision<usize>,
@@ -888,7 +888,19 @@ impl ColumnStatistics {
888888

889889
/// Set the sum value
890890
pub fn with_sum_value(mut self, sum_value: Precision<ScalarValue>) -> Self {
891-
self.sum_value = sum_value;
891+
self.sum_value = match sum_value {
892+
Precision::Exact(value) => {
893+
Precision::<ScalarValue>::cast_scalar_to_sum_type(&value)
894+
.map(Precision::Exact)
895+
.unwrap_or(Precision::Absent)
896+
}
897+
Precision::Inexact(value) => {
898+
Precision::<ScalarValue>::cast_scalar_to_sum_type(&value)
899+
.map(Precision::Inexact)
900+
.unwrap_or(Precision::Absent)
901+
}
902+
Precision::Absent => Precision::Absent,
903+
};
892904
self
893905
}
894906

@@ -1735,6 +1747,16 @@ mod tests {
17351747
assert_eq!(col_stats.byte_size, Precision::Exact(8192));
17361748
}
17371749

1750+
#[test]
1751+
fn test_with_sum_value_builder_widens_small_integers() {
1752+
let col_stats = ColumnStatistics::new_unknown()
1753+
.with_sum_value(Precision::Exact(ScalarValue::UInt32(Some(123))));
1754+
assert_eq!(
1755+
col_stats.sum_value,
1756+
Precision::Exact(ScalarValue::UInt64(Some(123)))
1757+
);
1758+
}
1759+
17381760
#[test]
17391761
fn test_with_fetch_scales_byte_size() {
17401762
// Test that byte_size is scaled by the row ratio in with_fetch
@@ -1882,7 +1904,7 @@ mod tests {
18821904
);
18831905
assert_eq!(
18841906
col1_stats.sum_value,
1885-
Precision::Exact(ScalarValue::Int32(Some(1100)))
1907+
Precision::Exact(ScalarValue::Int64(Some(1100)))
18861908
);
18871909

18881910
let col2_stats = &summary_stats.column_statistics[1];
@@ -1897,7 +1919,7 @@ mod tests {
18971919
);
18981920
assert_eq!(
18991921
col2_stats.sum_value,
1900-
Precision::Exact(ScalarValue::Int32(Some(2200)))
1922+
Precision::Exact(ScalarValue::Int64(Some(2200)))
19011923
);
19021924
}
19031925

@@ -2245,7 +2267,7 @@ mod tests {
22452267
);
22462268
assert_eq!(
22472269
col_stats.sum_value,
2248-
Precision::Inexact(ScalarValue::Int32(Some(1500)))
2270+
Precision::Inexact(ScalarValue::Int64(Some(1500)))
22492271
);
22502272
}
22512273
}

0 commit comments

Comments
 (0)