Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions native/core/src/execution/jni_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ use datafusion_spark::function::math::hex::SparkHex;
use datafusion_spark::function::math::width_bucket::SparkWidthBucket;
use datafusion_spark::function::string::char::CharFunc;
use datafusion_spark::function::string::concat::SparkConcat;
use datafusion_spark::function::string::luhn_check::SparkLuhnCheck;
use futures::poll;
use futures::stream::StreamExt;
use jni::objects::JByteBuffer;
Expand Down Expand Up @@ -400,6 +401,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) {
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkWidthBucket::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(MapFromEntries::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkCrc32::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkLuhnCheck::default()));
}

/// Prepares arrow arrays for output.
Expand Down
24 changes: 22 additions & 2 deletions spark/src/main/scala/org/apache/comet/serde/statics.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@

package org.apache.comet.serde

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionImplUtils}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.sql.types.BooleanType

import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType}

object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {

Expand All @@ -34,7 +36,8 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
: Map[(String, Class[_]), CometExpressionSerde[StaticInvoke]] =
Map(
("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction(
"read_side_padding"))
"read_side_padding"),
("isLuhnNumber", classOf[ExpressionImplUtils]) -> CometLuhnCheck)

override def convert(
expr: StaticInvoke,
Expand All @@ -52,3 +55,20 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
}
}
}

/**
 * Serde for a `StaticInvoke` of `ExpressionImplUtils.isLuhnNumber` (Spark 3.5+).
 *
 * The invocation is emitted as the `luhn_check` scalar function from datafusion-spark, declared
 * with a `BooleanType` return type.
 */
private object CometLuhnCheck extends CometExpressionSerde[StaticInvoke] {

  override def convert(
      expr: StaticInvoke,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    // The value to validate is the first argument of the static call.
    val candidate = expr.arguments.head
    // NOTE(review): the positional `false` is presumably the fail-on-error flag of
    // scalarFunctionExprToProtoWithReturnType -- confirm against QueryPlanSerde.
    optExprWithInfo(
      scalarFunctionExprToProtoWithReturnType(
        "luhn_check",
        BooleanType,
        false,
        exprToProtoInternal(candidate, inputs, binding)),
      expr,
      candidate)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,24 @@ class CometStringExpressionSuite extends CometTestBase {
}
}

test("luhn_check") {
  // Column values fed to luhn_check, including degenerate and null inputs.
  val samples: Seq[String] = Seq(
    "79927398710", // invalid (fails Luhn)
    "79927398713", // valid Luhn number
    "1234567812345670", // valid credit card-like
    "0", // valid single digit
    "", // empty string
    "abc", // non-numeric
    null)
  val rows = samples.map(Tuple1(_))

  // Each query is compared against Spark's answer, in this order.
  val queries = Seq(
    "SELECT luhn_check(_1) FROM tbl", // column input
    "SELECT luhn_check('79927398713') FROM tbl", // literal values
    "SELECT luhn_check(NULL) FROM tbl" // null handling
  )

  withParquetTable(rows, "tbl") {
    queries.foreach { query =>
      checkSparkAnswerAndOperator(query)
    }
  }
}

test("split string basic") {
withSQLConf("spark.comet.expression.StringSplit.allowIncompatible" -> "true") {
withParquetTable((0 until 5).map(i => (s"value$i,test$i", i)), "tbl") {
Expand Down
Loading