Skip to content

Commit 7ed806a

Browse files
committed
reduce the scope to just ScalaUDF instead of general spark expressions, tests pass
1 parent ff8ee79 commit 7ed806a

17 files changed

Lines changed: 396 additions & 1188 deletions

File tree

common/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 8 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -380,45 +380,16 @@ object CometConf extends ShimCometConf {
380380
.booleanConf
381381
.createWithDefault(false)
382382

383-
val REGEXP_ENGINE_RUST = "rust"
384-
val REGEXP_ENGINE_JAVA = "java"
385-
386-
val COMET_REGEXP_ENGINE: ConfigEntry[String] =
387-
conf("spark.comet.exec.regexp.engine")
383+
val COMET_SCALA_UDF_CODEGEN_ENABLED: ConfigEntry[Boolean] =
384+
conf("spark.comet.exec.scalaUDF.codegen.enabled")
388385
.category(CATEGORY_EXEC)
389386
.doc(
390-
"Experimental. Selects the engine used to evaluate supported regular-expression " +
391-
s"expressions. `$REGEXP_ENGINE_RUST` uses the native DataFusion regexp engine. " +
392-
s"`$REGEXP_ENGINE_JAVA` routes through a JVM-side UDF (java.util.regex.Pattern) for " +
393-
"Spark-compatible semantics, at the cost of JNI roundtrips per batch. Expressions " +
394-
"routed when set to java: rlike, regexp_extract, regexp_extract_all, regexp_replace, " +
395-
"regexp_instr, and split.")
396-
.stringConf
397-
.transform(_.toLowerCase(Locale.ROOT))
398-
.checkValues(Set(REGEXP_ENGINE_RUST, REGEXP_ENGINE_JAVA))
399-
.createWithDefault(REGEXP_ENGINE_JAVA)
400-
401-
val CODEGEN_DISPATCH_AUTO = "auto"
402-
val CODEGEN_DISPATCH_DISABLED = "disabled"
403-
val CODEGEN_DISPATCH_FORCE = "force"
404-
405-
val COMET_CODEGEN_DISPATCH_MODE: ConfigEntry[String] =
406-
conf("spark.comet.exec.codegenDispatch.mode")
407-
.category(CATEGORY_EXEC)
408-
.doc("Controls whether Comet routes eligible scalar expressions through the Arrow-direct " +
409-
"codegen dispatcher (`CometCodegenDispatchUDF`) rather than through a native " +
410-
s"DataFusion implementation or falling back to Spark. `$CODEGEN_DISPATCH_AUTO` lets " +
411-
"each expression's serde decide its preferred path based on measured evidence " +
412-
"(e.g. for regex, codegen is preferred when " +
413-
s"spark.comet.exec.regexp.engine=$REGEXP_ENGINE_JAVA). " +
414-
s"`$CODEGEN_DISPATCH_DISABLED` never uses codegen dispatch. `$CODEGEN_DISPATCH_FORCE` " +
415-
"inverts the chain: every serde tries codegen first and falls through to its next " +
416-
"preferred path only when `canHandle` rejects the expression. Useful for debugging " +
417-
"and benchmarking.")
418-
.stringConf
419-
.transform(_.toLowerCase(Locale.ROOT))
420-
.checkValues(Set(CODEGEN_DISPATCH_AUTO, CODEGEN_DISPATCH_DISABLED, CODEGEN_DISPATCH_FORCE))
421-
.createWithDefault(CODEGEN_DISPATCH_AUTO)
387+
"Whether to route Spark `ScalaUDF` expressions through Comet's Arrow-direct codegen " +
388+
"dispatcher. When enabled, a supported ScalaUDF is compiled into a per-batch kernel " +
389+
"that reads and writes Arrow vectors directly from native execution. When disabled, " +
390+
"plans containing a ScalaUDF fall back to Spark for the enclosing operator.")
391+
.booleanConf
392+
.createWithDefault(true)
422393

423394
val COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED: ConfigEntry[Boolean] =
424395
conf("spark.comet.native.shuffle.partitioning.hash.enabled")

common/src/main/scala/org/apache/comet/udf/CometBatchKernelCodegen.scala

