diff --git a/datafusion/functions-nested/src/repeat.rs b/datafusion/functions-nested/src/repeat.rs index ceec748a6e776..825530923fa79 100644 --- a/datafusion/functions-nested/src/repeat.rs +++ b/datafusion/functions-nested/src/repeat.rs @@ -38,8 +38,12 @@ use datafusion_expr::{ }; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; +use std::mem::size_of; use std::sync::Arc; +const ARRAY_REPEAT_LENGTH_EXCEEDED: &str = + "array_repeat: requested length exceeds maximum array size"; + make_udf_expr_and_func!( ArrayRepeat, array_repeat, @@ -175,28 +179,12 @@ fn general_repeat( array: &ArrayRef, count_array: &Int64Array, ) -> Result { - let total_repeated_values: usize = (0..count_array.len()) - .map(|i| get_count_with_validity(count_array, i)) - .sum(); + let (offsets, total_repeated_values) = build_repeat_offsets::(count_array)?; let mut take_indices = Vec::with_capacity(total_repeated_values); - let mut offsets = Vec::with_capacity(count_array.len() + 1); - offsets.push(O::zero()); - let mut running_offset = 0usize; for idx in 0..count_array.len() { let count = get_count_with_validity(count_array, idx); - running_offset = running_offset.checked_add(count).ok_or_else(|| { - DataFusionError::Execution( - "array_repeat: running_offset overflowed usize".to_string(), - ) - })?; - let offset = O::from_usize(running_offset).ok_or_else(|| { - DataFusionError::Execution(format!( - "array_repeat: offset {running_offset} exceeds the maximum value for offset type" - )) - })?; - offsets.push(offset); take_indices.extend(std::iter::repeat_n(idx as u64, count)); } @@ -231,23 +219,23 @@ fn general_list_repeat( count_array: &Int64Array, ) -> Result { let list_offsets = list_array.value_offsets(); + let (outer_offsets, outer_total) = build_repeat_offsets::(count_array)?; // calculate capacities for pre-allocation - let mut outer_total = 0usize; let mut inner_total = 0usize; for i in 0..count_array.len() { let count = get_count_with_validity(count_array, i); - if count > 0 { - outer_total += count; - if list_array.is_valid(i) { - let len = list_offsets[i + 1].to_usize().unwrap() - - list_offsets[i].to_usize().unwrap(); - inner_total += len * count; - } + if count > 0 && list_array.is_valid(i) { + let len = list_offsets[i + 1].to_usize().unwrap() + - list_offsets[i].to_usize().unwrap(); + inner_total = + checked_repeat_len_add(inner_total, checked_repeat_len_mul(len, count)?)?; + ensure_array_repeat_output_len::(inner_total)?; } } // Build inner structures + ensure_vec_capacity::(checked_repeat_len_add(outer_total, 1)?)?; let mut inner_offsets = Vec::with_capacity(outer_total + 1); let mut take_indices = Vec::with_capacity(inner_total); let mut inner_nulls = BooleanBufferBuilder::new(outer_total); @@ -262,11 +250,8 @@ fn general_list_repeat( let row_len = end - start; for _ in 0..count { - inner_running = inner_running.checked_add(row_len).ok_or_else(|| { - DataFusionError::Execution( - "array_repeat: inner offset overflowed usize".to_string(), - ) - })?; + inner_running = checked_repeat_len_add(inner_running, row_len)?; + ensure_array_repeat_output_len::(inner_running)?; let offset = O::from_usize(inner_running).ok_or_else(|| { DataFusionError::Execution(format!( "array_repeat: offset {inner_running} exceeds the maximum value for offset type" @@ -299,16 +284,85 @@ fn general_list_repeat( list_array.data_type().to_owned(), true, )), - OffsetBuffer::::from_lengths( - count_array - .iter() - .map(|c| c.map(|v| if v > 0 { v as usize } else { 0 }).unwrap_or(0)), - ), + OffsetBuffer::new(outer_offsets.into()), Arc::new(inner_list), count_array.nulls().cloned(), )?)) } +fn build_repeat_offsets( + count_array: &Int64Array, +) -> Result<(Vec, usize)> { + let mut offsets = Vec::with_capacity(count_array.len() + 1); + offsets.push(O::zero()); + let mut running_offset = 0usize; + + for idx in 0..count_array.len() { + let count = get_count_with_validity(count_array, idx); + running_offset = checked_repeat_len_add(running_offset, count)?; + ensure_array_repeat_output_len::(running_offset)?; + let offset = O::from_usize(running_offset).ok_or_else(|| { + DataFusionError::Execution(format!( + "array_repeat: offset {running_offset} exceeds the maximum value for offset type" + )) + })?; + offsets.push(offset); + } + + Ok((offsets, running_offset)) +} + +fn checked_repeat_len_add(lhs: usize, rhs: usize) -> Result { + lhs.checked_add(rhs).ok_or_else(|| { + DataFusionError::Execution(ARRAY_REPEAT_LENGTH_EXCEEDED.to_string()) + }) +} + +fn checked_repeat_len_mul(lhs: usize, rhs: usize) -> Result { + lhs.checked_mul(rhs).ok_or_else(|| { + DataFusionError::Execution(ARRAY_REPEAT_LENGTH_EXCEEDED.to_string()) + }) +} + +fn ensure_array_repeat_output_len(len: usize) -> Result<()> { + if len > max_array_repeat_output_len::() { + return Err(DataFusionError::Execution( + ARRAY_REPEAT_LENGTH_EXCEEDED.to_string(), + )); + } + + Ok(()) +} + +fn ensure_vec_capacity(len: usize) -> Result<()> { + if len > max_vec_elements::() { + return Err(DataFusionError::Execution( + ARRAY_REPEAT_LENGTH_EXCEEDED.to_string(), + )); + } + + Ok(()) +} + +fn max_array_repeat_output_len() -> usize { + max_offset_elements::().min(max_vec_elements::()) +} + +fn max_offset_elements() -> usize { + if size_of::() == size_of::() { + i32::MAX as usize + } else { + i64::MAX as usize + } +} + +fn max_vec_elements() -> usize { + let element_size = size_of::(); + (isize::MAX as usize) + .checked_div(element_size) + .unwrap_or(usize::MAX) +} + /// Helper function to get count from count_array at given index /// Return 0 for null values or non-positive count. #[inline] @@ -320,3 +374,22 @@ fn get_count_with_validity(count_array: &Int64Array, idx: usize) -> usize { if c > 0 { c as usize } else { 0 } } } + +#[cfg(test)] +mod tests { + use super::array_repeat_inner; + use arrow::array::{ArrayRef, Int64Array}; + use std::sync::Arc; + + #[test] + fn scalar_count_exceeding_max_array_size_returns_error() { + let element: ArrayRef = Arc::new(Int64Array::from(vec![1])); + let count: ArrayRef = Arc::new(Int64Array::from(vec![i64::MAX])); + + let err = array_repeat_inner(&[element, count]).unwrap_err(); + assert_eq!( + err.to_string(), + "Execution error: array_repeat: requested length exceeds maximum array size" + ); + } +} diff --git a/datafusion/sqllogictest/test_files/array/array_repeat.slt b/datafusion/sqllogictest/test_files/array/array_repeat.slt index 8052f09cb32c7..9f17c449c88c2 100644 --- a/datafusion/sqllogictest/test_files/array/array_repeat.slt +++ b/datafusion/sqllogictest/test_files/array/array_repeat.slt @@ -43,6 +43,9 @@ select ---- [[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[NULL, NULL], [NULL, NULL], [NULL, NULL]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] +query error DataFusion error: Execution error: array_repeat: requested length exceeds maximum array size +select array_repeat(1, 9223372036854775807); + query ???? select array_repeat(arrow_cast([1], 'LargeList(Int64)'), 5),