Skip to content

Commit 7cc6faf

Browse files
brijrajkclaude
andcommitted
[GLUTEN-12013][VL] Fix bloom-filter bytes corruption on whole-stage AQE fallback
Move BloomFilterMightContainJointRewriteRule from injectPreTransform (Rule[SparkPlan]) to injectOptimizerRule (Rule[LogicalPlan]), modelled after CollectRewriteRule. Running as an optimizer rule ensures both substitutions are baked into the originalPlan snapshot before ExpandFallbackPolicy takes it, so the bloom-filter byte format stays consistent regardless of which stages fall back. Also removes BloomFilterMightContainFallbackPatcher (no longer needed) and cleans up the now-dead CallerInfo.isBloomFilterStatFunction stack- trace detection. Regression tests in GlutenBloomFilterFallbackSuite (gluten-ut/test): - threshold=2: only the filter stage falls back; agg runs natively and emits Velox-format bytes; filter must still read them correctly - threshold=1: both stages fall back; both use Velox variants via JNI - native bloom filter disabled: rule early-exits, vanilla expressions remain and produce consistent Spark-format bytes Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 7479629 commit 7cc6faf

4 files changed

Lines changed: 229 additions & 73 deletions

File tree

backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ object VeloxRuleApi {
5353
private def injectSpark(injector: SparkInjector): Unit = {
5454
// Inject the regular Spark rules directly.
5555
injector.injectOptimizerRule(CollectRewriteRule.apply)
56+
injector.injectOptimizerRule(BloomFilterMightContainJointRewriteRule.apply)
5657
injector.injectOptimizerRule(HLLRewriteRule.apply)
5758
injector.injectOptimizerRule(CollapseGetJsonObjectExpressionRule.apply)
5859
injector.injectOptimizerRule(RewriteCastFromArray.apply)
@@ -81,11 +82,6 @@ object VeloxRuleApi {
8182
injector.injectPreTransform(c => FallbackMultiCodegens.apply(c.session))
8283
injector.injectPreTransform(c => MergeTwoPhasesHashBaseAggregate(c.session))
8384
injector.injectPreTransform(_ => RewriteSubqueryBroadcast())
84-
injector.injectPreTransform(
85-
c =>
86-
BloomFilterMightContainJointRewriteRule.apply(
87-
c.session,
88-
c.caller.isBloomFilterStatFunction()))
8985
injector.injectPreTransform(_ => EliminateRedundantGetTimestamp)
9086

9187
// Legacy: The legacy transform rule.

backends-velox/src/main/scala/org/apache/gluten/extension/BloomFilterMightContainJointRewriteRule.scala

Lines changed: 28 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -21,63 +21,38 @@ import org.apache.gluten.expression.VeloxBloomFilterMightContain
2121
import org.apache.gluten.expression.aggregate.VeloxBloomFilterAggregate
2222

2323
import org.apache.spark.sql.SparkSession
24-
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, BloomFilterMightContain, Expression}
25-
import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, TypedImperativeAggregate}
24+
import org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain
25+
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate}
26+
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2627
import org.apache.spark.sql.catalyst.rules.Rule
27-
import org.apache.spark.sql.execution.SparkPlan
2828

