Skip to content

Commit 0e08018

Browse files
committed
Fix schema mismatch in CometArrowStream.
1 parent 3da08dc commit 0e08018

2 files changed

Lines changed: 187 additions & 6 deletions

File tree

spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,22 @@
1919

2020
package org.apache.spark.sql.comet.execution.arrow
2121

22+
import scala.jdk.CollectionConverters._
23+
2224
import org.apache.arrow.c.{ArrowArrayStream, Data}
2325
import org.apache.arrow.memory.BufferAllocator
2426
import org.apache.arrow.vector.ipc.ArrowReader
27+
import org.apache.arrow.vector.types.pojo.{Field, Schema}
2528
import org.apache.spark.TaskContext
29+
import org.apache.spark.internal.Logging
2630
import org.apache.spark.rdd.RDD
2731
import org.apache.spark.sql.comet.util.Utils
2832
import org.apache.spark.sql.execution.SparkPlan
2933
import org.apache.spark.sql.types.StructType
3034
import org.apache.spark.sql.vectorized.ColumnarBatch
3135

3236
import org.apache.comet.CometArrowAllocator
33-
import org.apache.comet.vector.NativeUtil
37+
import org.apache.comet.vector.{CometDictionaryVector, CometVector, NativeUtil}
3438

3539
/**
3640
* Marker for Comet operators that can produce Arrow data destined for a Comet native executor
@@ -40,7 +44,7 @@ trait CometNativeArrowSource extends SparkPlan {
4044
def doExecuteAsArrowStream(): RDD[ArrowArrayStream]
4145
}
4246

43-
object CometArrowStream {
47+
object CometArrowStream extends Logging {
4448

4549
/**
4650
* Native side asserts `Timestamp(Microsecond, Some("UTC"))` regardless of session timezone;
@@ -61,8 +65,9 @@ object CometArrowStream {
6165
// Arrow `Schema` is not Serializable; only Spark's `StructType` is. Build the Arrow schema
6266
// inside the per-task body so the closure cleaner doesn't try to ship a Schema across.
6367
rdd.mapPartitionsInternal { batchIter =>
64-
val arrowSchema = Utils.toArrowSchema(sparkSchema, timeZoneId)
65-
stream(name, allocator => new ColumnarBatchArrowReader(allocator, arrowSchema, batchIter))
68+
val expected = Utils.toArrowSchema(sparkSchema, timeZoneId)
69+
val (arrowSchema, iter) = reconcileStreamSchema(name, expected, batchIter)
70+
stream(name, allocator => new ColumnarBatchArrowReader(allocator, arrowSchema, iter))
6671
}
6772
}
6873

@@ -76,8 +81,69 @@ object CometArrowStream {
7681
sparkSchema: StructType,
7782
timeZoneId: String,
7883
name: String): ArrowArrayStream = {
79-
val arrowSchema = Utils.toArrowSchema(sparkSchema, timeZoneId)
80-
stream(name, allocator => new ColumnarBatchArrowReader(allocator, arrowSchema, iter)).next()
84+
val expected = Utils.toArrowSchema(sparkSchema, timeZoneId)
85+
val (arrowSchema, reconciled) = reconcileStreamSchema(name, expected, iter)
86+
stream(name, allocator => new ColumnarBatchArrowReader(allocator, arrowSchema, reconciled))
87+
.next()
88+
}
89+
90+
/**
91+
* Build the stream's advertised Arrow schema from the actual `CometVector` types in the first
92+
* batch, not from `expected` (which derives from the consumer's Spark-declared types). Native
93+
* operators like `ScanExec` already cast their input to the declared scan-input schema in
94+
* `build_record_batch`, so the truthful schema lets that cast actually fire. Advertising
95+
* `expected` instead silently mislabels Int32 buffers as Int64 (and similar) and corrupts on
96+
* import. See PR #4393 width_bucket investigation.
97+
*
98+
* If the first batch's column types differ from `expected` in their `DataType` (timezone-only
99+
* differences on `Timestamp` are ignored), log one warning naming the operator, column, and
100+
* type drift; the cast happens transparently downstream in native.
101+
*/
102+
private[arrow] def reconcileStreamSchema(
    name: String,
    expected: Schema,
    iter: Iterator[ColumnarBatch]): (Schema, Iterator[ColumnarBatch]) = {
  // Local import: only `Field` and `Schema` are imported from this package at file level.
  import org.apache.arrow.vector.types.pojo.ArrowType

  // Type comparison used only for the drift warning. Two timestamps with the same unit are
  // treated as equal even when their timezones differ; every other pair uses plain Arrow
  // equality. The returned schema still carries the actual type, timezone included.
  def sameTypeIgnoringTimezone(actual: ArrowType, declared: ArrowType): Boolean =
    (actual, declared) match {
      case (a: ArrowType.Timestamp, d: ArrowType.Timestamp) => a.getUnit == d.getUnit
      case _ => actual == declared
    }

  // Buffer the iterator so the first batch can be inspected without being consumed; the
  // buffered iterator is what gets handed back to the caller.
  val buffered = iter.buffered
  if (!buffered.hasNext) {
    // Empty partition: keep the consumer-declared schema; consumer can still build its plan.
    (expected, buffered)
  } else {
    val first = buffered.head
    val expectedFields = expected.getFields
    // NOTE(review): this assumes every column of the batch is a `CometVector` and that the
    // batch has no more columns than `expected`; otherwise the cast / `get(i)` below throws.
    val actualFields = (0 until first.numCols()).map { i =>
      val col = first.column(i).asInstanceOf[CometVector]
      actualFieldOf(col, expectedFields.get(i))
    }
    val mismatches = actualFields.zip(expectedFields.asScala).zipWithIndex.collect {
      case ((actual, exp), idx) if !sameTypeIgnoringTimezone(actual.getType, exp.getType) =>
        s"col[$idx] '${exp.getName}': expected ${exp.getType}, child produced ${actual.getType}"
    }
    if (mismatches.nonEmpty) {
      logWarning(
        s"CometArrowStream '$name' input schema mismatch: ${mismatches.mkString("; ")}. " +
          "Native ScanExec will cast at the boundary. This usually means a DataFusion-Spark " +
          "function declares a different return type than Spark catalyst.")
    }
    // Advertise what the first batch actually contains, not what the consumer declared.
    (new Schema(actualFields.asJava), buffered)
  }
}
129+
130+
/**
131+
* The Arrow field that this column's buffers will look like once unloaded. For a
132+
* `CometDictionaryVector`, [[ColumnarBatchArrowReader]] decodes it via
133+
* `DictionaryEncoder.decode` before unloading, so the wire-level field is the dictionary's
134+
* *value* type, not `Dictionary<index, value>`. For everything else, use the underlying value
135+
* vector's field. Only the field name comes from `expected`, so that consumers indexing by
136+
* name keep working; nullability, metadata and children come from the actual vector's field.
137+
*/
138+
private def actualFieldOf(col: CometVector, expected: Field): Field = {
  // Pick the field describing the column's values: for a dictionary-encoded column that is
  // the field of the dictionary's own vector (looked up by dictionary id through the column's
  // provider); for any other column it is the field of the underlying value vector.
  val wireField: Field = col match {
    case dictCol: CometDictionaryVector =>
      val dictionaryId = dictCol.getValueVector.getField.getDictionary.getId
      dictCol.provider.lookup(dictionaryId).getVector.getField
    case plain =>
      plain.getValueVector.getField
  }
  // Only the name is taken from `expected`; field type and children are the vector's own.
  val fieldName = expected.getName
  new Field(fieldName, wireField.getFieldType, wireField.getChildren)
}
82148

