Skip to content

Commit cac891f

Browse files
committed
[GLUTEN-12013][VL] Fix bloom-filter bytes corruption on whole-stage AQE fallback
Register BloomFilterMightContainJointRewriteRule as a Rule[LogicalPlan] via injectOptimizerRule so that both BloomFilterAggregate -> VeloxBloomFilterAggregate and BloomFilterMightContain -> VeloxBloomFilterMightContain substitutions are baked into the originalPlan snapshot before ExpandFallbackPolicy takes it.

This ensures that when a stage falls back via whole-stage AQE fallback, the fallback plan already uses the Velox variants on both sides of the bloom-filter pair, so the byte format is always consistent regardless of which stages fall back and in what order.

This also fixes the case (threshold=1) where Stage 0 itself falls back: the previous FallbackPatcher approach would incorrectly rewrite BloomFilterMightContain -> VeloxBloomFilterMightContain even when Stage 0 produced Spark-format bytes, causing a version-mismatch IOException. With the optimizer rule, both sides are always rewritten together or not at all (when enableNativeBloomFilter=false).

Add regression tests covering:
- threshold=2: only the filter stage falls back, agg stage runs natively
- threshold=1: both stages fall back, agg stage produces Spark-format bytes
1 parent 473a5d2 commit cac891f

3 files changed

Lines changed: 142 additions & 58 deletions

File tree

backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ object VeloxRuleApi {
5353
private def injectSpark(injector: SparkInjector): Unit = {
5454
// Inject the regular Spark rules directly.
5555
injector.injectOptimizerRule(CollectRewriteRule.apply)
56+
injector.injectOptimizerRule(BloomFilterMightContainJointRewriteRule.apply)
5657
injector.injectOptimizerRule(HLLRewriteRule.apply)
5758
injector.injectOptimizerRule(CollapseGetJsonObjectExpressionRule.apply)
5859
injector.injectOptimizerRule(RewriteCastFromArray.apply)
@@ -81,11 +82,6 @@ object VeloxRuleApi {
8182
injector.injectPreTransform(c => FallbackMultiCodegens.apply(c.session))
8283
injector.injectPreTransform(c => MergeTwoPhasesHashBaseAggregate(c.session))
8384
injector.injectPreTransform(_ => RewriteSubqueryBroadcast())
84-
injector.injectPreTransform(
85-
c =>
86-
BloomFilterMightContainJointRewriteRule.apply(
87-
c.session,
88-
c.caller.isBloomFilterStatFunction()))
8985
injector.injectPreTransform(_ => EliminateRedundantGetTimestamp)
9086

9187
// Legacy: The legacy transform rule.

backends-velox/src/main/scala/org/apache/gluten/extension/BloomFilterMightContainJointRewriteRule.scala

Lines changed: 28 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -21,63 +21,38 @@ import org.apache.gluten.expression.VeloxBloomFilterMightContain
2121
import org.apache.gluten.expression.aggregate.VeloxBloomFilterAggregate
2222

2323
import org.apache.spark.sql.SparkSession
24-
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, BloomFilterMightContain, Expression}
25-
import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, TypedImperativeAggregate}
24+
import org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain
25+
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate}
26+
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2627
import org.apache.spark.sql.catalyst.rules.Rule
27-
import org.apache.spark.sql.execution.SparkPlan
2828

