Skip to content

Commit c66d57d

Browse files
committed
feat: support Spark luhn_check via StaticInvoke
Register datafusion-spark's SparkLuhnCheck UDF and add a StaticInvoke handler for ExpressionImplUtils.isLuhnNumber (Spark 3.5+).
1 parent 45b670a commit c66d57d

3 files changed

Lines changed: 46 additions & 2 deletions

File tree

native/core/src/execution/jni_api.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ use datafusion_spark::function::math::hex::SparkHex;
5555
use datafusion_spark::function::math::width_bucket::SparkWidthBucket;
5656
use datafusion_spark::function::string::char::CharFunc;
5757
use datafusion_spark::function::string::concat::SparkConcat;
58+
use datafusion_spark::function::string::luhn_check::SparkLuhnCheck;
5859
use futures::poll;
5960
use futures::stream::StreamExt;
6061
use jni::objects::JByteBuffer;
@@ -400,6 +401,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) {
400401
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkWidthBucket::default()));
401402
session_ctx.register_udf(ScalarUDF::new_from_impl(MapFromEntries::default()));
402403
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkCrc32::default()));
404+
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkLuhnCheck::default()));
403405
}
404406

405407
/// Prepares arrow arrays for output.

spark/src/main/scala/org/apache/comet/serde/statics.scala

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@
1919

2020
package org.apache.comet.serde
2121

22-
import org.apache.spark.sql.catalyst.expressions.Attribute
22+
import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionImplUtils}
2323
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
2424
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
25+
import org.apache.spark.sql.types.BooleanType
2526

2627
import org.apache.comet.CometSparkSessionExtensions.withInfo
28+
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType}
2729

2830
object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
2931

@@ -34,7 +36,8 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
3436
: Map[(String, Class[_]), CometExpressionSerde[StaticInvoke]] =
3537
Map(
3638
("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction(
37-
"read_side_padding"))
39+
"read_side_padding"),
40+
("isLuhnNumber", classOf[ExpressionImplUtils]) -> CometLuhnCheck)
3841

3942
override def convert(
4043
expr: StaticInvoke,
@@ -52,3 +55,23 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
5255
}
5356
}
5457
}
58+
59+
/**
 * Serde for the `StaticInvoke` of `ExpressionImplUtils.isLuhnNumber` (Spark 3.5+).
 *
 * The invoke's first argument is serialized and wrapped in a call to the native `luhn_check`
 * scalar function (provided by datafusion-spark) with a `BooleanType` return type.
 */
private object CometLuhnCheck extends CometExpressionSerde[StaticInvoke] {

  override def convert(
      expr: StaticInvoke,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    // The value being validated is the first argument of the invoke.
    val candidate = expr.arguments.head
    val candidateProto = exprToProtoInternal(candidate, inputs, binding)

    // Build the scalar-function call and hand it to optExprWithInfo together with the
    // original expression and its argument.
    optExprWithInfo(
      scalarFunctionExprToProtoWithReturnType("luhn_check", BooleanType, false, candidateProto),
      expr,
      candidate)
  }
}

spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,25 @@ class CometStringExpressionSuite extends CometTestBase {
148148
}
149149
}
150150

151+
test("luhn_check") {
  // Fixture rows: valid and invalid Luhn numbers plus degenerate inputs.
  val rows = Seq(
    "79927398710", // invalid (fails Luhn)
    "79927398713", // valid Luhn number
    "1234567812345670", // valid credit card-like
    "0", // valid single digit
    "", // empty string
    "abc", // non-numeric
    null).map(Tuple1(_))

  // Checked in order: column input, a literal value, and null handling.
  val queries = Seq(
    "SELECT luhn_check(_1) FROM tbl",
    "SELECT luhn_check('79927398713') FROM tbl",
    "SELECT luhn_check(NULL) FROM tbl")

  withParquetTable(rows, "tbl") {
    queries.foreach { query =>
      checkSparkAnswerAndOperator(query)
    }
  }
}
168+
169+
151170
test("split string basic") {
152171
withSQLConf("spark.comet.expression.StringSplit.allowIncompatible" -> "true") {
153172
withParquetTable((0 until 5).map(i => (s"value$i,test$i", i)), "tbl") {

0 commit comments

Comments
 (0)