83149
/**
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.spark.sql.comet.execution.arrow
21+
22+
import scala.jdk.CollectionConverters._
23+
24+
import org.scalatest.funsuite.AnyFunSuite
25+
import org.scalatest.matchers.should.Matchers
26+
27+
import org.apache.arrow.memory.RootAllocator
28+
import org.apache.arrow.vector.{BigIntVector, IntVector}
29+
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
30+
import org.apache.spark.sql.vectorized.ColumnarBatch
31+
32+
import org.apache.comet.vector.{CometPlainVector, CometVector}
33+
34+
/**
35+
* Direct tests for [[CometArrowStream.reconcileStreamSchema]]. The end-to-end regression that
36+
* motivated this (Spark Long vs DataFusion Int32 for `width_bucket`) lives in
37+
* `CometMathExpressionSuite`, but that test only catches *one* function-level type drift. This
38+
* suite covers the boundary contract independently of any specific function.
39+
*/
40+
class CometArrowStreamSuite extends AnyFunSuite with Matchers {

  /** Builds a consumer-declared schema of nullable, childless fields from (name, type) pairs. */
  private def expectedSchema(types: (String, ArrowType)*): Schema = {
    val fields = types.map { case (name, t) =>
      new Field(name, new FieldType(true, t, null), java.util.Collections.emptyList[Field]())
    }
    new Schema(fields.asJava)
  }

