Skip to content

Commit 7583a15

Browse files
committed
PR check fixes
1 parent 6f76a80 commit 7583a15

3 files changed

Lines changed: 31 additions & 5 deletions

File tree

spark/src/main/scala/org/apache/comet/serde/CometScalarFunction.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ package org.apache.comet.serde
2222
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
2323

2424
import org.apache.comet.serde.ExprOuterClass.Expr
25-
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType}
25+
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}
2626

2727
/** Serde for scalar function. */
2828
case class CometScalarFunction[T <: Expression](name: String) extends CometExpressionSerde[T] {
2929
override def convert(expr: T, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {
3030
val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding))
31-
val optExpr = scalarFunctionExprToProtoWithReturnType(name, expr.dataType, false, childExpr: _*)
31+
val optExpr = scalarFunctionExprToProto(name, childExpr: _*)
3232
optExprWithInfo(optExpr, expr, expr.children: _*)
3333
}
3434
}

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
166166
classOf[StringRPad] -> CometStringRPad,
167167
classOf[StringLPad] -> CometStringLPad,
168168
classOf[StringSpace] -> CometScalarFunction("string_space"),
169-
classOf[StringSplit] -> CometScalarFunction("split"),
169+
classOf[StringSplit] -> CometStringSplit,
170170
classOf[StringTranslate] -> CometScalarFunction("translate"),
171171
classOf[StringTrim] -> CometScalarFunction("trim"),
172172
classOf[StringTrimBoth] -> CometScalarFunction("btrim"),

spark/src/main/scala/org/apache/comet/serde/strings.scala

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@ package org.apache.comet.serde
2121

2222
import java.util.Locale
2323

24-
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Expression, InitCap, Left, Length, Like, Literal, Lower, RegExpReplace, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper}
24+
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Expression, InitCap, Left, Length, Like, Literal, Lower, RegExpReplace, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper}
2525
import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType}
2626

2727
import org.apache.comet.CometConf
2828
import org.apache.comet.CometSparkSessionExtensions.withInfo
2929
import org.apache.comet.expressions.{CometCast, CometEvalMode, RegExp}
3030
import org.apache.comet.serde.ExprOuterClass.Expr
31-
import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}
31+
import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType}
3232

3333
object CometStringRepeat extends CometExpressionSerde[StringRepeat] {
3434

@@ -289,6 +289,32 @@ object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] {
289289
}
290290
}
291291

292+
/**
293+
* Serde for StringSplit expression.
294+
* This is a custom Comet function (not a built-in DataFusion function),
295+
* so we need to include the return type in the protobuf to avoid
296+
* DataFusion registry lookup failures.
297+
*/
298+
object CometStringSplit extends CometExpressionSerde[StringSplit] {
299+
300+
override def convert(
301+
expr: StringSplit,
302+
inputs: Seq[Attribute],
303+
binding: Boolean): Option[Expr] = {
304+
val strExpr = exprToProtoInternal(expr.str, inputs, binding)
305+
val regexExpr = exprToProtoInternal(expr.regex, inputs, binding)
306+
val limitExpr = exprToProtoInternal(expr.limit, inputs, binding)
307+
val optExpr = scalarFunctionExprToProtoWithReturnType(
308+
"split",
309+
expr.dataType,
310+
false,
311+
strExpr,
312+
regexExpr,
313+
limitExpr)
314+
optExprWithInfo(optExpr, expr, expr.str, expr.regex, expr.limit)
315+
}
316+
}
317+
292318
trait CommonStringExprs {
293319

294320
def stringDecode(

0 commit comments

Comments
 (0)