Skip to content

Commit 27eff84

Browse files
committed
fix(stats): widen sum_value integer arithmetic to SUM-compatible types
1 parent acec058 commit 27eff84

File tree

5 files changed

+272
-32
lines changed

5 files changed

+272
-32
lines changed

datafusion/common/src/stats.rs

Lines changed: 91 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,14 @@ impl Precision<usize> {
180180
}
181181

182182
impl Precision<ScalarValue> {
183+
fn sum_data_type(data_type: &DataType) -> DataType {
184+
match data_type {
185+
DataType::Int8 | DataType::Int16 | DataType::Int32 => DataType::Int64,
186+
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => DataType::UInt64,
187+
_ => data_type.clone(),
188+
}
189+
}
190+
183191
/// Calculates the sum of two (possibly inexact) [`ScalarValue`] values,
184192
/// conservatively propagating exactness information. If one of the input
185193
/// values is [`Precision::Absent`], the result is `Absent` too.
@@ -198,6 +206,46 @@ impl Precision<ScalarValue> {
198206
}
199207
}
200208

209+
/// Casts integer values to the wider SQL `SUM` return type.
210+
///
211+
/// This narrows overflow risk when `sum_value` statistics are merged:
212+
/// `Int8/Int16/Int32 -> Int64` and `UInt8/UInt16/UInt32 -> UInt64`.
213+
pub fn cast_to_sum_type(&self) -> Precision<ScalarValue> {
214+
match self {
215+
Precision::Exact(value) => {
216+
let source_type = value.data_type();
217+
let target_type = Self::sum_data_type(&source_type);
218+
if source_type == target_type {
219+
Precision::Exact(value.clone())
220+
} else {
221+
value
222+
.cast_to(&target_type)
223+
.map(Precision::Exact)
224+
.unwrap_or(Precision::Absent)
225+
}
226+
}
227+
Precision::Inexact(value) => {
228+
let source_type = value.data_type();
229+
let target_type = Self::sum_data_type(&source_type);
230+
if source_type == target_type {
231+
Precision::Inexact(value.clone())
232+
} else {
233+
value
234+
.cast_to(&target_type)
235+
.map(Precision::Inexact)
236+
.unwrap_or(Precision::Absent)
237+
}
238+
}
239+
Precision::Absent => Precision::Absent,
240+
}
241+
}
242+
243+
/// SUM-style addition with integer widening to match SQL `SUM` return
244+
/// types for smaller integral inputs.
245+
pub fn add_for_sum(&self, other: &Precision<ScalarValue>) -> Precision<ScalarValue> {
246+
self.cast_to_sum_type().add(&other.cast_to_sum_type())
247+
}
248+
201249
/// Calculates the difference of two (possibly inexact) [`ScalarValue`] values,
202250
/// conservatively propagating exactness information. If one of the input
203251
/// values is [`Precision::Absent`], the result is `Absent` too.
@@ -636,7 +684,8 @@ impl Statistics {
636684
col_stats.null_count = col_stats.null_count.add(&item_col_stats.null_count);
637685
col_stats.max_value = col_stats.max_value.max(&item_col_stats.max_value);
638686
col_stats.min_value = col_stats.min_value.min(&item_col_stats.min_value);
639-
col_stats.sum_value = col_stats.sum_value.add(&item_col_stats.sum_value);
687+
col_stats.sum_value =
688+
col_stats.sum_value.add_for_sum(&item_col_stats.sum_value);
640689
col_stats.distinct_count = Precision::Absent;
641690
col_stats.byte_size = col_stats.byte_size.add(&item_col_stats.byte_size);
642691
}
@@ -948,6 +997,45 @@ mod tests {
948997
assert_eq!(precision.add(&Precision::Absent), Precision::Absent);
949998
}
950999

