Skip to content

Commit d7147db

Browse files
authored
bug: no column projection should still persist row count (#4444)
1 parent 089b6a5 commit d7147db

4 files changed

Lines changed: 65 additions & 0 deletions

File tree

.github/workflows/pr_build_linux.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,7 @@ jobs:
383383
org.apache.spark.sql.comet.CometDppFallbackRepro3949Suite
384384
org.apache.spark.sql.comet.CometShuffleFallbackStickinessSuite
385385
org.apache.spark.sql.comet.CometDecimalArithmeticViewSuite
386+
org.apache.spark.sql.comet.util.UtilsSuite
386387
org.apache.comet.objectstore.NativeConfigSuite
387388
org.apache.spark.sql.CometToPrettyStringSuite
388389
org.apache.spark.sql.CometCollationSuite

.github/workflows/pr_build_macos.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ jobs:
223223
org.apache.spark.sql.comet.CometDppFallbackRepro3949Suite
224224
org.apache.spark.sql.comet.CometShuffleFallbackStickinessSuite
225225
org.apache.spark.sql.comet.CometDecimalArithmeticViewSuite
226+
org.apache.spark.sql.comet.util.UtilsSuite
226227
org.apache.comet.objectstore.NativeConfigSuite
227228
org.apache.spark.sql.CometToPrettyStringSuite
228229
org.apache.spark.sql.CometCollationSuite

spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,10 @@ object Utils extends CometTypeShim with Logging {
224224

225225
val (fieldVectors, batchProviderOpt) = getBatchFieldVectors(batch)
226226
val root = new VectorSchemaRoot(fieldVectors.asJava)
227+
if (fieldVectors.isEmpty) {
228+
// VectorSchemaRoot cannot infer its row count when there are no field vectors, so set it explicitly
229+
root.setRowCount(batch.numRows())
230+
}
227231
val provider = batchProviderOpt.getOrElse(dictionaryProvider)
228232

229233
val writer = new ArrowStreamWriter(root, provider, Channels.newChannel(out))
@@ -336,6 +340,11 @@ object Utils extends CometTypeShim with Logging {
336340
return (Array.empty, 0L, 0L)
337341
}
338342

343+
if (targetRoot.getSchema.getFields.isEmpty) {
344+
// VectorSchemaRootAppender does not update the row count when the schema has no columns
345+
targetRoot.setRowCount(totalRows.toInt)
346+
}
347+
339348
assert(
340349
targetRoot.getRowCount.toLong == totalRows,
341350
s"Row count mismatch after coalesce: ${targetRoot.getRowCount} != $totalRows")
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.spark.sql.comet.util
21+
22+
import org.apache.spark.sql.CometTestBase
23+
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
24+
25+
/**
 * Round-trip tests for the batch serialization helpers in [[Utils]] when the batches carry a row
 * count but no column vectors at all.
 */
class UtilsSuite extends CometTestBase {

  /** Creates a batch with the given row count and an empty column-vector array. */
  private def zeroColumnBatch(rows: Int): ColumnarBatch =
    new ColumnarBatch(Array.empty[ColumnVector], rows)

  test("serializeBatches preserves row count for a zero-column batch") {
    val expectedRows = 5

    // The first element produced by serializeBatches is a (row count, buffer) pair.
    val serialized = Utils.serializeBatches(Iterator(zeroColumnBatch(expectedRows)))
    val (reportedRows, payload) = serialized.next()
    assert(reportedRows == expectedRows)

    // Decoding the buffer again must yield the same total number of rows.
    val roundTripped = Utils.decodeBatches(payload, "test").toSeq
    val decodedRows = roundTripped.map(_.numRows()).sum
    assert(decodedRows == expectedRows)
  }

  test("coalesceBroadcastBatches preserves row count across zero-column inputs") {
    val rowsPerBatch = 5
    val inputBatchCount = 3
    val inputs = Vector.fill(inputBatchCount)(zeroColumnBatch(rowsPerBatch))

    // Keep only the serialized buffers, materialized before they are handed to the coalescer.
    val payloads =
      Utils.serializeBatches(inputs.iterator).map { case (_, payload) => payload }.toSeq.iterator
    val (merged, reportedBatches, reportedRows) = Utils.coalesceBroadcastBatches(payloads)

    val expectedRows = rowsPerBatch.toLong * inputBatchCount
    assert(reportedBatches == inputBatchCount)
    assert(reportedRows == expectedRows)

    // Every coalesced buffer is decoded and the row counts are summed across all of them.
    val roundTripped = merged.iterator.flatMap(chunk => Utils.decodeBatches(chunk, "test")).toSeq
    val decodedRows = roundTripped.map(_.numRows()).sum
    assert(decodedRows == expectedRows)
  }
}

0 commit comments

Comments
 (0)