Skip to content

Commit 17dddd0

Browse files
authored
refactor: move arithmetic and math support checks to getSupportLevel (#4674)
1 parent b26a8df commit 17dddd0

2 files changed

Lines changed: 62 additions & 51 deletions

File tree

spark/src/main/scala/org/apache/comet/serde/arithmetic.scala

Lines changed: 48 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -83,18 +83,24 @@ trait MathBase {
8383
false
8484
}
8585

86+
def mathDataTypeSupportLevel(dt: DataType): SupportLevel =
87+
if (supportedDataType(dt)) {
88+
Compatible()
89+
} else {
90+
Unsupported(Some(s"Unsupported datatype $dt"))
91+
}
92+
8693
}
8794

8895
object CometAdd extends CometExpressionSerde[Add] with MathBase {
8996

97+
override def getSupportLevel(expr: Add): SupportLevel =
98+
mathDataTypeSupportLevel(expr.left.dataType)
99+
90100
override def convert(
91101
expr: Add,
92102
inputs: Seq[Attribute],
93103
binding: Boolean): Option[ExprOuterClass.Expr] = {
94-
if (!supportedDataType(expr.left.dataType)) {
95-
withFallbackReason(expr, s"Unsupported datatype ${expr.left.dataType}")
96-
return None
97-
}
98104
createMathExpression(
99105
expr,
100106
expr.left,
@@ -109,14 +115,13 @@ object CometAdd extends CometExpressionSerde[Add] with MathBase {
109115

110116
object CometSubtract extends CometExpressionSerde[Subtract] with MathBase {
111117

118+
override def getSupportLevel(expr: Subtract): SupportLevel =
119+
mathDataTypeSupportLevel(expr.left.dataType)
120+
112121
override def convert(
113122
expr: Subtract,
114123
inputs: Seq[Attribute],
115124
binding: Boolean): Option[ExprOuterClass.Expr] = {
116-
if (!supportedDataType(expr.left.dataType)) {
117-
withFallbackReason(expr, s"Unsupported datatype ${expr.left.dataType}")
118-
return None
119-
}
120125
createMathExpression(
121126
expr,
122127
expr.left,
@@ -131,14 +136,13 @@ object CometSubtract extends CometExpressionSerde[Subtract] with MathBase {
131136

132137
object CometMultiply extends CometExpressionSerde[Multiply] with MathBase {
133138

139+
override def getSupportLevel(expr: Multiply): SupportLevel =
140+
mathDataTypeSupportLevel(expr.left.dataType)
141+
134142
override def convert(
135143
expr: Multiply,
136144
inputs: Seq[Attribute],
137145
binding: Boolean): Option[ExprOuterClass.Expr] = {
138-
if (!supportedDataType(expr.left.dataType)) {
139-
withFallbackReason(expr, s"Unsupported datatype ${expr.left.dataType}")
140-
return None
141-
}
142146
createMathExpression(
143147
expr,
144148
expr.left,
@@ -153,6 +157,9 @@ object CometMultiply extends CometExpressionSerde[Multiply] with MathBase {
153157

154158
object CometDivide extends CometExpressionSerde[Divide] with MathBase {
155159

160+
override def getSupportLevel(expr: Divide): SupportLevel =
161+
mathDataTypeSupportLevel(expr.left.dataType)
162+
156163
override def convert(
157164
expr: Divide,
158165
inputs: Seq[Attribute],
@@ -162,10 +169,6 @@ object CometDivide extends CometExpressionSerde[Divide] with MathBase {
162169
// For now, use NullIf to swap zeros with nulls.
163170
val rightExpr =
164171
if (expr.evalMode != EvalMode.ANSI) nullIfWhenPrimitive(expr.right) else expr.right
165-
if (!supportedDataType(expr.left.dataType)) {
166-
withFallbackReason(expr, s"Unsupported datatype ${expr.left.dataType}")
167-
return None
168-
}
169172
val divideExpr = createMathExpression(
170173
expr,
171174
expr.left,
@@ -195,14 +198,13 @@ object CometDivide extends CometExpressionSerde[Divide] with MathBase {
195198

196199
object CometIntegralDivide extends CometExpressionSerde[IntegralDivide] with MathBase {
197200

201+
override def getSupportLevel(expr: IntegralDivide): SupportLevel =
202+
mathDataTypeSupportLevel(expr.left.dataType)
203+
198204
override def convert(
199205
expr: IntegralDivide,
200206
inputs: Seq[Attribute],
201207
binding: Boolean): Option[ExprOuterClass.Expr] = {
202-
if (!supportedDataType(expr.left.dataType)) {
203-
withFallbackReason(expr, s"Unsupported datatype ${expr.left.dataType}")
204-
return None
205-
}
206208

207209
// Precision is set to 19 (max precision for a numerical data type except DecimalType)
208210

@@ -259,15 +261,13 @@ object CometIntegralDivide extends CometExpressionSerde[IntegralDivide] with Mat
259261

260262
object CometRemainder extends CometExpressionSerde[Remainder] with MathBase {
261263

264+
override def getSupportLevel(expr: Remainder): SupportLevel =
265+
mathDataTypeSupportLevel(expr.left.dataType)
266+
262267
override def convert(
263268
expr: Remainder,
264269
inputs: Seq[Attribute],
265270
binding: Boolean): Option[ExprOuterClass.Expr] = {
266-
if (!supportedDataType(expr.left.dataType)) {
267-
withFallbackReason(expr, s"Unsupported datatype ${expr.left.dataType}")
268-
return None
269-
}
270-
271271
createMathExpression(
272272
expr,
273273
expr.left,
@@ -282,6 +282,29 @@ object CometRemainder extends CometExpressionSerde[Remainder] with MathBase {
282282

283283
object CometRound extends CometExpressionSerde[Round] {
284284

285+
override def getSupportLevel(expr: Round): SupportLevel = expr.child.dataType match {
286+
case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
287+
Unsupported(Some("Decimal type has negative scale"))
288+
case _: FloatType | DoubleType =>
289+
// We cannot properly match with the Spark behavior for floating-point numbers.
290+
// Spark uses BigDecimal for rounding float/double, and BigDecimal fist converts a
291+
// double to string internally in order to create its own internal representation.
292+
// The problem is BigDecimal uses java.lang.Double.toString() and it has complicated
293+
// rounding algorithm. E.g. -5.81855622136895E8 is actually
294+
// -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead of
295+
// 5. Java(Scala)'s toString() rounds it up to -581855622.136895. This makes a
296+
// difference when rounding at 5th digit, I.e. round(-5.81855622136895E8, 5) should be
297+
// -5.818556221369E8, instead of -5.8185562213689E8. There is also an example that
298+
// toString() does NOT round up. 6.1317116247283497E18 is 6131711624728349696. It can
299+
// be rounded up to 6.13171162472835E18 that still represents the same double number.
300+
// I.e. 6.13171162472835E18 == 6.1317116247283497E18. However, toString() does not.
301+
// That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18 instead
302+
// of 6.1317116247283999E18.
303+
Unsupported(Some("Comet does not support Spark's BigDecimal rounding"))
304+
case _ =>
305+
Compatible()
306+
}
307+
285308
override def convert(
286309
r: Round,
287310
inputs: Seq[Attribute],
@@ -292,30 +315,10 @@ object CometRound extends CometExpressionSerde[Round] {
292315

293316
lazy val childExpr = exprToProtoInternal(r.child, inputs, binding)
294317
r.child.dataType match {
295-
case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
296-
withFallbackReason(r, "Decimal type has negative scale")
297-
None
298318
case _ if scaleV == null =>
299319
exprToProtoInternal(Literal(null), inputs, binding)
300320
case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 =>
301321
childExpr // _scale(I.e. decimal place) >= 0 is a no-op for integer types in Spark
302-
case _: FloatType | DoubleType =>
303-
// We cannot properly match with the Spark behavior for floating-point numbers.
304-
// Spark uses BigDecimal for rounding float/double, and BigDecimal fist converts a
305-
// double to string internally in order to create its own internal representation.
306-
// The problem is BigDecimal uses java.lang.Double.toString() and it has complicated
307-
// rounding algorithm. E.g. -5.81855622136895E8 is actually
308-
// -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead of
309-
// 5. Java(Scala)'s toString() rounds it up to -581855622.136895. This makes a
310-
// difference when rounding at 5th digit, I.e. round(-5.81855622136895E8, 5) should be
311-
// -5.818556221369E8, instead of -5.8185562213689E8. There is also an example that
312-
// toString() does NOT round up. 6.1317116247283497E18 is 6131711624728349696. It can
313-
// be rounded up to 6.13171162472835E18 that still represents the same double number.
314-
// I.e. 6.13171162472835E18 == 6.1317116247283497E18. However, toString() does not.
315-
// That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18 instead
316-
// of 6.1317116247283999E18.
317-
withFallbackReason(r, "Comet does not support Spark's BigDecimal rounding")
318-
None
319322
case _ =>
320323
// `scale` must be Int64 type in DataFusion
321324
val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs, binding)

spark/src/main/scala/org/apache/comet/serde/math.scala

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ object CometAtan2 extends CometExpressionSerde[Atan2] {
4141
}
4242

4343
object CometCeil extends CometExpressionSerde[Ceil] {
44+
override def getSupportLevel(expr: Ceil): SupportLevel = expr.child.dataType match {
45+
case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
46+
Unsupported(Some(s"Decimal type $t has negative scale"))
47+
case _ =>
48+
Compatible()
49+
}
50+
4451
override def convert(
4552
expr: Ceil,
4653
inputs: Seq[Attribute],
@@ -49,9 +56,6 @@ object CometCeil extends CometExpressionSerde[Ceil] {
4956
expr.child.dataType match {
5057
case t: DecimalType if t.scale == 0 => // zero scale is no-op
5158
childExpr
52-
case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
53-
withFallbackReason(expr, s"Decimal type $t has negative scale")
54-
None
5559
case _ =>
5660
val optExpr =
5761
scalarFunctionExprToProtoWithReturnType("ceil", expr.dataType, false, childExpr)
@@ -61,6 +65,13 @@ object CometCeil extends CometExpressionSerde[Ceil] {
6165
}
6266

6367
object CometFloor extends CometExpressionSerde[Floor] {
68+
override def getSupportLevel(expr: Floor): SupportLevel = expr.child.dataType match {
69+
case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
70+
Unsupported(Some(s"Decimal type $t has negative scale"))
71+
case _ =>
72+
Compatible()
73+
}
74+
6475
override def convert(
6576
expr: Floor,
6677
inputs: Seq[Attribute],
@@ -69,9 +80,6 @@ object CometFloor extends CometExpressionSerde[Floor] {
6980
expr.child.dataType match {
7081
case t: DecimalType if t.scale == 0 => // zero scale is no-op
7182
childExpr
72-
case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
73-
withFallbackReason(expr, s"Decimal type $t has negative scale")
74-
None
7583
case _ =>
7684
val optExpr =
7785
scalarFunctionExprToProtoWithReturnType("floor", expr.dataType, false, childExpr)

0 commit comments

Comments
 (0)