Skip to content

Commit 63fb715

Browse files
author
Kazantsev Maksim
committed
Feat: impl elt function
1 parent d7857b2 commit 63fb715

4 files changed

Lines changed: 52 additions & 4 deletions

File tree

native/core/src/execution/jni_api.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ use datafusion_spark::function::math::expm1::SparkExpm1;
5151
use datafusion_spark::function::math::hex::SparkHex;
5252
use datafusion_spark::function::string::char::CharFunc;
5353
use datafusion_spark::function::string::concat::SparkConcat;
54+
use datafusion_spark::function::string::elt::SparkElt;
5455
use futures::poll;
5556
use futures::stream::StreamExt;
5657
use jni::objects::JByteBuffer;
@@ -351,6 +352,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) {
351352
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkConcat::default()));
352353
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseNot::default()));
353354
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkHex::default()));
355+
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkElt::default()));
354356
}
355357

356358
/// Prepares arrow arrays for output.

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
172172
classOf[StringTrimRight] -> CometScalarFunction("rtrim"),
173173
classOf[Left] -> CometLeft,
174174
classOf[Substring] -> CometSubstring,
175-
classOf[Upper] -> CometUpper)
175+
classOf[Upper] -> CometUpper,
176+
classOf[Elt] -> CometElt)
176177

177178
private val bitwiseExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
178179
classOf[BitwiseAnd] -> CometBitwiseAnd,

spark/src/main/scala/org/apache/comet/serde/strings.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ package org.apache.comet.serde
2121

2222
import java.util.Locale
2323

24-
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Expression, InitCap, Left, Length, Like, Literal, Lower, RegExpReplace, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper}
24+
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Elt, Expression, InitCap, Left, Length, Like, Literal, Lower, RegExpReplace, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper}
2525
import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType}
2626

2727
import org.apache.comet.CometConf
@@ -289,6 +289,16 @@ object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] {
289289
}
290290
}
291291

292+
object CometElt extends CometScalarFunction[Elt]("elt") {
293+
294+
override def getSupportLevel(expr: Elt): SupportLevel = {
295+
if (expr.failOnError) {
296+
return Unsupported(Some("failOnError=true is not supported"))
297+
}
298+
Compatible(None)
299+
}
300+
}
301+
292302
trait CommonStringExprs {
293303

294304
def stringDecode(

spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ package org.apache.comet
2222
import scala.util.Random
2323

2424
import org.apache.parquet.hadoop.ParquetOutputFormat
25-
import org.apache.spark.sql.{CometTestBase, DataFrame}
25+
import org.apache.spark.sql.{functions, CometTestBase, DataFrame}
26+
import org.apache.spark.sql.functions.lit
2627
import org.apache.spark.sql.internal.SQLConf
27-
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
28+
import org.apache.spark.sql.types.{DataTypes, StringType, StructField, StructType}
2829

2930
import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator}
3031

@@ -391,4 +392,38 @@ class CometStringExpressionSuite extends CometTestBase {
391392
}
392393
}
393394

395+
test("elt") {
396+
withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
397+
val r = new Random(42)
398+
val fieldsCount = 10
399+
val indexes = Seq.range(1, fieldsCount)
400+
val edgeCasesIndexes = Seq(-1, 0, -100, fieldsCount + 100)
401+
val schema = indexes
402+
.foldLeft(new StructType())((schema, idx) =>
403+
schema.add(s"c$idx", StringType, nullable = true))
404+
val df = FuzzDataGenerator.generateDataFrame(
405+
r,
406+
spark,
407+
schema,
408+
100,
409+
DataGenOptions(maxStringLength = 6))
410+
df.withColumn(
411+
"idx",
412+
lit(Random.shuffle(indexes ++ edgeCasesIndexes).headOption.getOrElse(-1)))
413+
.createOrReplaceTempView("t1")
414+
checkSparkAnswerAndOperator(
415+
sql(s"SELECT elt(idx, ${schema.fieldNames.mkString(",")}) FROM t1"))
416+
checkSparkAnswerAndOperator(
417+
sql(s"SELECT elt(cast(null as int), ${schema.fieldNames.mkString(",")}) FROM t1"))
418+
checkSparkAnswerMaybeThrows(sql(s"SELECT elt(1) FROM t1")) match {
419+
case (Some(spark), Some(comet)) =>
420+
assert(spark.getMessage.contains("WRONG_NUM_ARGS.WITHOUT_SUGGESTION"))
421+
assert(comet.getMessage.contains("WRONG_NUM_ARGS.WITHOUT_SUGGESTION"))
422+
case (spark, comet) =>
423+
fail(
424+
s"Expected Spark and Comet to throw exception, but got\nSpark: $spark\nComet: $comet")
425+
}
426+
}
427+
}
428+
394429
}

0 commit comments

Comments
 (0)