|
15 | 15 | // specific language governing permissions and limitations |
16 | 16 | // under the License. |
17 | 17 |
|
18 | | -use std::sync::Arc; |
19 | | - |
| 18 | +use crate::strings::{ |
| 19 | + BulkNullStringArrayBuilder, GenericStringArrayBuilder, StringViewArrayBuilder, |
| 20 | +}; |
20 | 21 | use crate::utils::make_scalar_function; |
21 | 22 | use DataType::{LargeUtf8, Utf8, Utf8View}; |
22 | | -use arrow::array::{ |
23 | | - Array, ArrayRef, AsArray, LargeStringBuilder, StringArrayType, StringBuilder, |
24 | | - StringLikeArrayBuilder, StringViewBuilder, |
25 | | -}; |
| 23 | +use arrow::array::{Array, ArrayRef, AsArray, StringArrayType}; |
26 | 24 | use arrow::datatypes::DataType; |
27 | 25 | use datafusion_common::{Result, exec_err}; |
28 | 26 | use datafusion_expr::{ |
@@ -103,56 +101,79 @@ fn reverse(args: &[ArrayRef]) -> Result<ArrayRef> { |
103 | 101 | let len = args[0].len(); |
104 | 102 |
|
105 | 103 | match args[0].data_type() { |
| 104 | + LargeUtf8 => reverse_impl( |
| 105 | + &args[0].as_string::<i64>(), |
| 106 | + GenericStringArrayBuilder::<i64>::with_capacity(len, 1024), |
| 107 | + ), |
106 | 108 | Utf8 => reverse_impl( |
107 | 109 | &args[0].as_string::<i32>(), |
108 | | - StringBuilder::with_capacity(len, 1024), |
| 110 | + GenericStringArrayBuilder::<i32>::with_capacity(len, 1024), |
109 | 111 | ), |
110 | 112 | Utf8View => reverse_impl( |
111 | 113 | &args[0].as_string_view(), |
112 | | - StringViewBuilder::with_capacity(len), |
113 | | - ), |
114 | | - LargeUtf8 => reverse_impl( |
115 | | - &args[0].as_string::<i64>(), |
116 | | - LargeStringBuilder::with_capacity(len, 1024), |
| 114 | + StringViewArrayBuilder::with_capacity(len), |
117 | 115 | ), |
118 | 116 | _ => unreachable!( |
119 | 117 | "Reverse can only be applied to Utf8View, Utf8 and LargeUtf8 types" |
120 | 118 | ), |
121 | 119 | } |
122 | 120 | } |
123 | 121 |
|
124 | | -fn reverse_impl<'a, StringArrType, StringBuilderType>( |
| 122 | +fn reverse_impl<'a, StringArrType, B>( |
125 | 123 | string_array: &StringArrType, |
126 | | - mut array_builder: StringBuilderType, |
| 124 | + mut array_builder: B, |
127 | 125 | ) -> Result<ArrayRef> |
128 | 126 | where |
129 | 127 | StringArrType: StringArrayType<'a>, |
130 | | - StringBuilderType: StringLikeArrayBuilder, |
| 128 | + B: BulkNullStringArrayBuilder, |
131 | 129 | { |
| 130 | + let item_len = string_array.len(); |
| 131 | + // Null-preserving: reuse the input null buffer as the output null buffer. |
| 132 | + let nulls = string_array.nulls().cloned(); |
132 | 133 | let mut string_buf = String::new(); |
133 | 134 | let mut byte_buf = Vec::<u8>::new(); |
134 | 135 |
|
135 | | - for string in string_array.iter() { |
136 | | - if let Some(s) = string { |
137 | | - if s.is_ascii() { |
138 | | - // reverse bytes directly since ASCII characters are single bytes |
139 | | - byte_buf.extend(s.as_bytes()); |
140 | | - byte_buf.reverse(); |
141 | | - // SAFETY: Since the original string was ASCII, reversing the bytes still results in valid UTF-8. |
142 | | - let reversed = unsafe { std::str::from_utf8_unchecked(&byte_buf) }; |
143 | | - array_builder.append_value(reversed); |
144 | | - byte_buf.clear(); |
| 136 | + if let Some(ref n) = nulls { |
| 137 | + for i in 0..item_len { |
| 138 | + if n.is_null(i) { |
| 139 | + array_builder.append_placeholder(); |
145 | 140 | } else { |
146 | | - string_buf.extend(s.chars().rev()); |
147 | | - array_builder.append_value(&string_buf); |
148 | | - string_buf.clear(); |
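| 141 | + // SAFETY: `i < item_len`, which is `string_array.len()`, so `i` is in bounds; `n.is_null(i)` was also false in the branch above, so the slot holds a real value. |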
| 142 | + let s = unsafe { string_array.value_unchecked(i) }; |
| 143 | + append_reversed(s, &mut array_builder, &mut byte_buf, &mut string_buf); |
149 | 144 | } |
150 | | - } else { |
151 | | - array_builder.append_null(); |
| 145 | + } |
| 146 | + } else { |
| 147 | + for i in 0..item_len { |
| 148 | + // SAFETY: `i < item_len`, which is `string_array.len()`, so `i` is in bounds; with no null buffer, every slot is non-null. |
| 149 | + let s = unsafe { string_array.value_unchecked(i) }; |
| 150 | + append_reversed(s, &mut array_builder, &mut byte_buf, &mut string_buf); |
152 | 151 | } |
153 | 152 | } |
154 | 153 |
|
155 | | - Ok(Arc::new(array_builder.finish()) as ArrayRef) |
| 154 | + array_builder.finish(nulls) |
| 155 | +} |
| 156 | + |
| 157 | +#[inline] |
| 158 | +fn append_reversed<B: BulkNullStringArrayBuilder>( |
| 159 | + s: &str, |
| 160 | + builder: &mut B, |
| 161 | + byte_buf: &mut Vec<u8>, |
| 162 | + string_buf: &mut String, |
| 163 | +) { |
| 164 | + if s.is_ascii() { |
| 165 | + // reverse bytes directly since ASCII characters are single bytes |
| 166 | + byte_buf.extend(s.as_bytes()); |
| 167 | + byte_buf.reverse(); |
| 168 | + // SAFETY: input was ASCII, so reversed bytes are still valid UTF-8. |
| 169 | + let reversed = unsafe { std::str::from_utf8_unchecked(byte_buf) }; |
| 170 | + builder.append_value(reversed); |
| 171 | + byte_buf.clear(); |
| 172 | + } else { |
| 173 | + string_buf.extend(s.chars().rev()); |
| 174 | + builder.append_value(string_buf); |
| 175 | + string_buf.clear(); |
| 176 | + } |
156 | 177 | } |
157 | 178 |
|
158 | 179 | #[cfg(test)] |
@@ -213,4 +234,58 @@ mod tests { |
213 | 234 |
|
214 | 235 | Ok(()) |
215 | 236 | } |
| 237 | + |
| 238 | + #[test] |
| 239 | + fn test_array_with_nulls() { |
| 240 | + use crate::unicode::reverse::reverse; |
| 241 | + use arrow::array::ArrayRef; |
| 242 | + use std::sync::Arc; |
| 243 | + |
| 244 | + let input_values = vec![Some("abcd"), None, Some("XYZ"), Some("héllo"), None]; |
| 245 | + let expected: Vec<Option<&str>> = |
| 246 | + vec![Some("dcba"), None, Some("ZYX"), Some("olléh"), None]; |
| 247 | + |
| 248 | + let cases: Vec<(&str, ArrayRef)> = vec![ |
| 249 | + ( |
| 250 | + "StringArray", |
| 251 | + Arc::new(StringArray::from(input_values.clone())), |
| 252 | + ), |
| 253 | + ( |
| 254 | + "LargeStringArray", |
| 255 | + Arc::new(LargeStringArray::from(input_values.clone())), |
| 256 | + ), |
| 257 | + ( |
| 258 | + "StringViewArray", |
| 259 | + Arc::new(StringViewArray::from(input_values.clone())), |
| 260 | + ), |
| 261 | + ]; |
| 262 | + |
| 263 | + for (label, input) in cases { |
| 264 | + let out = reverse(&[input]).unwrap(); |
| 265 | + assert_eq!(out.len(), expected.len(), "{label}: length mismatch"); |
| 266 | + |
| 267 | + let actual: Vec<Option<&str>> = match out.data_type() { |
| 268 | + Utf8 => out |
| 269 | + .as_any() |
| 270 | + .downcast_ref::<StringArray>() |
| 271 | + .unwrap() |
| 272 | + .iter() |
| 273 | + .collect(), |
| 274 | + LargeUtf8 => out |
| 275 | + .as_any() |
| 276 | + .downcast_ref::<LargeStringArray>() |
| 277 | + .unwrap() |
| 278 | + .iter() |
| 279 | + .collect(), |
| 280 | + Utf8View => out |
| 281 | + .as_any() |
| 282 | + .downcast_ref::<StringViewArray>() |
| 283 | + .unwrap() |
| 284 | + .iter() |
| 285 | + .collect(), |
| 286 | + other => panic!("{label}: unexpected output type {other:?}"), |
| 287 | + }; |
| 288 | + assert_eq!(actual, expected, "{label}: value mismatch"); |
| 289 | + } |
| 290 | + } |
216 | 291 | } |
0 commit comments