From 5fd1588a4b947fa839ce072decf41c8296689229 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 22 May 2026 12:56:03 +0000 Subject: [PATCH] [SPARK-57004][SQL] Refactor CheckOverflowInSum codegen under ANSI mode ### What changes were proposed in this pull request? Introduce `DecimalExpressionUtils.java` with a `checkOverflowInSum(Decimal, int, int, boolean, QueryContext)` static helper and call it from `CheckOverflowInSum.doGenCode` and the eval path. Codegen body shrinks from a 10-line if/else block (init `ev.value` to null, branch on `childGen.isNull`, conditionally throw, conditionally call `toPrecision` + re-set `ev.isNull`) to 4 lines (single helper call + post-check). Eval is now a single delegating call. ### Why are the changes needed? Part of SPARK-56908 (umbrella). `CheckOverflowInSum` is emitted around every decimal `Sum` and is one of the longer remaining inline ANSI bodies. Collapsing the per-call-site body shrinks generated Java source and Janino compile time on aggregation-heavy plans. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? ``` build/sbt "catalyst/testOnly *DecimalExpressionSuite *AggregateExpressionSuite" build/sbt "sql/testOnly *DataFrameAggregateSuite -- -z sum" ``` 20/20 pass (incl. the SPARK-39208 `CheckOverflowInSum` runtime-context test). ### Was this patch authored or co-authored using generative AI tooling? Generated-by: Cursor 1.x --- .../expressions/DecimalExpressionUtils.java | 53 +++++++++++++++++++ .../expressions/decimalExpressions.scala | 41 ++++---------- 2 files changed, 62 insertions(+), 32 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/DecimalExpressionUtils.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/DecimalExpressionUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/DecimalExpressionUtils.java new file mode 100644 index 0000000000000..f3454339e9a2d --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/DecimalExpressionUtils.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.QueryContext; +import org.apache.spark.sql.errors.QueryExecutionErrors; +import org.apache.spark.sql.types.Decimal; + +/** + * Static helpers shared by decimal expression {@code doGenCode} (and + * corresponding eval paths). The codegen invokes these via a single static + * call, replacing the multi-line inline overflow-handling body. + */ +public final class DecimalExpressionUtils { + + private DecimalExpressionUtils() {} + + /** + * Apply the target {@code precision}/{@code scale} to a {@code Sum} aggregate + * result and convert a {@code null} input into the {@code Sum}-specific + * overflow error. {@code Sum} uses {@code null} in its aggregation buffer to + * indicate "running total overflowed"; this method either rethrows that as + * {@code overflowInSumOfDecimalError} or propagates the {@code null}, gated + * by {@code nullOnOverflow}. + */ + public static Decimal checkOverflowInSum( + Decimal value, + int precision, + int scale, + boolean nullOnOverflow, + QueryContext context) { + if (value == null) { + if (nullOnOverflow) return null; + throw QueryExecutionErrors.overflowInSumOfDecimalError(context, "try_sum"); + } + return value.toPrecision(precision, scale, Decimal.ROUND_HALF_UP(), nullOnOverflow, context); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 3e463595ba674..f24c907681502 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -163,45 +163,22 @@ case class CheckOverflowInSum( override def nullable: Boolean = true - override def eval(input: InternalRow): Any = { - val value = child.eval(input) - if (value == null) { - if (nullOnOverflow) null - else { - throw QueryExecutionErrors.overflowInSumOfDecimalError(context, suggestedFunc = "try_sum") - } - } else { - value.asInstanceOf[Decimal].toPrecision( - dataType.precision, - dataType.scale, - Decimal.ROUND_HALF_UP, - nullOnOverflow, - context) - } - } + override def eval(input: InternalRow): Any = DecimalExpressionUtils.checkOverflowInSum( + child.eval(input).asInstanceOf[Decimal], + dataType.precision, dataType.scale, nullOnOverflow, context) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childGen = child.genCode(ctx) val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow) - val nullHandling = if (nullOnOverflow) { - "" - } else { - s"""throw QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode, "try_sum");""" - } - // scalastyle:off line.size.limit + val helper = classOf[DecimalExpressionUtils].getName + val input = ctx.freshName("input") val code = code""" |${childGen.code} - |boolean ${ev.isNull} = ${childGen.isNull}; - |Decimal ${ev.value} = null; - |if (${childGen.isNull}) { - | $nullHandling - |} else { - | ${ev.value} = ${childGen.value}.toPrecision( - | ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow, $errorContextCode); - | ${ev.isNull} = ${ev.value} == null; - |} + |Decimal $input = ${childGen.isNull} ? null : ${childGen.value}; + |Decimal ${ev.value} = $helper.checkOverflowInSum( + | $input, ${dataType.precision}, ${dataType.scale}, $nullOnOverflow, $errorContextCode); + |boolean ${ev.isNull} = ${ev.value} == null; |""".stripMargin - // scalastyle:on line.size.limit ev.copy(code = code) }