  /** Wraps the vectors in a batch whose row count is the first vector's value count. */
  private def batchOf(vectors: CometVector*): ColumnarBatch = {
    val numRows = if (vectors.isEmpty) 0 else vectors.head.getValueVector.getValueCount
    new ColumnarBatch(vectors.toArray, numRows)
  }

  /** Runs `body` with a fresh allocator and always closes it afterwards. */
  private def withAllocator(body: RootAllocator => Unit): Unit = {
    val allocator = new RootAllocator(Integer.MAX_VALUE)
    try body(allocator)
    finally allocator.close()
  }

  /**
   * Runs `body` against a one-column batch and closes the column even when an assertion fails.
   * Without this, a failed assertion would skip the close and the allocator's own close would
   * then raise a leak error that hides the real test failure.
   */
  private def withSingleColumnBatch(cv: CometVector)(body: ColumnarBatch => Unit): Unit = {
    try body(batchOf(cv))
    finally cv.close()
  }

  test("reconcileStreamSchema returns expected schema unchanged on empty iterator") {
    val expected = expectedSchema("c0" -> new ArrowType.Int(64, true))
    val (returned, iter) =
      CometArrowStream.reconcileStreamSchema("test", expected, Iterator.empty)
    returned shouldBe expected
    iter.hasNext shouldBe false
  }

  test("reconcileStreamSchema returns expected schema when types match") {
    withAllocator { allocator =>
      val v = new BigIntVector("col_0", allocator)
      v.allocateNew()
      v.setSafe(0, 1L)
      v.setValueCount(1)
      withSingleColumnBatch(new CometPlainVector(v, false)) { batch =>
        val expected = expectedSchema("c0" -> new ArrowType.Int(64, true))

        val (returned, iter) = CometArrowStream
          .reconcileStreamSchema("test", expected, Iterator.single(batch))

        returned.getFields.get(0).getType shouldBe new ArrowType.Int(64, true)
        iter.hasNext shouldBe true
        iter.next() should be theSameInstanceAs batch
      }
    }
  }

  test("reconcileStreamSchema rebuilds schema from actual vector types when they differ") {
    withAllocator { allocator =>
      // Producer produced Int32 (e.g., DataFusion-Spark width_bucket pre-fix), consumer expects
      // Int64 (Spark catalyst WidthBucket.dataType = LongType). The truthful schema is Int32 so
      // native ScanExec's build_record_batch can cast at the boundary.
      val v = new IntVector("col_0", allocator)
      v.allocateNew()
      v.setSafe(0, 1)
      v.setValueCount(1)
      withSingleColumnBatch(new CometPlainVector(v, false)) { batch =>
        val expected = expectedSchema("c0" -> new ArrowType.Int(64, true))

        val (returned, iter) = CometArrowStream
          .reconcileStreamSchema("test", expected, Iterator.single(batch))

        val returnedField = returned.getFields.get(0)
        returnedField.getType shouldBe new ArrowType.Int(32, true)
        // Names come from `expected` so name-indexed consumers keep working.
        returnedField.getName shouldBe "c0"
        iter.hasNext shouldBe true
        iter.next() should be theSameInstanceAs batch
      }
    }
  }
}

0 commit comments

Comments
 (0)