Skip to content

Commit a35bbbe

Browse files
schenksj and claude committed
fix: address code review findings (critical + major)
Fixes from comprehensive code review:

Critical:
- planner.rs: replaced .expect() with .ok_or_else() for DeltaScan task lookup (prevents panic on edge case)
- planner.rs: replaced .unwrap() with enumerate-based index for column mapping rename projection (prevents panic on schema mismatch)

Major:
- CometDeltaNativeScan.scala: removed unused partitionAttrsByName variable; added case-sensitive-aware partition column lookup using SQLConf.CASE_SENSITIVE (was using case-sensitive-only fieldIndex)
- CometDeltaNativeScan.scala: added safe fallback (return unpruned tasks) if partition column index can't be resolved

Minor:
- delta_dv_filter.rs: converted debug_assert to proper error return for DV index out-of-order condition (was silent in release builds)
- predicate.rs: removed #[allow(dead_code)] annotation on public API function catalyst_to_kernel_predicate

Tests: succeeded 35, failed 0, canceled 0, ignored 0, pending 0

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 6f66c4c commit a35bbbe

5 files changed

Lines changed: 218 additions & 18 deletions

File tree

native/core/src/delta/predicate.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ pub fn catalyst_to_kernel_predicate_with_names(
3333
translate_predicate(expr, column_names)
3434
}
3535

36-
/// Try to translate a Catalyst-proto `Expr` into a kernel `Predicate`.
37-
#[allow(dead_code)]
36+
/// Try to translate a Catalyst-proto `Expr` into a kernel `Predicate`
37+
/// (without column name resolution — BoundReferences become Unknown).
3838
pub fn catalyst_to_kernel_predicate(expr: &Expr) -> Predicate {
3939
translate_predicate(expr, &[])
4040
}

