Skip to content

Commit 9d096a3

Browse files
committed
[GLUTEN-12013][VL] Fix bloom-filter bytes corruption on whole-stage AQE fallback
`BloomFilterMightContainJointRewriteRule` previously rewrote every `BloomFilterAggregate` it encountered, including standalone usages such as `DataFrame.stat.bloomFilter()`. That API collects the aggregate output bytes and passes them directly to `BloomFilter.readFrom()`, which expects Spark-native format; receiving Velox-format bytes caused `java.io.IOException: Unexpected Bloom filter version number` (surfaced as a CI failure in `GlutenDataFrameStatSuite - Bloom filter`). Fix: only rewrite `BloomFilterAggregate` when it appears inside the `ScalarSubquery` of a `BloomFilterMightContain`. Standalone aggregates are left untouched so that collected bytes remain in Spark-native format. Add a regression test (`GlutenBloomFilterFallbackSuite`) to guard against reintroducing this regression. Local test results (Spark 4.0, Velox backend): - GlutenDataFrameStatSuite : 25/25 passed (was failing) - GlutenBloomFilterFallbackSuite : 4/4 passed - GlutenBloomFilterAggregateQuerySuite: 14/14 passed - GlutenInjectRuntimeFilterSuite : 13/13 passed
1 parent 5a5516b commit 9d096a3

4 files changed

Lines changed: 255 additions & 73 deletions

File tree

backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ object VeloxRuleApi {
5353
private def injectSpark(injector: SparkInjector): Unit = {
5454
// Inject the regular Spark rules directly.
5555
injector.injectOptimizerRule(CollectRewriteRule.apply)
56+
injector.injectOptimizerRule(BloomFilterMightContainJointRewriteRule.apply)
5657
injector.injectOptimizerRule(HLLRewriteRule.apply)
5758
injector.injectOptimizerRule(CollapseGetJsonObjectExpressionRule.apply)
5859
injector.injectOptimizerRule(RewriteCastFromArray.apply)
@@ -81,11 +82,6 @@ object VeloxRuleApi {
8182
injector.injectPreTransform(c => FallbackMultiCodegens.apply(c.session))
8283
injector.injectPreTransform(c => MergeTwoPhasesHashBaseAggregate(c.session))
8384
injector.injectPreTransform(_ => RewriteSubqueryBroadcast())
84-
injector.injectPreTransform(
85-
c =>
86-
BloomFilterMightContainJointRewriteRule.apply(
87-
c.session,
88-
c.caller.isBloomFilterStatFunction()))
8985
injector.injectPreTransform(_ => EliminateRedundantGetTimestamp)
9086

9187
// Legacy: The legacy transform rule.

backends-velox/src/main/scala/org/apache/gluten/extension/BloomFilterMightContainJointRewriteRule.scala

Lines changed: 37 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -21,63 +21,47 @@ import org.apache.gluten.expression.VeloxBloomFilterMightContain
2121
import org.apache.gluten.expression.aggregate.VeloxBloomFilterAggregate
2222

2323
import org.apache.spark.sql.SparkSession
24-
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, BloomFilterMightContain, Expression}
25-
import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, TypedImperativeAggregate}
24+
import org.apache.spark.sql.catalyst.expressions.{BloomFilterMightContain, ScalarSubquery}
25+
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate}
26+
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2627
import org.apache.spark.sql.catalyst.rules.Rule
27-
import org.apache.spark.sql.execution.SparkPlan
2828

