|
18 | 18 | package org.apache.spark.sql.execution |
19 | 19 |
|
20 | 20 | import org.apache.spark.sql.{QueryTest, Row} |
21 | | -import org.apache.spark.sql.catalyst.expressions.{Add, Expression, Literal, RuntimeReplaceable} |
| 21 | +import org.apache.spark.sql.catalyst.expressions.{Add, Expression, GreaterThan, Literal, RuntimeReplaceable} |
22 | 22 | import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan |
23 | 23 | import org.apache.spark.sql.catalyst.rules.Rule |
24 | 24 | import org.apache.spark.sql.catalyst.trees.BinaryLike |
25 | | -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec |
| 25 | +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec} |
| 26 | +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec |
26 | 27 | import org.apache.spark.sql.internal.SQLConf |
27 | 28 | import org.apache.spark.sql.test.SharedSparkSession |
28 | 29 |
|
@@ -54,6 +55,34 @@ object WrapAddWithRuntimeReplaceable extends Rule[LogicalPlan] { |
54 | 55 | } |
55 | 56 | } |
56 | 57 |
|
| 58 | +/** |
| 59 | + * A test-only binary predicate: a [[RuntimeReplaceable]] with `eagerReplace = false` whose lazily |
| 60 | + * built `replacement` is a plain `GreaterThan(left, right)` -- a shape that `CachedBatchSerializer.buildFilter` recognizes for |
| 61 | + * cached-batch pruning. Used to verify that a surviving predicate is materialized at the pruning |
| 62 | + * consumer (`InMemoryTableScanExec`), so pruning still kicks in. Child replacement goes through `copy`. |
| 63 | + */ |
| 64 | +case class TestPredicateRuntimeReplaceable(left: Expression, right: Expression) |
| 65 | + extends RuntimeReplaceable with BinaryLike[Expression] { |
| 66 | + |
| 67 | + override lazy val replacement: Expression = GreaterThan(left, right) |
| 68 | + |
| 69 | + override def eagerReplace: Boolean = false |
| 70 | + |
| 71 | + override protected def withNewChildrenInternal( |
| 72 | + newLeft: Expression, newRight: Expression): TestPredicateRuntimeReplaceable = |
| 73 | + copy(left = newLeft, right = newRight) |
| 74 | +} |
| 75 | + |
| 76 | +/** |
| 77 | + * Rewrites every [[GreaterThan]] whose right operand equals `Literal(88)` (e.g. `x > 88`), in all expressions of the plan, into a [[TestPredicateRuntimeReplaceable]] over the same operands. Meant to run after `ReplaceExpressions` so the wrapper survives (ordering is set by the caller -- confirm there). |
| 78 | + */ |
| 79 | +object WrapGreaterThanWithRuntimeReplaceable extends Rule[LogicalPlan] { |
| 80 | + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { |
| 81 | + case g: GreaterThan if g.right == Literal(88) => |
| 82 | + TestPredicateRuntimeReplaceable(g.left, g.right) |
| 83 | + } |
| 84 | +} |
| 85 | + |
57 | 86 | class MaterializeRuntimeReplaceableSuite extends QueryTest with SharedSparkSession { |
58 | 87 |
|
59 | 88 | private def withExtraOptimization(rule: Rule[LogicalPlan])(f: => Unit): Unit = { |
@@ -107,6 +136,63 @@ class MaterializeRuntimeReplaceableSuite extends QueryTest with SharedSparkSessi |
107 | 136 | } |
108 | 137 | } |
109 | 138 |
|
| 139 | + test("SPARK-57512: a surviving RuntimeReplaceable in cached-scan predicates is materialized " + |
| 140 | + "for partition pruning under AQE") { |
| 141 | + // Collect every InMemoryTableScanExec, unwrapping QueryStageExec (`plan`) and AdaptiveSparkPlanExec (`executedPlan`) on the way down: the cached scan is |
| 142 | + // wrapped in a TableCacheQueryStageExec, which is a leaf to a normal tree walk. |
| 143 | + def findScans(p: SparkPlan): Seq[InMemoryTableScanExec] = p match { |
| 144 | + case in: InMemoryTableScanExec => Seq(in) |
| 145 | + case q: QueryStageExec => findScans(q.plan) |
| 146 | + case a: AdaptiveSparkPlanExec => findScans(a.executedPlan) |
| 147 | + case other => other.children.flatMap(findScans) |
| 148 | + } |
| 149 | + |
| 150 | + withSQLConf( |
| 151 | + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", |
| 152 | + // Keep the range partitions intact (no AQE coalescing) so pruning is observable. |
| 153 | + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false", |
| 154 | + SQLConf.COLUMN_BATCH_SIZE.key -> "10", |
| 155 | + SQLConf.IN_MEMORY_PARTITION_PRUNING.key -> "true", |
| 156 | + SQLConf.IN_MEMORY_TABLE_SCAN_STATISTICS_ENABLED.key -> "true") { |
| 157 | + withExtraOptimization(WrapGreaterThanWithRuntimeReplaceable) { |
| 158 | + import testImplicits._ |
| 159 | + // `repartitionByRange` adds a shuffle, so the cached plan is adaptive and its scan is |
| 160 | + // wrapped in a `TableCacheQueryStageExec` (a leaf the AQE stage-finalization rules cannot |
| 161 | + // descend into) -- this is the case where the predicate is NOT materialized in the plan. |
| 162 | + // Range partitioning also keeps values clustered, so batch stats enable pruning. |
| 163 | + // 100 values, batch size 10 => about 10 batches across 5 partitions (assuming an even range split). |
| 164 | + val cached = sparkContext.makeRDD(1 to 100, 5).toDF("key").repartitionByRange(5, $"key") |
| 165 | + cached.cache() |
| 166 | + try { |
| 167 | + // `key > 88` is rewritten into a surviving `TestPredicateRuntimeReplaceable` and pushed |
| 168 | + // into the cached scan's `predicates`. |
| 169 | + val df = cached.filter("key > 88") |
| 170 | + checkAnswer(df, (89 to 100).map(Row(_))) |
| 171 | + |
| 172 | + val scans = findScans(df.queryExecution.executedPlan) |
| 173 | + assert(scans.size == 1, s"expected one cached scan, found ${scans.size}") |
| 174 | + val scan = scans.head |
| 175 | + |
| 176 | + // The scan is a leaf query stage, so its predicate is not materialized in the plan: the |
| 177 | + // surviving RuntimeReplaceable is still present. This is exactly why the consumer-side |
| 178 | + // unfold in `filteredCachedBatches` is needed. |
| 179 | + assert( |
| 180 | + scan.predicates.exists(_.exists(_.isInstanceOf[RuntimeReplaceable])), |
| 181 | + s"Expected a surviving RuntimeReplaceable in the cached scan predicates:\n$scan") |
| 182 | + |
| 183 | + // Pruning kicked in: strictly fewer than 10 batches and 5 partitions are read. Without unfolding |
| 184 | + // the predicate, `buildFilter` would not recognize it and would scan everything. |
| 185 | + assert(scan.readBatches.value < 10, |
| 186 | + s"Expected pruning (< 10 batches read), got ${scan.readBatches.value}") |
| 187 | + assert(scan.readPartitions.value < 5, |
| 188 | + s"Expected pruning (< 5 partitions read), got ${scan.readPartitions.value}") |
| 189 | + } finally { |
| 190 | + cached.unpersist() |
| 191 | + } |
| 192 | + } |
| 193 | + } |
| 194 | + } |
| 195 | + |
110 | 196 | test("a surviving RuntimeReplaceable self-evaluates via its replacement") { |
111 | 197 | // `eval` delegates to `replacement` as a backstop for paths that bypass |
112 | 198 | // `MaterializeRuntimeReplaceable`. |
|
0 commit comments