1000+
#[test]
1001+
fn test_add_for_sum_scalar_integer_widening() {
1002+
let precision = Precision::Exact(ScalarValue::Int32(Some(42)));
1003+
1004+
assert_eq!(
1005+
precision.add_for_sum(&Precision::Exact(ScalarValue::Int32(Some(23)))),
1006+
Precision::Exact(ScalarValue::Int64(Some(65))),
1007+
);
1008+
assert_eq!(
1009+
precision.add_for_sum(&Precision::Inexact(ScalarValue::Int32(Some(23)))),
1010+
Precision::Inexact(ScalarValue::Int64(Some(65))),
1011+
);
1012+
}
1013+
1014+
#[test]
1015+
fn test_add_for_sum_prevents_int32_overflow() {
1016+
let lhs = Precision::Exact(ScalarValue::Int32(Some(i32::MAX)));
1017+
let rhs = Precision::Exact(ScalarValue::Int32(Some(1)));
1018+
1019+
assert_eq!(
1020+
lhs.add_for_sum(&rhs),
1021+
Precision::Exact(ScalarValue::Int64(Some(i64::from(i32::MAX) + 1))),
1022+
);
1023+
}
1024+
1025+
#[test]
1026+
fn test_add_for_sum_scalar_unsigned_integer_widening() {
1027+
let precision = Precision::Exact(ScalarValue::UInt32(Some(42)));
1028+
1029+
assert_eq!(
1030+
precision.add_for_sum(&Precision::Exact(ScalarValue::UInt32(Some(23)))),
1031+
Precision::Exact(ScalarValue::UInt64(Some(65))),
1032+
);
1033+
assert_eq!(
1034+
precision.add_for_sum(&Precision::Inexact(ScalarValue::UInt32(Some(23)))),
1035+
Precision::Inexact(ScalarValue::UInt64(Some(65))),
1036+
);
1037+
}
1038+
9511039
#[test]
9521040
fn test_sub() {
9531041
let precision1 = Precision::Exact(42);
@@ -1193,7 +1281,7 @@ mod tests {
11931281
);
11941282
assert_eq!(
11951283
col1_stats.sum_value,
1196-
Precision::Exact(ScalarValue::Int32(Some(1100)))
1284+
Precision::Exact(ScalarValue::Int64(Some(1100)))
11971285
); // 500 + 600
11981286

11991287
let col2_stats = &summary_stats.column_statistics[1];
@@ -1208,7 +1296,7 @@ mod tests {
12081296
);
12091297
assert_eq!(
12101298
col2_stats.sum_value,
1211-
Precision::Exact(ScalarValue::Int32(Some(2200)))
1299+
Precision::Exact(ScalarValue::Int64(Some(2200)))
12121300
); // 1000 + 1200
12131301
}
12141302

