Skip to content

Commit 5251cd7

Browse files
committed
[SPARK-57512][SQL] Allow RuntimeReplaceable to opt out of eager replacement and survive into the physical plan
Add `RuntimeReplaceable.eagerReplace` (default `true`, preserving current behavior). An expression can override it to `false` to survive the logical optimizer into the physical plan, so a native engine can match the high-level expression directly and the optimized plan stays readable. `ReplaceExpressions` still rewrites eagerly when the replacement cannot survive (non-deterministic, or contains an `Unevaluable` other than `AttributeReference`); `RuntimeReplaceableAggregate` is always rewritten. Supporting changes: - `eval`/`doGenCode`/`deterministic`/`foldable` delegate to `replacement`; a new physical-preparation rule `MaterializeRuntimeReplaceable` unfolds survivors before `CollapseCodegenStages` (in both `QueryExecution.preparations` and AQE `postStageCreationRules`). This is required for correctness: whole-stage codegen reasons about `references`/input materialization/CSE via `children`, while the emitted code comes from `replacement`, so an un-materialized survivor can produce invalid generated code. - `FoldablePropagation` only propagates literals; `NormalizePlan` fully unfolds. - Structure-interpreting consumers that receive a predicate leaving Spark's evaluation engine unfold survivors at their boundary: data source filter pushdown (`DataSourceStrategy.translateLeafNodeFilter` for V1, `V2ExpressionBuilder` for V2 which also covers aggregate/group-by) and cached-batch pruning (`InMemoryTableScanExec.buildFilter`). - First adopter: `MultiGetJsonObject` (inserted by `OptimizeCsvJsonExprs` after `ReplaceExpressions`) becomes a surviving `RuntimeReplaceable` with an `Invoke` replacement. Co-authored-by: Isaac
1 parent a7ad18b commit 5251cd7

18 files changed

Lines changed: 499 additions & 43 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -454,19 +454,59 @@ trait RuntimeReplaceable extends Expression {
454454
override val nodePatterns: Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE)
455455
override def nullable: Boolean = replacement.nullable
456456
override def dataType: DataType = replacement.dataType
457+
// The actual evaluation is delegated to `replacement`, so determinism must reflect `replacement`,
458+
// not this expression's `children` (which are the original arguments). For example, the children
459+
// of `Uniform` are literal bounds and a seed (all deterministic), while its `replacement` is a
460+
// non-deterministic `Rand`. This matters once a `RuntimeReplaceable` may survive into the
461+
// physical plan (see `eagerReplace`): the survival decision relies on an accurate determinism
462+
// signal.
463+
override lazy val deterministic: Boolean = replacement.deterministic
464+
// Foldability is also derived from `replacement` rather than this expression's `children`. Note
465+
// that this can yield a foldable expression that still has references (e.g. `collation(c1)`,
466+
// whose value depends only on the child's type): such an expression is materialized into a
467+
// literal by `ConstantFolding`, and `FoldablePropagation` only propagates literals, never bare
468+
// foldables.
469+
override def foldable: Boolean = replacement.foldable
457470
// As this expression gets replaced at optimization with its `child" expression,
458471
// two `RuntimeReplaceable` are considered to be semantically equal if their "child" expressions
459472
// are semantically equal.
460473
override lazy val canonicalized: Expression = replacement.canonicalized
461474