Lines changed: 31 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,22 @@ package org.apache.comet.udf
2222
import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, Float4Vector, Float8Vector, IntVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector}
2323
import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
2424
import org.apache.spark.internal.Logging
25-
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, RegExpReplace, Unevaluable}
25+
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, Unevaluable}
2626
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeFormatter, CodegenContext, CodeGenerator, CodegenFallback, ExprCode, GeneratedClass}
2727
import org.apache.spark.sql.internal.SQLConf
28-
import org.apache.spark.sql.types.{DataType, StringType}
28+
import org.apache.spark.sql.types.DataType
2929

3030
import org.apache.comet.shims.CometExprTraitShim
3131

3232
/**
33-
* Compiles a bound [[Expression]] plus an input schema into a specialized [[CometBatchKernel]]
34-
* that fuses Arrow input reads, expression evaluation, and Arrow output writes into one
35-
* Janino-compiled method per (expression, schema) pair.
33+
* Compiles a bound [[Expression]] plus an input schema into a [[CometBatchKernel]] that fuses
34+
* Arrow input reads, expression evaluation, and Arrow output writes into one Janino-compiled
35+
* method per (expression, schema) pair.
3636
*
3737
* Input- and output-side emission live in [[CometBatchKernelCodegenInput]] and
3838
* [[CometBatchKernelCodegenOutput]]. This file is the orchestrator: the [[ArrowColumnSpec]]
3939
* vocabulary, [[canHandle]] / [[allocateOutput]] / [[compile]] / [[generateSource]] entry points,
40-
* and the cross-cutting kernel-shape decisions (null-intolerant short-circuit, CSE variant,
41-
* per-expression specialized emitters).
40+
* and the cross-cutting kernel-shape decisions (null-intolerant short-circuit, CSE variant).
4241
*
4342
* The generated kernel '''is''' the `InternalRow` that Spark's `BoundReference.genCode` reads
4443
* from. `ctx.INPUT_ROW = "row"` plus `InternalRow row = this;` inside `process` routes every
@@ -47,8 +46,8 @@ import org.apache.comet.shims.CometExprTraitShim
4746
* `splitExpressions` uses INPUT_ROW as a helper-method parameter name and `this` is a reserved
4847
* Java keyword.
4948
*
50-
* For the full feature list (type surface, optimizations, cache layers, specialized emitters,
51-
* open work items), see `docs/source/contributor-guide/jvm_udf_dispatch.md`.
49+
* For the full feature list (type surface, optimizations, cache layers, open work items), see
50+
* `docs/source/contributor-guide/jvm_udf_dispatch.md`.
5251
*/
5352
object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
5453

