Skip to content

Commit 576070a

Browse files
committed
[SPARK-57181][SQL] Simplify Pmod codegen by sharing a MathUtils.pmod helper with eval
### What changes were proposed in this pull request?

`Pmod.doGenCode` emitted the positive-modulo `remainder`/adjust block inline, once for byte/short and once for the int/long/float/double case (~6-8 lines each), duplicating the algorithm already implemented by `Pmod`'s private `pmod` eval methods. This adds `MathUtils.pmod` overloads (Int, Long, Byte, Short, Float, Double) -- the exact bodies moved out of `Pmod` -- and routes both the eval dispatch (`pmodFunc`) and codegen through them. The primitive codegen cases collapse to a single `MathUtils.pmod(left, right)` call. The Decimal case (which returns null / applies `toPrecision`) is unchanged.

### Why are the changes needed?

Part of SPARK-56908 (umbrella). `Pmod` over IntegerType is emitted by every `HashPartitioning` (`Pmod(Murmur3Hash(...), numPartitions)`), so collapsing the inline block shrinks the generated Java on a very common path, and the eval and codegen paths now share one implementation instead of duplicating the algorithm (helping with the JVM 64KB method / constant-pool limits, Janino compile time, and JIT work).

### Does this PR introduce _any_ user-facing change?

No. The compiled behavior is identical; only the emitted Java source text changes.

### How was this patch tested?

Existing `ArithmeticExpressionSuite."pmod"` covers all numeric types, negative operands / divisors, mod-by-zero (ANSI on/off), and `checkConsistencyBetweenInterpretedAndCodegenAllowingException` across all numeric types (which verifies eval and codegen agree -- exactly the invariant this refactor must preserve).

```
build/sbt "catalyst/testOnly *ArithmeticExpressionSuite"
```

35/35 pass.

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: Claude Code (Opus 4.8)

Closes #56232 from gengliangwang/spark-pmod-codegen.

Authored-by: Gengliang Wang <gengliang@apache.org>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
1 parent 2d83237 commit 576070a

2 files changed

Lines changed: 46 additions & 54 deletions

File tree

sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,40 @@ object MathUtils {
8989

9090
def floorMod(a: Long, b: Long): Long = withOverflow(Math.floorMod(a, b))
9191

92+
// Positive modulo (`pmod`): the remainder `a % n`, shifted to be non-negative when `n > 0`.
93+
// Unlike `floorMod`, this matches the `pmod` SQL function / `HashPartitioning` semantics.
94+
// Shared by `Pmod`'s eval and codegen paths so the two never diverge.
95+
96+
def pmod(a: Int, n: Int): Int = {
97+
val r = a % n
98+
if (r < 0) (r + n) % n else r
99+
}
100+
101+
def pmod(a: Long, n: Long): Long = {
102+
val r = a % n
103+
if (r < 0) (r + n) % n else r
104+
}
105+
106+
def pmod(a: Byte, n: Byte): Byte = {
107+
val r = a % n
108+
if (r < 0) ((r + n) % n).toByte else r.toByte
109+
}
110+
111+
def pmod(a: Short, n: Short): Short = {
112+
val r = a % n
113+
if (r < 0) ((r + n) % n).toShort else r.toShort
114+
}
115+
116+
def pmod(a: Float, n: Float): Float = {
117+
val r = a % n
118+
if (r < 0) (r + n) % n else r
119+
}
120+
121+
def pmod(a: Double, n: Double): Double = {
122+
val r = a % n
123+
if (r < 0) (r + n) % n else r
124+
}
125+
92126
def withOverflow[A](f: => A, hint: String = "", context: QueryContext = null): A = {
93127
try {
94128
f

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 12 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,12 +1080,12 @@ case class Pmod(
10801080
}
10811081

10821082
private lazy val pmodFunc: (Any, Any) => Any = dataType match {
1083-
case _: IntegerType => (l, r) => pmod(l.asInstanceOf[Int], r.asInstanceOf[Int])
1084-
case _: LongType => (l, r) => pmod(l.asInstanceOf[Long], r.asInstanceOf[Long])
1085-
case _: ShortType => (l, r) => pmod(l.asInstanceOf[Short], r.asInstanceOf[Short])
1086-
case _: ByteType => (l, r) => pmod(l.asInstanceOf[Byte], r.asInstanceOf[Byte])
1087-
case _: FloatType => (l, r) => pmod(l.asInstanceOf[Float], r.asInstanceOf[Float])
1088-
case _: DoubleType => (l, r) => pmod(l.asInstanceOf[Double], r.asInstanceOf[Double])
1083+
case _: IntegerType => (l, r) => MathUtils.pmod(l.asInstanceOf[Int], r.asInstanceOf[Int])
1084+
case _: LongType => (l, r) => MathUtils.pmod(l.asInstanceOf[Long], r.asInstanceOf[Long])
1085+
case _: ShortType => (l, r) => MathUtils.pmod(l.asInstanceOf[Short], r.asInstanceOf[Short])
1086+
case _: ByteType => (l, r) => MathUtils.pmod(l.asInstanceOf[Byte], r.asInstanceOf[Byte])
1087+
case _: FloatType => (l, r) => MathUtils.pmod(l.asInstanceOf[Float], r.asInstanceOf[Float])
1088+
case _: DoubleType => (l, r) => MathUtils.pmod(l.asInstanceOf[Double], r.asInstanceOf[Double])
10891089
case DecimalType.Fixed(precision, scale) => (l, r) => checkDecimalOverflow(
10901090
pmod(l.asInstanceOf[Decimal], r.asInstanceOf[Decimal]), precision, scale)
10911091
}
@@ -1120,6 +1120,7 @@ case class Pmod(
11201120
val remainder = ctx.freshName("remainder")
11211121
val javaType = CodeGenerator.javaType(dataType)
11221122
val errorContext = getContextOrNullCode(ctx)
1123+
val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
11231124
val result = dataType match {
11241125
case DecimalType.Fixed(precision, scale) =>
11251126
val decimalAdd = "$plus"
@@ -1135,25 +1136,12 @@ case class Pmod(
11351136
|${ev.isNull} = ${ev.value} == null;
11361137
|""".stripMargin
11371138

1138-
// byte and short are casted into int when add, minus, times or divide
1139-
case ByteType | ShortType =>
1140-
s"""
1141-
$javaType $remainder = ($javaType)(${eval1.value} % ${eval2.value});
1142-
if ($remainder < 0) {
1143-
${ev.value}=($javaType)(($remainder + ${eval2.value}) % ${eval2.value});
1144-
} else {
1145-
${ev.value}=$remainder;
1146-
}
1147-
"""
1139+
// The positive-modulo arithmetic is the same fixed algorithm for every primitive numeric
1140+
// type, so delegate to the shared MathUtils.pmod helper (also used by the eval path) instead
1141+
// of emitting the remainder/adjust block inline. byte/short are widened to int by `%`, and
1142+
// the matching MathUtils.pmod overload narrows the result back.
11481143
case _ =>
1149-
s"""
1150-
$javaType $remainder = ${eval1.value} % ${eval2.value};
1151-
if ($remainder < 0) {
1152-
${ev.value}=($remainder + ${eval2.value}) % ${eval2.value};
1153-
} else {
1154-
${ev.value}=$remainder;
1155-
}
1156-
"""
1144+
s"${ev.value} = $mathUtils.pmod(${eval1.value}, ${eval2.value});"
11571145
}
11581146

11591147
// evaluate right first as we have a chance to skip left if right is 0
@@ -1198,36 +1186,6 @@ case class Pmod(
11981186
}
11991187
}
12001188

1201-
private def pmod(a: Int, n: Int): Int = {
1202-
val r = a % n
1203-
if (r < 0) {(r + n) % n} else r
1204-
}
1205-
1206-
private def pmod(a: Long, n: Long): Long = {
1207-
val r = a % n
1208-
if (r < 0) {(r + n) % n} else r
1209-
}
1210-
1211-
private def pmod(a: Byte, n: Byte): Byte = {
1212-
val r = a % n
1213-
if (r < 0) {((r + n) % n).toByte} else r.toByte
1214-
}
1215-
1216-
private def pmod(a: Double, n: Double): Double = {
1217-
val r = a % n
1218-
if (r < 0) {(r + n) % n} else r
1219-
}
1220-
1221-
private def pmod(a: Short, n: Short): Short = {
1222-
val r = a % n
1223-
if (r < 0) {((r + n) % n).toShort} else r.toShort
1224-
}
1225-
1226-
private def pmod(a: Float, n: Float): Float = {
1227-
val r = a % n
1228-
if (r < 0) {(r + n) % n} else r
1229-
}
1230-
12311189
private def pmod(a: Decimal, n: Decimal): Decimal = {
12321190
val r = a % n
12331191
if (r != null && r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r

0 commit comments

Comments
 (0)