@@ -22,23 +22,22 @@ package org.apache.comet.udf
2222import org .apache .arrow .vector .{BigIntVector , BitVector , DateDayVector , DecimalVector , FieldVector , Float4Vector , Float8Vector , IntVector , SmallIntVector , TimeStampMicroTZVector , TimeStampMicroVector , TinyIntVector , ValueVector , VarBinaryVector , VarCharVector }
2323import org .apache .arrow .vector .complex .{ListVector , MapVector , StructVector }
2424import org .apache .spark .internal .Logging
25- import org .apache .spark .sql .catalyst .expressions .{BoundReference , Expression , Literal , RegExpReplace , Unevaluable }
25+ import org .apache .spark .sql .catalyst .expressions .{BoundReference , Expression , Literal , Unevaluable }
2626import org .apache .spark .sql .catalyst .expressions .codegen .{CodeAndComment , CodeFormatter , CodegenContext , CodeGenerator , CodegenFallback , ExprCode , GeneratedClass }
2727import org .apache .spark .sql .internal .SQLConf
28- import org .apache .spark .sql .types .{ DataType , StringType }
28+ import org .apache .spark .sql .types .DataType
2929
3030import org .apache .comet .shims .CometExprTraitShim
3131
3232/**
33- * Compiles a bound [[Expression ]] plus an input schema into a specialized [[CometBatchKernel ]]
34- * that fuses Arrow input reads, expression evaluation, and Arrow output writes into one
35- * Janino-compiled method per (expression, schema) pair.
33+ * Compiles a bound [[Expression ]] plus an input schema into a [[CometBatchKernel ]] that fuses
34+ * Arrow input reads, expression evaluation, and Arrow output writes into one Janino-compiled
35+ * method per (expression, schema) pair.
3636 *
3737 * Input- and output-side emission live in [[CometBatchKernelCodegenInput ]] and
3838 * [[CometBatchKernelCodegenOutput ]]. This file is the orchestrator: the [[ArrowColumnSpec ]]
3939 * vocabulary, [[canHandle ]] / [[allocateOutput ]] / [[compile ]] / [[generateSource ]] entry points,
40- * and the cross-cutting kernel-shape decisions (null-intolerant short-circuit, CSE variant,
41- * per-expression specialized emitters).
40+ * and the cross-cutting kernel-shape decisions (null-intolerant short-circuit, CSE variant).
4241 *
4342 * The generated kernel '''is''' the `InternalRow` that Spark's `BoundReference.genCode` reads
4443 * from. `ctx.INPUT_ROW = "row"` plus `InternalRow row = this;` inside `process` routes every
@@ -47,8 +46,8 @@ import org.apache.comet.shims.CometExprTraitShim
4746 * `splitExpressions` uses INPUT_ROW as a helper-method parameter name and `this` is a reserved
4847 * Java keyword.
4948 *
50- * For the full feature list (type surface, optimizations, cache layers, specialized emitters,
51- * open work items), see `docs/source/contributor-guide/jvm_udf_dispatch.md`.
49+ * For the full feature list (type surface, optimizations, cache layers, open work items), see
50+ * `docs/source/contributor-guide/jvm_udf_dispatch.md`.
5251 */
5352object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
5453
@@ -292,7 +291,7 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
292291 /**
293292 * Generate the Java source for a kernel without compiling it. Factored out of [[compile ]] so
294293 * tests can assert on the emitted source (null short-circuit present, non-nullable `isNullAt`
295- * returns literal `false`, specialized emitter engaged, etc.) without paying for Janino.
294+ * returns literal `false`, etc.) without paying for Janino.
296295 */
297296 def generateSource (
298297 boundExpr : Expression ,
@@ -313,39 +312,34 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
313312 val valueVectorClass = classOf [ValueVector ].getName
314313 val fieldVectorClass = classOf [FieldVector ].getName
315314
316- // Pick the per-row body. Specialized emitters get priority; the default reuses
317- // Spark's doGenCode.
315+ // Build the per-row body via Spark's doGenCode.
318316 //
319317 // `outputSetup` holds once-per-batch declarations (typed child-vector casts for complex
320318 // output) that `emitOutputWriter` factors out of the per-row body so they do not repeat on
321- // every row. Scalar outputs return an empty string here. Specialized emitters (like
322- // RegExpReplace) do not need setup because they write directly to the root `output`.
319+ // every row. Scalar outputs return an empty string here.
323320 //
324321 // TODO(method-size): perRowBody is inlined inside process's for-loop and not split.
325322 // Sufficiently deep trees can exceed Janino's 64KB method size; wrap in
326323 // ctx.splitExpressionsWithCurrentInputs when hit. See
327324 // docs/source/contributor-guide/jvm_udf_dispatch.md#open-items.
328- val (concreteOutClass, outputSetup, perRowBody) = boundExpr match {
329- case rr : RegExpReplace if canSpecializeRegExpReplace(rr) =>
330- (classOf [VarCharVector ].getName, " " , specializedRegExpReplaceBody(ctx, rr, inputSchema))
331- case _ =>
332- // Class-field CSE. `generateExpressions` runs `subexpressionElimination` under the
333- // hood, which populates `ctx.subexprFunctions` with per-row helper calls that write
334- // common subexpression results into `addMutableState`-allocated fields; the returned
335- // `ExprCode` then references those fields. `subexprFunctionsCode` is the concatenated
336- // helper invocation block, spliced into the per-row body by `defaultBody` (inside the
337- // NullIntolerant else-branch when that short-circuit fires, otherwise before
338- // `ev.code`). See the "Subexpression elimination" section of the object-level
339- // Scaladoc for why we use this variant rather than the WSCG one.
340- val ev = if (SQLConf .get.subexpressionEliminationEnabled) {
341- ctx.generateExpressions(Seq (boundExpr), doSubexpressionElimination = true ).head
342- } else {
343- boundExpr.genCode(ctx)
344- }
345- val subExprsCode = ctx.subexprFunctionsCode
346- val (cls, setup, snippet) =
347- CometBatchKernelCodegenOutput .emitOutputWriter(boundExpr.dataType, ev.value, ctx)
348- (cls, setup, defaultBody(boundExpr, ev, snippet, subExprsCode))
325+ val (concreteOutClass, outputSetup, perRowBody) = {
326+ // Class-field CSE. `generateExpressions` runs `subexpressionElimination` under the
327+ // hood, which populates `ctx.subexprFunctions` with per-row helper calls that write
328+ // common subexpression results into `addMutableState`-allocated fields; the returned
329+ // `ExprCode` then references those fields. `subexprFunctionsCode` is the concatenated
330+ // helper invocation block, spliced into the per-row body by `defaultBody` (inside the
331+ // NullIntolerant else-branch when that short-circuit fires, otherwise before
332+ // `ev.code`). See the "Subexpression elimination" section of the object-level
333+ // Scaladoc for why we use this variant rather than the WSCG one.
334+ val ev = if (SQLConf .get.subexpressionEliminationEnabled) {
335+ ctx.generateExpressions(Seq (boundExpr), doSubexpressionElimination = true ).head
336+ } else {
337+ boundExpr.genCode(ctx)
338+ }
339+ val subExprsCode = ctx.subexprFunctionsCode
340+ val (cls, setup, snippet) =
341+ CometBatchKernelCodegenOutput .emitOutputWriter(boundExpr.dataType, ev.value, ctx)
342+ (cls, setup, defaultBody(boundExpr, ev, snippet, subExprsCode))
349343 }
350344
351345 val typedFieldDecls = CometBatchKernelCodegenInput .emitInputFieldDecls(inputSchema)
@@ -431,14 +425,8 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
431425 }
432426 // One log per unique (expr, schema) compile; the caller caches the result so subsequent
433427 // batches with the same shape reuse this compile.
434- val specialized = boundExpr match {
435- case _ : RegExpReplace
436- if canSpecializeRegExpReplace(boundExpr.asInstanceOf [RegExpReplace ]) =>
437- " [specialized]"
438- case _ => " "
439- }
440428 logInfo(
441- s " CometBatchKernelCodegen: compiled ${boundExpr.getClass.getSimpleName}$specialized " +
429+ s " CometBatchKernelCodegen: compiled ${boundExpr.getClass.getSimpleName} " +
442430 s " -> ${boundExpr.dataType} inputs= " +
443431 inputSchema
444432 .map(s => s " ${s.vectorClass.getSimpleName}${if (s.nullable) " ?" else " " }" )
@@ -453,106 +441,7 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
453441 }
454442
455443 /**
456- * Can this `RegExpReplace` instance be handled by the specialized emitter? Requires a direct
457- * column reference as subject, non-null foldable pattern and replacement, and offset of 1.
458- * Other shapes fall back to the default `doGenCode` path.
459- */
460- private def canSpecializeRegExpReplace (rr : RegExpReplace ): Boolean = {
461- val subjectIsBound =
462- rr.subject.isInstanceOf [BoundReference ] && rr.subject.dataType == StringType
463- val patternOk =
464- rr.regexp.foldable && rr.regexp.dataType == StringType && rr.regexp.eval() != null
465- val replOk = rr.rep.foldable && rr.rep.dataType == StringType && rr.rep.eval() != null
466- val posIsOne = rr.pos match {
467- case Literal (v : Int , _) => v == 1
468- case _ => false
469- }
470- subjectIsBound && patternOk && replOk && posIsOne
471- }
472-
473- /**
474- * Emit the per-row body for `RegExpReplace`. Per-row shape: read Arrow subject bytes, decode to
475- * Java `String`, run `Matcher.replaceAll` with a cached `Pattern` and the replacement String,
476- * re-encode to bytes, write to Arrow.
477- *
478- * ==Why this specialization exists==
479- *
480- * The default path runs `boundExpr.genCode(ctx)` and wraps it with kernel-side getter reads and
481- * a `UTF8String -> bytes -> Arrow` write. For `RegExpReplace` specifically, Spark's generated
482- * code does not stay in `UTF8String` space: `java.util.regex.Matcher` requires a
483- * `CharSequence`, so the generated code materializes a Java `String` from the input
484- * `UTF8String` (a UTF-8 decode, allocating a `char[]`), runs the matcher, then wraps the result
485- * String back into a `UTF8String` (a UTF-8 encode, allocating a `byte[]`). The per-row shape
486- * is:
487- *
488- * {{{
489- * default: Arrow bytes -> UTF8String -> String -> Matcher ->
490- * String -> UTF8String -> bytes -> Arrow
491- * }}}
492- *
493- * On a wide-match workload (every character of the row gets replaced, so the output is the full
494- * row length), the round trip added ~44% per-row cost versus a tight byte-oriented loop with
495- * shape:
496- *
497- * {{{
498- * specialized: Arrow bytes -> String -> Matcher -> String -> bytes -> Arrow
499- * }}}
500- *
501- * This specialization emits the byte-oriented shape directly. No `UTF8String` appears in the
502- * generated per-row loop. The expression remains a first-class citizen of the dispatcher
503- * (plan-time serde, schema-keyed caching, zero-config for the caller).
504- *
505- * ==When to add a specialization==
506- *
507- * The general rule: specialize when an expression's `doGenCode` output shape forces conversions
508- * that an Arrow-aware byte-oriented implementation does not pay. The common case is expressions
509- * whose implementation requires a Java `String` (anything using `java.util.regex` and some
510- * `DateTimeFormatter` expressions), because Spark's `UTF8String <-> String` round-trip is not
511- * free for wide outputs. Keep specializations minimal so comparisons stay honest. Avoid
512- * layering speculative optimizations; let the default-path optimization menu handle the common
513- * cases.
514- */
515- private def specializedRegExpReplaceBody (
516- ctx : CodegenContext ,
517- rr : RegExpReplace ,
518- inputSchema : Seq [ArrowColumnSpec ]): String = {
519- val subjectOrd = rr.subject.asInstanceOf [BoundReference ].ordinal
520- val subjectClass = inputSchema(subjectOrd).vectorClass
521- require(
522- subjectClass == classOf [VarCharVector ],
523- " specializedRegExpReplaceBody expects VarCharVector at ordinal " +
524- s " $subjectOrd, got ${subjectClass.getSimpleName}" )
525-
526- val patternStr = rr.regexp.eval().toString
527- val replStr = rr.rep.eval().toString
528- val compiledPattern = java.util.regex.Pattern .compile(patternStr)
529-
530- // addReferenceObj adds a class-level field initialized from references[] in the constructor,
531- // so the Pattern and replacement String are resolved once, not per row.
532- val patternRef =
533- ctx.addReferenceObj(" pattern" , compiledPattern, " java.util.regex.Pattern" )
534- val replRef = ctx.addReferenceObj(" replacement" , replStr, " java.lang.String" )
535-
536- val sb = ctx.freshName(" sb" )
537- val s = ctx.freshName(" s" )
538- val r = ctx.freshName(" r" )
539- val rb = ctx.freshName(" rb" )
540-
541- s """
542- |if (this.col $subjectOrd.isNull(i)) {
543- | output.setNull(i);
544- |} else {
545- | byte[] $sb = this.col $subjectOrd.get(i);
546- | String $s = new String( $sb, java.nio.charset.StandardCharsets.UTF_8);
547- | String $r = $patternRef.matcher( $s).replaceAll( $replRef);
548- | byte[] $rb = $r.getBytes(java.nio.charset.StandardCharsets.UTF_8);
549- | output.setSafe(i, $rb, 0, $rb.length);
550- |}
551- """ .stripMargin
552- }
553-
554- /**
555- * Per-row body for the default (non-specialized) path.
444+ * Per-row body for the default path.
556445 *
557446 * For expressions that implement the `NullIntolerant` marker trait (null in any input -> null
558447 * output), emits a short-circuit that skips expression evaluation entirely when any input
0 commit comments