diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index b551d2ac707a9..a53f1e2e4fc42 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -26,7 +26,9 @@ use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; use datafusion_common::cast::as_int64_array; use datafusion_common::types::{NativeType, logical_int64, logical_string}; use datafusion_common::utils::take_function_args; -use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err, internal_err}; +use datafusion_common::{ + DataFusionError, Result, ScalarValue, exec_datafusion_err, exec_err, internal_err, +}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; @@ -166,7 +168,21 @@ fn compute_repeat(s: &str, count: i64, max_size: usize) -> Result { if count <= 0 { return Ok(String::new()); } - let result_len = s.len().saturating_mul(count as usize); + let result_len = repeat_len(s.len(), count, max_size)?; + debug_assert!(result_len <= max_size); + let count = repeat_count(count, max_size)?; + Ok(s.repeat(count)) +} + +fn repeat_len(string_len: usize, count: i64, max_size: usize) -> Result { + let count = repeat_count(count, max_size)?; + let result_len = string_len.checked_mul(count).ok_or_else(|| { + exec_datafusion_err!( + "string size overflow on repeat, max size is {}, but got {}", + max_size, + usize::MAX + ) + })?; if result_len > max_size { return exec_err!( "string size overflow on repeat, max size is {}, but got {}", @@ -174,7 +190,18 @@ fn compute_repeat(s: &str, count: i64, max_size: usize) -> Result { result_len ); } - Ok(s.repeat(count as usize)) + Ok(result_len) +} + +fn repeat_count(count: i64, max_size: usize) -> Result { + match usize::try_from(count) { + Ok(count) => Ok(count), + Err(_) => exec_err!( + "string size overflow on repeat, max size is {}, but got {}", + max_size, + usize::MAX + ), + } } /// Repeats string the specified number of times. @@ -227,22 +254,22 @@ fn calculate_capacities<'a, S>( where S: StringArrayType<'a>, { - let mut total_capacity = 0; - let mut max_item_capacity = 0; + let mut total_capacity = 0usize; + let mut max_item_capacity = 0usize; string_array.iter().zip(number_array.iter()).try_for_each( |(string, number)| -> Result<(), DataFusionError> { match (string, number) { (Some(string), Some(number)) if number >= 0 => { - let item_capacity = string.len() * number as usize; - if item_capacity > max_str_len { - return exec_err!( - "string size overflow on repeat, max size is {}, but got {}", - max_str_len, - number as usize * string.len() - ); - } - total_capacity += item_capacity; + let item_capacity = repeat_len(string.len(), number, max_str_len)?; + total_capacity = + total_capacity.checked_add(item_capacity).ok_or_else(|| { + exec_datafusion_err!( + "string size overflow on repeat, max size is {}, but got {}", + max_str_len, + usize::MAX + ) + })?; max_item_capacity = max_item_capacity.max(item_capacity); } _ => (), @@ -487,6 +514,18 @@ mod tests { assert_sliced_offset_output::(result); } + #[test] + fn test_repeat_string_array_overflow() { + let strings: ArrayRef = Arc::new(StringArray::from(vec![Some("abc")])); + let counts: ArrayRef = Arc::new(Int64Array::from(vec![Some(i64::MAX)])); + + let err = super::repeat(&strings, &counts).unwrap_err().to_string(); + assert!( + err.contains("string size overflow on repeat"), + "unexpected error: {err}" + ); + } + #[test] fn test_repeat_sliced_large_string_with_null_offset() { let (strings, counts) = diff --git a/datafusion/sqllogictest/test_files/string/string_literal.slt b/datafusion/sqllogictest/test_files/string/string_literal.slt index 97f2a40c13fea..d7547bf145dd9 100644 --- a/datafusion/sqllogictest/test_files/string/string_literal.slt +++ b/datafusion/sqllogictest/test_files/string/string_literal.slt @@ -391,6 +391,10 @@ SELECT repeat(arrow_cast('foo', 'Dictionary(Int32, Utf8)'), 3) ---- foofoofoo +query error DataFusion error: Execution error: string size overflow on repeat, max size is 2147483647, but got \d+ +SELECT repeat(x, 9223372036854775807) +FROM (VALUES ('abc')) AS t(x); + query T SELECT arrow_typeof(repeat('foo', 3)) ----