Skip to content

Commit dfadd2d

Browse files
authored
chore: use DF scalar functions for StartsWith, EndsWith, Contains, DF LikeExpr (apache#1887)
1 parent 5ddd921 commit dfadd2d

8 files changed

Lines changed: 137 additions & 224 deletions

File tree

native/core/src/execution/planner.rs

Lines changed: 26 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ use datafusion::{
4949
logical_expr::Operator as DataFusionOperator,
5050
physical_expr::{
5151
expressions::{
52-
in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr,
52+
in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr,
5353
Literal as DataFusionLiteral, NotExpr,
5454
},
5555
PhysicalExpr, PhysicalSortExpr, ScalarFunctionExpr,
@@ -104,10 +104,10 @@ use datafusion_comet_proto::{
104104
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
105105
};
106106
use datafusion_comet_spark_expr::{
107-
ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Contains, Correlation, Covariance,
108-
CreateNamedStruct, EndsWith, GetArrayStructFields, GetStructField, IfExpr, Like, ListExtract,
109-
NormalizeNaNAndZero, RLike, RandExpr, SparkCastOptions, StartsWith, Stddev, StringSpaceExpr,
110-
SubstringExpr, SumDecimal, TimestampTruncExpr, ToJson, UnboundColumn, Variance,
107+
ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct,
108+
GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RLike,
109+
RandExpr, SparkCastOptions, Stddev, StringSpaceExpr, SubstringExpr, SumDecimal,
110+
TimestampTruncExpr, ToJson, UnboundColumn, Variance,
111111
};
112112
use itertools::Itertools;
113113
use jni::objects::GlobalRef;
@@ -511,33 +511,12 @@ impl PhysicalPlanner {
511511

512512
Ok(Arc::new(StringSpaceExpr::new(child)))
513513
}
514-
ExprStruct::Contains(expr) => {
515-
let left =
516-
self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
517-
let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
518-
519-
Ok(Arc::new(Contains::new(left, right)))
520-
}
521-
ExprStruct::StartsWith(expr) => {
522-
let left =
523-
self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
524-
let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
525-
526-
Ok(Arc::new(StartsWith::new(left, right)))
527-
}
528-
ExprStruct::EndsWith(expr) => {
529-
let left =
530-
self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
531-
let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
532-
533-
Ok(Arc::new(EndsWith::new(left, right)))
534-
}
535514
ExprStruct::Like(expr) => {
536515
let left =
537516
self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
538517
let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
539518

540-
Ok(Arc::new(Like::new(left, right)))
519+
Ok(Arc::new(LikeExpr::new(false, false, left, right)))
541520
}
542521
ExprStruct::Rlike(expr) => {
543522
let left =
@@ -987,17 +966,25 @@ impl PhysicalPlanner {
987966
let predicate =
988967
self.create_expr(filter.predicate.as_ref().unwrap(), child.schema())?;
989968

990-
let filter: Arc<dyn ExecutionPlan> = if filter.use_datafusion_filter {
991-
Arc::new(DataFusionFilterExec::try_new(
992-
predicate,
993-
Arc::clone(&child.native_plan),
994-
)?)
995-
} else {
996-
Arc::new(CometFilterExec::try_new(
997-
predicate,
998-
Arc::clone(&child.native_plan),
999-
)?)
1000-
};
969+
let filter: Arc<dyn ExecutionPlan> =
970+
match (filter.wrap_child_in_copy_exec, filter.use_datafusion_filter) {
971+
(true, true) => Arc::new(DataFusionFilterExec::try_new(
972+
predicate,
973+
Self::wrap_in_copy_exec(Arc::clone(&child.native_plan)),
974+
)?),
975+
(true, false) => Arc::new(CometFilterExec::try_new(
976+
predicate,
977+
Self::wrap_in_copy_exec(Arc::clone(&child.native_plan)),
978+
)?),
979+
(false, true) => Arc::new(DataFusionFilterExec::try_new(
980+
predicate,
981+
Arc::clone(&child.native_plan),
982+
)?),
983+
(false, false) => Arc::new(CometFilterExec::try_new(
984+
predicate,
985+
Arc::clone(&child.native_plan),
986+
)?),
987+
};
1001988

1002989
Ok((
1003990
scans,
@@ -2868,6 +2855,7 @@ mod tests {
28682855
op_struct: Some(OpStruct::Filter(spark_operator::Filter {
28692856
predicate: Some(expr),
28702857
use_datafusion_filter: false,
2858+
wrap_child_in_copy_exec: false,
28712859
})),
28722860
}
28732861
}

native/proto/src/proto/expr.proto

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,6 @@ message Expr {
5151
Second second = 24;
5252
CheckOverflow check_overflow = 25;
5353
BinaryExpr like = 26;
54-
BinaryExpr startsWith = 27;
55-
BinaryExpr endsWith = 28;
56-
BinaryExpr contains = 29;
5754
BinaryExpr rlike = 30;
5855
ScalarFunc scalarFunc = 31;
5956
BinaryExpr eqNullSafe = 32;

native/proto/src/proto/operator.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ message Projection {
110110
message Filter {
111111
spark.spark_expression.Expr predicate = 1;
112112
bool use_datafusion_filter = 2;
113+
// Some expressions don't support dictionary arrays, so may need to wrap the child in a CopyExec
114+
bool wrap_child_in_copy_exec = 3;
113115
}
114116

115117
message Sort {

native/spark-expr/src/string_funcs/mod.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,9 @@
1616
// under the License.
1717

1818
mod chr;
19-
mod prediction;
2019
mod string_space;
2120
mod substring;
2221

2322
pub use chr::SparkChrFunc;
24-
pub use prediction::*;
2523
pub use string_space::StringSpaceExpr;
2624
pub use substring::SubstringExpr;

native/spark-expr/src/string_funcs/prediction.rs

Lines changed: 0 additions & 141 deletions
This file was deleted.

spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ import org.apache.spark.unsafe.types.UTF8String
4646

4747
import org.apache.comet.parquet.SourceFilterSerde.{createBinaryExpr, createNameExpr, createUnaryExpr, createValueExpr}
4848
import org.apache.comet.serde.ExprOuterClass
49+
import org.apache.comet.serde.QueryPlanSerde.scalarFunctionExprToProto
4950
import org.apache.comet.shims.ShimSQLConf
5051

5152
/**
@@ -1011,23 +1012,29 @@ class ParquetFilters(
10111012
}
10121013
}
10131014

1014-
case sources.StringStartsWith(name, prefix)
1015-
if pushDownStringPredicate && canMakeFilterOn(name, prefix) =>
1016-
nameValueBinaryExpr(name, prefix) { (builder, binaryExpr) =>
1017-
builder.setStartsWith(binaryExpr)
1015+
case sources.StringStartsWith(attribute, prefix)
1016+
if pushDownStringPredicate && canMakeFilterOn(attribute, prefix) =>
1017+
val attributeExpr = createNameExpr(attribute, dataSchema)
1018+
val prefixExpr = attributeExpr.flatMap { case (dataType, _) =>
1019+
createValueExpr(prefix, dataType)
10181020
}
1021+
scalarFunctionExprToProto("starts_with", Some(attributeExpr.get._2), prefixExpr)
10191022

1020-
case sources.StringEndsWith(name, suffix)
1021-
if pushDownStringPredicate && canMakeFilterOn(name, suffix) =>
1022-
nameValueBinaryExpr(name, suffix) { (builder, binaryExpr) =>
1023-
builder.setEndsWith(binaryExpr)
1023+
case sources.StringEndsWith(attribute, suffix)
1024+
if pushDownStringPredicate && canMakeFilterOn(attribute, suffix) =>
1025+
val attributeExpr = createNameExpr(attribute, dataSchema)
1026+
val suffixExpr = attributeExpr.flatMap { case (dataType, _) =>
1027+
createValueExpr(suffix, dataType)
10241028
}
1029+
scalarFunctionExprToProto("ends_with", Some(attributeExpr.get._2), suffixExpr)
10251030

1026-
case sources.StringContains(name, value)
1027-
if pushDownStringPredicate && canMakeFilterOn(name, value) =>
1028-
nameValueBinaryExpr(name, value) { (builder, binaryExpr) =>
1029-
builder.setContains(binaryExpr)
1031+
case sources.StringContains(attribute, value)
1032+
if pushDownStringPredicate && canMakeFilterOn(attribute, value) =>
1033+
val attributeExpr = createNameExpr(attribute, dataSchema)
1034+
val valueExpr = attributeExpr.flatMap { case (dataType, _) =>
1035+
createValueExpr(value, dataType)
10301036
}
1037+
scalarFunctionExprToProto("contains", Some(attributeExpr.get._2), valueExpr)
10311038

10321039
case _ => None
10331040
}

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,32 +1026,20 @@ object QueryPlanSerde extends Logging with CometExprShim {
10261026
binding,
10271027
(builder, binaryExpr) => builder.setRlike(binaryExpr))
10281028

1029-
case StartsWith(left, right) =>
1030-
createBinaryExpr(
1031-
expr,
1032-
left,
1033-
right,
1034-
inputs,
1035-
binding,
1036-
(builder, binaryExpr) => builder.setStartsWith(binaryExpr))
1037-
1038-
case EndsWith(left, right) =>
1039-
createBinaryExpr(
1040-
expr,
1041-
left,
1042-
right,
1043-
inputs,
1044-
binding,
1045-
(builder, binaryExpr) => builder.setEndsWith(binaryExpr))
1046-
1047-
case Contains(left, right) =>
1048-
createBinaryExpr(
1049-
expr,
1050-
left,
1051-
right,
1052-
inputs,
1053-
binding,
1054-
(builder, binaryExpr) => builder.setContains(binaryExpr))
1029+
case StartsWith(attribute, prefix) =>
1030+
val attributeExpr = exprToProtoInternal(attribute, inputs, binding)
1031+
val prefixExpr = exprToProtoInternal(prefix, inputs, binding)
1032+
scalarFunctionExprToProto("starts_with", attributeExpr, prefixExpr)
1033+
1034+
case EndsWith(attribute, suffix) =>
1035+
val attributeExpr = exprToProtoInternal(attribute, inputs, binding)
1036+
val suffixExpr = exprToProtoInternal(suffix, inputs, binding)
1037+
scalarFunctionExprToProto("ends_with", attributeExpr, suffixExpr)
1038+
1039+
case Contains(attribute, value) =>
1040+
val attributeExpr = exprToProtoInternal(attribute, inputs, binding)
1041+
val valueExpr = exprToProtoInternal(value, inputs, binding)
1042+
scalarFunctionExprToProto("contains", attributeExpr, valueExpr)
10551043

10561044
case StringSpace(child) =>
10571045
createUnaryExpr(
@@ -2326,10 +2314,20 @@ object QueryPlanSerde extends Logging with CometExprShim {
23262314
}
23272315
}
23282316

2317+
// Some native expressions do not support operating on dictionary-encoded arrays, so
2318+
// wrap the child in a CopyExec to unpack dictionaries first.
2319+
def wrapChildInCopyExec(condition: Expression): Boolean = {
2320+
condition.exists(expr => {
2321+
expr.isInstanceOf[StartsWith] || expr.isInstanceOf[EndsWith] || expr
2322+
.isInstanceOf[Contains]
2323+
})
2324+
}
2325+
23292326
val filterBuilder = OperatorOuterClass.Filter
23302327
.newBuilder()
23312328
.setPredicate(cond.get)
23322329
.setUseDatafusionFilter(!containsNativeCometScan(op))
2330+
.setWrapChildInCopyExec(wrapChildInCopyExec(condition))
23332331
Some(result.setFilter(filterBuilder).build())
23342332
} else {
23352333
withInfo(op, condition, child)

0 commit comments

Comments
 (0)