@@ -21,11 +21,12 @@ package org.apache.comet.codegen
2121
2222import org .apache .arrow .vector ._
2323import org .apache .arrow .vector .complex .{ListVector , MapVector , StructVector }
24+ import org .apache .arrow .vector .types .pojo .Field
2425import org .apache .spark .internal .Logging
2526import org .apache .spark .sql .catalyst .expressions .{BoundReference , Expression , Literal , Unevaluable }
2627import org .apache .spark .sql .catalyst .expressions .codegen ._
2728import org .apache .spark .sql .internal .SQLConf
28- import org .apache .spark .sql .types .DataType
29+ import org .apache .spark .sql .types ._
2930
3031import org .apache .comet .shims .CometExprTraitShim
3132
@@ -83,18 +84,35 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
8384 case other => throw new IllegalArgumentException (s " unknown Arrow vector class: $other" )
8485 }
8586
87+ /**
88+ * Type surface the kernel covers, on both the input getter side and the output writer side.
89+ * Recursive: `ArrayType` / `StructType` / `MapType` are supported when their children are.
90+ * Input and output use a single predicate today; if they ever need to diverge, split this back
91+ * into per-direction methods.
92+ */
93+ def isSupportedDataType (dt : DataType ): Boolean = dt match {
94+ case BooleanType | ByteType | ShortType | IntegerType | LongType => true
95+ case FloatType | DoubleType => true
96+ case _ : DecimalType => true
97+ case _ : StringType | _ : BinaryType => true
98+ case DateType | TimestampType | TimestampNTZType => true
99+ case ArrayType (inner, _) => isSupportedDataType(inner)
100+ case st : StructType => st.fields.forall(f => isSupportedDataType(f.dataType))
101+ case mt : MapType => isSupportedDataType(mt.keyType) && isSupportedDataType(mt.valueType)
102+ case _ => false
103+ }
104+
86105 /**
87106 * Plan-time predicate: can the codegen dispatcher handle this bound expression end to end? If
88107 * it returns `None`, the serde is free to emit the codegen proto. If it returns `Some(reason)`,
89108 * the serde must fall back (usually via `withInfo(...) + None`) so Spark runs the expression
90109 * rather than crashing in the Janino compile at execute time.
91110 *
92111 * Checks:
93- * - every `BoundReference`'s data type is in
94- * [[CometBatchKernelCodegenInput.isSupportedInputType ]] (i.e. the kernel has a typed getter
95- * for it)
96- * - the overall `expr.dataType` is in [[CometBatchKernelCodegenOutput.isSupportedOutputType ]]
97- * (i.e. `allocateOutput` and `emitWrite` know how to materialize it)
112+ * - every `BoundReference`'s data type is in [[isSupportedDataType ]] (i.e. the kernel has a
113+ * typed getter for it)
114+ * - the overall `expr.dataType` is in [[isSupportedDataType ]] (i.e. `allocateOutput` and
115+ * `emitWrite` know how to materialize it)
98116 * - the expression is scalar (no `AggregateFunction`, no generators). These never reach a
99117 * scalar serde, but we belt-and-suspenders anyway.
100118 *
@@ -103,7 +121,7 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
103121 * the output vector) touch Arrow.
104122 */
105123 def canHandle (boundExpr : Expression ): Option [String ] = {
106- if (! CometBatchKernelCodegenOutput .isSupportedOutputType (boundExpr.dataType)) {
124+ if (! isSupportedDataType (boundExpr.dataType)) {
107125 return Some (s " codegen dispatch: unsupported output type ${boundExpr.dataType}" )
108126 }
109127 // Reject expressions that can't be safely compiled or cached:
@@ -155,7 +173,7 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
155173 case None =>
156174 }
157175 val badRef = boundExpr.collectFirst {
158- case b : BoundReference if ! CometBatchKernelCodegenInput .isSupportedInputType (b.dataType) =>
176+ case b : BoundReference if ! isSupportedDataType (b.dataType) =>
159177 b
160178 }
161179 badRef.map(b =>
@@ -175,6 +193,10 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
175193 estimatedBytes : Int = - 1 ): FieldVector =
176194 CometBatchKernelCodegenOutput .allocateOutput(dataType, name, numRows, estimatedBytes)
177195
196+ /** Variant that takes a pre-computed Arrow `Field`, letting hot-path callers cache it. */
197+ def allocateOutput (field : Field , numRows : Int , estimatedBytes : Int ): FieldVector =
198+ CometBatchKernelCodegenOutput .allocateOutput(field, numRows, estimatedBytes)
199+
178200 def compile (boundExpr : Expression , inputSchema : Seq [ArrowColumnSpec ]): CompiledKernel = {
179201 val src = generateSource(boundExpr, inputSchema)
180202 val (clazz, _) =
@@ -188,8 +210,6 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
188210 t)
189211 throw t
190212 }
191- // One log per unique (expr, schema) compile; the caller caches the result so subsequent
192- // batches with the same shape reuse this compile.
193213 logInfo(
194214 s " CometBatchKernelCodegen: compiled ${boundExpr.getClass.getSimpleName} " +
195215 s " -> ${boundExpr.dataType} inputs= " +
@@ -529,8 +549,9 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
529549 ScalarColumnSpec (vectorClass, nullable)
530550
531551 /**
532- * Backward-compatible extractor for the common scalar case. Callers that want array / struct
533- * / future map specs should pattern match on the subclass directly.
552+ * Trait-level extractor that destructures only the scalar case. Pattern-match callers use
553+ * `case ArrowColumnSpec(cls, nullable)` to filter on scalar specs and pull out their vector
554+ * class and nullability in one step; complex specs return `None` and skip the case.
534555 */
535556 def unapply (spec : ArrowColumnSpec ): Option [(Class [_ <: ValueVector ], Boolean )] = spec match {
536557 case ScalarColumnSpec (c, n) => Some ((c, n))
0 commit comments