Skip to content

Commit b35253c

Browse files
committed
fix(stats): widen sum_value integer arithmetic to SUM-compatible types
1 parent a3dc8fa commit b35253c

File tree

5 files changed

+274
-34
lines changed

5 files changed

+274
-34
lines changed

datafusion/common/src/stats.rs

Lines changed: 93 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,14 @@ impl Precision<usize> {
203203
}
204204

205205
impl Precision<ScalarValue> {
206+
fn sum_data_type(data_type: &DataType) -> DataType {
207+
match data_type {
208+
DataType::Int8 | DataType::Int16 | DataType::Int32 => DataType::Int64,
209+
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => DataType::UInt64,
210+
_ => data_type.clone(),
211+
}
212+
}
213+
206214
/// Calculates the sum of two (possibly inexact) [`ScalarValue`] values,
207215
/// conservatively propagating exactness information. If one of the input
208216
/// values is [`Precision::Absent`], the result is `Absent` too.
@@ -228,6 +236,46 @@ impl Precision<ScalarValue> {
228236
}
229237
}
230238

239+
/// Casts integer values to the wider SQL `SUM` return type.
240+
///
241+
/// This reduces overflow risk when `sum_value` statistics are merged:
242+
/// `Int8/Int16/Int32 -> Int64` and `UInt8/UInt16/UInt32 -> UInt64`.
243+
pub fn cast_to_sum_type(&self) -> Precision<ScalarValue> {
244+
match self {
245+
Precision::Exact(value) => {
246+
let source_type = value.data_type();
247+
let target_type = Self::sum_data_type(&source_type);
248+
if source_type == target_type {
249+
Precision::Exact(value.clone())
250+
} else {
251+
value
252+
.cast_to(&target_type)
253+
.map(Precision::Exact)
254+
.unwrap_or(Precision::Absent)
255+
}
256+
}
257+
Precision::Inexact(value) => {
258+
let source_type = value.data_type();
259+
let target_type = Self::sum_data_type(&source_type);
260+
if source_type == target_type {
261+
Precision::Inexact(value.clone())
262+
} else {
263+
value
264+
.cast_to(&target_type)
265+
.map(Precision::Inexact)
266+
.unwrap_or(Precision::Absent)
267+
}
268+
}
269+
Precision::Absent => Precision::Absent,
270+
}
271+
}
272+
273+
/// SUM-style addition with integer widening to match SQL `SUM` return
274+
/// types for smaller integral inputs.
275+
pub fn add_for_sum(&self, other: &Precision<ScalarValue>) -> Precision<ScalarValue> {
276+
self.cast_to_sum_type().add(&other.cast_to_sum_type())
277+
}
278+
231279
/// Calculates the difference of two (possibly inexact) [`ScalarValue`] values,
232280
/// conservatively propagating exactness information. If one of the input
233281
/// values is [`Precision::Absent`], the result is `Absent` too.
@@ -620,7 +668,7 @@ impl Statistics {
620668
/// assert_eq!(merged.column_statistics[0].max_value,
621669
/// Precision::Exact(ScalarValue::from(200)));
622670
/// assert_eq!(merged.column_statistics[0].sum_value,
623-
/// Precision::Exact(ScalarValue::from(1500)));
671+
/// Precision::Exact(ScalarValue::Int64(Some(1500))));
624672
/// ```
625673
pub fn try_merge_iter<'a, I>(items: I, schema: &Schema) -> Result<Statistics>
626674
where
@@ -664,7 +712,7 @@ impl Statistics {
664712
null_count: cs.null_count,
665713
max_value: cs.max_value.clone(),
666714
min_value: cs.min_value.clone(),
667-
sum_value: cs.sum_value.clone(),
715+
sum_value: cs.sum_value.cast_to_sum_type(),
668716
distinct_count: cs.distinct_count,
669717
byte_size: cs.byte_size,
670718
})
@@ -693,7 +741,8 @@ impl Statistics {
693741
};
694742
col_stats.min_value = col_stats.min_value.min(&item_cs.min_value);
695743
col_stats.max_value = col_stats.max_value.max(&item_cs.max_value);
696-
precision_add(&mut col_stats.sum_value, &item_cs.sum_value);
744+
let item_sum_value = item_cs.sum_value.cast_to_sum_type();
745+
precision_add(&mut col_stats.sum_value, &item_sum_value);
697746
col_stats.byte_size = col_stats.byte_size.add(&item_cs.byte_size);
698747
}
699748
}
@@ -1095,6 +1144,45 @@ mod tests {
10951144
assert_eq!(precision.add(&Precision::Absent), Precision::Absent);
10961145
}
10971146

