Skip to content

Commit 5966055

Browse files
committed
cleanup
1 parent cbf96df commit 5966055

9 files changed

Lines changed: 174 additions & 147 deletions

File tree

common/src/main/java/org/apache/comet/udf/CometUdfBridge.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,6 @@ private static void evaluateInternal(
199199
}
200200

201201
result = udf.evaluate(inputs, numRows);
202-
assert result instanceof FieldVector
203-
: "CometUDF implementations must return FieldVector; got "
204-
+ (result == null ? "null" : result.getClass().getName());
205202
if (!(result instanceof FieldVector)) {
206203
throw new RuntimeException(
207204
"CometUDF.evaluate() must return a FieldVector, got: " + result.getClass().getName());

common/src/main/scala/org/apache/comet/codegen/CometArrayData.scala

Lines changed: 21 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -27,53 +27,36 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
2727
import org.apache.comet.shims.CometInternalRowShim
2828

2929
/**
30-
* Shim base for Comet-owned [[ArrayData]] views used by the Arrow-direct codegen kernel.
31-
*
30+
* Shim base for classes that implement Spark's [[ArrayData]] in the Arrow-direct codegen kernel.
3231
* Provides `UnsupportedOperationException` defaults for every abstract method on `ArrayData` and
33-
* `SpecializedGetters`. Codegen emits a concrete subclass per complex-typed input column,
34-
* overriding only the small set of getters the element type requires (e.g. `numElements`,
35-
* `isNullAt`, and `getUTF8String` for an `ArrayType(StringType)` input).
32+
* `SpecializedGetters`; codegen-emitted subclasses override only the getters their element type
33+
* needs (e.g. `numElements`, `isNullAt`, and `getUTF8String` for an `ArrayType(StringType)`
34+
* input).
35+
*
36+
* Consumer: `InputArray_${path}` nested classes the input emitter generates per `ArrayType` input
37+
* column. These back the kernel's `getArray(ord)` switch and the recursive nested classes for
38+
* `Array<Array<...>>` / array-typed map keys / array-typed struct fields.
3639
*
37-
* Pattern mirrors [[CometInternalRow]]: centralize the boilerplate throws so the codegen- emitted
38-
* subclasses stay short, and absorb forward-compat breakage if Spark adds abstract methods to
39-
* `ArrayData` in a future version.
40+
* Why this exists separately from [[CometInternalRow]]: in Spark, `ArrayData` and `InternalRow`
41+
* are sibling abstract classes. They both extend `SpecializedGetters` (so they share the typed
42+
* scalar getters) but neither inherits the other, so a base aimed at one cannot serve the other.
43+
* The `get(ordinal, dataType)` dispatch body that '''is''' shared between the two lives in
44+
* [[CometSpecializedGettersDispatch]].
45+
*
46+
* [[CometMapData]] is the third sibling for `MapType` views; it backs `InputMap_*` and routes
47+
* `keyArray()` / `valueArray()` through `CometArrayData` instances.
4048
*
4149
* Mixes in [[CometInternalRowShim]] for the same reason `CometInternalRow` does: Spark 4.x adds
42-
* new abstract getters (`getVariant`, `getGeography`, `getGeometry`) on `SpecializedGetters` that
43-
* both `InternalRow` and `ArrayData` inherit. The shim is per-profile and provides throwing
44-
* defaults only on the profiles that declare those methods abstract.
50+
* abstract `SpecializedGetters` methods (`getVariant`, `getGeography`, `getGeometry`) that both
51+
* `InternalRow` and `ArrayData` inherit. The shim is per-profile and provides throwing defaults
52+
* only on the profiles where those methods are abstract.
4553
*/
4654
abstract class CometArrayData extends ArrayData with CometInternalRowShim {
4755

4856
override def getInterval(ordinal: Int): CalendarInterval = unsupported("getInterval")
4957

50-
/**
51-
* Generic `get(ordinal, dataType)` dispatcher. Spark codegen sometimes calls this rather than
52-
* the typed getter (`SafeProjection` uses it when deserializing struct-valued ScalaUDF args,
53-
* for example); leaving it as a throw leaks NPEs once callers catch the
54-
* `UnsupportedOperationException` and propagate null. Dispatches to the typed getter matching
55-
* `dataType`; a null entry returns `null` outright.
56-
*/
57-
override def get(ordinal: Int, dataType: DataType): AnyRef = {
58-
if (isNullAt(ordinal)) return null
59-
dataType match {
60-
case BooleanType => java.lang.Boolean.valueOf(getBoolean(ordinal))
61-
case ByteType => java.lang.Byte.valueOf(getByte(ordinal))
62-
case ShortType => java.lang.Short.valueOf(getShort(ordinal))
63-
case IntegerType | DateType => java.lang.Integer.valueOf(getInt(ordinal))
64-
case LongType | TimestampType | TimestampNTZType =>
65-
java.lang.Long.valueOf(getLong(ordinal))
66-
case FloatType => java.lang.Float.valueOf(getFloat(ordinal))
67-
case DoubleType => java.lang.Double.valueOf(getDouble(ordinal))
68-
case _: StringType => getUTF8String(ordinal)
69-
case BinaryType => getBinary(ordinal)
70-
case dt: DecimalType => getDecimal(ordinal, dt.precision, dt.scale)
71-
case st: StructType => getStruct(ordinal, st.size)
72-
case _: ArrayType => getArray(ordinal)
73-
case _: MapType => getMap(ordinal)
74-
case other => unsupported(s"get for dataType $other")
75-
}
76-
}
58+
override def get(ordinal: Int, dataType: DataType): AnyRef =
59+
CometSpecializedGettersDispatch.get(this, ordinal, dataType)
7760

7861
override def isNullAt(ordinal: Int): Boolean = unsupported("isNullAt")
7962

common/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ package org.apache.comet.codegen
2121

2222
import org.apache.arrow.vector._
2323
import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
24+
import org.apache.arrow.vector.types.pojo.Field
2425
import org.apache.spark.internal.Logging
2526
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, Unevaluable}
2627
import org.apache.spark.sql.catalyst.expressions.codegen._
2728
import org.apache.spark.sql.internal.SQLConf
28-
import org.apache.spark.sql.types.DataType
29+
import org.apache.spark.sql.types._
2930

3031
import org.apache.comet.shims.CometExprTraitShim
3132

@@ -83,18 +84,35 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
8384
case other => throw new IllegalArgumentException(s"unknown Arrow vector class: $other")
8485
}
8586