29-
case class BloomFilterMightContainJointRewriteRule(
30-
spark: SparkSession,
31-
isBloomFilterStatFunction: Boolean)
32-
extends Rule[SparkPlan] {
33-
override def apply(plan: SparkPlan): SparkPlan = {
34-
if (isBloomFilterStatFunction || !GlutenConfig.get.enableNativeBloomFilter) {
29+
/**
30+
* Optimizer rule that rewrites `BloomFilterAggregate` -> `VeloxBloomFilterAggregate` and
31+
* `BloomFilterMightContain` -> `VeloxBloomFilterMightContain` at the logical plan level.
32+
*
33+
* Running as an optimizer rule ensures the substitution is captured in the `originalPlan` snapshot
34+
* that [[org.apache.gluten.extension.columnar.heuristic.ExpandFallbackPolicy]] uses when promoting
35+
* an individual stage fallback to a whole-stage AQE fallback. This guarantees that both sides of
36+
* the bloom-filter pair always produce and consume the same byte format, regardless of whether
37+
* stages fall back to JVM execution after AQE re-planning.
38+
*
39+
* `BloomFilterAggregate` is only rewritten when it appears inside the [[ScalarSubquery]] of a
40+
* [[BloomFilterMightContain]]. Standalone usages (in particular `DataFrame.stat.bloomFilter()`,
41+
* which collects bloom filter bytes and passes them to `BloomFilter.readFrom()`) are intentionally
42+
* left untouched so that the returned bytes remain in Spark-native format.
43+
*/
44+
case class BloomFilterMightContainJointRewriteRule(spark: SparkSession)
45+
extends Rule[LogicalPlan] {
46+
47+
override def apply(plan: LogicalPlan): LogicalPlan = {
48+
if (!GlutenConfig.get.enableNativeBloomFilter) {
3549
return plan
3650
}
37-
val out = plan.transformWithSubqueries {
38-
case p =>
39-
applyForNode(p)
40-
}
41-
out
42-
}
43-
44-
private def replaceBloomFilterAggregate[T](
45-
expr: Expression,
46-
bloomFilterAggReplacer: (
47-
Expression,
48-
Expression,
49-
Expression,
50-
Int,
51-
Int) => TypedImperativeAggregate[T]): Expression = expr match {
52-
case BloomFilterAggregate(
53-
child,
54-
estimatedNumItemsExpression,
55-
numBitsExpression,
56-
mutableAggBufferOffset,
57-
inputAggBufferOffset) =>
58-
bloomFilterAggReplacer(
59-
child,
60-
estimatedNumItemsExpression,
61-
numBitsExpression,
62-
mutableAggBufferOffset,
63-
inputAggBufferOffset)
64-
case other => other
65-
}
66-
67-
private def replaceMightContain[T](
68-
expr: Expression,
69-
mightContainReplacer: (Expression, Expression) => BinaryExpression): Expression = expr match {
70-
case BloomFilterMightContain(bloomFilterExpression, valueExpression) =>
71-
mightContainReplacer(bloomFilterExpression, valueExpression)
72-
case other => other
73-
}
74-
75-
private def applyForNode(p: SparkPlan) = {
76-
p.transformExpressions {
77-
case e =>
78-
replaceMightContain(
79-
replaceBloomFilterAggregate(e, VeloxBloomFilterAggregate.apply),
80-
VeloxBloomFilterMightContain.apply)
51+
plan.transformAllExpressions {
52+
case BloomFilterMightContain(subq: ScalarSubquery, v) =>
53+
val rewrittenPlan = subq.plan.transformAllExpressions {
54+
case ae @ AggregateExpression(b: BloomFilterAggregate, _, _, _, _) =>
55+
ae.copy(aggregateFunction = VeloxBloomFilterAggregate(
56+
b.child,
57+
b.estimatedNumItemsExpression,
58+
b.numBitsExpression,
59+
b.mutableAggBufferOffset,
60+
b.inputAggBufferOffset))
61+
}
62+
VeloxBloomFilterMightContain(subq.withNewPlan(rewrittenPlan), v)
63+
case BloomFilterMightContain(bf, v) =>
64+
VeloxBloomFilterMightContain(bf, v)
8165
}
8266
}
8367
}

gluten-core/src/main/scala/org/apache/gluten/extension/caller/CallerInfo.scala

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ trait CallerInfo {
3030
def isAqe(): Boolean
3131
def isCache(): Boolean
3232
def isStreaming(): Boolean
33-
def isBloomFilterStatFunction(): Boolean
3433
}
3534

3635
object CallerInfo {
@@ -42,8 +41,7 @@ object CallerInfo {
4241
private class Impl(
4342
override val isAqe: Boolean,
4443
override val isCache: Boolean,
45-
override val isStreaming: Boolean,
46-
override val isBloomFilterStatFunction: Boolean
44+
override val isStreaming: Boolean
4745
) extends CallerInfo
4846

4947
/*
@@ -57,8 +55,7 @@ object CallerInfo {
5755
new Impl(
5856
isAqe = inAqeCall(stack),
5957
isCache = inCacheCall(stack),
60-
isStreaming = inStreamingCall(stack),
61-
isBloomFilterStatFunction = inBloomFilterStatFunctionCall(stack))
58+
isStreaming = inStreamingCall(stack))
6259
}
6360

6461
private def inAqeCall(stack: Seq[StackTraceElement]): Boolean = {
@@ -78,21 +75,13 @@ object CallerInfo {
7875
stack.exists(_.getClassName.equals(streamName))
7976
}
8077

81-
private def inBloomFilterStatFunctionCall(stack: Seq[StackTraceElement]): Boolean = {
82-
val res = stack.exists(
83-
_.getClassName.equals("org.apache.spark.sql.DataFrameStatFunctions")
84-
&& stack.exists(_.getMethodName.equals("bloomFilter")))
85-
res
86-
}
87-
8878
/** For testing only. */
8979
def withLocalValue[T](
9080
isAqe: Boolean,
9181
isCache: Boolean,
92-
isStreaming: Boolean = false,
93-
isBloomFilterStatFunction: Boolean = false)(body: => T): T = {
82+
isStreaming: Boolean = false)(body: => T): T = {
9483
val prevValue = localStorage.get()
95-
val newValue = new Impl(isAqe, isCache, isStreaming, isBloomFilterStatFunction)
84+
val newValue = new Impl(isAqe, isCache, isStreaming)
9685
localStorage.set(Some(newValue))
9786
try {
9887
body
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.gluten.sql
18+
19+
import org.apache.gluten.backendsapi.BackendsApiManager
20+
import org.apache.gluten.config.GlutenConfig
21+
import org.apache.gluten.execution.WholeStageTransformerSuite
22+
23+
import org.apache.spark.sql.catalyst.FunctionIdentifier
24+
import org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain
25+
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
26+
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
27+
import org.apache.spark.sql.internal.SQLConf
28+
29+
/**
30+
* Regression tests for https://github.com/apache/gluten/issues/12013.
31+
*
32+
* Verifies that `BloomFilterMightContainJointRewriteRule`, registered as a `Rule[LogicalPlan]` via
33+
* `injectOptimizerRule`, correctly handles whole-stage AQE fallback scenarios where one or both
34+
* bloom-filter stages revert to vanilla Spark execution.
35+
*/
36+
class GlutenBloomFilterFallbackSuite extends WholeStageTransformerSuite {
37+
protected val resourcePath: String = null
38+
protected val fileFormat: String = null
39+
40+
import testImplicits._
41+
42+
private val funcIdBloomFilterAgg = FunctionIdentifier("bloom_filter_agg")
43+
private val funcIdMightContain = FunctionIdentifier("might_contain")
44+
45+
override def beforeAll(): Unit = {
46+
super.beforeAll()
47+
spark.sessionState.functionRegistry.registerFunction(
48+
funcIdBloomFilterAgg,
49+
new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
50+
args =>
51+
args.size match {
52+
case 1 => new BloomFilterAggregate(args(0))
53+
case 2 => new BloomFilterAggregate(args(0), args(1))
54+
case 3 => new BloomFilterAggregate(args(0), args(1), args(2))
55+
case _ => throw new IllegalArgumentException("bloom_filter_agg requires 1-3 arguments")
56+
}
57+
)
58+
spark.sessionState.functionRegistry.registerFunction(
59+
funcIdMightContain,
60+
new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"),
61+
args => BloomFilterMightContain(args(0), args(1)))
62+
}
63+
64+
override def afterAll(): Unit = {
65+
spark.sessionState.functionRegistry.dropFunction(funcIdBloomFilterAgg)
66+
spark.sessionState.functionRegistry.dropFunction(funcIdMightContain)
67+
super.afterAll()
68+
}
69+
70+
private val veloxBloomFilterMaxNumBits = 4194304L
71+
72+
// GLUTEN-12013: only filter stage falls back (threshold=2).
73+
// bloom_filter_agg subquery runs natively and produces Velox-format bytes; the filter stage
74+
// falls back via ExpandFallbackPolicy. The optimizer-level substitution ensures the fallback
75+
// plan still uses VeloxBloomFilterMightContain so the JVM filter reads Velox-format bytes.
76+
test("GLUTEN-12013: bloom_filter_agg whole-stage fallback does not corrupt bloom filter bytes") {
77+
if (BackendsApiManager.getSettings.requireBloomFilterAggMightContainJointFallback()) {
78+
val table = "bloom_filter_test"
79+
val numEstimatedItems = 5000000L
80+
val sqlString =
81+
s"""
82+
|SELECT col positive_membership_test
83+
|FROM $table
84+
|WHERE might_contain(
85+
| (SELECT bloom_filter_agg(col,
86+
| cast($numEstimatedItems as long),
87+
| cast($veloxBloomFilterMaxNumBits as long))
88+
| FROM $table), col)
89+
|""".stripMargin
90+
withTempView(table) {
91+
(Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 200000L))
92+
.toDF("col")
93+
.createOrReplaceTempView(table)
94+
// Threshold=2: FilterExec fallback cost=2 triggers whole-stage fallback; agg cost=1
95+
// does not, so Stage 0 runs natively. ANSI off keeps agg cost at 1 on Spark 4.0+.
96+
withSQLConf(
97+
GlutenConfig.COLUMNAR_FILTER_ENABLED.key -> "false",
98+
GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> "2",
99+
SQLConf.ANSI_ENABLED.key -> "false"
100+
) {
101+
val df = spark.sql(sqlString)
102+
// Must not throw: java.io.IOException: Unexpected Bloom filter version number.
103+
assert(df.collect().length == 200003)
104+
// Verify the optimizer rule ran: VeloxBloomFilterMightContain must be present even
105+
// though Stage 1 executes inside a FallbackNode.
106+
assert(
107+
df.queryExecution.optimizedPlan.toString.contains("velox_might_contain"),
108+
"Expected velox_might_contain in optimized plan -- optimizer rule may not have run"
109+
)
110+
}
111+
}
112+
}
113+
}
114+
115+
// GLUTEN-12013: both stages fall back (threshold=1).
116+
// Stage 0's inherent transition cost of 1 meets the threshold so ExpandFallbackPolicy
117+
// promotes it to a whole-stage fallback too. The optimizer rule has already rewritten both
118+
// sides to Velox variants before ExpandFallbackPolicy captures its snapshot. Even in JVM
119+
// row-mode, VeloxBloomFilterAggregate produces Velox-format bytes (via JNI) and
120+
// VeloxBloomFilterMightContain consumes them -- both sides are consistent.
121+
test("GLUTEN-12013: bloom_filter_agg whole-stage fallback when both stages fall back") {
122+
if (BackendsApiManager.getSettings.requireBloomFilterAggMightContainJointFallback()) {
123+
val table = "bloom_filter_test"
124+
val numEstimatedItems = 5000000L
125+
val sqlString =
126+
s"""
127+
|SELECT col positive_membership_test
128+
|FROM $table
129+
|WHERE might_contain(
130+
| (SELECT bloom_filter_agg(col,
131+
| cast($numEstimatedItems as long),
132+
| cast($veloxBloomFilterMaxNumBits as long))
133+
| FROM $table), col)
134+
|""".stripMargin
135+
withTempView(table) {
136+
(Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 200000L))
137+
.toDF("col")
138+
.createOrReplaceTempView(table)
139+
// Threshold=1: both stages fall back; both use Velox variants via JNI.
140+
withSQLConf(
141+
GlutenConfig.COLUMNAR_FILTER_ENABLED.key -> "false",
142+
GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> "1",
143+
SQLConf.ANSI_ENABLED.key -> "false"
144+
) {
145+
val df = spark.sql(sqlString)
146+
// Must not throw: java.io.IOException: Unexpected Bloom filter version number.
147+
assert(df.collect().length == 200003)
148+
// Verify the optimizer rule ran on both sides.
149+
assert(
150+
df.queryExecution.optimizedPlan.toString.contains("velox_might_contain"),
151+
"Expected velox_might_contain in optimized plan -- optimizer rule may not have run"
152+
)
153+
}
154+
}
155+
}
156+
}
157+
158+
// GLUTEN-12013: DataFrame.stat.bloomFilter() must not be affected by the optimizer rule.
159+
// The rule must only rewrite BloomFilterAggregate inside a BloomFilterMightContain subquery.
160+
// A standalone BloomFilterAggregate (as used here) must remain vanilla so that the collected
161+
// bytes are in Spark-native format and BloomFilter.readFrom() succeeds.
162+
test("GLUTEN-12013: DataFrame.stat.bloomFilter() produces Spark-readable bytes") {
163+
if (BackendsApiManager.getSettings.requireBloomFilterAggMightContainJointFallback()) {
164+
val table = "bloom_filter_stat_test"
165+
withTempView(table) {
166+
(1L to 1000L).toDF("col").createOrReplaceTempView(table)
167+
// Must not throw: java.io.IOException: Unexpected Bloom filter version number
168+
val bf = spark.table(table).stat.bloomFilter("col", 1000L, 0.01)
169+
// Bloom filters have no false negatives: every inserted value must be present.
170+
assert(bf.mightContainLong(500L), "Expected 500 to be in bloom filter")
171+
}
172+
}
173+
}
174+
175+
// GLUTEN-12013: native bloom filter disabled -- early-exit path of the optimizer rule.
176+
// When spark.gluten.sql.native.bloomFilter=false the rule returns the plan unchanged.
177+
// BloomFilterAggregate / BloomFilterMightContain remain as vanilla Spark expressions and
178+
// produce/consume consistent Spark-format bytes.
179+
test(
180+
"GLUTEN-12013: native bloom filter disabled skips rewrite and produces correct results") {
181+
if (BackendsApiManager.getSettings.requireBloomFilterAggMightContainJointFallback()) {
182+
val table = "bloom_filter_test"
183+
val numEstimatedItems = 5000000L
184+
val sqlString =
185+
s"""
186+
|SELECT col positive_membership_test
187+
|FROM $table
188+
|WHERE might_contain(
189+
| (SELECT bloom_filter_agg(col,
190+
| cast($numEstimatedItems as long),
191+
| cast($veloxBloomFilterMaxNumBits as long))
192+
| FROM $table), col)
193+
|""".stripMargin
194+
withTempView(table) {
195+
(Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 200000L))
196+
.toDF("col")
197+
.createOrReplaceTempView(table)
198+
withSQLConf(
199+
GlutenConfig.COLUMNAR_NATIVE_BLOOMFILTER_ENABLED.key -> "false",
200+
SQLConf.ANSI_ENABLED.key -> "false"
201+
) {
202+
val df = spark.sql(sqlString)
203+
assert(df.collect().length == 200003)
204+
// Verify the rule early-exited: plan must NOT contain Velox variants.
205+
assert(
206+
!df.queryExecution.optimizedPlan.toString.contains("velox_might_contain"),
207+
"Expected vanilla BloomFilterMightContain when native bloom filter is disabled"
208+
)
209+
}
210+
}
211+
}
212+
}
213+
}

0 commit comments

Comments
 (0)