Skip to content

Commit 7943199

Browse files
authored
perf: Optimize contains expression with SIMD-based scalar pattern sea… (#2991)
1 parent 599af33 commit 7943199

4 files changed

Lines changed: 269 additions & 3 deletions

File tree

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ use crate::math_funcs::modulo_expr::spark_modulo;
2222
use crate::{
2323
spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
2424
spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
25-
spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkDateDiff, SparkDateTrunc,
26-
SparkSizeFunc, SparkStringSpace,
25+
spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkContains, SparkDateDiff,
26+
SparkDateTrunc, SparkSizeFunc, SparkStringSpace,
2727
};
2828
use arrow::datatypes::DataType;
2929
use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -192,6 +192,7 @@ pub fn create_comet_physical_fun_with_eval_mode(
192192
fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
193193
vec![
194194
Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
195+
Arc::new(ScalarUDF::new_from_impl(SparkContains::default())),
195196
Arc::new(ScalarUDF::new_from_impl(SparkDateDiff::default())),
196197
Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
197198
Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())),
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Optimized `contains` string function for Spark compatibility.
19+
//!
20+
//! Optimized for scalar pattern case by passing scalar directly to arrow_contains
21+
//! instead of expanding to arrays like DataFusion's built-in contains.
22+
23+
use arrow::array::{Array, ArrayRef, BooleanArray, StringArray};
24+
use arrow::compute::kernels::comparison::contains as arrow_contains;
25+
use arrow::datatypes::DataType;
26+
use datafusion::common::{exec_err, Result, ScalarValue};
27+
use datafusion::logical_expr::{
28+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
29+
};
30+
use std::any::Any;
31+
use std::sync::Arc;
32+
33+
/// Spark-optimized contains function.
34+
/// Returns true if the first string argument contains the second string argument.
35+
/// Optimized for scalar pattern constants.
36+
#[derive(Debug, PartialEq, Eq, Hash)]
37+
pub struct SparkContains {
38+
signature: Signature,
39+
}
40+
41+
impl Default for SparkContains {
42+
fn default() -> Self {
43+
Self::new()
44+
}
45+
}
46+
47+
impl SparkContains {
48+
pub fn new() -> Self {
49+
Self {
50+
signature: Signature::variadic_any(Volatility::Immutable),
51+
}
52+
}
53+
}
54+
55+
impl ScalarUDFImpl for SparkContains {
56+
fn as_any(&self) -> &dyn Any {
57+
self
58+
}
59+
60+
fn name(&self) -> &str {
61+
"contains"
62+
}
63+
64+
fn signature(&self) -> &Signature {
65+
&self.signature
66+
}
67+
68+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
69+
Ok(DataType::Boolean)
70+
}
71+
72+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
73+
if args.args.len() != 2 {
74+
return exec_err!("contains function requires exactly 2 arguments");
75+
}
76+
spark_contains(&args.args[0], &args.args[1])
77+
}
78+
}
79+
80+
/// Execute contains function with optimized scalar pattern handling.
81+
fn spark_contains(haystack: &ColumnarValue, needle: &ColumnarValue) -> Result<ColumnarValue> {
82+
match (haystack, needle) {
83+
// Both arrays - use arrow's contains directly
84+
(ColumnarValue::Array(haystack_array), ColumnarValue::Array(needle_array)) => {
85+
let result = arrow_contains(haystack_array, needle_array)?;
86+
Ok(ColumnarValue::Array(Arc::new(result)))
87+
}
88+
89+
// Array haystack, scalar needle - OPTIMIZED PATH
90+
(ColumnarValue::Array(haystack_array), ColumnarValue::Scalar(needle_scalar)) => {
91+
let result = contains_with_arrow_scalar(haystack_array, needle_scalar)?;
92+
Ok(ColumnarValue::Array(result))
93+
}
94+
95+
// Scalar haystack, array needle - less common
96+
(ColumnarValue::Scalar(haystack_scalar), ColumnarValue::Array(needle_array)) => {
97+
let haystack_array = haystack_scalar.to_array_of_size(needle_array.len())?;
98+
let result = arrow_contains(&haystack_array, needle_array)?;
99+
Ok(ColumnarValue::Array(Arc::new(result)))
100+
}
101+
102+
// Both scalars - compute single result
103+
(ColumnarValue::Scalar(haystack_scalar), ColumnarValue::Scalar(needle_scalar)) => {
104+
let result = contains_scalar_scalar(haystack_scalar, needle_scalar)?;
105+
Ok(ColumnarValue::Scalar(result))
106+
}
107+
}
108+
}
109+
110+
/// Optimized contains for array haystack with scalar needle.
111+
/// Uses Arrow's native scalar handling for better performance.
112+
fn contains_with_arrow_scalar(
113+
haystack_array: &ArrayRef,
114+
needle_scalar: &ScalarValue,
115+
) -> Result<ArrayRef> {
116+
// Handle null needle
117+
if needle_scalar.is_null() {
118+
return Ok(Arc::new(BooleanArray::new_null(haystack_array.len())));
119+
}
120+
121+
// Extract the needle string
122+
let needle_str = match needle_scalar {
123+
ScalarValue::Utf8(Some(s))
124+
| ScalarValue::LargeUtf8(Some(s))
125+
| ScalarValue::Utf8View(Some(s)) => s.clone(),
126+
_ => {
127+
return exec_err!(
128+
"contains function requires string type for needle, got {:?}",
129+
needle_scalar.data_type()
130+
)
131+
}
132+
};
133+
134+
// Create scalar array for needle - tells Arrow to use optimized paths
135+
let needle_scalar_array = StringArray::new_scalar(needle_str);
136+
137+
// Use Arrow's contains which detects scalar case and uses optimized paths
138+
let result = arrow_contains(haystack_array, &needle_scalar_array)?;
139+
Ok(Arc::new(result))
140+
}
141+
142+
/// Contains for two scalar values.
143+
fn contains_scalar_scalar(
144+
haystack_scalar: &ScalarValue,
145+
needle_scalar: &ScalarValue,
146+
) -> Result<ScalarValue> {
147+
// Handle nulls
148+
if haystack_scalar.is_null() || needle_scalar.is_null() {
149+
return Ok(ScalarValue::Boolean(None));
150+
}
151+
152+
let haystack_str = match haystack_scalar {
153+
ScalarValue::Utf8(Some(s))
154+
| ScalarValue::LargeUtf8(Some(s))
155+
| ScalarValue::Utf8View(Some(s)) => s.as_str(),
156+
_ => {
157+
return exec_err!(
158+
"contains function requires string type for haystack, got {:?}",
159+
haystack_scalar.data_type()
160+
)
161+
}
162+
};
163+
164+
let needle_str = match needle_scalar {
165+
ScalarValue::Utf8(Some(s))
166+
| ScalarValue::LargeUtf8(Some(s))
167+
| ScalarValue::Utf8View(Some(s)) => s.as_str(),
168+
_ => {
169+
return exec_err!(
170+
"contains function requires string type for needle, got {:?}",
171+
needle_scalar.data_type()
172+
)
173+
}
174+
};
175+
176+
Ok(ScalarValue::Boolean(Some(
177+
haystack_str.contains(needle_str),
178+
)))
179+
}
180+
181+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::StringArray;

    /// Runs the array-haystack / scalar-needle path and returns the boolean result.
    fn eval_array_scalar(rows: Vec<Option<&str>>, needle: Option<&str>) -> BooleanArray {
        let haystack: ArrayRef = Arc::new(StringArray::from(rows));
        let needle = ScalarValue::Utf8(needle.map(String::from));

        let result = contains_with_arrow_scalar(&haystack, &needle).unwrap();
        result
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap()
            .clone()
    }

    #[test]
    fn test_contains_array_scalar() {
        let rows = vec![Some("hello world"), Some("foo bar"), Some("testing"), None];
        let matches = eval_array_scalar(rows, Some("world"));

        assert!(matches.value(0)); // "hello world" contains "world"
        assert!(!matches.value(1)); // "foo bar" does not contain "world"
        assert!(!matches.value(2)); // "testing" does not contain "world"
        assert!(matches.is_null(3)); // null input => null output
    }

    #[test]
    fn test_contains_scalar_scalar() {
        let haystack = ScalarValue::Utf8(Some("hello world".to_string()));

        let present = ScalarValue::Utf8(Some("world".to_string()));
        assert_eq!(
            contains_scalar_scalar(&haystack, &present).unwrap(),
            ScalarValue::Boolean(Some(true))
        );

        let absent = ScalarValue::Utf8(Some("xyz".to_string()));
        assert_eq!(
            contains_scalar_scalar(&haystack, &absent).unwrap(),
            ScalarValue::Boolean(Some(false))
        );
    }

    #[test]
    fn test_contains_null_needle() {
        let matches = eval_array_scalar(vec![Some("hello world"), Some("foo bar")], None);

        // A null needle makes every row null.
        assert!(matches.is_null(0));
        assert!(matches.is_null(1));
    }

    #[test]
    fn test_contains_empty_needle() {
        let matches = eval_array_scalar(vec![Some("hello world"), Some("")], Some(""));

        // The empty string is contained in every string, including the empty one.
        assert!(matches.value(0));
        assert!(matches.value(1));
    }
}