87+
/**
88+
* Type surface the kernel covers, on both the input getter side and the output writer side.
89+
* Recursive: `ArrayType` / `StructType` / `MapType` are supported when their children are.
90+
* Input and output use a single predicate today; if they ever need to diverge, split this back
91+
* into per-direction methods.
92+
*/
93+
def isSupportedDataType(dt: DataType): Boolean = dt match {
94+
case BooleanType | ByteType | ShortType | IntegerType | LongType => true
95+
case FloatType | DoubleType => true
96+
case _: DecimalType => true
97+
case _: StringType | _: BinaryType => true
98+
case DateType | TimestampType | TimestampNTZType => true
99+
case ArrayType(inner, _) => isSupportedDataType(inner)
100+
case st: StructType => st.fields.forall(f => isSupportedDataType(f.dataType))
101+
case mt: MapType => isSupportedDataType(mt.keyType) && isSupportedDataType(mt.valueType)
102+
case _ => false
103+
}
104+
86105
/**
87106
* Plan-time predicate: can the codegen dispatcher handle this bound expression end to end? If
88107
* it returns `None`, the serde is free to emit the codegen proto. If it returns `Some(reason)`,
89108
* the serde must fall back (usually via `withInfo(...) + None`) so Spark runs the expression
90109
* rather than crashing in the Janino compile at execute time.
91110
*
92111
* Checks:
93-
* - every `BoundReference`'s data type is in
94-
* [[CometBatchKernelCodegenInput.isSupportedInputType]] (i.e. the kernel has a typed getter
95-
* for it)
96-
* - the overall `expr.dataType` is in [[CometBatchKernelCodegenOutput.isSupportedOutputType]]
97-
* (i.e. `allocateOutput` and `emitWrite` know how to materialize it)
112+
* - every `BoundReference`'s data type passes [[isSupportedDataType]] (i.e. the kernel has a
113+
* typed getter for it)
114+
* - the overall `expr.dataType` passes [[isSupportedDataType]] (i.e. `allocateOutput` and
115+
* `emitWrite` know how to materialize it)
98116
* - the expression is scalar (no `AggregateFunction`, no generators). These never reach a
99117
* scalar serde, but we check anyway as a belt-and-suspenders measure.
100118
*
@@ -103,7 +121,7 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
103121
* the output vector) touch Arrow.
104122
*/
105123
def canHandle(boundExpr: Expression): Option[String] = {
106-
if (!CometBatchKernelCodegenOutput.isSupportedOutputType(boundExpr.dataType)) {
124+
if (!isSupportedDataType(boundExpr.dataType)) {
107125
return Some(s"codegen dispatch: unsupported output type ${boundExpr.dataType}")
108126
}
109127
// Reject expressions that can't be safely compiled or cached:
@@ -155,7 +173,7 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
155173
case None =>
156174
}
157175
val badRef = boundExpr.collectFirst {
158-
case b: BoundReference if !CometBatchKernelCodegenInput.isSupportedInputType(b.dataType) =>
176+
case b: BoundReference if !isSupportedDataType(b.dataType) =>
159177
b
160178
}
161179
badRef.map(b =>
@@ -175,6 +193,10 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
175193
estimatedBytes: Int = -1): FieldVector =
176194
CometBatchKernelCodegenOutput.allocateOutput(dataType, name, numRows, estimatedBytes)
177195

196+
/** Variant that takes a pre-computed Arrow `Field`, letting hot-path callers cache it. */
197+
def allocateOutput(field: Field, numRows: Int, estimatedBytes: Int): FieldVector =
198+
CometBatchKernelCodegenOutput.allocateOutput(field, numRows, estimatedBytes)
199+
178200
def compile(boundExpr: Expression, inputSchema: Seq[ArrowColumnSpec]): CompiledKernel = {
179201
val src = generateSource(boundExpr, inputSchema)
180202
val (clazz, _) =
@@ -188,8 +210,6 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
188210
t)
189211
throw t
190212
}
191-
// One log per unique (expr, schema) compile; the caller caches the result so subsequent
192-
// batches with the same shape reuse this compile.
193213
logInfo(
194214
s"CometBatchKernelCodegen: compiled ${boundExpr.getClass.getSimpleName} " +
195215
s"-> ${boundExpr.dataType} inputs=" +
@@ -529,8 +549,9 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
529549
ScalarColumnSpec(vectorClass, nullable)
530550

531551
/**
532-
* Backward-compatible extractor for the common scalar case. Callers that want array / struct
533-
* / future map specs should pattern match on the subclass directly.
552+
* Trait-level extractor that destructures only the scalar case. Pattern-match callers use
553+
* `case ArrowColumnSpec(cls, nullable)` to filter on scalar specs and pull out their vector
554+
* class and nullability in one step; complex specs return `None` and skip the case.
534555
*/
535556
def unapply(spec: ArrowColumnSpec): Option[(Class[_ <: ValueVector], Boolean)] = spec match {
536557
case ScalarColumnSpec(c, n) => Some((c, n))

common/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -94,23 +94,6 @@ private[codegen] object CometBatchKernelCodegenInput {
9494
classOf[TimeStampMicroTZVector])
9595
private val cometPlainVectorName: String = classOf[CometPlainVector].getName
9696

97-
/**
98-
* Input types the kernel has a typed getter for. Recursive: `ArrayType(inner)` supported when
99-
* `inner` is supported; `StructType` when every field is; `MapType` when key and value types
100-
* are both supported.
101-
*/
102-
def isSupportedInputType(dt: DataType): Boolean = dt match {
103-
case BooleanType | ByteType | ShortType | IntegerType | LongType => true
104-
case FloatType | DoubleType => true
105-
case _: DecimalType => true
106-
case _: StringType | _: BinaryType => true
107-
case DateType | TimestampType | TimestampNTZType => true
108-
case ArrayType(inner, _) => isSupportedInputType(inner)
109-
case st: StructType => st.fields.forall(f => isSupportedInputType(f.dataType))
110-
case mt: MapType => isSupportedInputType(mt.keyType) && isSupportedInputType(mt.valueType)
111-
case _ => false
112-
}
113-
11497
/**
11598
* Emit the kernel's typed vector-field declarations for every level of every input column's
11699
* spec tree. Top-level complex columns additionally get an instance-field declaration for the
@@ -215,10 +198,10 @@ private[codegen] object CometBatchKernelCodegenInput {
215198
val fastPath = emitDecimalFastBodyUnsafe(valueAddr, "this.rowIdx", " ")
216199
val slowPath = emitDecimalSlowBody(slowField, "this.rowIdx", " ")
217200
val body = known match {
218-
case Some(dt) if dt.precision <= 18 => fastPath
201+
case Some(dt) if dt.precision <= Decimal.MAX_LONG_DIGITS => fastPath
219202
case Some(_) => slowPath
220203
case None =>
221-
s""" if (precision <= 18) {
204+
s""" if (precision <= ${Decimal.MAX_LONG_DIGITS}) {
222205
|$fastPath
223206
| } else {
224207
|$slowPath
@@ -608,7 +591,7 @@ private[codegen] object CometBatchKernelCodegenInput {
608591
collectNestedClasses(s"${path}_f$fi", f.child, out)
609592
}
610593
case mp: MapColumnSpec =>
611-
out += emitMapClass(path, mp)
594+
out += emitMapClass(path)
612595
// Emit InputArray_${path}_k and InputArray_${path}_v - the ArrayData views returned by
613596
// `MapData.keyArray()` / `valueArray()`. They follow the standard array-element
614597
// convention: each reads from `${classPath}_e` which maps to the key / value vector
@@ -754,7 +737,7 @@ private[codegen] object CometBatchKernelCodegenInput {
754737
| }""".stripMargin
755738
case dt: DecimalType =>
756739
val body =
757-
if (dt.precision <= 18) {
740+
if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
758741
emitDecimalFastBodyUnsafe(s"${childField}_valueAddr", "startIndex + i", " ")
759742
} else {
760743
emitDecimalSlowBody(childField, "startIndex + i", " ")
@@ -947,7 +930,7 @@ private[codegen] object CometBatchKernelCodegenInput {
947930
val dt = f.sparkType.asInstanceOf[DecimalType]
948931
val field = s"${path}_f$fi"
949932
val body =
950-
if (dt.precision <= 18) {
933+
if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
951934
emitDecimalFastBodyUnsafe(s"${field}_valueAddr", "this.rowIdx", " ")
952935
} else {
953936
emitDecimalSlowBody(field, "this.rowIdx", " ")
@@ -1024,8 +1007,7 @@ private[codegen] object CometBatchKernelCodegenInput {
10241007
* `keyArray()` / `valueArray()` through pre-allocated `InputArray_${path}_k` /
10251008
* `InputArray_${path}_v` instances (emitted by [[collectNestedClasses]]).
10261009
*/
1027-
private def emitMapClass(path: String, spec: MapColumnSpec): String = {
1028-
val _ = spec // key/value arrays declared via path convention below
1010+
private def emitMapClass(path: String): String = {
10291011
val baseClassName = classOf[CometMapData].getName
10301012
val keyPath = s"${path}_k"
10311013
val valPath = s"${path}_v"

common/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenOutput.scala

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ package org.apache.comet.codegen
2121

2222
import org.apache.arrow.vector._
2323
import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
24+
import org.apache.arrow.vector.types.pojo.Field
2425
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
2526
import org.apache.spark.sql.comet.util.Utils
2627
import org.apache.spark.sql.types._
@@ -37,23 +38,6 @@ import org.apache.comet.CometArrowAllocator
3738
*/
3839
private[codegen] object CometBatchKernelCodegenOutput {
3940

40-
/**
41-
* Output types [[allocateOutput]] and [[emitOutputWriter]] can materialize. Recursive: complex
42-
* types are supported when their children are.
43-
*/
44-
def isSupportedOutputType(dt: DataType): Boolean = dt match {
45-
case BooleanType | ByteType | ShortType | IntegerType | LongType => true
46-
case FloatType | DoubleType => true
47-
case _: DecimalType => true
48-
case _: StringType | _: BinaryType => true
49-
case DateType | TimestampType | TimestampNTZType => true
50-
case ArrayType(inner, _) => isSupportedOutputType(inner)
51-
case st: StructType => st.fields.forall(f => isSupportedOutputType(f.dataType))
52-
case mt: MapType =>
53-
isSupportedOutputType(mt.keyType) && isSupportedOutputType(mt.valueType)
54-
case _ => false
55-
}
56-
5741
/**
5842
* Allocate an Arrow output vector matching `dataType`. Delegates field and vector construction
5943
* to [[Utils.toArrowField]] + `Field.createVector`, which is the pattern the rest of Comet uses
@@ -73,15 +57,19 @@ private[codegen] object CometBatchKernelCodegenOutput {
7357
dataType: DataType,
7458
name: String,
7559
numRows: Int,
76-
estimatedBytes: Int = -1): FieldVector = {
77-
val field = Utils.toArrowField(name, dataType, nullable = true, "UTC")
60+
estimatedBytes: Int = -1): FieldVector =
61+
allocateOutput(
62+
Utils.toArrowField(name, dataType, nullable = true, "UTC"),
63+
numRows,
64+
estimatedBytes)
65+
66+
/** Variant that takes a pre-computed Arrow `Field`, letting hot-path callers cache it. */
67+
def allocateOutput(field: Field, numRows: Int, estimatedBytes: Int): FieldVector = {
7868
val vec = field.createVector(CometArrowAllocator).asInstanceOf[FieldVector]
7969
try {
8070
vec.setInitialCapacity(numRows)
8171
vec match {
82-
case v: VarCharVector if estimatedBytes > 0 =>
83-
v.allocateNew(estimatedBytes.toLong, numRows)
84-
case v: VarBinaryVector if estimatedBytes > 0 =>
72+
case v: BaseVariableWidthVector if estimatedBytes > 0 =>
8573
v.allocateNew(estimatedBytes.toLong, numRows)
8674
case _ =>
8775
vec.allocateNew()
@@ -172,8 +160,11 @@ private[codegen] object CometBatchKernelCodegenOutput {
172160
// `DecimalVector.setSafe(int, long)` and skip the `java.math.BigDecimal` allocation
173161
// `setSafe(int, BigDecimal)` requires. For p > 18 the BigDecimal path is unavoidable.
174162
val write =
175-
if (dt.precision <= 18) s"$targetVec.setSafe($idx, $source.toUnscaledLong());"
176-
else s"$targetVec.setSafe($idx, $source.toJavaBigDecimal());"
163+
if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
164+
s"$targetVec.setSafe($idx, $source.toUnscaledLong());"
165+
} else {
166+
s"$targetVec.setSafe($idx, $source.toJavaBigDecimal());"
167+
}
177168
OutputEmit("", write)
178169
case _: StringType =>
179170
// Optimization: Utf8OutputOnHeapShortcut.

0 commit comments

Comments
 (0)