@@ -22,6 +22,7 @@ use std::fmt::{self, Debug, Display};
2222use crate :: { Result , ScalarValue } ;
2323
2424use crate :: error:: _plan_err;
25+ use crate :: utils:: aggregate:: precision_add;
2526use arrow:: datatypes:: { DataType , Schema } ;
2627
2728/// Represents a value with a degree of certainty. `Precision` is used to
@@ -210,6 +211,16 @@ impl Precision<ScalarValue> {
210211 }
211212 }
212213
214+ fn cast_scalar_to_sum_type ( value : & ScalarValue ) -> Result < ScalarValue > {
215+ let source_type = value. data_type ( ) ;
216+ let target_type = Self :: sum_data_type ( & source_type) ;
217+ if source_type == target_type {
218+ Ok ( value. clone ( ) )
219+ } else {
220+ value. cast_to ( & target_type)
221+ }
222+ }
223+
213224 /// Calculates the sum of two (possibly inexact) [`ScalarValue`] values,
214225 /// conservatively propagating exactness information. If one of the input
215226 /// values is [`Precision::Absent`], the result is `Absent` too.
@@ -240,39 +251,21 @@ impl Precision<ScalarValue> {
240251 /// This narrows overflow risk when `sum_value` statistics are merged:
241252 /// `Int8/Int16/Int32 -> Int64` and `UInt8/UInt16/UInt32 -> UInt64`.
242253 pub fn cast_to_sum_type ( & self ) -> Precision < ScalarValue > {
243- match self {
244- Precision :: Exact ( value) => {
245- let source_type = value. data_type ( ) ;
246- let target_type = Self :: sum_data_type ( & source_type) ;
247- if source_type == target_type {
248- Precision :: Exact ( value. clone ( ) )
249- } else {
250- value
251- . cast_to ( & target_type)
252- . map ( Precision :: Exact )
253- . unwrap_or ( Precision :: Absent )
254- }
255- }
256- Precision :: Inexact ( value) => {
257- let source_type = value. data_type ( ) ;
258- let target_type = Self :: sum_data_type ( & source_type) ;
259- if source_type == target_type {
260- Precision :: Inexact ( value. clone ( ) )
261- } else {
262- value
263- . cast_to ( & target_type)
264- . map ( Precision :: Inexact )
265- . unwrap_or ( Precision :: Absent )
266- }
267- }
268- Precision :: Absent => Precision :: Absent ,
254+ match ( self . is_exact ( ) , self . get_value ( ) ) {
255+ ( Some ( true ) , Some ( value) ) => Self :: cast_scalar_to_sum_type ( value)
256+ . map ( Precision :: Exact )
257+ . unwrap_or ( Precision :: Absent ) ,
258+ ( Some ( false ) , Some ( value) ) => Self :: cast_scalar_to_sum_type ( value)
259+ . map ( Precision :: Inexact )
260+ . unwrap_or ( Precision :: Absent ) ,
261+ ( _, _) => Precision :: Absent ,
269262 }
270263 }
271264
272265 /// SUM-style addition with integer widening to match SQL `SUM` return
273266 /// types for smaller integral inputs.
274267 pub fn add_for_sum ( & self , other : & Precision < ScalarValue > ) -> Precision < ScalarValue > {
275- self . cast_to_sum_type ( ) . add ( & other. cast_to_sum_type ( ) )
268+ precision_add ( & self . cast_to_sum_type ( ) , & other. cast_to_sum_type ( ) )
276269 }
277270
278271 /// Calculates the difference of two (possibly inexact) [`ScalarValue`] values,
@@ -727,8 +720,7 @@ impl Statistics {
727720
728721 col_stats. null_count = col_stats. null_count . add ( & item_cs. null_count ) ;
729722 col_stats. byte_size = col_stats. byte_size . add ( & item_cs. byte_size ) ;
730- col_stats. sum_value =
731- col_stats. sum_value . add_for_sum ( & item_cs. sum_value ) ;
723+ col_stats. sum_value = col_stats. sum_value . add_for_sum ( & item_cs. sum_value ) ;
732724 col_stats. min_value = col_stats. min_value . min ( & item_cs. min_value ) ;
733725 col_stats. max_value = col_stats. max_value . max ( & item_cs. max_value ) ;
734726 }
@@ -823,7 +815,15 @@ pub struct ColumnStatistics {
823815 pub max_value : Precision < ScalarValue > ,
824816 /// Minimum value of column
825817 pub min_value : Precision < ScalarValue > ,
826- /// Sum value of a column
818+ /// Sum value of a column.
819+ ///
820+ /// For integral columns, values should be kept in SUM-compatible widened
821+ /// types (`Int8/Int16/Int32 -> Int64`, `UInt8/UInt16/UInt32 -> UInt64`) to
822+ /// reduce overflow risk during statistics propagation.
823+ ///
824+ /// Callers should prefer [`ColumnStatistics::with_sum_value`] for setting
825+ /// this field and [`Precision<ScalarValue>::add_for_sum`] /
826+ /// [`Precision<ScalarValue>::cast_to_sum_type`] for sum arithmetic.
827827 pub sum_value : Precision < ScalarValue > ,
828828 /// Number of distinct values
829829 pub distinct_count : Precision < usize > ,
@@ -888,7 +888,19 @@ impl ColumnStatistics {
888888
889889 /// Set the sum value
890890 pub fn with_sum_value ( mut self , sum_value : Precision < ScalarValue > ) -> Self {
891- self . sum_value = sum_value;
891+ self . sum_value = match sum_value {
892+ Precision :: Exact ( value) => {
893+ Precision :: < ScalarValue > :: cast_scalar_to_sum_type ( & value)
894+ . map ( Precision :: Exact )
895+ . unwrap_or ( Precision :: Absent )
896+ }
897+ Precision :: Inexact ( value) => {
898+ Precision :: < ScalarValue > :: cast_scalar_to_sum_type ( & value)
899+ . map ( Precision :: Inexact )
900+ . unwrap_or ( Precision :: Absent )
901+ }
902+ Precision :: Absent => Precision :: Absent ,
903+ } ;
892904 self
893905 }
894906
@@ -1735,6 +1747,16 @@ mod tests {
17351747 assert_eq ! ( col_stats. byte_size, Precision :: Exact ( 8192 ) ) ;
17361748 }
17371749
#[test]
fn test_with_sum_value_builder_widens_small_integers() {
    // A `UInt32` sum handed to the builder must be stored as `UInt64`.
    let narrow = Precision::Exact(ScalarValue::UInt32(Some(123)));
    let widened = Precision::Exact(ScalarValue::UInt64(Some(123)));

    let col_stats = ColumnStatistics::new_unknown().with_sum_value(narrow);

    assert_eq!(col_stats.sum_value, widened);
}
1759+
17381760 #[ test]
17391761 fn test_with_fetch_scales_byte_size ( ) {
17401762 // Test that byte_size is scaled by the row ratio in with_fetch
@@ -1882,7 +1904,7 @@ mod tests {
18821904 ) ;
18831905 assert_eq ! (
18841906 col1_stats. sum_value,
1885- Precision :: Exact ( ScalarValue :: Int32 ( Some ( 1100 ) ) )
1907+ Precision :: Exact ( ScalarValue :: Int64 ( Some ( 1100 ) ) )
18861908 ) ;
18871909
18881910 let col2_stats = & summary_stats. column_statistics [ 1 ] ;
@@ -1897,7 +1919,7 @@ mod tests {
18971919 ) ;
18981920 assert_eq ! (
18991921 col2_stats. sum_value,
1900- Precision :: Exact ( ScalarValue :: Int32 ( Some ( 2200 ) ) )
1922+ Precision :: Exact ( ScalarValue :: Int64 ( Some ( 2200 ) ) )
19011923 ) ;
19021924 }
19031925
@@ -2245,7 +2267,7 @@ mod tests {
22452267 ) ;
22462268 assert_eq ! (
22472269 col_stats. sum_value,
2248- Precision :: Inexact ( ScalarValue :: Int32 ( Some ( 1500 ) ) )
2270+ Precision :: Inexact ( ScalarValue :: Int64 ( Some ( 1500 ) ) )
22492271 ) ;
22502272 }
22512273}
0 commit comments