29-
case class BloomFilterMightContainJointRewriteRule(
30-
spark: SparkSession,
31-
isBloomFilterStatFunction: Boolean)
32-
extends Rule[SparkPlan] {
33-
override def apply(plan: SparkPlan): SparkPlan = {
34-
if (isBloomFilterStatFunction || !GlutenConfig.get.enableNativeBloomFilter) {
29+
/**
30+
* Optimizer rule that rewrites `BloomFilterAggregate` -> `VeloxBloomFilterAggregate` and
31+
* `BloomFilterMightContain` -> `VeloxBloomFilterMightContain` at the logical plan level.
32+
*
33+
* Running as an optimizer rule ensures the substitution is captured in the `originalPlan` snapshot
34+
* that [[org.apache.gluten.extension.columnar.heuristic.ExpandFallbackPolicy]] uses when promoting
35+
* an individual stage fallback to a whole-stage AQE fallback. This guarantees that both sides of
36+
* the bloom-filter pair always produce and consume the same byte format, regardless of whether
37+
* stages fall back to JVM execution after AQE re-planning.
38+
*/
39+
case class BloomFilterMightContainJointRewriteRule(spark: SparkSession)
40+
extends Rule[LogicalPlan] {
41+
42+
override def apply(plan: LogicalPlan): LogicalPlan = {
43+
if (!GlutenConfig.get.enableNativeBloomFilter) {
3544
return plan
3645
}
37-
val out = plan.transformWithSubqueries {
38-
case p =>
39-
applyForNode(p)
40-
}
41-
out
42-
}
43-
44-
private def replaceBloomFilterAggregate[T](
45-
expr: Expression,
46-
bloomFilterAggReplacer: (
47-
Expression,
48-
Expression,
49-
Expression,
50-
Int,
51-
Int) => TypedImperativeAggregate[T]): Expression = expr match {
52-
case BloomFilterAggregate(
53-
child,
54-
estimatedNumItemsExpression,
55-
numBitsExpression,
56-
mutableAggBufferOffset,
57-
inputAggBufferOffset) =>
58-
bloomFilterAggReplacer(
59-
child,
60-
estimatedNumItemsExpression,
61-
numBitsExpression,
62-
mutableAggBufferOffset,
63-
inputAggBufferOffset)
64-
case other => other
65-
}
66-
67-
private def replaceMightContain[T](
68-
expr: Expression,
69-
mightContainReplacer: (Expression, Expression) => BinaryExpression): Expression = expr match {
70-
case BloomFilterMightContain(bloomFilterExpression, valueExpression) =>
71-
mightContainReplacer(bloomFilterExpression, valueExpression)
72-
case other => other
73-
}
74-
75-
private def applyForNode(p: SparkPlan) = {
76-
p.transformExpressions {
77-
case e =>
78-
replaceMightContain(
79-
replaceBloomFilterAggregate(e, VeloxBloomFilterAggregate.apply),
80-
VeloxBloomFilterMightContain.apply)
46+
plan.transformAllExpressions {
47+
case aggExpr @ AggregateExpression(b: BloomFilterAggregate, _, _, _, _) =>
48+
aggExpr.copy(aggregateFunction = VeloxBloomFilterAggregate(
49+
b.child,
50+
b.estimatedNumItemsExpression,
51+
b.numBitsExpression,
52+
b.mutableAggBufferOffset,
53+
b.inputAggBufferOffset))
54+
case BloomFilterMightContain(bf, v) =>
55+
VeloxBloomFilterMightContain(bf, v)
8156
}
8257
}
8358
}

gluten-ut/spark40/src/test/scala/org/apache/spark/sql/GlutenBloomFilterAggregateQuerySuite.scala

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ import org.apache.spark.SparkConf
2424
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
2525
import org.apache.spark.sql.internal.SQLConf
2626

27+
import org.scalatest.Tag
28+
29+
/**
30+
* ScalaTest tag for the issue-12013 regression test. Run with:
31+
* {{{
32+
* --test-tags=org.apache.gluten.tags.Issue12013
33+
* }}}
34+
*/
35+
object Issue12013 extends Tag("org.apache.gluten.tags.Issue12013")
36+
2737
class GlutenBloomFilterAggregateQuerySuite
2838
extends BloomFilterAggregateQuerySuite
2939
with GlutenSQLTestsTrait
@@ -112,6 +122,109 @@ class GlutenBloomFilterAggregateQuerySuite
112122
}
113123
}
114124

125+
// Regression test for https://github.com/apache/gluten/issues/12013
126+
// When ExpandFallbackPolicy triggers a whole-stage AQE fallback, the resulting plan comes
127+
// from the originalPlan snapshot. Before the fix that snapshot still contained the vanilla
128+
// BloomFilterMightContain, so if Stage 0 (bloom_filter_agg subquery) had already run natively
129+
// and produced Velox-format bytes, BloomFilterImpl.readFrom() could not deserialize them.
130+
// BloomFilterMightContainJointRewriteRule now runs as an optimizer rule, so the snapshot uses
131+
// VeloxBloomFilterMightContain and Stage 1 can read Velox bytes via JNI after falling back.
132+
testGluten(
133+
"Test bloom_filter_agg whole-stage fallback does not corrupt bloom filter bytes",
134+
Issue12013) {
135+
val table = "bloom_filter_test"
136+
val numEstimatedItems = 5000000L
137+
val sqlString =
138+
s"""
139+
|SELECT col positive_membership_test
140+
|FROM $table
141+
|WHERE might_contain(
142+
| (SELECT bloom_filter_agg(col,
143+
| cast($numEstimatedItems as long),
144+
| cast($veloxBloomFilterMaxNumBits as long))
145+
| FROM $table), col)
146+
|""".stripMargin
147+
148+
withTempView(table) {
149+
(Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 200000L))
150+
.toDF("col")
151+
.createOrReplaceTempView(table)
152+
if (BackendsApiManager.getSettings.requireBloomFilterAggMightContainJointFallback()) {
153+
// Disable columnar filter so FilterExec falls back, and set the whole-stage fallback
154+
// threshold so ExpandFallbackPolicy promotes the individual fallback to whole-stage.
155+
// This reproduces the scenario where the filter stage falls back to the original
156+
// vanilla plan while the bloom_filter_agg subquery has already produced Velox-format
157+
// bloom filter bytes.
158+
//
159+
// Threshold=2: a fallen-back FilterExec introduces two ColumnarToRow/RowToColumnar
160+
// transitions (net transition cost=2), which meets the threshold and triggers the
161+
// whole-stage AQE fallback. The bloom_filter_agg subquery stages have an inherent
162+
// transition cost of 1, so they do NOT trigger the fallback and run natively.
163+
//
164+
// ANSI mode must be off: Spark 4.0 enables ANSI by default, which causes
165+
// ObjectHashAggregateExec to fail Gluten validation ("does not support ansi mode"),
166+
// raising the agg-stage transition cost above 1. With ANSI off the agg-stage cost
167+
// stays at 1 (< threshold 2), so only the filter stage falls back as intended.
168+
withSQLConf(
169+
GlutenConfig.COLUMNAR_FILTER_ENABLED.key -> "false",
170+
GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> "2",
171+
SQLConf.ANSI_ENABLED.key -> "false"
172+
) {
173+
val df = spark.sql(sqlString)
174+
// Must not throw java.io.IOException: Unexpected Bloom filter version number (16777217).
175+
// All 200003 rows match the bloom filter built from the same data.
176+
assert(df.collect().length == 200003L)
177+
}
178+
}
179+
}
180+
}
181+
182+
// Validates the case where Stage 0 (bloom_filter_agg subquery) itself falls back via
183+
// ExpandFallbackPolicy (not just when native bloom filter is disabled via config). With
184+
// threshold=1 the subquery stage's inherent transition cost of 1 meets the threshold, so
185+
// Stage 0 is wrapped in a whole-stage FallbackNode as well. Because the optimizer rule rewrites
186+
// BloomFilterAggregate and BloomFilterMightContain together, both fallen-back stages must agree
187+
// on the bloom-filter byte format and the query must not hit a version-mismatch IOException.
188+
testGluten(
189+
"Test bloom_filter_agg whole-stage fallback when both stages fall back",
190+
Issue12013) {
191+
val table = "bloom_filter_test"
192+
val numEstimatedItems = 5000000L
193+
val sqlString =
194+
s"""
195+
|SELECT col positive_membership_test
196+
|FROM $table
197+
|WHERE might_contain(
198+
| (SELECT bloom_filter_agg(col,
199+
| cast($numEstimatedItems as long),
200+
| cast($veloxBloomFilterMaxNumBits as long))
201+
| FROM $table), col)
202+
|""".stripMargin
203+
204+
withTempView(table) {
205+
(Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 200000L))
206+
.toDF("col")
207+
.createOrReplaceTempView(table)
208+
if (BackendsApiManager.getSettings.requireBloomFilterAggMightContainJointFallback()) {
209+
// threshold=1: Stage 0's inherent transition cost of 1 meets the threshold, so
210+
// ExpandFallbackPolicy promotes Stage 0 to a whole-stage fallback as well.
211+
// Stage 0 therefore runs on the JVM. Stage 1 also falls back
212+
// (COLUMNAR_FILTER_ENABLED=false, cost >= 1). Both fallback plans come from the same
213+
// optimizer-rewritten originalPlan, so the aggregate and might_contain sides use the same
214+
// bloom-filter byte format and no version-mismatch IOException is raised.
215+
withSQLConf(
216+
GlutenConfig.COLUMNAR_FILTER_ENABLED.key -> "false",
217+
GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> "1",
218+
SQLConf.ANSI_ENABLED.key -> "false"
219+
) {
220+
val df = spark.sql(sqlString)
221+
// Must not throw java.io.IOException: Unexpected Bloom filter version number.
222+
assert(df.collect().length == 200003L)
223+
}
224+
}
225+
}
226+
}
227+
115228
testGluten("Test bloom_filter_agg agg fallback") {
116229
val table = "bloom_filter_test"
117230
val numEstimatedItems = 5000000L

0 commit comments

Comments
 (0)