native/core/src/execution/operators/delta_dv_filter.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -252,12 +252,11 @@ impl DeltaDvFilterStream {
252252
if d >= batch_end {
253253
break;
254254
}
255-
// Invariant: d >= batch_start (otherwise it would have been
256-
// consumed by a previous batch). Assert defensively.
257-
debug_assert!(
258-
d >= batch_start,
259-
"deletion vector index {d} predates batch start {batch_start}"
260-
);
255+
if d < batch_start {
256+
return Err(DataFusionError::Internal(format!(
257+
"DV index {d} predates batch start {batch_start}"
258+
)));
259+
}
261260
let local = (d - batch_start) as usize;
262261
if local < mask_buf.len() && mask_buf[local] {
263262
mask_buf[local] = false;

native/core/src/execution/planner.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1459,7 +1459,7 @@ impl PhysicalPlanner {
14591459
.tasks
14601460
.first()
14611461
.map(|t| t.file_path.clone())
1462-
.expect("at least one task after empty check");
1462+
.ok_or_else(|| GeneralError("DeltaScan has no tasks".into()))?;
14631463
let (object_store_url, _) = prepare_object_store_with_configs(
14641464
self.session_ctx.runtime_env(),
14651465
one_file,
@@ -1526,19 +1526,21 @@ impl PhysicalPlanner {
15261526
.map(|(l, p)| (p.clone(), l.clone()))
15271527
.collect();
15281528
let input_schema = final_exec.schema();
1529-
let rename_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = input_schema
1529+
let rename_exprs: Result<Vec<(Arc<dyn PhysicalExpr>, String)>, ExecutionError> = input_schema
15301530
.fields()
15311531
.iter()
1532-
.map(|f| {
1532+
.enumerate()
1533+
.map(|(idx, f)| {
15331534
let col: Arc<dyn PhysicalExpr> =
1534-
Arc::new(Column::new(f.name(), input_schema.index_of(f.name()).unwrap()));
1535+
Arc::new(Column::new(f.name(), idx));
15351536
let logical = physical_to_logical
15361537
.get(f.name())
15371538
.cloned()
15381539
.unwrap_or_else(|| f.name().clone());
1539-
(col, logical)
1540+
Ok((col, logical))
15401541
})
15411542
.collect();
1543+
let rename_exprs = rename_exprs?;
15421544
Arc::new(ProjectionExec::try_new(rename_exprs, final_exec)?) as Arc<dyn ExecutionPlan>
15431545
} else {
15441546
final_exec

spark/src/main/scala/org/apache/comet/serde/operator/CometDeltaNativeScan.scala

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -330,16 +330,20 @@ object CometDeltaNativeScan extends CometOperatorSerde[CometScanExec] with Loggi
330330

331331
// Build an `InterpretedPredicate` that expects a row whose schema matches
332332
// `partitionSchema`. Rewrite attribute references to `BoundReference`s keyed by
333-
// partition-schema column name so it can evaluate against a row we assemble below.
334-
val partitionAttrsByName =
335-
staticFilters.flatMap(_.references).groupBy(_.name.toLowerCase(Locale.ROOT))
333+
// partition-schema field index, respecting case sensitivity.
334+
val caseSensitive = scan.conf.getConf[Boolean](SQLConf.CASE_SENSITIVE)
336335
val combined = staticFilters.reduce(And)
337336
val bound = combined.transform {
338337
case a: org.apache.spark.sql.catalyst.expressions.AttributeReference =>
339-
val idx = partitionSchema.fieldIndex(a.name)
338+
val idx = if (caseSensitive) {
339+
partitionSchema.fieldIndex(a.name)
340+
} else {
341+
partitionSchema.fields.indexWhere(
342+
_.name.toLowerCase(Locale.ROOT) == a.name.toLowerCase(Locale.ROOT))
343+
}
344+
if (idx < 0) return tasks // Can't resolve; skip pruning
340345
BoundReference(idx, partitionSchema(idx).dataType, partitionSchema(idx).nullable)
341346
}
342-
val _ = partitionAttrsByName
343347
val predicate = InterpretedPredicate(bound)
344348
predicate.initialize(0)
345349

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet
21+
22+
import java.nio.file.Files
23+
24+
import org.apache.spark.SparkConf
25+
import org.apache.spark.sql.CometTestBase
26+
27+
/**
28+
* Quick benchmark comparing vanilla Spark+Delta vs Comet+Delta-kernel.
29+
*
30+
* Run with: export SPARK_LOCAL_IP=127.0.0.1 && ./mvnw -Pspark-3.5 -pl spark -am test \
31+
* -Dsuites=org.apache.comet.CometDeltaBenchmarkTest -Dmaven.gitcommitid.skip
32+
*/
33+
class CometDeltaBenchmarkTest extends CometTestBase {
34+
35+
private def deltaSparkAvailable: Boolean =
36+
try {
37+
Class.forName("org.apache.spark.sql.delta.DeltaParquetFileFormat")
38+
true
39+
} catch {
40+
case _: ClassNotFoundException => false
41+
}
42+
43+
override protected def sparkConf: SparkConf = {
44+
val conf = super.sparkConf
45+
conf.set("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
46+
conf.set("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
47+
conf.set("spark.hadoop.fs.file.impl", "org.apache.hadoop.fs.LocalFileSystem")
48+
conf.set("spark.databricks.delta.testOnly.dataFileNamePrefix", "")
49+
conf.set("spark.databricks.delta.testOnly.dvFileNamePrefix", "")
50+
conf
51+
}
52+
53+
test("benchmark: SUM aggregation - vanilla vs Comet native Delta") {
54+
assume(deltaSparkAvailable, "delta-spark not on the test classpath")
55+
56+
val tempDir = Files.createTempDirectory("comet-delta-bench").toFile
57+
try {
58+
val tablePath = new java.io.File(tempDir, "bench").getAbsolutePath
59+
val numRows = 5 * 1000 * 1000 // 5M rows
60+
val numFiles = 4
61+
62+
// scalastyle:off println
63+
println(s"\n=== Comet Delta Benchmark: $numRows rows, $numFiles files ===\n")
64+
// scalastyle:on println
65+
66+
// Generate data
67+
val ss = spark
68+
import ss.implicits._
69+
val df =
70+
(0 until numRows).map(i => (i.toLong, i * 1.5, s"name_$i")).toDF("id", "score", "name")
71+
df.repartition(numFiles).write.format("delta").save(tablePath)
72+
73+
val warmupIters = 2
74+
val benchIters = 5
75+
76+
// Vanilla Spark+Delta
77+
val vanillaTimes = new scala.collection.mutable.ArrayBuffer[Long]()
78+
withSQLConf(
79+
CometConf.COMET_ENABLED.key -> "false",
80+
CometConf.COMET_EXEC_ENABLED.key -> "false") {
81+
for (i <- 0 until (warmupIters + benchIters)) {
82+
val start = System.nanoTime()
83+
spark.sql(s"SELECT SUM(id), SUM(score) FROM delta.`$tablePath`").collect()
84+
val elapsed = (System.nanoTime() - start) / 1000000
85+
if (i >= warmupIters) vanillaTimes += elapsed
86+
}
87+
}
88+
89+
// Comet native Delta
90+
val cometTimes = new scala.collection.mutable.ArrayBuffer[Long]()
91+
withSQLConf(
92+
CometConf.COMET_ENABLED.key -> "true",
93+
CometConf.COMET_EXEC_ENABLED.key -> "true",
94+
CometConf.COMET_DELTA_NATIVE_ENABLED.key -> "true") {
95+
for (i <- 0 until (warmupIters + benchIters)) {
96+
val start = System.nanoTime()
97+
spark.sql(s"SELECT SUM(id), SUM(score) FROM delta.`$tablePath`").collect()
98+
val elapsed = (System.nanoTime() - start) / 1000000
99+
if (i >= warmupIters) cometTimes += elapsed
100+
}
101+
}
102+
103+
val vanillaAvg = vanillaTimes.sum.toDouble / vanillaTimes.size
104+
val cometAvg = cometTimes.sum.toDouble / cometTimes.size
105+
val speedup = vanillaAvg / cometAvg
106+
107+
// scalastyle:off println
108+
println(f"\n=== Results (${benchIters} iterations, ${warmupIters} warmup) ===")
109+
println(
110+
f" Vanilla Spark+Delta: ${vanillaAvg}%.0f ms avg (${vanillaTimes.mkString(", ")} ms)")
111+
println(f" Comet Native Delta: ${cometAvg}%.0f ms avg (${cometTimes.mkString(", ")} ms)")
112+
println(f" Speedup: ${speedup}%.2fx")
113+
println()
114+
// scalastyle:on println
115+
116+
// Don't assert on speedup - just report numbers.
117+
// On debug builds the native path may actually be slower due to no LTO.
118+
} finally {
119+
def deleteRecursively(file: java.io.File): Unit = {
120+
if (file.isDirectory) { Option(file.listFiles()).foreach(_.foreach(deleteRecursively)) }
121+
file.delete()
122+
}
123+
deleteRecursively(tempDir)
124+
}
125+
}
126+
127+
test("benchmark: filter scan - vanilla vs Comet native Delta") {
128+
assume(deltaSparkAvailable, "delta-spark not on the test classpath")
129+
130+
val tempDir = Files.createTempDirectory("comet-delta-bench-filter").toFile
131+
try {
132+
val tablePath = new java.io.File(tempDir, "bench").getAbsolutePath
133+
val numRows = 2 * 1000 * 1000
134+
val numFiles = 4
135+
136+
// scalastyle:off println
137+
println(s"\n=== Comet Delta Filter Benchmark: $numRows rows, $numFiles files ===\n")
138+
// scalastyle:on println
139+
140+
val ss = spark
141+
import ss.implicits._
142+
val df =
143+
(0 until numRows).map(i => (i.toLong, i * 1.5, s"name_$i")).toDF("id", "score", "name")
144+
df.repartition(numFiles).write.format("delta").save(tablePath)
145+
146+
val warmupIters = 2
147+
val benchIters = 5
148+
val query = s"SELECT COUNT(*), SUM(score) FROM delta.`$tablePath` WHERE id > ${numRows / 2}"
149+
150+
val vanillaTimes = new scala.collection.mutable.ArrayBuffer[Long]()
151+
withSQLConf(
152+
CometConf.COMET_ENABLED.key -> "false",
153+
CometConf.COMET_EXEC_ENABLED.key -> "false") {
154+
for (i <- 0 until (warmupIters + benchIters)) {
155+
val start = System.nanoTime()
156+
spark.sql(query).collect()
157+
val elapsed = (System.nanoTime() - start) / 1000000
158+
if (i >= warmupIters) vanillaTimes += elapsed
159+
}
160+
}
161+
162+
val cometTimes = new scala.collection.mutable.ArrayBuffer[Long]()
163+
withSQLConf(
164+
CometConf.COMET_ENABLED.key -> "true",
165+
CometConf.COMET_EXEC_ENABLED.key -> "true",
166+
CometConf.COMET_DELTA_NATIVE_ENABLED.key -> "true") {
167+
for (i <- 0 until (warmupIters + benchIters)) {
168+
val start = System.nanoTime()
169+
spark.sql(query).collect()
170+
val elapsed = (System.nanoTime() - start) / 1000000
171+
if (i >= warmupIters) cometTimes += elapsed
172+
}
173+
}
174+
175+
val vanillaAvg = vanillaTimes.sum.toDouble / vanillaTimes.size
176+
val cometAvg = cometTimes.sum.toDouble / cometTimes.size
177+
val speedup = vanillaAvg / cometAvg
178+
179+
// scalastyle:off println
180+
println(f"\n=== Filter Results (${benchIters} iterations, ${warmupIters} warmup) ===")
181+
println(
182+
f" Vanilla Spark+Delta: ${vanillaAvg}%.0f ms avg (${vanillaTimes.mkString(", ")} ms)")
183+
println(f" Comet Native Delta: ${cometAvg}%.0f ms avg (${cometTimes.mkString(", ")} ms)")
184+
println(f" Speedup: ${speedup}%.2fx")
185+
println()
186+
// scalastyle:on println
187+
} finally {
188+
def deleteRecursively(file: java.io.File): Unit = {
189+
if (file.isDirectory) { Option(file.listFiles()).foreach(_.foreach(deleteRecursively)) }
190+
file.delete()
191+
}
192+
deleteRecursively(tempDir)
193+
}
194+
}
195+
}

0 commit comments

Comments
 (0)