diff --git a/datafusion/functions-nested/src/repeat.rs b/datafusion/functions-nested/src/repeat.rs
index ceec748a6e776..cfec5b09065b6 100644
--- a/datafusion/functions-nested/src/repeat.rs
+++ b/datafusion/functions-nested/src/repeat.rs
@@ -175,9 +175,10 @@ fn general_repeat(
     array: &ArrayRef,
     count_array: &Int64Array,
 ) -> Result<ArrayRef> {
-    let total_repeated_values: usize = (0..count_array.len())
-        .map(|i| get_count_with_validity(count_array, i))
-        .sum();
+    let total_repeated_values = checked_sum_counts(
+        (0..count_array.len()).map(|i| get_count_with_validity(count_array, i)),
+        "array_repeat: total repeated values overflowed usize",
+    )?;
 
     let mut take_indices = Vec::with_capacity(total_repeated_values);
     let mut offsets = Vec::with_capacity(count_array.len() + 1);
@@ -238,11 +239,24 @@ fn general_list_repeat(
     for i in 0..count_array.len() {
         let count = get_count_with_validity(count_array, i);
         if count > 0 {
-            outer_total += count;
+            outer_total = outer_total.checked_add(count).ok_or_else(|| {
+                DataFusionError::Execution(
+                    "array_repeat: outer total overflowed usize".to_string(),
+                )
+            })?;
             if list_array.is_valid(i) {
                 let len = list_offsets[i + 1].to_usize().unwrap()
                     - list_offsets[i].to_usize().unwrap();
-                inner_total += len * count;
+                let repeated_len = len.checked_mul(count).ok_or_else(|| {
+                    DataFusionError::Execution(
+                        "array_repeat: inner total overflowed usize".to_string(),
+                    )
+                })?;
+                inner_total = inner_total.checked_add(repeated_len).ok_or_else(|| {
+                    DataFusionError::Execution(
+                        "array_repeat: inner total overflowed usize".to_string(),
+                    )
+                })?;
             }
         }
     }
@@ -320,3 +334,14 @@ fn get_count_with_validity(count_array: &Int64Array, idx: usize) -> usize {
         if c > 0 { c as usize } else { 0 }
     }
 }
+
+fn checked_sum_counts(
+    counts: impl IntoIterator<Item = usize>,
+    overflow_message: &'static str,
+) -> Result<usize> {
+    counts.into_iter().try_fold(0usize, |total, count| {
+        total
+            .checked_add(count)
+            .ok_or_else(|| DataFusionError::Execution(overflow_message.to_string()))
+    })
+}
diff --git a/datafusion/sqllogictest/test_files/array/array_repeat.slt b/datafusion/sqllogictest/test_files/array/array_repeat.slt index 8052f09cb32c7..799986b99764e 100644 --- a/datafusion/sqllogictest/test_files/array/array_repeat.slt +++ b/datafusion/sqllogictest/test_files/array/array_repeat.slt @@ -76,6 +76,16 @@ Select ---- [] [] [] [] +# array_repeat returns an execution error on scalar output-size overflow +query error DataFusion error: Execution error: array_repeat: total repeated values overflowed usize +SELECT array_repeat(1, c) +FROM ( + VALUES + (9223372036854775807), + (9223372036854775807), + (9223372036854775807) +) AS t(c); + # array_repeat with columns #1 statement ok