Skip to content

Commit 50cc6a1

Browse files
committed
[SPARK-57512][SQL] Materialize surviving RuntimeReplaceable for cached-batch pruning
Address review: a surviving RuntimeReplaceable in a cached scan's pushed-down predicates was not recognized by CachedBatchSerializer.buildFilter, silently disabling cached-batch pruning under AQE (the scan is wrapped in a leaf TableCacheQueryStageExec that the stage-finalization MaterializeRuntimeReplaceable cannot reach). Fix it at the pushdown consumer: unfold RuntimeReplaceable in InMemoryTableScanExec before calling buildFilter, rather than extending the codegen-prep rule into the non-codegen leaf scan. This covers AQE and non-AQE uniformly and keeps the readable expression in the plan/EXPLAIN. Document why codegen-materialization lives in preparations/postStageCreationRules and why the InMemory branch deliberately skips it. Co-authored-by: Isaac
1 parent b80438b commit 50cc6a1

4 files changed

Lines changed: 123 additions & 8 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/MaterializeRuntimeReplaceable.scala

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,25 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.RUNTIME_REPLACEABLE
2727
*
2828
* A `RuntimeReplaceable` with `eagerReplace = false` is intentionally kept in the plan by the
2929
* optimizer (see `ReplaceExpressions`) so that a native engine can match the high-level expression
30-
* directly. This rule then materializes the replacement for the Spark execution path, so Spark
31-
* codegen/interpreted evaluation behaves exactly as today. It is placed after the columnar/native
32-
* conversion and before `CollapseCodegenStages`, so a native engine sees the original expression
33-
* while Spark whole-stage codegen never sees a `RuntimeReplaceable`.
30+
* directly. This rule then materializes the replacement for the Spark execution path. It is placed
31+
* after the columnar/native conversion and before `CollapseCodegenStages`, so a native engine sees
32+
* the original expression while Spark whole-stage codegen never sees a `RuntimeReplaceable`.
33+
*
34+
* Materializing before codegen is a correctness requirement, not just cleanup. A surviving
35+
* `RuntimeReplaceable` evaluates correctly on its own (`eval`/`doGenCode` delegate to
36+
* `replacement`), but whole-stage codegen reasons about `references`, input materialization, and
37+
* subexpression elimination via the node's `children`, while the emitted code comes from
38+
* `replacement`. When `replacement` reads an input differently from `children` (e.g. an
39+
* un-simplified branch that reads a column more than once), that mismatch produces invalid
40+
* generated code. Eager replacement avoids this because the unfolded form is then simplified by the
41+
* optimizer; a survivor's `replacement` is not, so it must be unfolded before codegen.
42+
*
43+
* This runs wherever a physical plan is finalized into a codegen-bearing form:
44+
* `QueryExecution.preparations` (non-AQE) and `AdaptiveSparkPlanExec.postStageCreationRules` (AQE,
45+
* applied to every codegen-producing stage). A `RuntimeReplaceable` that only feeds a structural,
46+
* non-codegen consumer is materialized at that consumer instead -- see the cached-batch pruning
47+
* predicates in `InMemoryTableScanExec`, whose leaf scan never reaches codegen and is unreachable
48+
* from AQE stage finalization.
3449
*/
3550
object MaterializeRuntimeReplaceable extends Rule[SparkPlan] {
3651
override def apply(plan: SparkPlan): SparkPlan = plan.transformUpWithSubqueries {

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,10 @@ case class AdaptiveSparkPlanExec(
713713
case i: InMemoryTableScanLike =>
714714
// Apply `queryStageOptimizerRules` so that we can reuse subquery.
715715
// No need to apply `postStageCreationRules` for `InMemoryTableScanLike`
716-
// as it's a leaf node.
716+
// as it's a leaf node. In particular, `MaterializeRuntimeReplaceable` is intentionally not
717+
// applied here: this scan does not reach whole-stage codegen, and its only expressions that
718+
// may hold a surviving `RuntimeReplaceable` are the pushed-down `predicates`, which are
719+
// materialized at their consumer in `InMemoryTableScanExec` (see `filteredCachedBatches`).
717720
val newPlan = optimizeQueryStage(i, isFinalStage = false)
718721
if (!newPlan.isInstanceOf[InMemoryTableScanLike]) {
719722
throw SparkException.internalError(

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,18 @@ case class InMemoryTableScanExec(
141141
val buffers = relation.cacheBuilder.cachedColumnBuffers
142142

143143
if (inMemoryPartitionPruningEnabled) {
144-
val filterFunc = relation.cacheBuilder.serializer.buildFilter(predicates, relation.output)
144+
// `predicates` may contain a surviving `RuntimeReplaceable` (`eagerReplace = false`), which
145+
// is intentionally kept in the plan. `buildFilter` matches on expression shape to build the
146+
// cached-batch pruning filter, so it must see the materialized form. We unfold here, at the
147+
// consumer, rather than relying on the codegen-prep materialization rule
148+
// (`MaterializeRuntimeReplaceable` in `QueryExecution.preparations` /
149+
// `AdaptiveSparkPlanExec.postStageCreationRules`): this scan is a leaf that never reaches
150+
// whole-stage codegen, and under AQE it is wrapped in a `TableCacheQueryStageExec` that the
151+
// stage-finalization rules cannot descend into. Unfolding here covers both the AQE and
152+
// non-AQE paths uniformly while keeping the readable expression in the plan/EXPLAIN output.
153+
val materializedPredicates = predicates.map(RuntimeReplaceable.unfold)
154+
val filterFunc =
155+
relation.cacheBuilder.serializer.buildFilter(materializedPredicates, relation.output)
145156
buffers.mapPartitionsWithIndexInternal(filterFunc)
146157
} else {
147158
buffers

sql/core/src/test/scala/org/apache/spark/sql/execution/MaterializeRuntimeReplaceableSuite.scala

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@
1818
package org.apache.spark.sql.execution
1919

2020
import org.apache.spark.sql.{QueryTest, Row}
21-
import org.apache.spark.sql.catalyst.expressions.{Add, Expression, Literal, RuntimeReplaceable}
21+
import org.apache.spark.sql.catalyst.expressions.{Add, Expression, GreaterThan, Literal, RuntimeReplaceable}
2222
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2323
import org.apache.spark.sql.catalyst.rules.Rule
2424
import org.apache.spark.sql.catalyst.trees.BinaryLike
25-
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
25+
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
26+
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
2627
import org.apache.spark.sql.internal.SQLConf
2728
import org.apache.spark.sql.test.SharedSparkSession
2829

@@ -54,6 +55,34 @@ object WrapAddWithRuntimeReplaceable extends Rule[LogicalPlan] {
5455
}
5556
}
5657

58+
/**
59+
* A test-only predicate [[RuntimeReplaceable]] (`eagerReplace = false`) whose `replacement` is a
60+
* plain [[GreaterThan]] -- a shape that `CachedBatchSerializer.buildFilter` recognizes for
61+
* cached-batch pruning. Used to verify that a surviving predicate is materialized at the pruning
62+
* consumer (`InMemoryTableScanExec`), so pruning still kicks in.
63+
*/
64+
case class TestPredicateRuntimeReplaceable(left: Expression, right: Expression)
65+
extends RuntimeReplaceable with BinaryLike[Expression] {
66+
67+
override lazy val replacement: Expression = GreaterThan(left, right)
68+
69+
override def eagerReplace: Boolean = false
70+
71+
override protected def withNewChildrenInternal(
72+
newLeft: Expression, newRight: Expression): TestPredicateRuntimeReplaceable =
73+
copy(left = newLeft, right = newRight)
74+
}
75+
76+
/**
77+
* Wraps `x > 88` into a surviving [[TestPredicateRuntimeReplaceable]], after `ReplaceExpressions`.
78+
*/
79+
object WrapGreaterThanWithRuntimeReplaceable extends Rule[LogicalPlan] {
80+
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
81+
case g: GreaterThan if g.right == Literal(88) =>
82+
TestPredicateRuntimeReplaceable(g.left, g.right)
83+
}
84+
}
85+
5786
class MaterializeRuntimeReplaceableSuite extends QueryTest with SharedSparkSession {
5887

5988
private def withExtraOptimization(rule: Rule[LogicalPlan])(f: => Unit): Unit = {
@@ -107,6 +136,63 @@ class MaterializeRuntimeReplaceableSuite extends QueryTest with SharedSparkSessi
107136
}
108137
}
109138

139+
test("SPARK-57512: a surviving RuntimeReplaceable in cached-scan predicates is materialized " +
140+
"for partition pruning under AQE") {
141+
// Find every InMemoryTableScanExec, descending through AQE query stages (the cached scan is
142+
// wrapped in a TableCacheQueryStageExec, which is a leaf to a normal tree walk).
143+
def findScans(p: SparkPlan): Seq[InMemoryTableScanExec] = p match {
144+
case in: InMemoryTableScanExec => Seq(in)
145+
case q: QueryStageExec => findScans(q.plan)
146+
case a: AdaptiveSparkPlanExec => findScans(a.executedPlan)
147+
case other => other.children.flatMap(findScans)
148+
}
149+
150+
withSQLConf(
151+
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
152+
// Keep the range partitions intact so pruning is observable.
153+
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false",
154+
SQLConf.COLUMN_BATCH_SIZE.key -> "10",
155+
SQLConf.IN_MEMORY_PARTITION_PRUNING.key -> "true",
156+
SQLConf.IN_MEMORY_TABLE_SCAN_STATISTICS_ENABLED.key -> "true") {
157+
withExtraOptimization(WrapGreaterThanWithRuntimeReplaceable) {
158+
import testImplicits._
159+
// `repartitionByRange` adds a shuffle, so the cached plan is adaptive and its scan is
160+
// wrapped in a `TableCacheQueryStageExec` (a leaf the AQE stage-finalization rules cannot
161+
// descend into) -- this is the case where the predicate is NOT materialized in the plan.
162+
// Range partitioning also keeps values clustered, so batch stats enable pruning.
163+
// 100 values, batch size 10 => 10 batches total across 5 partitions.
164+
val cached = sparkContext.makeRDD(1 to 100, 5).toDF("key").repartitionByRange(5, $"key")
165+
cached.cache()
166+
try {
167+
// `key > 88` is rewritten into a surviving `TestPredicateRuntimeReplaceable` and pushed
168+
// into the cached scan's `predicates`.
169+
val df = cached.filter("key > 88")
170+
checkAnswer(df, (89 to 100).map(Row(_)))
171+
172+
val scans = findScans(df.queryExecution.executedPlan)
173+
assert(scans.size == 1, s"expected one cached scan, found ${scans.size}")
174+
val scan = scans.head
175+
176+
// The scan is a leaf query stage, so its predicate is not materialized in the plan: the
177+
// surviving RuntimeReplaceable is still present. This is exactly why the consumer-side
178+
// unfold in `filteredCachedBatches` is needed.
179+
assert(
180+
scan.predicates.exists(_.exists(_.isInstanceOf[RuntimeReplaceable])),
181+
s"Expected a surviving RuntimeReplaceable in the cached scan predicates:\n$scan")
182+
183+
// Pruning kicked in: fewer than all 10 batches / 5 partitions are read. Without unfolding
184+
// the predicate, `buildFilter` would not recognize it and would scan everything.
185+
assert(scan.readBatches.value < 10,
186+
s"Expected pruning (< 10 batches read), got ${scan.readBatches.value}")
187+
assert(scan.readPartitions.value < 5,
188+
s"Expected pruning (< 5 partitions read), got ${scan.readPartitions.value}")
189+
} finally {
190+
cached.unpersist()
191+
}
192+
}
193+
}
194+
}
195+
110196
test("a surviving RuntimeReplaceable self-evaluates via its replacement") {
111197
// `eval` delegates to `replacement` as a backstop for paths that bypass
112198
// `MaterializeRuntimeReplaceable`.

0 commit comments

Comments
 (0)