diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 0193f3012c..fe639f47fc 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -55,6 +55,7 @@ use datafusion_spark::function::math::hex::SparkHex; use datafusion_spark::function::math::width_bucket::SparkWidthBucket; use datafusion_spark::function::string::char::CharFunc; use datafusion_spark::function::string::concat::SparkConcat; +use datafusion_spark::function::string::elt::SparkElt; use futures::poll; use futures::stream::StreamExt; use jni::objects::JByteBuffer; @@ -400,6 +401,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) { session_ctx.register_udf(ScalarUDF::new_from_impl(SparkWidthBucket::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(MapFromEntries::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkCrc32::default())); + session_ctx.register_udf(ScalarUDF::new_from_impl(SparkElt::default())); } /// Prepares arrow arrays for output. diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 9d13ccd9ed..21db29afcd 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -178,7 +178,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Left] -> CometLeft, classOf[Right] -> CometRight, classOf[Substring] -> CometSubstring, - classOf[Upper] -> CometUpper) + classOf[Upper] -> CometUpper, + classOf[Elt] -> CometElt) private val bitwiseExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[BitwiseAnd] -> CometBitwiseAnd, diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index 64ba644048..3484be56bc 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Elt, Expression, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper} import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -29,7 +29,7 @@ import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.expressions.{CometCast, CometEvalMode, RegExp} import org.apache.comet.serde.ExprOuterClass.Expr -import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType} +import org.apache.comet.serde.QueryPlanSerde._ object CometStringRepeat extends CometExpressionSerde[StringRepeat] { @@ -382,6 +382,16 @@ object CometStringSplit extends CometExpressionSerde[StringSplit] { } } +object CometElt extends CometScalarFunction[Elt]("elt") { + + override def getSupportLevel(expr: Elt): SupportLevel = { + if (expr.failOnError) { + return Unsupported(Some("ANSI mode not supported")) + } + Compatible(None) + } +} + trait CommonStringExprs { def stringDecode( diff --git a/spark/src/test/resources/sql-tests/expressions/string/elt.sql b/spark/src/test/resources/sql-tests/expressions/string/elt.sql new file mode 100644 index 0000000000..fe49bd25fd --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/elt.sql @@ -0,0 +1,27 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- ConfigMatrix: parquet.enable.dictionary=false,true + +statement +CREATE TABLE test_elt(a string, b string, c string, idx int) USING parquet + +statement +INSERT INTO test_elt VALUES ('a', 'b', 'c', 1), ('a', 'b', '', 2), (NULL, 'b', 'c', NULL), ('a', NULL, 'c', -100), (NULL, NULL, NULL, 0) + +query +SELECT elt(idx, a, b, c) FROM test_elt diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index 121d7f7d5a..7968e10fe5 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -23,8 +23,9 @@ import scala.util.Random import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.spark.sql.{CometTestBase, DataFrame} +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataTypes, StructField, StructType} +import org.apache.spark.sql.types.{DataTypes, StringType, StructField, StructType} import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator} @@ -478,4 +479,43 @@ class CometStringExpressionSuite extends CometTestBase { } } + test("elt") { + val wrongNumArgsWithoutSuggestionExceptionMsg = + "[WRONG_NUM_ARGS.WITHOUT_SUGGESTION] The `elt` requires > 1 parameters but the actual number is 1." + withSQLConf( + SQLConf.ANSI_ENABLED.key -> "false", + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> + "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") { + val r = new Random(42) + val fieldsCount = 10 + val indexes = Seq.range(1, fieldsCount) + val edgeCasesIndexes = Seq(-1, 0, -100, fieldsCount + 100) + val schema = indexes + .foldLeft(new StructType())((schema, idx) => + schema.add(s"c$idx", StringType, nullable = true)) + val df = FuzzDataGenerator.generateDataFrame( + r, + spark, + schema, + 100, + DataGenOptions(maxStringLength = 6)) + df.withColumn( + "idx", + lit(Random.shuffle(indexes ++ edgeCasesIndexes).headOption.getOrElse(-1))) + .createOrReplaceTempView("t1") + checkSparkAnswerAndOperator( + sql(s"SELECT elt(idx, ${schema.fieldNames.mkString(",")}) FROM t1")) + checkSparkAnswerAndOperator( + sql(s"SELECT elt(cast(null as int), ${schema.fieldNames.mkString(",")}) FROM t1")) + checkSparkAnswerMaybeThrows(sql("SELECT elt(1) FROM t1")) match { + case (Some(spark), Some(comet)) => + assert(spark.getMessage.contains(wrongNumArgsWithoutSuggestionExceptionMsg)) + assert(comet.getMessage.contains(wrongNumArgsWithoutSuggestionExceptionMsg)) + case (spark, comet) => + fail( + s"Expected Spark and Comet to throw exception, but got\nSpark: $spark\nComet: $comet") + } + checkSparkAnswerAndOperator("SELECT elt(2, 'a', 'b', 'c')") + } + } } diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala index c7c750aed6..d219477d77 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala @@ -76,7 +76,8 @@ object CometStringExpressionBenchmark extends CometBenchmarkBase { StringExprConfig("substring", "select substring(c1, 1, 100) from parquetV1Table"), StringExprConfig("translate", "select translate(c1, '123456', 'aBcDeF') from parquetV1Table"), StringExprConfig("trim", "select trim(c1) from parquetV1Table"), - StringExprConfig("upper", "select upper(c1) from parquetV1Table")) + StringExprConfig("upper", "select upper(c1) from parquetV1Table"), + StringExprConfig("elt", "select elt(2, c1, c1) from parquetV1Table")) override def runCometBenchmark(mainArgs: Array[String]): Unit = { runBenchmarkWithTable("String expressions", 1024) { v =>