@@ -95,7 +95,7 @@ private[udf] object CometBatchKernelCodegenInput {
9595 val lines = new mutable.ArrayBuffer [String ]()
9696 inputSchema.zipWithIndex.foreach { case (spec, ord) =>
9797 val path = s " col $ord"
98- collectVectorFieldDecls(path, spec, topLevel = true , lines)
98+ collectVectorFieldDecls(path, spec, lines)
9999 collectTopLevelInstanceDecl(path, spec, lines)
100100 }
101101 lines.mkString(" \n " )
@@ -145,43 +145,53 @@ private[udf] object CometBatchKernelCodegenInput {
145145 private def needsOffsetAddrField (cls : Class [_]): Boolean =
146146 cls == classOf [VarCharVector ] || cls == classOf [VarBinaryVector ]
147147
148+ /**
149+ * Java method name for the null check on a column's typed field. Primitive scalars wrapped in
150+ * [[CometPlainVector ]] expose `isNullAt`; Arrow typed fields (complex containers,
151+ * `DecimalVector`, `VarCharVector`, `VarBinaryVector`) expose `isNull`. Both read the validity
152+ * bitmap.
153+ */
154+ private def nullCheckMethod (spec : ArrowColumnSpec ): String = spec match {
155+ case sc : ScalarColumnSpec if wrapsInCometPlainVector(sc.vectorClass) => " isNullAt"
156+ case _ => " isNull"
157+ }
158+
148159 private val cometPlainVectorName : String = classOf [CometPlainVector ].getName
149160
150161 private def collectVectorFieldDecls (
151162 path : String ,
152163 spec : ArrowColumnSpec ,
153- topLevel : Boolean ,
154164 out : mutable.ArrayBuffer [String ]): Unit = spec match {
155165 case sc : ScalarColumnSpec =>
156- // CometPlainVector wrapping and cached-address fields apply only at the kernel's top
157- // level. Nested-class children stay on Arrow typed fields because their generated method
158- // bodies (inside `InputArray_*` / `InputStruct_*` / `InputMap_*`) call Arrow-style
159- // `.isNull(i)` / `.get(i)`; converting those too is Phase D .
166+ // Primitive scalar columns (at any nesting depth) are wrapped in CometPlainVector so
167+ // per-row reads go through JIT-inlined Platform.get* against a cached buffer address.
168+ // DecimalVector / VarCharVector / VarBinaryVector stay on the Arrow typed field but
169+ // cache data- and (variable-width) offset-buffer addresses for inline unsafe reads .
160170 val fieldClass =
161- if (topLevel && wrapsInCometPlainVector(sc.vectorClass)) cometPlainVectorName
171+ if (wrapsInCometPlainVector(sc.vectorClass)) cometPlainVectorName
162172 else sc.vectorClass.getName
163173 out += s " private $fieldClass $path; "
164- if (topLevel && needsValueAddrField(sc.vectorClass)) {
174+ if (needsValueAddrField(sc.vectorClass)) {
165175 out += s " private long ${path}_valueAddr; "
166176 }
167- if (topLevel && needsOffsetAddrField(sc.vectorClass)) {
177+ if (needsOffsetAddrField(sc.vectorClass)) {
168178 out += s " private long ${path}_offsetAddr; "
169179 }
170180 case ar : ArrayColumnSpec =>
171181 out += s " private ${classOf [ListVector ].getName} $path; "
172- collectVectorFieldDecls(s " ${path}_e " , ar.element, topLevel = false , out)
182+ collectVectorFieldDecls(s " ${path}_e " , ar.element, out)
173183 case st : StructColumnSpec =>
174184 out += s " private ${classOf [StructVector ].getName} $path; "
175185 st.fields.zipWithIndex.foreach { case (f, fi) =>
176- collectVectorFieldDecls(s " ${path}_f $fi" , f.child, topLevel = false , out)
186+ collectVectorFieldDecls(s " ${path}_f $fi" , f.child, out)
177187 }
178188 case mp : MapColumnSpec =>
179189 out += s " private ${classOf [MapVector ].getName} $path; "
180190 // Key and value vectors live at `${P}_k_e` / `${P}_v_e` so the `InputArray_${P}_k` /
181191 // `InputArray_${P}_v` synthetic classes (which follow the array-element convention of
182192 // reading from `${path}_e`) resolve their element reads correctly.
183- collectVectorFieldDecls(s " ${path}_k_e " , mp.key, topLevel = false , out)
184- collectVectorFieldDecls(s " ${path}_v_e " , mp.value, topLevel = false , out)
193+ collectVectorFieldDecls(s " ${path}_k_e " , mp.key, out)
194+ collectVectorFieldDecls(s " ${path}_v_e " , mp.value, out)
185195 }
186196
187197 private def collectTopLevelInstanceDecl (
@@ -208,7 +218,7 @@ private[udf] object CometBatchKernelCodegenInput {
208218 val lines = new mutable.ArrayBuffer [String ]()
209219 inputSchema.zipWithIndex.foreach { case (spec, ord) =>
210220 val path = s " col $ord"
211- collectCasts(path, spec, s " inputs[ $ord] " , topLevel = true , lines)
221+ collectCasts(path, spec, s " inputs[ $ord] " , lines)
212222 }
213223 lines.mkString(" \n " )
214224 }
@@ -217,38 +227,30 @@ private[udf] object CometBatchKernelCodegenInput {
217227 path : String ,
218228 spec : ArrowColumnSpec ,
219229 source : String ,
220- topLevel : Boolean ,
221230 out : mutable.ArrayBuffer [String ]): Unit = spec match {
222231 case sc : ScalarColumnSpec =>
223- if (topLevel && wrapsInCometPlainVector(sc.vectorClass)) {
232+ if (wrapsInCometPlainVector(sc.vectorClass)) {
224233 // Wrap in CometPlainVector so per-row reads go through Platform.get* against a final
225234 // long buffer address. JIT inlines the one-liner getters, treating the address as a
226- // register-cached constant across the process loop. useDecimal128 = true matches Spark's
227- // 128-bit decimal storage.
235+ // register-cached constant across the process loop. useDecimal128 = true matches
236+ // Spark's 128-bit decimal storage.
228237 out += s " this. $path = new $cometPlainVectorName( $source, true); "
229238 } else {
230239 out += s " this. $path = ( ${sc.vectorClass.getName}) $source; "
231240 }
232- // Address caching applies only at the kernel top level; nested-class reads still go
233- // through Arrow typed getters (Phase D).
234- if (topLevel && needsValueAddrField(sc.vectorClass)) {
241+ if (needsValueAddrField(sc.vectorClass)) {
235242 out += s " this. ${path}_valueAddr = this. $path.getDataBuffer().memoryAddress(); "
236243 }
237- if (topLevel && needsOffsetAddrField(sc.vectorClass)) {
244+ if (needsOffsetAddrField(sc.vectorClass)) {
238245 out += s " this. ${path}_offsetAddr = this. $path.getOffsetBuffer().memoryAddress(); "
239246 }
240247 case ar : ArrayColumnSpec =>
241248 out += s " this. $path = ( ${classOf [ListVector ].getName}) $source; "
242- collectCasts(s " ${path}_e " , ar.element, s " this. $path.getDataVector() " , topLevel = false , out)
249+ collectCasts(s " ${path}_e " , ar.element, s " this. $path.getDataVector() " , out)
243250 case st : StructColumnSpec =>
244251 out += s " this. $path = ( ${classOf [StructVector ].getName}) $source; "
245252 st.fields.zipWithIndex.foreach { case (f, fi) =>
246- collectCasts(
247- s " ${path}_f $fi" ,
248- f.child,
249- s " this. $path.getChildByOrdinal( $fi) " ,
250- topLevel = false ,
251- out)
253+ collectCasts(s " ${path}_f $fi" , f.child, s " this. $path.getChildByOrdinal( $fi) " , out)
252254 }
253255 case mp : MapColumnSpec =>
254256 // MapVector's data vector is a StructVector with key at child 0 and value at child 1.
@@ -259,18 +261,8 @@ private[udf] object CometBatchKernelCodegenInput {
259261 out += s " this. $path = ( ${classOf [MapVector ].getName}) $source; "
260262 out += s " ${classOf [StructVector ].getName} $structLocal = " +
261263 s " ( ${classOf [StructVector ].getName}) this. $path.getDataVector(); "
262- collectCasts(
263- s " ${path}_k_e " ,
264- mp.key,
265- s " $structLocal.getChildByOrdinal(0) " ,
266- topLevel = false ,
267- out)
268- collectCasts(
269- s " ${path}_v_e " ,
270- mp.value,
271- s " $structLocal.getChildByOrdinal(1) " ,
272- topLevel = false ,
273- out)
264+ collectCasts(s " ${path}_k_e " , mp.key, s " $structLocal.getChildByOrdinal(0) " , out)
265+ collectCasts(s " ${path}_v_e " , mp.value, s " $structLocal.getChildByOrdinal(1) " , out)
274266 }
275267
276268 /**
@@ -506,7 +498,7 @@ private[udf] object CometBatchKernelCodegenInput {
506498 val isNullAt =
507499 s """ @Override
508500 | public boolean isNullAt(int i) {
509- | return $elemPath.isNull (startIndex + i);
501+ | return $elemPath. ${nullCheckMethod(spec.element)} (startIndex + i);
510502 | } """ .stripMargin
511503 val elementGetter = emitArrayElementGetter(path, spec)
512504 s """ private final class InputArray_ $path extends $baseClassName {
@@ -574,41 +566,42 @@ private[udf] object CometBatchKernelCodegenInput {
574566 case BooleanType =>
575567 s """ @Override
576568 | public boolean getBoolean(int i) {
577- | return $childField.get (startIndex + i) == 1 ;
569+ | return $childField.getBoolean (startIndex + i);
578570 | } """ .stripMargin
579571 case ByteType =>
580572 s """ @Override
581573 | public byte getByte(int i) {
582- | return $childField.get (startIndex + i);
574+ | return $childField.getByte (startIndex + i);
583575 | } """ .stripMargin
584576 case ShortType =>
585577 s """ @Override
586578 | public short getShort(int i) {
587- | return $childField.get (startIndex + i);
579+ | return $childField.getShort (startIndex + i);
588580 | } """ .stripMargin
589581 case IntegerType | DateType =>
590582 s """ @Override
591583 | public int getInt(int i) {
592- | return $childField.get (startIndex + i);
584+ | return $childField.getInt (startIndex + i);
593585 | } """ .stripMargin
594586 case LongType | TimestampType | TimestampNTZType =>
595587 s """ @Override
596588 | public long getLong(int i) {
597- | return $childField.get (startIndex + i);
589+ | return $childField.getLong (startIndex + i);
598590 | } """ .stripMargin
599591 case FloatType =>
600592 s """ @Override
601593 | public float getFloat(int i) {
602- | return $childField.get (startIndex + i);
594+ | return $childField.getFloat (startIndex + i);
603595 | } """ .stripMargin
604596 case DoubleType =>
605597 s """ @Override
606598 | public double getDouble(int i) {
607- | return $childField.get (startIndex + i);
599+ | return $childField.getDouble (startIndex + i);
608600 | } """ .stripMargin
609601 case dt : DecimalType =>
610602 val body =
611- if (dt.precision <= 18 ) emitDecimalFastBody(childField, " startIndex + i" , " " )
603+ if (dt.precision <= 18 )
604+ emitDecimalFastBodyUnsafe(s " ${childField}_valueAddr " , " startIndex + i" , " " )
612605 else emitDecimalSlowBody(childField, " startIndex + i" , " " )
613606 s """ @Override
614607 | public org.apache.spark.sql.types.Decimal getDecimal(
@@ -618,12 +611,20 @@ private[udf] object CometBatchKernelCodegenInput {
618611 case _ : StringType =>
619612 s """ @Override
620613 | public org.apache.spark.unsafe.types.UTF8String getUTF8String(int i) {
621- | ${emitUtf8Body(childField, " startIndex + i" , " " )}
614+ | ${emitUtf8BodyUnsafe(
615+ s " ${childField}_valueAddr " ,
616+ s " ${childField}_offsetAddr " ,
617+ " startIndex + i" ,
618+ " " )}
622619 | } """ .stripMargin
623620 case BinaryType =>
624621 s """ @Override
625622 | public byte[] getBinary(int i) {
626- | return $childField.get(startIndex + i);
623+ | ${emitBinaryBodyUnsafe(
624+ s " ${childField}_valueAddr " ,
625+ s " ${childField}_offsetAddr " ,
626+ " startIndex + i" ,
627+ " " )}
627628 | } """ .stripMargin
628629 case other =>
629630 throw new UnsupportedOperationException (
@@ -677,8 +678,8 @@ private[udf] object CometBatchKernelCodegenInput {
677678 val isNullCases = spec.fields.zipWithIndex.map {
678679 case (f, fi) if ! f.nullable =>
679680 s " case $fi: return false; "
680- case (_ , fi) =>
681- s " case $fi: return ${path}_f $fi.isNull (this.rowIdx); "
681+ case (f , fi) =>
682+ s " case $fi: return ${path}_f $fi. ${nullCheckMethod(f.child)} (this.rowIdx); "
682683 }
683684 val scalarGetters = emitStructScalarGetters(path, spec)
684685 val complexGetters = emitStructComplexGetters(path, spec)
@@ -716,15 +717,34 @@ private[udf] object CometBatchKernelCodegenInput {
716717
717718 def fieldReadScalar (fi : Int , dt : DataType ): String = dt match {
718719 case BooleanType =>
719- s " case $fi: return ${path}_f $fi.get(this.rowIdx) == 1; "
720- case ByteType | ShortType | IntegerType | DateType | LongType | TimestampType |
721- TimestampNTZType | FloatType | DoubleType =>
722- s " case $fi: return ${path}_f $fi.get(this.rowIdx); "
720+ s " case $fi: return ${path}_f $fi.getBoolean(this.rowIdx); "
721+ case ByteType =>
722+ s " case $fi: return ${path}_f $fi.getByte(this.rowIdx); "
723+ case ShortType =>
724+ s " case $fi: return ${path}_f $fi.getShort(this.rowIdx); "
725+ case IntegerType | DateType =>
726+ s " case $fi: return ${path}_f $fi.getInt(this.rowIdx); "
727+ case LongType | TimestampType | TimestampNTZType =>
728+ s " case $fi: return ${path}_f $fi.getLong(this.rowIdx); "
729+ case FloatType =>
730+ s " case $fi: return ${path}_f $fi.getFloat(this.rowIdx); "
731+ case DoubleType =>
732+ s " case $fi: return ${path}_f $fi.getDouble(this.rowIdx); "
723733 case BinaryType =>
724- s " case $fi: return ${path}_f $fi.get(this.rowIdx); "
734+ s """ case $fi: {
735+ | ${emitBinaryBodyUnsafe(
736+ s " ${path}_f ${fi}_valueAddr " ,
737+ s " ${path}_f ${fi}_offsetAddr " ,
738+ " this.rowIdx" ,
739+ " " )}
740+ | } """ .stripMargin
725741 case _ : StringType =>
726742 s """ case $fi: {
727- | ${emitUtf8Body(s " ${path}_f $fi" , " this.rowIdx" , " " )}
743+ | ${emitUtf8BodyUnsafe(
744+ s " ${path}_f ${fi}_valueAddr " ,
745+ s " ${path}_f ${fi}_offsetAddr " ,
746+ " this.rowIdx" ,
747+ " " )}
728748 | } """ .stripMargin
729749 case _ : DecimalType =>
730750 throw new IllegalStateException (" decimal handled separately" )
@@ -782,7 +802,8 @@ private[udf] object CometBatchKernelCodegenInput {
782802 val dt = f.sparkType.asInstanceOf [DecimalType ]
783803 val field = s " ${path}_f $fi"
784804 val body =
785- if (dt.precision <= 18 ) emitDecimalFastBody(field, " this.rowIdx" , " " )
805+ if (dt.precision <= 18 )
806+ emitDecimalFastBodyUnsafe(s " ${field}_valueAddr " , " this.rowIdx" , " " )
786807 else emitDecimalSlowBody(field, " this.rowIdx" , " " )
787808 s """ case $fi: {
788809 | $body
@@ -998,47 +1019,12 @@ private[udf] object CometBatchKernelCodegenInput {
9981019 }
9991020
10001021 // -------------------------------------------------------------------------------------------
1001- // Scalar-read body templates shared by `emitTypedGetters`, `emitArrayElementScalarGetter`, and
1002- // `emitStructScalarGetters`. Each helper emits the per-type read statements parameterized on
1003- // `field` (Java expression for the Arrow vector), `idx` (Java expression for the row/slot),
1004- // and `ind` (per-line indent prefix). Continuation lines are indented by `ind + " "`. The
1005- // caller wraps the result in the appropriate control-flow (switch case or method override).
1006- // -------------------------------------------------------------------------------------------
1007-
1008- /** Parenthesize `idx` when it contains whitespace, to keep `(long) idx * 16L` well-formed. */
1009- private def castableIdx (idx : String ): String = if (idx.contains(' ' )) s " ( $idx) " else idx
1010-
1011- private def emitDecimalFastBody (field : String , idx : String , ind : String ): String = {
1012- val cont = ind + " "
1013- val i = castableIdx(idx)
1014- s """ ${ind}long unscaled = $field.getDataBuffer()
1015- | $cont.getLong((long) $i * 16L);
1016- | ${ind}return org.apache.spark.sql.types.Decimal $$ .MODULE $$
1017- | $cont.createUnsafe(unscaled, precision, scale); """ .stripMargin
1018- }
1019-
1020- private def emitDecimalSlowBody (field : String , idx : String , ind : String ): String = {
1021- val cont = ind + " "
1022- s """ ${ind}java.math.BigDecimal bd = $field.getObject( $idx);
1023- | ${ind}return org.apache.spark.sql.types.Decimal $$ .MODULE $$
1024- | $cont.apply(bd, precision, scale); """ .stripMargin
1025- }
1026-
1027- private def emitUtf8Body (field : String , idx : String , ind : String ): String = {
1028- val cont = ind + " "
1029- s """ ${ind}int s = $field.getStartOffset( $idx);
1030- | ${ind}int e = $field.getEndOffset( $idx);
1031- | ${ind}long addr = $field.getDataBuffer().memoryAddress() + s;
1032- | ${ind}return org.apache.spark.unsafe.types.UTF8String
1033- | $cont.fromAddress(null, addr, e - s); """ .stripMargin
1034- }
1035-
1036- // -------------------------------------------------------------------------------------------
1037- // Unsafe variants for top-level scalar columns. Each batch caches the data-buffer address (and
1038- // offset-buffer address for variable-width) on the kernel, letting per-row reads go through
1039- // Platform.get* directly without re-dereferencing the Arrow vector's ArrowBuf per call. Nested
1040- // classes still use the Arrow-buffer variants above until the same address caching lands at
1041- // nested-level emission.
1022+ // Scalar-read body templates. Each helper emits the per-type read statements parameterized on
1023+ // a Java expression for the row/slot index (`idx`), the cached buffer address(es) for unsafe
1024+ // reads (`valueAddr`, `offsetAddr`), or the Arrow typed field (`field`) for the slow-path
1025+ // decimal case that still needs `getObject`. `ind` is the per-line indent prefix;
1026+ // continuation lines add four spaces. Callers wrap the output in switch cases or method
1027+ // overrides.
10421028 //
10431029 // The VarChar / VarBinary unsafe emitters below duplicate what CometPlainVector.getUTF8String
10441030 // and getBinary do today, with two differences: they skip CometPlainVector's internal
@@ -1051,6 +1037,16 @@ private[udf] object CometBatchKernelCodegenInput {
10511037 // specialization) unrelated to those issues.
10521038 // -------------------------------------------------------------------------------------------
10531039
1040+ /** Parenthesize `idx` when it contains whitespace, to keep `(long) idx * 16L` well-formed. */
1041+ private def castableIdx (idx : String ): String = if (idx.contains(' ' )) s " ( $idx) " else idx
1042+
1043+ private def emitDecimalSlowBody (field : String , idx : String , ind : String ): String = {
1044+ val cont = ind + " "
1045+ s """ ${ind}java.math.BigDecimal bd = $field.getObject( $idx);
1046+ | ${ind}return org.apache.spark.sql.types.Decimal $$ .MODULE $$
1047+ | $cont.apply(bd, precision, scale); """ .stripMargin
1048+ }
1049+
10541050 private def emitDecimalFastBodyUnsafe (valueAddr : String , idx : String , ind : String ): String = {
10551051 val cont = ind + " "
10561052 val i = castableIdx(idx)
0 commit comments