462-
final override def eval(input: InternalRow = null): Any = {
463-
// For convenience, we allow to evaluate `RuntimeReplaceable` expressions, in case we need to
464-
// get a constant from foldable expression before the query execution starts.
465-
assert(input == null)
466-
replacement.eval()
475+
// Whether `ReplaceExpressions` should rewrite this expression into its `replacement` eagerly, in
476+
// the logical optimizer. This is `true` by default, which preserves the historical behavior where
477+
// a `RuntimeReplaceable` never reaches the physical plan. An expression can override this to
478+
// `false` to survive into the physical plan (e.g. so a native engine can match the high-level
479+
// expression directly); such a survivor is materialized into its `replacement` right before
480+
// codegen by `MaterializeRuntimeReplaceable`. Note that an expression that opts out can still be
481+
// rewritten eagerly if its `replacement` cannot survive (non-deterministic or unevaluable); see
482+
// `ReplaceExpressions`.
483+
def eagerReplace: Boolean = true
484+
485+
// `RuntimeReplaceable` expressions are normally rewritten into their `replacement` by the
486+
// `ReplaceExpressions` rule before execution. However, an expression with `eagerReplace = false`
487+
// survives into the physical plan, and a `RuntimeReplaceable` may also be produced *after*
488+
// `ReplaceExpressions` has run (e.g. by an optimizer rule). To keep such an expression evaluable
489+
// without depending on the rewrite, both `eval` and `doGenCode` delegate to `replacement`. As
490+
// `replacement` is derived from this expression's children, it is bound and code-generated
491+
// together with them, so the delegation observes the same input row.
492+
final override def eval(input: InternalRow = null): Any = replacement.eval(input)
493+
494+
final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
495+
val childGen = replacement.genCode(ctx)
496+
ev.copy(code = childGen.code, isNull = childGen.isNull, value = childGen.value)
497+
}
498+
}
499+
500+
object RuntimeReplaceable {
501+
/**
502+
* Fully unfolds every [[RuntimeReplaceable]] in `e` into its `replacement`, recursively. The
503+
* result contains no [[RuntimeReplaceable]] node. This is the canonical "materialize" transform
504+
* shared by `MaterializeRuntimeReplaceable` (physical plan) and `NormalizePlan` (comparison).
505+
*/
506+
def unfold(e: Expression): Expression = e match {
507+
case r: RuntimeReplaceable => unfold(r.replacement)
508+
case _ => e.mapChildren(unfold)
467509
}
468-
final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
469-
throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this)
470510
}
471511