29-
case class BloomFilterMightContainJointRewriteRule(
30-
spark: SparkSession,
31-
isBloomFilterStatFunction: Boolean)
32-
extends Rule[SparkPlan] {
33-
override def apply(plan: SparkPlan): SparkPlan = {
34-
if (isBloomFilterStatFunction || !GlutenConfig.get.enableNativeBloomFilter) {
29+
/**
30+
* Optimizer rule that rewrites `BloomFilterAggregate` -> `VeloxBloomFilterAggregate` and
31+
* `BloomFilterMightContain` -> `VeloxBloomFilterMightContain` at the logical plan level.
32+
*
33+
* Running as an optimizer rule ensures the substitution is captured in the `originalPlan` snapshot
34+
* that [[org.apache.gluten.extension.columnar.heuristic.ExpandFallbackPolicy]] uses when promoting
35+
* an individual stage fallback to a whole-stage AQE fallback. This guarantees that both sides of
36+
* the bloom-filter pair always produce and consume the same byte format, regardless of whether
37+
* stages fall back to JVM execution after AQE re-planning.
38+
*/
39+
case class BloomFilterMightContainJointRewriteRule(spark: SparkSession)
40+
extends Rule[LogicalPlan] {
41+
42+
override def apply(plan: LogicalPlan): LogicalPlan = {
43+
if (!GlutenConfig.get.enableNativeBloomFilter) {
3544
return plan
3645
}
37-
val out = plan.transformWithSubqueries {
38-
case p =>
39-
applyForNode(p)
40-
}
41-
out
42-
}
43-
44-
private def replaceBloomFilterAggregate[T](
45-
expr: Expression,
46-
bloomFilterAggReplacer: (
47-
Expression,
48-
Expression,
49-
Expression,
50-
Int,
51-
Int) => TypedImperativeAggregate[T]): Expression = expr match {
52-
case BloomFilterAggregate(
53-
child,
54-
estimatedNumItemsExpression,
55-
numBitsExpression,
56-
mutableAggBufferOffset,
57-
inputAggBufferOffset) =>
58-
bloomFilterAggReplacer(
59-
child,
60-
estimatedNumItemsExpression,
61-
numBitsExpression,
62-
mutableAggBufferOffset,
63-
inputAggBufferOffset)
64-
case other => other
65-
}
66-
67-
private def replaceMightContain[T](
68-
expr: Expression,
69-
mightContainReplacer: (Expression, Expression) => BinaryExpression): Expression = expr match {
70-
case BloomFilterMightContain(bloomFilterExpression, valueExpression) =>
71-
mightContainReplacer(bloomFilterExpression, valueExpression)
72-
case other => other
73-
}
74-
75-
private def applyForNode(p: SparkPlan) = {
76-
p.transformExpressions {
77-
case e =>
78-
replaceMightContain(
79-
replaceBloomFilterAggregate(e, VeloxBloomFilterAggregate.apply),
80-
VeloxBloomFilterMightContain.apply)
46+
plan.transformAllExpressions {
47+
case aggExpr @ AggregateExpression(b: BloomFilterAggregate, _, _, _, _) =>
48+
aggExpr.copy(aggregateFunction = VeloxBloomFilterAggregate(
49+
b.child,
50+
b.estimatedNumItemsExpression,
51+
b.numBitsExpression,
52+
b.mutableAggBufferOffset,
53+
b.inputAggBufferOffset))
54+
case BloomFilterMightContain(bf, v) =>
55+
VeloxBloomFilterMightContain(bf, v)
8156
}
8257
}
8358
}

gluten-core/src/main/scala/org/apache/gluten/extension/caller/CallerInfo.scala

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ trait CallerInfo {
3030
def isAqe(): Boolean
3131
def isCache(): Boolean
3232
def isStreaming(): Boolean
33-
def isBloomFilterStatFunction(): Boolean
3433
}
3534

3635
object CallerInfo {
@@ -42,8 +41,7 @@ object CallerInfo {
4241
private class Impl(
4342
override val isAqe: Boolean,
4443
override val isCache: Boolean,
45-
override val isStreaming: Boolean,
46-
override val isBloomFilterStatFunction: Boolean
44+
override val isStreaming: Boolean
4745
) extends CallerInfo
4846

4947
/*
@@ -57,8 +55,7 @@ object CallerInfo {
5755
new Impl(
5856
isAqe = inAqeCall(stack),
5957
isCache = inCacheCall(stack),
60-
isStreaming = inStreamingCall(stack),
61-
isBloomFilterStatFunction = inBloomFilterStatFunctionCall(stack))
58+
isStreaming = inStreamingCall(stack))
6259
}
6360

6461
private def inAqeCall(stack: Seq[StackTraceElement]): Boolean = {
@@ -78,21 +75,13 @@ object CallerInfo {
7875
stack.exists(_.getClassName.equals(streamName))
7976
}
8077

81-
private def inBloomFilterStatFunctionCall(stack: Seq[StackTraceElement]): Boolean = {
82-
val res = stack.exists(
83-
_.getClassName.equals("org.apache.spark.sql.DataFrameStatFunctions")
84-
&& stack.exists(_.getMethodName.equals("bloomFilter")))
85-
res
86-
}
87-
8878
/** For testing only. */
8979
def withLocalValue[T](
9080
isAqe: Boolean,
9181
isCache: Boolean,
92-
isStreaming: Boolean = false,
93-
isBloomFilterStatFunction: Boolean = false)(body: => T): T = {
82+
isStreaming: Boolean = false)(body: => T): T = {
9483
val prevValue = localStorage.get()
95-
val newValue = new Impl(isAqe, isCache, isStreaming, isBloomFilterStatFunction)
84+
val newValue = new Impl(isAqe, isCache, isStreaming)
9685
localStorage.set(Some(newValue))
9786
try {
9887
body
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.gluten.sql
18+
19+
import org.apache.gluten.backendsapi.BackendsApiManager
20+
import org.apache.gluten.config.GlutenConfig
21+
import org.apache.gluten.execution.WholeStageTransformerSuite
22+
23+
import org.apache.spark.sql.catalyst.FunctionIdentifier
24+
import org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain
25+
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
26+
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
27+
import org.apache.spark.sql.internal.SQLConf
28+
29+
/**
30+
* Regression tests for https://github.com/apache/gluten/issues/12013.
31+
*
32+
* Verifies that `BloomFilterMightContainJointRewriteRule`, registered as a `Rule[LogicalPlan]` via
33+
* `injectOptimizerRule`, correctly handles whole-stage AQE fallback scenarios where one or both
34+
* bloom-filter stages revert to vanilla Spark execution.
35+
*/
36+
class GlutenBloomFilterFallbackSuite extends WholeStageTransformerSuite {
37+
protected val resourcePath: String = null
38+
protected val fileFormat: String = null
39+
40+
import testImplicits._
41+
42+
private val funcIdBloomFilterAgg = FunctionIdentifier("bloom_filter_agg")
43+
private val funcIdMightContain = FunctionIdentifier("might_contain")
44+
45+
override def beforeAll(): Unit = {
46+
super.beforeAll()
47+
spark.sessionState.functionRegistry.registerFunction(
48+
funcIdBloomFilterAgg,
49+
new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
50+
args =>
51+
args.size match {
52+
case 1 => new BloomFilterAggregate(args(0))
53+
case 2 => new BloomFilterAggregate(args(0), args(1))
54+
case 3 => new BloomFilterAggregate(args(0), args(1), args(2))
55+
case _ => throw new IllegalArgumentException("bloom_filter_agg requires 1-3 arguments")
56+
}
57+
)
58+
spark.sessionState.functionRegistry.registerFunction(
59+
funcIdMightContain,
60+
new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"),
61+
args => BloomFilterMightContain(args(0), args(1)))
62+
}
63+
64+
override def afterAll(): Unit = {
65+
spark.sessionState.functionRegistry.dropFunction(funcIdBloomFilterAgg)
66+
spark.sessionState.functionRegistry.dropFunction(funcIdMightContain)
67+
super.afterAll()
68+
}
69+
70+
private val veloxBloomFilterMaxNumBits = 4194304L
71+
72+
// GLUTEN-12013: only filter stage falls back (threshold=2).
73+
// bloom_filter_agg subquery runs natively and produces Velox-format bytes; the filter stage
74+
// falls back via ExpandFallbackPolicy. The optimizer-level substitution ensures the fallback
75+
// plan still uses VeloxBloomFilterMightContain so the JVM filter reads Velox-format bytes.
76+
test("GLUTEN-12013: bloom_filter_agg whole-stage fallback does not corrupt bloom filter bytes") {
77+
if (BackendsApiManager.getSettings.requireBloomFilterAggMightContainJointFallback()) {
78+
val table = "bloom_filter_test"
79+
val numEstimatedItems = 5000000L
80+
val sqlString =
81+
s"""
82+
|SELECT col positive_membership_test
83+
|FROM $table
84+
|WHERE might_contain(
85+
| (SELECT bloom_filter_agg(col,
86+
| cast($numEstimatedItems as long),
87+
| cast($veloxBloomFilterMaxNumBits as long))
88+
| FROM $table), col)
89+
|""".stripMargin
90+
withTempView(table) {
91+
(Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 200000L))
92+
.toDF("col")
93+
.createOrReplaceTempView(table)
94+
// Threshold=2: FilterExec fallback cost=2 triggers whole-stage fallback; agg cost=1
95+
// does not, so Stage 0 runs natively. ANSI off keeps agg cost at 1 on Spark 4.0+.
96+
withSQLConf(
97+
GlutenConfig.COLUMNAR_FILTER_ENABLED.key -> "false",
98+
GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> "2",
99+
SQLConf.ANSI_ENABLED.key -> "false"
100+
) {
101+
val df = spark.sql(sqlString)
102+
// Must not throw: java.io.IOException: Unexpected Bloom filter version number.
103+
assert(df.collect().length == 200003)
104+
// Verify the optimizer rule ran: VeloxBloomFilterMightContain must be present even
105+
// though Stage 1 executes inside a FallbackNode.
106+
assert(
107+
df.queryExecution.optimizedPlan.toString.contains("velox_might_contain"),
108+
"Expected velox_might_contain in optimized plan -- optimizer rule may not have run"
109+
)
110+
}
111+
}
112+
}
113+
}
114+
115+
// GLUTEN-12013: both stages fall back (threshold=1).
116+
// Stage 0's inherent transition cost of 1 meets the threshold so ExpandFallbackPolicy
117+
// promotes it to a whole-stage fallback too. The optimizer rule has already rewritten both
118+
// sides to Velox variants before ExpandFallbackPolicy captures its snapshot. Even in JVM
119+
// row-mode, VeloxBloomFilterAggregate produces Velox-format bytes (via JNI) and
120+
// VeloxBloomFilterMightContain consumes them -- both sides are consistent.
121+
test("GLUTEN-12013: bloom_filter_agg whole-stage fallback when both stages fall back") {
122+
if (BackendsApiManager.getSettings.requireBloomFilterAggMightContainJointFallback()) {
123+
val table = "bloom_filter_test"
124+
val numEstimatedItems = 5000000L
125+
val sqlString =
126+
s"""
127+
|SELECT col positive_membership_test
128+
|FROM $table
129+
|WHERE might_contain(
130+
| (SELECT bloom_filter_agg(col,
131+
| cast($numEstimatedItems as long),
132+
| cast($veloxBloomFilterMaxNumBits as long))
133+
| FROM $table), col)
134+
|""".stripMargin
135+
withTempView(table) {
136+
(Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 200000L))
137+
.toDF("col")
138+
.createOrReplaceTempView(table)
139+
// Threshold=1: both stages fall back; both use Velox variants via JNI.
140+
withSQLConf(
141+
GlutenConfig.COLUMNAR_FILTER_ENABLED.key -> "false",
142+
GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> "1",
143+
SQLConf.ANSI_ENABLED.key -> "false"
144+
) {
145+
val df = spark.sql(sqlString)
146+
// Must not throw: java.io.IOException: Unexpected Bloom filter version number.
147+
assert(df.collect().length == 200003)
148+
// Verify the optimizer rule ran on both sides.
149+
assert(
150+
df.queryExecution.optimizedPlan.toString.contains("velox_might_contain"),
151+
"Expected velox_might_contain in optimized plan -- optimizer rule may not have run"
152+
)
153+
}
154+
}
155+
}
156+
}
157+
158+
// GLUTEN-12013: native bloom filter disabled -- early-exit path of the optimizer rule.
159+
// When spark.gluten.sql.native.bloomFilter=false the rule returns the plan unchanged.
160+
// BloomFilterAggregate / BloomFilterMightContain remain as vanilla Spark expressions and
161+
// produce/consume consistent Spark-format bytes.
162+
test(
163+
"GLUTEN-12013: native bloom filter disabled skips rewrite and produces correct results") {
164+
if (BackendsApiManager.getSettings.requireBloomFilterAggMightContainJointFallback()) {
165+
val table = "bloom_filter_test"
166+
val numEstimatedItems = 5000000L
167+
val sqlString =
168+
s"""
169+
|SELECT col positive_membership_test
170+
|FROM $table
171+
|WHERE might_contain(
172+
| (SELECT bloom_filter_agg(col,
173+
| cast($numEstimatedItems as long),
174+
| cast($veloxBloomFilterMaxNumBits as long))
175+
| FROM $table), col)
176+
|""".stripMargin
177+
withTempView(table) {
178+
(Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 200000L))
179+
.toDF("col")
180+
.createOrReplaceTempView(table)
181+
withSQLConf(
182+
GlutenConfig.COLUMNAR_NATIVE_BLOOMFILTER_ENABLED.key -> "false",
183+
SQLConf.ANSI_ENABLED.key -> "false"
184+
) {
185+
val df = spark.sql(sqlString)
186+
assert(df.collect().length == 200003)
187+
// Verify the rule early-exited: plan must NOT contain Velox variants.
188+
assert(
189+
!df.queryExecution.optimizedPlan.toString.contains("velox_might_contain"),
190+
"Expected vanilla BloomFilterMightContain when native bloom filter is disabled"
191+
)
192+
}
193+
}
194+
}
195+
}
196+
}

0 commit comments

Comments
 (0)