1919
2020package org .apache .spark .sql .comet .execution .arrow
2121
22+ import scala .jdk .CollectionConverters ._
23+
2224import org .apache .arrow .c .{ArrowArrayStream , Data }
2325import org .apache .arrow .memory .BufferAllocator
2426import org .apache .arrow .vector .ipc .ArrowReader
27+ import org .apache .arrow .vector .types .pojo .{Field , Schema }
2528import org .apache .spark .TaskContext
29+ import org .apache .spark .internal .Logging
2630import org .apache .spark .rdd .RDD
2731import org .apache .spark .sql .comet .util .Utils
2832import org .apache .spark .sql .execution .SparkPlan
2933import org .apache .spark .sql .types .StructType
3034import org .apache .spark .sql .vectorized .ColumnarBatch
3135
3236import org .apache .comet .CometArrowAllocator
33- import org .apache .comet .vector .NativeUtil
37+ import org .apache .comet .vector .{ CometDictionaryVector , CometVector , NativeUtil }
3438
3539/**
3640 * Marker for Comet operators that can produce Arrow data destined for a Comet native executor
@@ -40,7 +44,7 @@ trait CometNativeArrowSource extends SparkPlan {
4044 def doExecuteAsArrowStream (): RDD [ArrowArrayStream ]
4145}
4246
43- object CometArrowStream {
47+ object CometArrowStream extends Logging {
4448
4549 /**
4650 * Native side asserts `Timestamp(Microsecond, Some("UTC"))` regardless of session timezone;
@@ -61,8 +65,9 @@ object CometArrowStream {
6165 // Arrow `Schema` is not Serializable; only Spark's `StructType` is. Build the Arrow schema
6266 // inside the per-task body so the closure cleaner doesn't try to ship a Schema across.
6367 rdd.mapPartitionsInternal { batchIter =>
64- val arrowSchema = Utils .toArrowSchema(sparkSchema, timeZoneId)
65- stream(name, allocator => new ColumnarBatchArrowReader (allocator, arrowSchema, batchIter))
68+ val expected = Utils .toArrowSchema(sparkSchema, timeZoneId)
69+ val (arrowSchema, iter) = reconcileStreamSchema(name, expected, batchIter)
70+ stream(name, allocator => new ColumnarBatchArrowReader (allocator, arrowSchema, iter))
6671 }
6772 }
6873
@@ -76,8 +81,69 @@ object CometArrowStream {
7681 sparkSchema : StructType ,
7782 timeZoneId : String ,
7883 name : String ): ArrowArrayStream = {
79- val arrowSchema = Utils .toArrowSchema(sparkSchema, timeZoneId)
80- stream(name, allocator => new ColumnarBatchArrowReader (allocator, arrowSchema, iter)).next()
84+ val expected = Utils .toArrowSchema(sparkSchema, timeZoneId)
85+ val (arrowSchema, reconciled) = reconcileStreamSchema(name, expected, iter)
86+ stream(name, allocator => new ColumnarBatchArrowReader (allocator, arrowSchema, reconciled))
87+ .next()
88+ }
89+
/**
 * Build the stream's advertised Arrow schema from the actual `CometVector` types in the first
 * batch, not from `expected` (which derives from the consumer's Spark-declared types). Native
 * operators like `ScanExec` already cast their input to the declared scan-input schema in
 * `build_record_batch`, so the truthful schema lets that cast actually fire. Advertising
 * `expected` instead silently mislabels Int32 buffers as Int64 (and similar) and corrupts on
 * import. See PR #4393 width_bucket investigation.
 *
 * If the first batch's column types differ from `expected` in their `DataType` (timezone-only
 * differences on `Timestamp` are ignored), log one warning naming the operator, column, and
 * type drift; the cast happens transparently downstream in native.
 *
 * Only the first batch is inspected, and it is not consumed: the returned iterator is a
 * buffered view over `iter` that still yields that batch first.
 *
 * @param name
 *   operator name, used only in the warning message
 * @param expected
 *   Arrow schema derived from the consumer's declared Spark schema
 * @param iter
 *   batches of the partition; every column of the first batch must be a `CometVector`
 * @return
 *   the schema to advertise and the iterator to read from. For an empty partition the schema is
 *   `expected` unchanged.
 */
private[arrow] def reconcileStreamSchema(
    name: String,
    expected: Schema,
    iter: Iterator[ColumnarBatch]): (Schema, Iterator[ColumnarBatch]) = {
  import org.apache.arrow.vector.types.pojo.ArrowType

  // `ArrowType.Timestamp` equality also compares the timezone string, so a plain `!=` would
  // report drift for timezone-only differences. Compare timestamps by unit alone.
  def sameDataType(actual: ArrowType, declared: ArrowType): Boolean = (actual, declared) match {
    case (a: ArrowType.Timestamp, d: ArrowType.Timestamp) => a.getUnit == d.getUnit
    case _ => actual == declared
  }

  val buffered = iter.buffered
  if (!buffered.hasNext) {
    // Empty partition: keep the consumer-declared schema; consumer can still build its plan.
    (expected, buffered)
  } else {
    // `head` peeks without consuming, so the first batch is still delivered downstream.
    val first = buffered.head
    val expectedFields = expected.getFields
    val actualFields = (0 until first.numCols()).map { i =>
      val col = first.column(i).asInstanceOf[CometVector]
      actualFieldOf(col, expectedFields.get(i))
    }
    val mismatches = actualFields.zip(expectedFields.asScala).zipWithIndex.collect {
      case ((actual, exp), idx) if !sameDataType(actual.getType, exp.getType) =>
        s"col[$idx] '${exp.getName}': expected ${exp.getType}, child produced ${actual.getType}"
    }
    if (mismatches.nonEmpty) {
      val details = mismatches.mkString("; ")
      logWarning(
        s"CometArrowStream '$name' input schema mismatch: $details. " +
          "Native ScanExec will cast at the boundary. This usually means a DataFusion-Spark " +
          "function declares a different return type than Spark catalyst.")
    }
    // The advertised schema always reflects the actual vectors, including their timezone.
    (new Schema(actualFields.asJava), buffered)
  }
}
129+
/**
 * The Arrow field that this column's buffers will look like once unloaded.
 *
 * For a `CometDictionaryVector`, [[ColumnarBatchArrowReader]] decodes it via
 * `DictionaryEncoder.decode` before unloading, so the wire-level field is that of the
 * dictionary's *value* vector, not `Dictionary<index, value>`. For every other column the
 * underlying value vector's own field is used.
 *
 * Only the field name is taken from `expected`, so that consumers indexing by name keep
 * working. The `FieldType` and the child fields are copied from the vector's field.
 *
 * NOTE(review): an earlier version of this doc said nullability and metadata also come from
 * `expected`; the code takes them from the vector's `FieldType` - confirm which is intended.
 *
 * @param col
 *   column of the batch being inspected
 * @param expected
 *   consumer-declared field at the same position; supplies the name
 * @return
 *   a new field named after `expected` and typed after the data actually carried by `col`
 */
private def actualFieldOf(col: CometVector, expected: Field): Field = {
  val wireField = col match {
    case dictCol: CometDictionaryVector =>
      // Resolve the dictionary referenced by the index vector and describe its values.
      val dictionaryId = dictCol.getValueVector.getField.getDictionary.getId
      dictCol.provider.lookup(dictionaryId).getVector.getField
    case plainCol =>
      plainCol.getValueVector.getField
  }
  new Field(expected.getName, wireField.getFieldType, wireField.getChildren)
}
82148
83149 /**
0 commit comments