472512
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,11 @@ case class AggregateExpression(
206206
*/
207207
abstract class AggregateFunction extends Expression {
208208

209-
/** An aggregate function is not foldable. */
210-
final override def foldable: Boolean = false
209+
// An aggregate function is not foldable. This is not `final` so that
210+
// `RuntimeReplaceableAggregate` can inherit `RuntimeReplaceable.foldable` (which delegates to
211+
// `replacement`); since such a replacement is itself an aggregate, the effective foldability
212+
// stays `false`.
213+
override def foldable: Boolean = false
211214

212215
/** The schema of the aggregation buffer. */
213216
def aggBufferSchema: StructType

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ case class MultiGetJsonObject(
164164
fieldNames: Seq[String],
165165
fallbackPaths: Seq[String])
166166
extends UnaryExpression
167+
with RuntimeReplaceable
167168
with ExpectsInputTypes {
168169

169170
require(
@@ -182,36 +183,29 @@ case class MultiGetJsonObject(
182183

183184
override def nullable: Boolean = true
184185

185-
// This internal unary expression always returns null when its JSON child is null.
186-
override def nullIntolerant: Boolean = true
187-
188186
override def prettyName: String = "multi_get_json_object"
189187

190-
final override val nodePatterns: Seq[TreePattern] = Seq(GET_JSON_OBJECT)
188+
final override val nodePatterns: Seq[TreePattern] = Seq(GET_JSON_OBJECT, RUNTIME_REPLACEABLE)
189+
190+
// This expression is produced by `OptimizeCsvJsonExprs`, which runs after `ReplaceExpressions`,
191+
// so it must opt out of eager replacement to survive into the physical plan. It keeps the
192+
// readable `multi_get_json_object` node in the optimized plan and is materialized into its
193+
// `replacement` just before codegen by `MaterializeRuntimeReplaceable`.
194+
override def eagerReplace: Boolean = false
191195

192196
@transient
193197
private lazy val evaluator = MultiGetJsonObjectEvaluator(
194198
fieldNames,
195199
fallbackPaths.map(UTF8String.fromString))
196200

197-
override def eval(input: InternalRow): Any = {
198-
evaluator.evaluate(json.eval(input).asInstanceOf[UTF8String])
199-
}
200-
201-
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
202-
val refEvaluator = ctx.addReferenceObj("evaluator", evaluator)
203-
val jsonEval = json.genCode(ctx)
204-
val resultType = CodeGenerator.javaType(dataType)
205-
ev.copy(code = code"""
206-
|${jsonEval.code}
207-
|boolean ${ev.isNull} = ${jsonEval.isNull};
208-
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
209-
|if (!${ev.isNull}) {
210-
| ${ev.value} = ($resultType) $refEvaluator.evaluate(${jsonEval.value});
211-
| ${ev.isNull} = ${ev.value} == null;
212-
|}
213-
|""".stripMargin)
214-
}
201+
// Delegates evaluation to `MultiGetJsonObjectEvaluator`. `Invoke`'s default `propagateNull`
202+
// returns null when the JSON child is null, matching the original null-intolerant behavior.
203+
override def replacement: Expression = Invoke(
204+
Literal.create(evaluator, ObjectType(classOf[MultiGetJsonObjectEvaluator])),
205+
"evaluate",
206+
dataType,
207+
Seq(json),
208+
Seq(json.dataType))
215209

216210
override protected def withNewChildInternal(newChild: Expression): MultiGetJsonObject =
217211
copy(json = newChild)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,8 @@ case class Uniform(
228228
def this(min: Expression, max: Expression, seedExpression: Expression) =
229229
this(min, max, seedExpression, hideSeed = false)
230230

231-
final override lazy val deterministic: Boolean = false
231+
// `deterministic` is inherited from `RuntimeReplaceable`, which delegates to `replacement` (a
232+
// non-deterministic `Rand`-based expression unless an argument is null).
232233
override def nodePatternsInternal(): Seq[TreePattern] =
233234
Seq(RUNTIME_REPLACEABLE, EXPRESSION_WITH_RANDOM_SEED)
234235

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1118,7 +1118,13 @@ object FoldablePropagation extends Rule[LogicalPlan] {
11181118

11191119
private def collectFoldables(expressions: Seq[NamedExpression]) = {
11201120
AttributeMap(expressions.collect {
1121-
case a: Alias if a.child.foldable => (a.toAttribute, a)
1121+
// Only propagate literals. A foldable expression is not necessarily a self-contained
1122+
// constant: a `RuntimeReplaceable` such as `collation(c1)` is foldable (its value depends
1123+
// only on the child's type) yet still references its children, so substituting it for its
1124+
// alias elsewhere would leave those references dangling. Such expressions are turned into
1125+
// literals by `ConstantFolding` (in the same fixed-point batch), after which they propagate
1126+
// safely.
1127+
case a: Alias if a.child.isInstanceOf[Literal] => (a.toAttribute, a)
11221128
})
11231129
}
11241130

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,28 @@ object ReplaceExpressions extends Rule[LogicalPlan] {
5353
}
5454

5555
private def replace(e: Expression): Expression = e match {
56-
case r: RuntimeReplaceable => replace(r.replacement)
56+
// Aggregates can never self-evaluate (no real aggregation buffer), so always rewrite early.
57+
case r: RuntimeReplaceableAggregate => replace(r.replacement)
58+
59+
case r: RuntimeReplaceable =>
60+
val replaced = replace(r.replacement)
61+
// By default (`eagerReplace = true`) a `RuntimeReplaceable` is rewritten here, so it never
62+
// reaches the physical plan. An expression with `eagerReplace = false` is instead allowed to
63+
// survive into the physical plan (to be matched by a native engine or materialized just
64+
// before codegen). Even then, a survivor must be rewritten early if its replacement cannot
65+
// survive:
66+
// - A non-deterministic replacement (e.g. the `Rand` inside `uniform`) carries mutable
67+
// per-partition state that must be initialized before eval. That state cannot be reliably
68+
// initialized through the `lazy val replacement`, which tree transforms may re-create.
69+
// - A replacement that contains an `Unevaluable` expression (e.g. `With`) depends on a later
70+
// logical rule (such as `RewriteWithExpression`) that can only run in the logical phase.
71+
// `AttributeReference` is `Unevaluable` too but is bound at execution like any input
72+
// column, so it does not block survival -- we reuse `ConvertToLocalRelation`'s check,
73+
// which already excludes it.
74+
val cannotSurvive =
75+
!replaced.deterministic || ConvertToLocalRelation.hasUnevaluableExpr(replaced)
76+
if (r.eagerReplace || cannotSurvive) replaced else r
77+
5778
case _ => e.mapChildren(replace)
5879
}
5980
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.analysis.NormalizeableRelation
2323
import org.apache.spark.sql.catalyst.analysis.resolver.ResolverTag
2424
import org.apache.spark.sql.catalyst.expressions._
2525
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
26-
import org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions
2726
import org.apache.spark.sql.catalyst.plans.logical._
28-
import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
27+
import org.apache.spark.sql.catalyst.trees.TreePattern.{PLAN_EXPRESSION, RUNTIME_REPLACEABLE}
2928

3029
/**
3130
* Object that handles normalization of operators and expressions. Used when comparing plans.
@@ -78,8 +77,18 @@ object NormalizePlan extends PredicateHelper {
7877
* [[InheritAnalysisRules]] is the replacement expression, the original expression will be lost
7978
* and timezone will never be applied. This causes inconsistencies, because fixed-point semantic
8079
* is to ALWAYS apply timezone, regardless of whether the Cast actually needs it.
80+
*
81+
* Note: this unconditionally unfolds every [[RuntimeReplaceable]] into its `replacement`. It
82+
* intentionally does NOT reuse the `ReplaceExpressions` optimizer rule, which now keeps
83+
* deterministic, evaluable [[RuntimeReplaceable]] nodes in the plan (they are materialized later,
84+
* before codegen). Normalization still needs the fully-unfolded form so that the non-child
85+
* `parameters` of [[InheritAnalysisRules]] (e.g. the original, un-timezoned cast) are dropped.
8186
*/
82-
def normalizeRuntimeReplaceable(plan: LogicalPlan): LogicalPlan = ReplaceExpressions(plan)
87+
def normalizeRuntimeReplaceable(plan: LogicalPlan): LogicalPlan = {
88+
plan.transformWithPruning(_.containsAnyPattern(RUNTIME_REPLACEABLE)) {
89+
case p => p.mapExpressions(RuntimeReplaceable.unfold)
90+
}
91+
}
8392

8493
/**
8594
* Since attribute references are given globally unique ids during analysis,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,11 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L
376376
} else {
377377
None
378378
}
379+
// A surviving `RuntimeReplaceable` (`eagerReplace = false`) that no explicit case above pushes
380+
// natively is a Spark-internal optimizer node the connector cannot understand. Fall back to its
381+
// concrete `replacement` so the lowered form can still be pushed -- same boundary rationale as
382+
// `DataSourceStrategy.translateLeafNodeFilter` (V1) and `CachedBatchSerializer.buildFilter`.
383+
case r: RuntimeReplaceable => generateExpression(r.replacement, isPredicate)
379384
case _ => None
380385
}
381386

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ import org.apache.spark.SparkException
2121
import org.apache.spark.sql.AnalysisException
2222
import org.apache.spark.sql.catalyst.dsl.expressions._
2323
import org.apache.spark.sql.catalyst.dsl.plans._
24-
import org.apache.spark.sql.catalyst.expressions.{Add, Alias, ArrayCompact, AttributeReference, CreateArray, CreateStruct, IntegerLiteral, Literal, MapFromEntries, Multiply, NamedExpression, NullIf, Remainder, RuntimeReplaceable}
24+
import org.apache.spark.sql.catalyst.expressions.{Add, Alias, ArrayCompact, AttributeReference, CreateArray, CreateStruct, Expression, IntegerLiteral, Literal, MapFromEntries, Multiply, NamedExpression, NullIf, Remainder, RuntimeReplaceable}
2525
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
2626
import org.apache.spark.sql.catalyst.plans.PlanTest
2727
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, LogicalPlan, OneRowRelation, Project}
2828
import org.apache.spark.sql.catalyst.rules.Rule
29+
import org.apache.spark.sql.catalyst.trees.UnaryLike
2930
import org.apache.spark.sql.internal.SQLConf
3031
import org.apache.spark.sql.types.{ArrayType, BooleanType, IntegerType, MapType, StructField, StructType}
3132

@@ -38,6 +39,23 @@ object DecrementLiterals extends Rule[LogicalPlan] {
3839
}
3940
}
4041

42+
/**
43+
* A test-only [[RuntimeReplaceable]] that opts out of eager replacement and whose `replacement`
44+
* references an input column (via an [[AttributeReference]], which is `Unevaluable`). Used to check
45+
* that `ReplaceExpressions` lets such a column-based survivor through, rather than treating the
46+
* `AttributeReference` as a reason to rewrite eagerly.
47+
*/
48+
case class ColumnBasedRuntimeReplaceable(child: Expression)
49+
extends RuntimeReplaceable with UnaryLike[Expression] {
50+
51+
override lazy val replacement: Expression = Add(child, Literal(1))
52+
53+
override def eagerReplace: Boolean = false
54+
55+
override protected def withNewChildInternal(
56+
newChild: Expression): ColumnBasedRuntimeReplaceable = copy(child = newChild)
57+
}
58+
4159
class OptimizerSuite extends PlanTest {
4260
test("Optimizer exceeds max iterations") {
4361
val iterations = 5
@@ -354,4 +372,24 @@ class OptimizerSuite extends PlanTest {
354372
assert(optimized.expressions.forall(!_.exists(_.isInstanceOf[RuntimeReplaceable])))
355373
}
356374
}
375+
376+
test("SPARK-57512: a RuntimeReplaceable with a column-based replacement survives " +
377+
"ReplaceExpressions") {
378+
val optimizer = new SimpleTestOptimizer() {
379+
override def defaultBatches: Seq[Batch] =
380+
Batch("test", fixedPoint,
381+
ReplaceExpressions) :: Nil
382+
}
383+
384+
val relation = LocalRelation($"a".int)
385+
val wrapper = ColumnBasedRuntimeReplaceable(relation.output.head)
386+
val plan = Project(Alias(wrapper, "out")() :: Nil, relation).analyze
387+
val optimized = optimizer.execute(plan)
388+
389+
// The replacement references an input column (an `AttributeReference`, which is `Unevaluable`),
390+
// but the survivor opts out of eager replacement, so `ReplaceExpressions` keeps it in the plan
391+
// instead of rewriting it into the bare `Add`.
392+
assert(optimized.expressions.exists(_.exists(_.isInstanceOf[ColumnBasedRuntimeReplaceable])),
393+
s"Expected the column-based RuntimeReplaceable to survive:\n$optimized")
394+
}
357395
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution
19+
20+
import org.apache.spark.sql.catalyst.expressions.RuntimeReplaceable
21+
import org.apache.spark.sql.catalyst.rules.Rule
22+
import org.apache.spark.sql.catalyst.trees.TreePattern.RUNTIME_REPLACEABLE
23+
24+
/**
25+
* Materializes any [[RuntimeReplaceable]] that survived the logical optimizer into its
26+
* `replacement`, on the physical plan.
27+
*
28+
* A `RuntimeReplaceable` with `eagerReplace = false` is intentionally kept in the plan by the
29+
* optimizer (see `ReplaceExpressions`) so that a native engine can match the high-level expression
30+
* directly. This rule then materializes the replacement for the Spark execution path. It is placed
31+
* after the columnar/native conversion and before `CollapseCodegenStages`, so a native engine sees
32+
* the original expression while Spark whole-stage codegen never sees a `RuntimeReplaceable`.
33+
*
34+
* Materializing before codegen is a correctness requirement, not just cleanup. A surviving
35+
* `RuntimeReplaceable` evaluates correctly on its own (`eval`/`doGenCode` delegate to
36+
* `replacement`), but whole-stage codegen reasons about `references`, input materialization, and
37+
* subexpression elimination via the node's `children`, while the emitted code comes from
38+
* `replacement`. When `replacement` reads an input differently from `children` (e.g. an
39+
* un-simplified branch that reads a column more than once), that mismatch produces invalid
40+
* generated code. Eager replacement avoids this because the unfolded form is then simplified by the
41+
* optimizer; a survivor's `replacement` is not, so it must be unfolded before codegen.
42+
*
43+
* This runs wherever a physical plan is finalized into a codegen-bearing form:
44+
* `QueryExecution.preparations` (non-AQE) and `AdaptiveSparkPlanExec.postStageCreationRules` (AQE,
45+
* applied to every codegen-producing stage). A `RuntimeReplaceable` that only feeds a structural,
46+
* non-codegen consumer is materialized at that consumer instead -- see the cached-batch pruning
47+
* predicates in `InMemoryTableScanExec`, whose leaf scan never reaches codegen and is unreachable
48+
* from AQE stage finalization.
49+
*/
50+
object MaterializeRuntimeReplaceable extends Rule[SparkPlan] {
51+
override def apply(plan: SparkPlan): SparkPlan = plan.transformUpWithSubqueries {
52+
case p if p.expressions.exists(_.containsPattern(RUNTIME_REPLACEABLE)) =>
53+
p.mapExpressions(RuntimeReplaceable.unfold)
54+
}
55+
}

0 commit comments

Comments
 (0)