native/spark-expr/src/string_funcs/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
mod contains;
1819
mod string_space;
1920
mod substring;
2021

22+
pub use contains::SparkContains;
2123
pub use string_space::SparkStringSpace;
2224
pub use substring::SubstringExpr;

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1163,7 +1163,24 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
11631163

11641164
// Filter rows that contain 'rose' in 'name' column
11651165
val queryContains = sql(s"select id from $table where contains (name, 'rose')")
1166-
checkAnswer(queryContains, Row(5) :: Nil)
1166+
checkSparkAnswerAndOperator(queryContains)
1167+
1168+
// Additional test cases for optimized contains implementation
1169+
// Test with empty pattern (should match all non-null rows)
1170+
val queryEmptyPattern = sql(s"select id from $table where contains (name, '')")
1171+
checkSparkAnswerAndOperator(queryEmptyPattern)
1172+
1173+
// Test with pattern not found
1174+
val queryNotFound = sql(s"select id from $table where contains (name, 'xyz')")
1175+
checkSparkAnswerAndOperator(queryNotFound)
1176+
1177+
// Test with pattern at start
1178+
val queryStart = sql(s"select id from $table where contains (name, 'James')")
1179+
checkSparkAnswerAndOperator(queryStart)
1180+
1181+
// Test with pattern at end
1182+
val queryEnd = sql(s"select id from $table where contains (name, 'Smith')")
1183+
checkSparkAnswerAndOperator(queryEnd)
11671184
}
11681185
}
11691186

0 commit comments

Comments
 (0)