1147+
#[test]
1148+
fn test_add_for_sum_scalar_integer_widening() {
1149+
let precision = Precision::Exact(ScalarValue::Int32(Some(42)));
1150+
1151+
assert_eq!(
1152+
precision.add_for_sum(&Precision::Exact(ScalarValue::Int32(Some(23)))),
1153+
Precision::Exact(ScalarValue::Int64(Some(65))),
1154+
);
1155+
assert_eq!(
1156+
precision.add_for_sum(&Precision::Inexact(ScalarValue::Int32(Some(23)))),
1157+
Precision::Inexact(ScalarValue::Int64(Some(65))),
1158+
);
1159+
}
1160+
1161+
#[test]
1162+
fn test_add_for_sum_prevents_int32_overflow() {
1163+
let lhs = Precision::Exact(ScalarValue::Int32(Some(i32::MAX)));
1164+
let rhs = Precision::Exact(ScalarValue::Int32(Some(1)));
1165+
1166+
assert_eq!(
1167+
lhs.add_for_sum(&rhs),
1168+
Precision::Exact(ScalarValue::Int64(Some(i64::from(i32::MAX) + 1))),
1169+
);
1170+
}
1171+
1172+
#[test]
1173+
fn test_add_for_sum_scalar_unsigned_integer_widening() {
1174+
let precision = Precision::Exact(ScalarValue::UInt32(Some(42)));
1175+
1176+
assert_eq!(
1177+
precision.add_for_sum(&Precision::Exact(ScalarValue::UInt32(Some(23)))),
1178+
Precision::Exact(ScalarValue::UInt64(Some(65))),
1179+
);
1180+
assert_eq!(
1181+
precision.add_for_sum(&Precision::Inexact(ScalarValue::UInt32(Some(23)))),
1182+
Precision::Inexact(ScalarValue::UInt64(Some(65))),
1183+
);
1184+
}
1185+
10981186
#[test]
10991187
fn test_sub() {
11001188
let precision1 = Precision::Exact(42);
@@ -1340,7 +1428,7 @@ mod tests {
13401428
);
13411429
assert_eq!(
13421430
col1_stats.sum_value,
1343-
Precision::Exact(ScalarValue::Int32(Some(1100)))
1431+
Precision::Exact(ScalarValue::Int64(Some(1100)))
13441432
); // 500 + 600
13451433

13461434
let col2_stats = &summary_stats.column_statistics[1];
@@ -1355,7 +1443,7 @@ mod tests {
13551443
);
13561444
assert_eq!(
13571445
col2_stats.sum_value,
1358-
Precision::Exact(ScalarValue::Int32(Some(2200)))
1446+
Precision::Exact(ScalarValue::Int64(Some(2200)))
13591447
); // 1000 + 1200
13601448
}
13611449

