Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 55 additions & 13 deletions datafusion/functions/src/string/repeat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,15 +166,43 @@ fn compute_repeat(s: &str, count: i64, max_size: usize) -> Result<String> {
if count <= 0 {
return Ok(String::new());
}
let result_len = s.len().saturating_mul(count as usize);
let result_len = repeat_len(s.len(), count, max_size)?;
debug_assert!(result_len <= max_size);
let count = repeat_count(count, max_size)?;
Ok(s.repeat(count))
}

fn repeat_len(string_len: usize, count: i64, max_size: usize) -> Result<usize> {
let count = repeat_count(count, max_size)?;
let result_len = match string_len.checked_mul(count) {
Some(result_len) => result_len,
None => {
return exec_err!(
"string size overflow on repeat, max size is {}, but got {}",
max_size,
usize::MAX
);
}
};
if result_len > max_size {
return exec_err!(
"string size overflow on repeat, max size is {}, but got {}",
max_size,
result_len
);
}
Ok(s.repeat(count as usize))
Ok(result_len)
}

fn repeat_count(count: i64, max_size: usize) -> Result<usize> {
match usize::try_from(count) {
Ok(count) => Ok(count),
Err(_) => exec_err!(
"string size overflow on repeat, max size is {}, but got {}",
max_size,
usize::MAX
),
}
}

/// Repeats string the specified number of times.
Expand Down Expand Up @@ -227,22 +255,24 @@ fn calculate_capacities<'a, S>(
where
S: StringArrayType<'a>,
{
let mut total_capacity = 0;
let mut max_item_capacity = 0;
let mut total_capacity = 0usize;
let mut max_item_capacity = 0usize;

string_array.iter().zip(number_array.iter()).try_for_each(
|(string, number)| -> Result<(), DataFusionError> {
match (string, number) {
(Some(string), Some(number)) if number >= 0 => {
let item_capacity = string.len() * number as usize;
if item_capacity > max_str_len {
return exec_err!(
"string size overflow on repeat, max size is {}, but got {}",
max_str_len,
number as usize * string.len()
);
}
total_capacity += item_capacity;
let item_capacity = repeat_len(string.len(), number, max_str_len)?;
total_capacity = match total_capacity.checked_add(item_capacity) {
Some(total_capacity) => total_capacity,
None => {
return exec_err!(
"string size overflow on repeat, max size is {}, but got {}",
max_str_len,
usize::MAX
);
}
};
max_item_capacity = max_item_capacity.max(item_capacity);
}
_ => (),
Expand Down Expand Up @@ -487,6 +517,18 @@ mod tests {
assert_sliced_offset_output::<StringArray>(result);
}

#[test]
fn test_repeat_string_array_overflow() {
let strings: ArrayRef = Arc::new(StringArray::from(vec![Some("abc")]));
let counts: ArrayRef = Arc::new(Int64Array::from(vec![Some(i64::MAX)]));

let err = super::repeat(&strings, &counts).unwrap_err().to_string();
assert!(
err.contains("string size overflow on repeat"),
"unexpected error: {err}"
);
}

#[test]
fn test_repeat_sliced_large_string_with_null_offset() {
let (strings, counts) =
Expand Down
4 changes: 4 additions & 0 deletions datafusion/sqllogictest/test_files/string/string_literal.slt
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,10 @@ SELECT repeat(arrow_cast('foo', 'Dictionary(Int32, Utf8)'), 3)
----
foofoofoo

query error DataFusion error: Execution error: string size overflow on repeat, max size is 2147483647, but got 18446744073709551615
SELECT repeat(x, 9223372036854775807)
FROM (VALUES ('abc')) AS t(x);

query T
SELECT arrow_typeof(repeat('foo', 3))
----
Expand Down