datafusion/datasource/src/statistics.rs

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ fn sort_columns_from_physical_sort_exprs(
293293
since = "47.0.0",
294294
note = "Please use `get_files_with_limit` and `compute_all_files_statistics` instead"
295295
)]
296-
#[expect(unused)]
296+
#[cfg_attr(not(test), expect(unused))]
297297
pub async fn get_statistics_with_limit(
298298
all_files: impl Stream<Item = Result<(PartitionedFile, Arc<Statistics>)>>,
299299
file_schema: SchemaRef,
@@ -329,7 +329,7 @@ pub async fn get_statistics_with_limit(
329329
col_stats_set[index].null_count = file_column.null_count;
330330
col_stats_set[index].max_value = file_column.max_value;
331331
col_stats_set[index].min_value = file_column.min_value;
332-
col_stats_set[index].sum_value = file_column.sum_value;
332+
col_stats_set[index].sum_value = file_column.sum_value.cast_to_sum_type();
333333
}
334334

335335
// If the number of rows exceeds the limit, we can stop processing
@@ -374,7 +374,7 @@ pub async fn get_statistics_with_limit(
374374
col_stats.null_count = col_stats.null_count.add(file_nc);
375375
col_stats.max_value = col_stats.max_value.max(file_max);
376376
col_stats.min_value = col_stats.min_value.min(file_min);
377-
col_stats.sum_value = col_stats.sum_value.add(file_sum);
377+
col_stats.sum_value = col_stats.sum_value.add_for_sum(file_sum);
378378
col_stats.byte_size = col_stats.byte_size.add(file_sbs);
379379
}
380380

@@ -497,3 +497,78 @@ pub fn add_row_stats(
497497
) -> Precision<usize> {
498498
file_num_rows.add(&num_rows)
499499
}
500+
501+
#[cfg(test)]
502+
mod tests {
503+
use super::*;
504+
use crate::PartitionedFile;
505+
use arrow::datatypes::{DataType, Field, Schema};
506+
use futures::stream;
507+
508+
fn file_stats(sum: u32) -> Statistics {
509+
Statistics {
510+
num_rows: Precision::Exact(1),
511+
total_byte_size: Precision::Exact(4),
512+
column_statistics: vec![ColumnStatistics {
513+
null_count: Precision::Exact(0),
514+
max_value: Precision::Exact(ScalarValue::UInt32(Some(sum))),
515+
min_value: Precision::Exact(ScalarValue::UInt32(Some(sum))),
516+
sum_value: Precision::Exact(ScalarValue::UInt32(Some(sum))),
517+
distinct_count: Precision::Exact(1),
518+
byte_size: Precision::Exact(4),
519+
}],
520+
}
521+
}
522+
523+
#[tokio::test]
524+
#[expect(deprecated)]
525+
async fn test_get_statistics_with_limit_casts_first_file_sum_to_sum_type()
526+
-> Result<()> {
527+
let schema =
528+
Arc::new(Schema::new(vec![Field::new("c1", DataType::UInt32, true)]));
529+
530+
let files = stream::iter(vec![Ok((
531+
PartitionedFile::new("f1.parquet", 1),
532+
Arc::new(file_stats(100)),
533+
))]);
534+
535+
let (_group, stats) =
536+
get_statistics_with_limit(files, schema, None, false).await?;
537+
538+
assert_eq!(
539+
stats.column_statistics[0].sum_value,
540+
Precision::Exact(ScalarValue::UInt64(Some(100)))
541+
);
542+
543+
Ok(())
544+
}
545+
546+
#[tokio::test]
547+
#[expect(deprecated)]
548+
async fn test_get_statistics_with_limit_merges_sum_with_unsigned_widening()
549+
-> Result<()> {
550+
let schema =
551+
Arc::new(Schema::new(vec![Field::new("c1", DataType::UInt32, true)]));
552+
553+
let files = stream::iter(vec![
554+
Ok((
555+
PartitionedFile::new("f1.parquet", 1),
556+
Arc::new(file_stats(100)),
557+
)),
558+
Ok((
559+
PartitionedFile::new("f2.parquet", 1),
560+
Arc::new(file_stats(200)),
561+
)),
562+
]);
563+
564+
let (_group, stats) =
565+
get_statistics_with_limit(files, schema, None, true).await?;
566+
567+
assert_eq!(
568+
stats.column_statistics[0].sum_value,
569+
Precision::Exact(ScalarValue::UInt64(Some(300)))
570+
);
571+
572+
Ok(())
573+
}
574+
}

datafusion/physical-expr/src/projection.rs

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -693,12 +693,15 @@ impl ProjectionExprs {
693693
Precision::Absent
694694
};
695695

696-
let sum_value = Precision::<ScalarValue>::from(stats.num_rows)
697-
.cast_to(&value.data_type())
698-
.ok()
699-
.map(|row_count| {
700-
Precision::Exact(value.clone()).multiply(&row_count)
696+
let widened_sum = Precision::Exact(value.clone()).cast_to_sum_type();
697+
let sum_value = widened_sum
698+
.get_value()
699+
.and_then(|sum| {
700+
Precision::<ScalarValue>::from(stats.num_rows)
701+
.cast_to(&sum.data_type())
702+
.ok()
701703
})
704+
.map(|row_count| widened_sum.multiply(&row_count))
702705
.unwrap_or(Precision::Absent);
703706

704707
ColumnStatistics {
@@ -2866,6 +2869,35 @@ pub(crate) mod tests {
28662869
Ok(())
28672870
}
28682871

2872+
#[test]
2873+
fn test_project_statistics_with_i32_literal_sum_widens_to_i64() -> Result<()> {
2874+
let input_stats = get_stats();
2875+
let input_schema = get_schema();
2876+
2877+
let projection = ProjectionExprs::new(vec![
2878+
ProjectionExpr {
2879+
expr: Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
2880+
alias: "constant".to_string(),
2881+
},
2882+
ProjectionExpr {
2883+
expr: Arc::new(Column::new("col0", 0)),
2884+
alias: "num".to_string(),
2885+
},
2886+
]);
2887+
2888+
let output_stats = projection.project_statistics(
2889+
input_stats,
2890+
&projection.project_schema(&input_schema)?,
2891+
)?;
2892+
2893+
assert_eq!(
2894+
output_stats.column_statistics[0].sum_value,
2895+
Precision::Exact(ScalarValue::Int64(Some(50)))
2896+
);
2897+
2898+
Ok(())
2899+
}
2900+
28692901
// Test statistics calculation for NULL literal (constant NULL column)
28702902
#[test]
28712903
fn test_project_statistics_with_null_literal() -> Result<()> {

0 commit comments

Comments
 (0)