datafusion/datasource/src/statistics.rs

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ fn sort_columns_from_physical_sort_exprs(
293293
since = "47.0.0",
294294
note = "Please use `get_files_with_limit` and `compute_all_files_statistics` instead"
295295
)]
296-
#[expect(unused)]
296+
#[cfg_attr(not(test), expect(unused))]
297297
pub async fn get_statistics_with_limit(
298298
all_files: impl Stream<Item = Result<(PartitionedFile, Arc<Statistics>)>>,
299299
file_schema: SchemaRef,
@@ -329,7 +329,7 @@ pub async fn get_statistics_with_limit(
329329
col_stats_set[index].null_count = file_column.null_count;
330330
col_stats_set[index].max_value = file_column.max_value;
331331
col_stats_set[index].min_value = file_column.min_value;
332-
col_stats_set[index].sum_value = file_column.sum_value;
332+
col_stats_set[index].sum_value = file_column.sum_value.cast_to_sum_type();
333333
}
334334

335335
// If the number of rows exceeds the limit, we can stop processing
@@ -374,7 +374,7 @@ pub async fn get_statistics_with_limit(
374374
col_stats.null_count = col_stats.null_count.add(file_nc);
375375
col_stats.max_value = col_stats.max_value.max(file_max);
376376
col_stats.min_value = col_stats.min_value.min(file_min);
377-
col_stats.sum_value = col_stats.sum_value.add(file_sum);
377+
col_stats.sum_value = col_stats.sum_value.add_for_sum(file_sum);
378378
col_stats.byte_size = col_stats.byte_size.add(file_sbs);
379379
}
380380

@@ -497,3 +497,78 @@ pub fn add_row_stats(
497497
) -> Precision<usize> {
498498
file_num_rows.add(&num_rows)
499499
}
500+
501+
#[cfg(test)]
502+
mod tests {
503+
use super::*;
504+
use crate::PartitionedFile;
505+
use arrow::datatypes::{DataType, Field, Schema};
506+
use futures::stream;
507+
508+
fn file_stats(sum: u32) -> Statistics {
509+
Statistics {
510+
num_rows: Precision::Exact(1),
511+
total_byte_size: Precision::Exact(4),
512+
column_statistics: vec![ColumnStatistics {
513+
null_count: Precision::Exact(0),
514+
max_value: Precision::Exact(ScalarValue::UInt32(Some(sum))),
515+
min_value: Precision::Exact(ScalarValue::UInt32(Some(sum))),
516+
sum_value: Precision::Exact(ScalarValue::UInt32(Some(sum))),
517+
distinct_count: Precision::Exact(1),
518+
byte_size: Precision::Exact(4),
519+
}],
520+
}
521+
}
522+
523+
#[tokio::test]
524+
#[expect(deprecated)]
525+
async fn test_get_statistics_with_limit_casts_first_file_sum_to_sum_type()
526+
-> Result<()> {
527+
let schema =
528+
Arc::new(Schema::new(vec![Field::new("c1", DataType::UInt32, true)]));
529+
530+
let files = stream::iter(vec![Ok((
531+
PartitionedFile::new("f1.parquet", 1),
532+
Arc::new(file_stats(100)),
533+
))]);
534+
535+
let (_group, stats) =
536+
get_statistics_with_limit(files, schema, None, false).await?;
537+
538+
assert_eq!(
539+
stats.column_statistics[0].sum_value,
540+
Precision::Exact(ScalarValue::UInt64(Some(100)))
541+
);
542+
543+
Ok(())
544+
}
545+
546+
#[tokio::test]
547+
#[expect(deprecated)]
548+
async fn test_get_statistics_with_limit_merges_sum_with_unsigned_widening()
549+
-> Result<()> {
550+
let schema =
551+
Arc::new(Schema::new(vec![Field::new("c1", DataType::UInt32, true)]));
552+
553+
let files = stream::iter(vec![
554+
Ok((
555+
PartitionedFile::new("f1.parquet", 1),
556+
Arc::new(file_stats(100)),
557+
)),
558+
Ok((
559+
PartitionedFile::new("f2.parquet", 1),
560+
Arc::new(file_stats(200)),
561+
)),
562+
]);
563+
564+
let (_group, stats) =
565+
get_statistics_with_limit(files, schema, None, true).await?;
566+
567+
assert_eq!(
568+
stats.column_statistics[0].sum_value,
569+
Precision::Exact(ScalarValue::UInt64(Some(300)))
570+
);
571+
572+
Ok(())
573+
}
574+
}

datafusion/physical-expr/src/projection.rs

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -693,12 +693,15 @@ impl ProjectionExprs {
693693
Precision::Absent
694694
};
695695

696-
let sum_value = Precision::<ScalarValue>::from(stats.num_rows)
697-
.cast_to(&value.data_type())
698-
.ok()
699-
.map(|row_count| {
700-
Precision::Exact(value.clone()).multiply(&row_count)
696+
let widened_sum = Precision::Exact(value.clone()).cast_to_sum_type();
697+
let sum_value = widened_sum
698+
.get_value()
699+
.and_then(|sum| {
700+
Precision::<ScalarValue>::from(stats.num_rows)
701+
.cast_to(&sum.data_type())
702+
.ok()
701703
})
704+
.map(|row_count| widened_sum.multiply(&row_count))
702705
.unwrap_or(Precision::Absent);
703706

704707
ColumnStatistics {
@@ -2864,6 +2867,35 @@ pub(crate) mod tests {
28642867
Ok(())
28652868
}
28662869

2870+
#[test]
2871+
fn test_project_statistics_with_i32_literal_sum_widens_to_i64() -> Result<()> {
2872+
let input_stats = get_stats();
2873+
let input_schema = get_schema();
2874+
2875+
let projection = ProjectionExprs::new(vec![
2876+
ProjectionExpr {
2877+
expr: Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
2878+
alias: "constant".to_string(),
2879+
},
2880+
ProjectionExpr {
2881+
expr: Arc::new(Column::new("col0", 0)),
2882+
alias: "num".to_string(),
2883+
},
2884+
]);
2885+
2886+
let output_stats = projection.project_statistics(
2887+
input_stats,
2888+
&projection.project_schema(&input_schema)?,
2889+
)?;
2890+
2891+
assert_eq!(
2892+
output_stats.column_statistics[0].sum_value,
2893+
Precision::Exact(ScalarValue::Int64(Some(50)))
2894+
);
2895+
2896+
Ok(())
2897+
}
2898+
28672899
// Test statistics calculation for NULL literal (constant NULL column)
28682900
#[test]
28692901
fn test_project_statistics_with_null_literal() -> Result<()> {

0 commit comments

Comments
 (0)