Skip to content

Commit a5046e3

Browse files
committed
Fix nullability mismatch in CometArrowStreamSuite.
1 parent 0e08018 commit a5046e3

2 files changed

Lines changed: 45 additions & 4 deletions

File tree

spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import scala.jdk.CollectionConverters._
2424
import org.apache.arrow.c.{ArrowArrayStream, Data}
2525
import org.apache.arrow.memory.BufferAllocator
2626
import org.apache.arrow.vector.ipc.ArrowReader
27-
import org.apache.arrow.vector.types.pojo.{Field, Schema}
27+
import org.apache.arrow.vector.types.pojo.{Field, FieldType, Schema}
2828
import org.apache.spark.TaskContext
2929
import org.apache.spark.internal.Logging
3030
import org.apache.spark.rdd.RDD
@@ -132,8 +132,14 @@ object CometArrowStream extends Logging {
132132
* `CometDictionaryVector`, [[ColumnarBatchArrowReader]] decodes it via
133133
* `DictionaryEncoder.decode` before unloading, so the wire-level field is the dictionary's
134134
* *value* type, not `Dictionary<index, value>`. For everything else, use the underlying value
135-
* vector's field. Field name / nullability / metadata come from `expected` so that consumers
136-
* indexing by name keep working.
135+
* vector's field.
136+
*
137+
* Field name and metadata come from `expected` so that consumers indexing by name keep working.
138+
* Nullability is the union of the two — a CometVector that happens to hold no nulls in this
139+
* batch can still be nullable per Spark's contract (the next batch may have one), and a column
140+
* whose actual buffer carries validity bits must stay nullable even if Spark thought otherwise.
141+
* Taking only `raw.isNullable` here would advertise non-nullable when the next batch does carry
142+
* a null and crash native validation.
137143
*/
138144
private def actualFieldOf(col: CometVector, expected: Field): Field = {
139145
val raw = col match {
@@ -143,7 +149,10 @@ object CometArrowStream extends Logging {
143149
dict.getVector.getField
144150
case _ => col.getValueVector.getField
145151
}
146-
new Field(expected.getName, raw.getFieldType, raw.getChildren)
152+
val nullable = expected.isNullable || raw.isNullable
153+
val fieldType =
154+
new FieldType(nullable, raw.getType, raw.getDictionary, expected.getMetadata)
155+
new Field(expected.getName, fieldType, raw.getChildren)
147156
}
148157

149158
/**

spark/src/test/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowStreamSuite.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,36 @@ class CometArrowStreamSuite extends AnyFunSuite with Matchers {
112112
allocator.close()
113113
}
114114
}
115+
116+
test("reconcileStreamSchema preserves nullability when expected is nullable but actual is not") {
117+
val allocator = new RootAllocator(Integer.MAX_VALUE)
118+
try {
119+
// Spark catalyst declares the column nullable; the first batch happens to come from a
120+
// vector whose Field reports non-nullable. Subsequent batches may carry nulls, so the
121+
// wire schema must stay nullable or native validation rejects the next null with
122+
// "declared as non-nullable but contains null values".
123+
val v = new BigIntVector(
124+
new Field(
125+
"col_0",
126+
new FieldType(false, new ArrowType.Int(64, true), null),
127+
java.util.Collections.emptyList[Field]()),
128+
allocator)
129+
v.allocateNew()
130+
v.setSafe(0, 1L)
131+
v.setValueCount(1)
132+
val cv = new CometPlainVector(v, false)
133+
val batch = batchOf(cv)
134+
val expected = expectedSchema("c0" -> new ArrowType.Int(64, true)) // nullable=true
135+
136+
val (returned, _) = CometArrowStream
137+
.reconcileStreamSchema("test", expected, Iterator.single(batch))
138+
139+
val returnedField = returned.getFields.get(0)
140+
returnedField.isNullable shouldBe true
141+
142+
cv.close()
143+
} finally {
144+
allocator.close()
145+
}
146+
}
115147
}

0 commit comments

Comments
 (0)