Skip to content

Commit 0705dff

Browse files
committed
use cometplainvector part 2
1 parent 421c60c commit 0705dff

2 files changed

Lines changed: 101 additions & 105 deletions

File tree

common/src/main/scala/org/apache/comet/udf/CometBatchKernelCodegenInput.scala

Lines changed: 98 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ private[udf] object CometBatchKernelCodegenInput {
9595
val lines = new mutable.ArrayBuffer[String]()
9696
inputSchema.zipWithIndex.foreach { case (spec, ord) =>
9797
val path = s"col$ord"
98-
collectVectorFieldDecls(path, spec, topLevel = true, lines)
98+
collectVectorFieldDecls(path, spec, lines)
9999
collectTopLevelInstanceDecl(path, spec, lines)
100100
}
101101
lines.mkString("\n ")
@@ -145,43 +145,53 @@ private[udf] object CometBatchKernelCodegenInput {
145145
private def needsOffsetAddrField(cls: Class[_]): Boolean =
146146
cls == classOf[VarCharVector] || cls == classOf[VarBinaryVector]
147147

148+
/**
149+
* Java method name for the null check on a column's typed field. Primitive scalars wrapped in
150+
* [[CometPlainVector]] expose `isNullAt`; Arrow typed fields (complex containers,
151+
* `DecimalVector`, `VarCharVector`, `VarBinaryVector`) expose `isNull`. Both read the validity
152+
* bitmap.
153+
*/
154+
private def nullCheckMethod(spec: ArrowColumnSpec): String = spec match {
155+
case sc: ScalarColumnSpec if wrapsInCometPlainVector(sc.vectorClass) => "isNullAt"
156+
case _ => "isNull"
157+
}
158+
148159
private val cometPlainVectorName: String = classOf[CometPlainVector].getName
149160

150161
private def collectVectorFieldDecls(
151162
path: String,
152163
spec: ArrowColumnSpec,
153-
topLevel: Boolean,
154164
out: mutable.ArrayBuffer[String]): Unit = spec match {
155165
case sc: ScalarColumnSpec =>
156-
// CometPlainVector wrapping and cached-address fields apply only at the kernel's top
157-
// level. Nested-class children stay on Arrow typed fields because their generated method
158-
// bodies (inside `InputArray_*` / `InputStruct_*` / `InputMap_*`) call Arrow-style
159-
// `.isNull(i)` / `.get(i)`; converting those too is Phase D.
166+
// Primitive scalar columns (at any nesting depth) are wrapped in CometPlainVector so
167+
// per-row reads go through JIT-inlined Platform.get* against a cached buffer address.
168+
// DecimalVector / VarCharVector / VarBinaryVector stay on the Arrow typed field but
169+
// cache data- and (variable-width) offset-buffer addresses for inline unsafe reads.
160170
val fieldClass =
161-
if (topLevel && wrapsInCometPlainVector(sc.vectorClass)) cometPlainVectorName
171+
if (wrapsInCometPlainVector(sc.vectorClass)) cometPlainVectorName
162172
else sc.vectorClass.getName
163173
out += s"private $fieldClass $path;"
164-
if (topLevel && needsValueAddrField(sc.vectorClass)) {
174+
if (needsValueAddrField(sc.vectorClass)) {
165175
out += s"private long ${path}_valueAddr;"
166176
}
167-
if (topLevel && needsOffsetAddrField(sc.vectorClass)) {
177+
if (needsOffsetAddrField(sc.vectorClass)) {
168178
out += s"private long ${path}_offsetAddr;"
169179
}
170180
case ar: ArrayColumnSpec =>
171181
out += s"private ${classOf[ListVector].getName} $path;"
172-
collectVectorFieldDecls(s"${path}_e", ar.element, topLevel = false, out)
182+
collectVectorFieldDecls(s"${path}_e", ar.element, out)
173183
case st: StructColumnSpec =>
174184
out += s"private ${classOf[StructVector].getName} $path;"
175185
st.fields.zipWithIndex.foreach { case (f, fi) =>
176-
collectVectorFieldDecls(s"${path}_f$fi", f.child, topLevel = false, out)
186+
collectVectorFieldDecls(s"${path}_f$fi", f.child, out)
177187
}
178188
case mp: MapColumnSpec =>
179189
out += s"private ${classOf[MapVector].getName} $path;"
180190
// Key and value vectors live at `${P}_k_e` / `${P}_v_e` so the `InputArray_${P}_k` /
181191
// `InputArray_${P}_v` synthetic classes (which follow the array-element convention of
182192
// reading from `${path}_e`) resolve their element reads correctly.
183-
collectVectorFieldDecls(s"${path}_k_e", mp.key, topLevel = false, out)
184-
collectVectorFieldDecls(s"${path}_v_e", mp.value, topLevel = false, out)
193+
collectVectorFieldDecls(s"${path}_k_e", mp.key, out)
194+
collectVectorFieldDecls(s"${path}_v_e", mp.value, out)
185195
}
186196

187197
private def collectTopLevelInstanceDecl(
@@ -208,7 +218,7 @@ private[udf] object CometBatchKernelCodegenInput {
208218
val lines = new mutable.ArrayBuffer[String]()
209219
inputSchema.zipWithIndex.foreach { case (spec, ord) =>
210220
val path = s"col$ord"
211-
collectCasts(path, spec, s"inputs[$ord]", topLevel = true, lines)
221+
collectCasts(path, spec, s"inputs[$ord]", lines)
212222
}
213223
lines.mkString("\n ")
214224
}
@@ -217,38 +227,30 @@ private[udf] object CometBatchKernelCodegenInput {
217227
path: String,
218228
spec: ArrowColumnSpec,
219229
source: String,
220-
topLevel: Boolean,
221230
out: mutable.ArrayBuffer[String]): Unit = spec match {
222231
case sc: ScalarColumnSpec =>
223-
if (topLevel && wrapsInCometPlainVector(sc.vectorClass)) {
232+
if (wrapsInCometPlainVector(sc.vectorClass)) {
224233
// Wrap in CometPlainVector so per-row reads go through Platform.get* against a final
225234
// long buffer address. JIT inlines the one-liner getters, treating the address as a
226-
// register-cached constant across the process loop. useDecimal128 = true matches Spark's
227-
// 128-bit decimal storage.
235+
// register-cached constant across the process loop. useDecimal128 = true matches
236+
// Spark's 128-bit decimal storage.
228237
out += s"this.$path = new $cometPlainVectorName($source, true);"
229238
} else {
230239
out += s"this.$path = (${sc.vectorClass.getName}) $source;"
231240
}
232-
// Address caching applies only at the kernel top level; nested-class reads still go
233-
// through Arrow typed getters (Phase D).
234-
if (topLevel && needsValueAddrField(sc.vectorClass)) {
241+
if (needsValueAddrField(sc.vectorClass)) {
235242
out += s"this.${path}_valueAddr = this.$path.getDataBuffer().memoryAddress();"
236243
}
237-
if (topLevel && needsOffsetAddrField(sc.vectorClass)) {
244+
if (needsOffsetAddrField(sc.vectorClass)) {
238245
out += s"this.${path}_offsetAddr = this.$path.getOffsetBuffer().memoryAddress();"
239246
}
240247
case ar: ArrayColumnSpec =>
241248
out += s"this.$path = (${classOf[ListVector].getName}) $source;"
242-
collectCasts(s"${path}_e", ar.element, s"this.$path.getDataVector()", topLevel = false, out)
249+
collectCasts(s"${path}_e", ar.element, s"this.$path.getDataVector()", out)
243250
case st: StructColumnSpec =>
244251
out += s"this.$path = (${classOf[StructVector].getName}) $source;"
245252
st.fields.zipWithIndex.foreach { case (f, fi) =>
246-
collectCasts(
247-
s"${path}_f$fi",
248-
f.child,
249-
s"this.$path.getChildByOrdinal($fi)",
250-
topLevel = false,
251-
out)
253+
collectCasts(s"${path}_f$fi", f.child, s"this.$path.getChildByOrdinal($fi)", out)
252254
}
253255
case mp: MapColumnSpec =>
254256
// MapVector's data vector is a StructVector with key at child 0 and value at child 1.
@@ -259,18 +261,8 @@ private[udf] object CometBatchKernelCodegenInput {
259261
out += s"this.$path = (${classOf[MapVector].getName}) $source;"
260262
out += s"${classOf[StructVector].getName} $structLocal = " +
261263
s"(${classOf[StructVector].getName}) this.$path.getDataVector();"
262-
collectCasts(
263-
s"${path}_k_e",
264-
mp.key,
265-
s"$structLocal.getChildByOrdinal(0)",
266-
topLevel = false,
267-
out)
268-
collectCasts(
269-
s"${path}_v_e",
270-
mp.value,
271-
s"$structLocal.getChildByOrdinal(1)",
272-
topLevel = false,
273-
out)
264+
collectCasts(s"${path}_k_e", mp.key, s"$structLocal.getChildByOrdinal(0)", out)
265+
collectCasts(s"${path}_v_e", mp.value, s"$structLocal.getChildByOrdinal(1)", out)
274266
}
275267

276268
/**
@@ -506,7 +498,7 @@ private[udf] object CometBatchKernelCodegenInput {
506498
val isNullAt =
507499
s""" @Override
508500
| public boolean isNullAt(int i) {
509-
| return $elemPath.isNull(startIndex + i);
501+
| return $elemPath.${nullCheckMethod(spec.element)}(startIndex + i);
510502
| }""".stripMargin
511503
val elementGetter = emitArrayElementGetter(path, spec)
512504
s""" private final class InputArray_$path extends $baseClassName {
@@ -574,41 +566,42 @@ private[udf] object CometBatchKernelCodegenInput {
574566
case BooleanType =>
575567
s""" @Override
576568
| public boolean getBoolean(int i) {
577-
| return $childField.get(startIndex + i) == 1;
569+
| return $childField.getBoolean(startIndex + i);
578570
| }""".stripMargin
579571
case ByteType =>
580572
s""" @Override
581573
| public byte getByte(int i) {
582-
| return $childField.get(startIndex + i);
574+
| return $childField.getByte(startIndex + i);
583575
| }""".stripMargin
584576
case ShortType =>
585577
s""" @Override
586578
| public short getShort(int i) {
587-
| return $childField.get(startIndex + i);
579+
| return $childField.getShort(startIndex + i);
588580
| }""".stripMargin
589581
case IntegerType | DateType =>
590582
s""" @Override
591583
| public int getInt(int i) {
592-
| return $childField.get(startIndex + i);
584+
| return $childField.getInt(startIndex + i);
593585
| }""".stripMargin
594586
case LongType | TimestampType | TimestampNTZType =>
595587
s""" @Override
596588
| public long getLong(int i) {
597-
| return $childField.get(startIndex + i);
589+
| return $childField.getLong(startIndex + i);
598590
| }""".stripMargin
599591
case FloatType =>
600592
s""" @Override
601593
| public float getFloat(int i) {
602-
| return $childField.get(startIndex + i);
594+
| return $childField.getFloat(startIndex + i);
603595
| }""".stripMargin
604596
case DoubleType =>
605597
s""" @Override
606598
| public double getDouble(int i) {
607-
| return $childField.get(startIndex + i);
599+
| return $childField.getDouble(startIndex + i);
608600
| }""".stripMargin
609601
case dt: DecimalType =>
610602
val body =
611-
if (dt.precision <= 18) emitDecimalFastBody(childField, "startIndex + i", " ")
603+
if (dt.precision <= 18)
604+
emitDecimalFastBodyUnsafe(s"${childField}_valueAddr", "startIndex + i", " ")
612605
else emitDecimalSlowBody(childField, "startIndex + i", " ")
613606
s""" @Override
614607
| public org.apache.spark.sql.types.Decimal getDecimal(
@@ -618,12 +611,20 @@ private[udf] object CometBatchKernelCodegenInput {
618611
case _: StringType =>
619612
s""" @Override
620613
| public org.apache.spark.unsafe.types.UTF8String getUTF8String(int i) {
621-
|${emitUtf8Body(childField, "startIndex + i", " ")}
614+
|${emitUtf8BodyUnsafe(
615+
s"${childField}_valueAddr",
616+
s"${childField}_offsetAddr",
617+
"startIndex + i",
618+
" ")}
622619
| }""".stripMargin
623620
case BinaryType =>
624621
s""" @Override
625622
| public byte[] getBinary(int i) {
626-
| return $childField.get(startIndex + i);
623+
|${emitBinaryBodyUnsafe(
624+
s"${childField}_valueAddr",
625+
s"${childField}_offsetAddr",
626+
"startIndex + i",
627+
" ")}
627628
| }""".stripMargin
628629
case other =>
629630
throw new UnsupportedOperationException(
@@ -677,8 +678,8 @@ private[udf] object CometBatchKernelCodegenInput {
677678
val isNullCases = spec.fields.zipWithIndex.map {
678679
case (f, fi) if !f.nullable =>
679680
s" case $fi: return false;"
680-
case (_, fi) =>
681-
s" case $fi: return ${path}_f$fi.isNull(this.rowIdx);"
681+
case (f, fi) =>
682+
s" case $fi: return ${path}_f$fi.${nullCheckMethod(f.child)}(this.rowIdx);"
682683
}
683684
val scalarGetters = emitStructScalarGetters(path, spec)
684685
val complexGetters = emitStructComplexGetters(path, spec)
@@ -716,15 +717,34 @@ private[udf] object CometBatchKernelCodegenInput {
716717

717718
def fieldReadScalar(fi: Int, dt: DataType): String = dt match {
718719
case BooleanType =>
719-
s" case $fi: return ${path}_f$fi.get(this.rowIdx) == 1;"
720-
case ByteType | ShortType | IntegerType | DateType | LongType | TimestampType |
721-
TimestampNTZType | FloatType | DoubleType =>
722-
s" case $fi: return ${path}_f$fi.get(this.rowIdx);"
720+
s" case $fi: return ${path}_f$fi.getBoolean(this.rowIdx);"
721+
case ByteType =>
722+
s" case $fi: return ${path}_f$fi.getByte(this.rowIdx);"
723+
case ShortType =>
724+
s" case $fi: return ${path}_f$fi.getShort(this.rowIdx);"
725+
case IntegerType | DateType =>
726+
s" case $fi: return ${path}_f$fi.getInt(this.rowIdx);"
727+
case LongType | TimestampType | TimestampNTZType =>
728+
s" case $fi: return ${path}_f$fi.getLong(this.rowIdx);"
729+
case FloatType =>
730+
s" case $fi: return ${path}_f$fi.getFloat(this.rowIdx);"
731+
case DoubleType =>
732+
s" case $fi: return ${path}_f$fi.getDouble(this.rowIdx);"
723733
case BinaryType =>
724-
s" case $fi: return ${path}_f$fi.get(this.rowIdx);"
734+
s""" case $fi: {
735+
|${emitBinaryBodyUnsafe(
736+
s"${path}_f${fi}_valueAddr",
737+
s"${path}_f${fi}_offsetAddr",
738+
"this.rowIdx",
739+
" ")}
740+
| }""".stripMargin
725741
case _: StringType =>
726742
s""" case $fi: {
727-
|${emitUtf8Body(s"${path}_f$fi", "this.rowIdx", " ")}
743+
|${emitUtf8BodyUnsafe(
744+
s"${path}_f${fi}_valueAddr",
745+
s"${path}_f${fi}_offsetAddr",
746+
"this.rowIdx",
747+
" ")}
728748
| }""".stripMargin
729749
case _: DecimalType =>
730750
throw new IllegalStateException("decimal handled separately")
@@ -782,7 +802,8 @@ private[udf] object CometBatchKernelCodegenInput {
782802
val dt = f.sparkType.asInstanceOf[DecimalType]
783803
val field = s"${path}_f$fi"
784804
val body =
785-
if (dt.precision <= 18) emitDecimalFastBody(field, "this.rowIdx", " ")
805+
if (dt.precision <= 18)
806+
emitDecimalFastBodyUnsafe(s"${field}_valueAddr", "this.rowIdx", " ")
786807
else emitDecimalSlowBody(field, "this.rowIdx", " ")
787808
s""" case $fi: {
788809
|$body
@@ -998,47 +1019,12 @@ private[udf] object CometBatchKernelCodegenInput {
9981019
}
9991020

10001021
// -------------------------------------------------------------------------------------------
1001-
// Scalar-read body templates shared by `emitTypedGetters`, `emitArrayElementScalarGetter`, and
1002-
// `emitStructScalarGetters`. Each helper emits the per-type read statements parameterized on
1003-
// `field` (Java expression for the Arrow vector), `idx` (Java expression for the row/slot),
1004-
// and `ind` (per-line indent prefix). Continuation lines are indented by `ind + " "`. The
1005-
// caller wraps the result in the appropriate control-flow (switch case or method override).
1006-
// -------------------------------------------------------------------------------------------
1007-
1008-
/** Parenthesize `idx` when it contains whitespace, to keep `(long) idx * 16L` well-formed. */
1009-
private def castableIdx(idx: String): String = if (idx.contains(' ')) s"($idx)" else idx
1010-
1011-
private def emitDecimalFastBody(field: String, idx: String, ind: String): String = {
1012-
val cont = ind + " "
1013-
val i = castableIdx(idx)
1014-
s"""${ind}long unscaled = $field.getDataBuffer()
1015-
|$cont.getLong((long) $i * 16L);
1016-
|${ind}return org.apache.spark.sql.types.Decimal$$.MODULE$$
1017-
|$cont.createUnsafe(unscaled, precision, scale);""".stripMargin
1018-
}
1019-
1020-
private def emitDecimalSlowBody(field: String, idx: String, ind: String): String = {
1021-
val cont = ind + " "
1022-
s"""${ind}java.math.BigDecimal bd = $field.getObject($idx);
1023-
|${ind}return org.apache.spark.sql.types.Decimal$$.MODULE$$
1024-
|$cont.apply(bd, precision, scale);""".stripMargin
1025-
}
1026-
1027-
private def emitUtf8Body(field: String, idx: String, ind: String): String = {
1028-
val cont = ind + " "
1029-
s"""${ind}int s = $field.getStartOffset($idx);
1030-
|${ind}int e = $field.getEndOffset($idx);
1031-
|${ind}long addr = $field.getDataBuffer().memoryAddress() + s;
1032-
|${ind}return org.apache.spark.unsafe.types.UTF8String
1033-
|$cont.fromAddress(null, addr, e - s);""".stripMargin
1034-
}
1035-
1036-
// -------------------------------------------------------------------------------------------
1037-
// Unsafe variants for top-level scalar columns. Each batch caches the data-buffer address (and
1038-
// offset-buffer address for variable-width) on the kernel, letting per-row reads go through
1039-
// Platform.get* directly without re-dereferencing the Arrow vector's ArrowBuf per call. Nested
1040-
// classes still use the Arrow-buffer variants above until the same address caching lands at
1041-
// nested-level emission.
1022+
// Scalar-read body templates. Each helper emits the per-type read statements parameterized on
1023+
// a Java expression for the row/slot index (`idx`), the cached buffer address(es) for unsafe
1024+
// reads (`valueAddr`, `offsetAddr`), or the Arrow typed field (`field`) for the slow-path
1025+
// decimal case that still needs `getObject`. `ind` is the per-line indent prefix;
1026+
// continuation lines add four spaces. Callers wrap the output in switch cases or method
1027+
// overrides.
10421028
//
10431029
// The VarChar / VarBinary unsafe emitters below duplicate what CometPlainVector.getUTF8String
10441030
// and getBinary do today, with two differences: they skip CometPlainVector's internal
@@ -1051,6 +1037,16 @@ private[udf] object CometBatchKernelCodegenInput {
10511037
// specialization) unrelated to those issues.
10521038
// -------------------------------------------------------------------------------------------
10531039

1040+
/** Parenthesize `idx` when it contains whitespace, to keep `(long) idx * 16L` well-formed. */
1041+
private def castableIdx(idx: String): String = if (idx.contains(' ')) s"($idx)" else idx
1042+
1043+
private def emitDecimalSlowBody(field: String, idx: String, ind: String): String = {
1044+
val cont = ind + " "
1045+
s"""${ind}java.math.BigDecimal bd = $field.getObject($idx);
1046+
|${ind}return org.apache.spark.sql.types.Decimal$$.MODULE$$
1047+
|$cont.apply(bd, precision, scale);""".stripMargin
1048+
}
1049+
10541050
private def emitDecimalFastBodyUnsafe(valueAddr: String, idx: String, ind: String): String = {
10551051
val cont = ind + " "
10561052
val i = castableIdx(idx)

docs/source/user-guide/latest/jvm_udf_dispatch.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,10 @@ Complex (as both input and output, including arbitrary nesting): `ArrayType`, `S
4040

4141
## Configuration
4242

43-
| Key | Default | Description |
44-
| --------------------------------------- | ------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
43+
| Key | Default | Description |
44+
| --------------------------------------- | ------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
4545
| `spark.comet.exec.codegenDispatch.mode` | `auto` | `auto` routes through JVM codegen when it is the serde's primary path (regex with java engine, ScalaUDF). `force` routes through codegen whenever accepted. `disabled` never routes through codegen. |
46-
| `spark.comet.exec.regexp.engine` | `java` | `java` uses the JVM codegen path for the regex family. `rust` prefers the native DataFusion engine where one exists and falls back to Spark otherwise. |
46+
| `spark.comet.exec.regexp.engine` | `java` | `java` uses the JVM codegen path for the regex family. `rust` prefers the native DataFusion engine where one exists and falls back to Spark otherwise. |
4747

4848
## Regex routing
4949

0 commit comments

Comments
 (0)