@@ -292,7 +291,7 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
292291
/**
293292
* Generate the Java source for a kernel without compiling it. Factored out of [[compile]] so
294293
* tests can assert on the emitted source (null short-circuit present, non-nullable `isNullAt`
295-
* returns literal `false`, specialized emitter engaged, etc.) without paying for Janino.
294+
* returns literal `false`, etc.) without paying for Janino.
296295
*/
297296
def generateSource(
298297
boundExpr: Expression,
@@ -313,39 +312,34 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
313312
val valueVectorClass = classOf[ValueVector].getName
314313
val fieldVectorClass = classOf[FieldVector].getName
315314

316-
// Pick the per-row body. Specialized emitters get priority; the default reuses
317-
// Spark's doGenCode.
315+
// Build the per-row body via Spark's doGenCode.
318316
//
319317
// `outputSetup` holds once-per-batch declarations (typed child-vector casts for complex
320318
// output) that `emitOutputWriter` factors out of the per-row body so they do not repeat on
321-
// every row. Scalar outputs return an empty string here. Specialized emitters (like
322-
// RegExpReplace) do not need setup because they write directly to the root `output`.
319+
// every row. Scalar outputs return an empty string here.
323320
//
324321
// TODO(method-size): perRowBody is inlined inside process's for-loop and not split.
325322
// Sufficiently deep trees can exceed Janino's 64KB method size; wrap in
326323
// ctx.splitExpressionsWithCurrentInputs when hit. See
327324
// docs/source/contributor-guide/jvm_udf_dispatch.md#open-items.
328-
val (concreteOutClass, outputSetup, perRowBody) = boundExpr match {
329-
case rr: RegExpReplace if canSpecializeRegExpReplace(rr) =>
330-
(classOf[VarCharVector].getName, "", specializedRegExpReplaceBody(ctx, rr, inputSchema))
331-
case _ =>
332-
// Class-field CSE. `generateExpressions` runs `subexpressionElimination` under the
333-
// hood, which populates `ctx.subexprFunctions` with per-row helper calls that write
334-
// common subexpression results into `addMutableState`-allocated fields; the returned
335-
// `ExprCode` then references those fields. `subexprFunctionsCode` is the concatenated
336-
// helper invocation block, spliced into the per-row body by `defaultBody` (inside the
337-
// NullIntolerant else-branch when that short-circuit fires, otherwise before
338-
// `ev.code`). See the "Subexpression elimination" section of the object-level
339-
// Scaladoc for why we use this variant rather than the WSCG one.
340-
val ev = if (SQLConf.get.subexpressionEliminationEnabled) {
341-
ctx.generateExpressions(Seq(boundExpr), doSubexpressionElimination = true).head
342-
} else {
343-
boundExpr.genCode(ctx)
344-
}
345-
val subExprsCode = ctx.subexprFunctionsCode
346-
val (cls, setup, snippet) =
347-
CometBatchKernelCodegenOutput.emitOutputWriter(boundExpr.dataType, ev.value, ctx)
348-
(cls, setup, defaultBody(boundExpr, ev, snippet, subExprsCode))
325+
val (concreteOutClass, outputSetup, perRowBody) = {
326+
// Class-field CSE. `generateExpressions` runs `subexpressionElimination` under the
327+
// hood, which populates `ctx.subexprFunctions` with per-row helper calls that write
328+
// common subexpression results into `addMutableState`-allocated fields; the returned
329+
// `ExprCode` then references those fields. `subexprFunctionsCode` is the concatenated
330+
// helper invocation block, spliced into the per-row body by `defaultBody` (inside the
331+
// NullIntolerant else-branch when that short-circuit fires, otherwise before
332+
// `ev.code`). See the "Subexpression elimination" section of the object-level
333+
// Scaladoc for why we use this variant rather than the WSCG one.
334+
val ev = if (SQLConf.get.subexpressionEliminationEnabled) {
335+
ctx.generateExpressions(Seq(boundExpr), doSubexpressionElimination = true).head
336+
} else {
337+
boundExpr.genCode(ctx)
338+
}
339+
val subExprsCode = ctx.subexprFunctionsCode
340+
val (cls, setup, snippet) =
341+
CometBatchKernelCodegenOutput.emitOutputWriter(boundExpr.dataType, ev.value, ctx)
342+
(cls, setup, defaultBody(boundExpr, ev, snippet, subExprsCode))
349343
}
350344

351345
val typedFieldDecls = CometBatchKernelCodegenInput.emitInputFieldDecls(inputSchema)
@@ -431,14 +425,8 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
431425
}
432426
// One log per unique (expr, schema) compile; the caller caches the result so subsequent
433427
// batches with the same shape reuse this compile.
434-
val specialized = boundExpr match {
435-
case _: RegExpReplace
436-
if canSpecializeRegExpReplace(boundExpr.asInstanceOf[RegExpReplace]) =>
437-
" [specialized]"
438-
case _ => ""
439-
}
440428
logInfo(
441-
s"CometBatchKernelCodegen: compiled ${boundExpr.getClass.getSimpleName}$specialized " +
429+
s"CometBatchKernelCodegen: compiled ${boundExpr.getClass.getSimpleName} " +
442430
s"-> ${boundExpr.dataType} inputs=" +
443431
inputSchema
444432
.map(s => s"${s.vectorClass.getSimpleName}${if (s.nullable) "?" else ""}")
@@ -453,106 +441,7 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
453441
}
454442

455443
/**
456-
* Can this `RegExpReplace` instance be handled by the specialized emitter? Requires a direct
457-
* column reference as subject, non-null foldable pattern and replacement, and offset of 1.
458-
* Other shapes fall back to the default `doGenCode` path.
459-
*/
460-
private def canSpecializeRegExpReplace(rr: RegExpReplace): Boolean = {
461-
val subjectIsBound =
462-
rr.subject.isInstanceOf[BoundReference] && rr.subject.dataType == StringType
463-
val patternOk =
464-
rr.regexp.foldable && rr.regexp.dataType == StringType && rr.regexp.eval() != null
465-
val replOk = rr.rep.foldable && rr.rep.dataType == StringType && rr.rep.eval() != null
466-
val posIsOne = rr.pos match {
467-
case Literal(v: Int, _) => v == 1
468-
case _ => false
469-
}
470-
subjectIsBound && patternOk && replOk && posIsOne
471-
}
472-
473-
/**
474-
* Emit the per-row body for `RegExpReplace`. Per-row shape: read Arrow subject bytes, decode to
475-
* Java `String`, run `Matcher.replaceAll` with a cached `Pattern` and the replacement String,
476-
* re-encode to bytes, write to Arrow.
477-
*
478-
* ==Why this specialization exists==
479-
*
480-
* The default path runs `boundExpr.genCode(ctx)` and wraps it with kernel-side getter reads and
481-
* a `UTF8String -> bytes -> Arrow` write. For `RegExpReplace` specifically, Spark's generated
482-
* code does not stay in `UTF8String` space: `java.util.regex.Matcher` requires a
483-
* `CharSequence`, so the generated code materializes a Java `String` from the input
484-
* `UTF8String` (a UTF-8 decode, allocating a `char[]`), runs the matcher, then wraps the result
485-
* String back into a `UTF8String` (a UTF-8 encode, allocating a `byte[]`). The per-row shape
486-
* is:
487-
*
488-
* {{{
489-
* default: Arrow bytes -> UTF8String -> String -> Matcher ->
490-
* String -> UTF8String -> bytes -> Arrow
491-
* }}}
492-
*
493-
* On a wide-match workload (every character of the row gets replaced, so the output is the full
494-
* row length), the round trip added ~44% per-row cost versus a tight byte-oriented loop with
495-
* shape:
496-
*
497-
* {{{
498-
* specialized: Arrow bytes -> String -> Matcher -> String -> bytes -> Arrow
499-
* }}}
500-
*
501-
* This specialization emits the byte-oriented shape directly. No `UTF8String` appears in the
502-
* generated per-row loop. The expression remains a first-class citizen of the dispatcher
503-
* (plan-time serde, schema-keyed caching, zero-config for the caller).
504-
*
505-
* ==When to add a specialization==
506-
*
507-
* The general rule: specialize when an expression's `doGenCode` output shape forces conversions
508-
* that an Arrow-aware byte-oriented implementation does not pay. The common case is expressions
509-
* whose implementation requires a Java `String` (anything using `java.util.regex` and some
510-
* `DateTimeFormatter` expressions), because Spark's `UTF8String <-> String` round-trip is not
511-
* free for wide outputs. Keep specializations minimal so comparisons stay honest. Avoid
512-
* layering speculative optimizations; let the default-path optimization menu handle the common
513-
* cases.
514-
*/
515-
private def specializedRegExpReplaceBody(
516-
ctx: CodegenContext,
517-
rr: RegExpReplace,
518-
inputSchema: Seq[ArrowColumnSpec]): String = {
519-
val subjectOrd = rr.subject.asInstanceOf[BoundReference].ordinal
520-
val subjectClass = inputSchema(subjectOrd).vectorClass
521-
require(
522-
subjectClass == classOf[VarCharVector],
523-
"specializedRegExpReplaceBody expects VarCharVector at ordinal " +
524-
s"$subjectOrd, got ${subjectClass.getSimpleName}")
525-
526-
val patternStr = rr.regexp.eval().toString
527-
val replStr = rr.rep.eval().toString
528-
val compiledPattern = java.util.regex.Pattern.compile(patternStr)
529-
530-
// addReferenceObj adds a class-level field initialized from references[] in the constructor,
531-
// so the Pattern and replacement String are resolved once, not per row.
532-
val patternRef =
533-
ctx.addReferenceObj("pattern", compiledPattern, "java.util.regex.Pattern")
534-
val replRef = ctx.addReferenceObj("replacement", replStr, "java.lang.String")
535-
536-
val sb = ctx.freshName("sb")
537-
val s = ctx.freshName("s")
538-
val r = ctx.freshName("r")
539-
val rb = ctx.freshName("rb")
540-
541-
s"""
542-
|if (this.col$subjectOrd.isNull(i)) {
543-
| output.setNull(i);
544-
|} else {
545-
| byte[] $sb = this.col$subjectOrd.get(i);
546-
| String $s = new String($sb, java.nio.charset.StandardCharsets.UTF_8);
547-
| String $r = $patternRef.matcher($s).replaceAll($replRef);
548-
| byte[] $rb = $r.getBytes(java.nio.charset.StandardCharsets.UTF_8);
549-
| output.setSafe(i, $rb, 0, $rb.length);
550-
|}
551-
""".stripMargin
552-
}
553-
554-
/**
555-
* Per-row body for the default (non-specialized) path.
444+
* Per-row body for the default path.
556445
*
557446
* For expressions that implement the `NullIntolerant` marker trait (null in any input -> null
558447
* output), emits a short-circuit that skips expression evaluation entirely when any input

native/spark-expr/src/jvm_udf/mod.rs

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,13 @@ pub struct JvmScalarUdfExpr {
4141
args: Vec<Arc<dyn PhysicalExpr>>,
4242
return_type: DataType,
4343
return_nullable: bool,
44-
/// Spark `TaskContext` captured on the driving Spark task thread, stashed in the
45-
/// [`ExecutionContext`] at `createPlan` time, and threaded here by the planner. Passed
46-
/// through the JNI bridge so [`CometUdfBridge.evaluate`] can install it as the
47-
/// thread-local `TaskContext` on the Tokio worker that drives the UDF call. Without this,
48-
/// partition-sensitive built-ins inside a user UDF tree (`Rand`, `Uuid`,
49-
/// `MonotonicallyIncreasingID`, custom UDF code that reads
50-
/// `TaskContext.get().partitionId()`) see a null `TaskContext` and seed / branch
51-
/// incorrectly. `None` means the surrounding driver had no `TaskContext` to propagate
52-
/// (unit tests, direct native driver runs); the bridge then leaves whatever
53-
/// `TaskContext.get()` already returns in place.
44+
/// Captured at `createPlan` time and threaded here by the planner. Passed through the
45+
/// JNI bridge so `CometUdfBridge.evaluate` can install it as the Tokio worker's
46+
/// thread-local `TaskContext`. Without this, partition-sensitive built-ins inside a UDF
47+
/// tree (`Rand`, `Uuid`, `MonotonicallyIncreasingID`, user code reading
48+
/// `TaskContext.get()`) see `null` and seed / branch incorrectly. `None` when no driving
49+
/// Spark task is available; the bridge then leaves whatever `TaskContext.get()` already
50+
/// returns in place.
5451
task_context: Option<Arc<Global<JObject<'static>>>>,
5552
}
5653

@@ -183,7 +180,7 @@ impl PhysicalExpr for JvmScalarUdfExpr {
183180
CometError::from(ExecutionError::GeneralError(
184181
"JVM UDF bridge unavailable: org.apache.comet.udf.CometUdfBridge \
185182
class was not found on the JVM classpath. Set \
186-
spark.comet.exec.regexp.engine=rust to disable this path."
183+
spark.comet.exec.scalaUDF.codegen.enabled=false to disable this path."
187184
.to_string(),
188185
))
189186
})?;

spark/src/main/scala/org/apache/comet/CometExecIterator.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,8 @@ class CometExecIterator(
128128
taskAttemptId,
129129
taskCPUs,
130130
keyUnwrapper,
131-
// Capture the Spark task thread's TaskContext at `createPlan` time. Stashed native-side
132-
// in the ExecutionContext and passed through the JVM UDF bridge so that Tokio workers
133-
// running JVM UDFs see the real `TaskContext` via their thread-local. See
134-
// `CometUdfBridge.evaluate` and `CometTaskContextShim` for the receive side.
131+
// Propagated to Tokio workers running JVM UDFs so they see this Spark task's
132+
// TaskContext. See CometUdfBridge.evaluate.
135133
TaskContext.get())
136134
}
137135

0 commit comments

Comments
 (0)