Commit f4bec83

[SPARK-57003][SQL][SS] Widen stateful operator output and state schema nullability
### What changes were proposed in this pull request?

Introduce a three-component fix for stateful-operator nullability drift, gated by `spark.sql.streaming.statefulOperator.alwaysNullableOutput.enabled` (pinned per-query via the offset log):

- (a) `WidenStatefulOpNullability.widenStateSchema`: every stateful physical exec widens its state key/value schema to fully nullable at construction. This covers `StateStoreSaveExec`, `BaseStreamingDeduplicateExec`, `StreamingSymmetricHashJoinExec`, `FlatMapGroupsWithStateExec`, `TransformWithStateExec` (including user-defined state variable col family schemas), `TransformWithStateInPySparkExec`, and `StreamingGlobalLimitExec`.
- (b) `WidenStatefulOpNullability.widenOutputForStatefulOp`: every stateful logical and physical operator widens its declared `output` to fully nullable.
- (c) `WidenStatefulOperatorAttributeNullability`: an optimizer rule that widens `AttributeReference`s inside stateful ops' internal expressions and propagates the widening upward through ancestor expressions. The rule uses `resolveOperatorsUp` (bottom-up) and scopes the widening precisely: at a stateful operator, all children's output is included (for internal expression references such as grouping keys); at non-stateful ancestors, only children whose subtrees contain a stateful operator are included, which avoids unnecessary widening of non-stateful siblings. The node's own `p.output` is excluded for non-stateful ancestors because the bottom-up traversal guarantees that children are already transformed.

With this fix, the state schema is "fully" nullable (top-level columns, nested columns, and collection types) regardless of the input schema, and the output schema of the stateful operator is "fully" nullable as well. Changing the output schema of the stateful operator is necessary because, even if the input schema is non-nullable, state can produce null values, so the output can be nullable.

### Why are the changes needed?

This has been a long-standing conflict between the streaming engine and the query optimizer. A streaming query is meant to be long-running, and in many cases it spans multiple Spark versions. Also, the logical plan is not always the same across batches (e.g. there are multiple stream sources and one of them has no new data at batch N). This exposes the streaming query to analyzer and optimizer behavior.

The state schema of a stateful operator is mostly determined by the operator's input schema, and nullability is no exception. If the input schema has a nullable column, the state schema has a nullable column, and likewise for a non-nullable column.

One of the query optimizer's optimizations is to flip nullability, say, from nullable to non-nullable where appropriate. This can happen directly or indirectly, and the most problematic case is when the optimization is applied "selectively". An easy example is the elimination of `Union`: in a streaming query that combines multiple streams with `Union`, batch N could have one non-empty stream while another stream is empty. In that case, `PropagateEmptyRelation` can drop empty `Union` branches, causing a per-column nullability flip that propagates into a stateful operator's state schema across microbatches or restarts. This causes either `STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE` on restart or a codegen NPE when state-restored rows carry nulls in columns declared non-nullable.

### Does this PR introduce _any_ user-facing change?

No user-visible behavior change for new queries (all stateful operator outputs become nullable, which is semantically correct). Existing queries keep their original behavior via the offset log gate.

### How was this patch tested?

New `StreamingStatefulOperatorNullabilityDriftSuite` covering:

- New-query path: Union-branch-drop restart scenarios for aggregate, dropDuplicates, dropDuplicatesWithinWatermark, stream-stream join, flatMapGroupsWithState, and transformWithState.
- Codegen NPE regression with struct grouping keys.
- Existing-query path: widening forced off still triggers schema mismatch.
- State schema assertion validates all state stores and column families (both v2 file format and v3 directory format including `_stateSchema`).
- Rule-level: scope check (non-stateful subtrees skipped).
- Helper-level: `deepWidenAttribute` recursion into nested types.

### Was this patch authored or co-authored using generative AI tooling?

Yes.

Generated-by: Claude 4.7 Opus

Closes #56061 from HeartSaVioR/widen-stateful-op-nullability.

Authored-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
(cherry picked from commit 0fb04a4)
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
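The `Union` branch-drop scenario from the commit message can be sketched with a toy model. This is a minimal illustration, not Spark code: `Col`, `unionOutput`, `strictlyCompatible`, and `widen` are hypothetical names, and the strict equality check is a deliberately simple stand-in for the real state schema compatibility check, whose exact rules are not reproduced here.

```scala
// Toy model (not Spark's API): a column is a name plus a nullable flag.
final case class Col(name: String, nullable: Boolean)

object NullabilityDriftSketch {
  type Schema = Seq[Col]

  // Simplification: a union column is nullable if it is nullable in any branch.
  def unionOutput(branches: Seq[Schema]): Schema =
    branches.transpose.map(cols => Col(cols.head.name, cols.exists(_.nullable)))

  // Hypothetical stand-in for a schema check that also compares nullability.
  def strictlyCompatible(stored: Schema, current: Schema): Boolean =
    stored == current

  // Stand-in for the widening idea: every column is declared nullable.
  def widen(schema: Schema): Schema = schema.map(_.copy(nullable = true))

  def main(args: Array[String]): Unit = {
    val streamA: Schema = Seq(Col("key", nullable = false))
    val streamB: Schema = Seq(Col("key", nullable = true))

    // Batch 0: both branches are present, so the key is nullable.
    val batch0 = unionOutput(Seq(streamA, streamB))
    // Batch N: streamB is empty and its branch is dropped, so the key flips
    // to non-nullable.
    val batchN = unionOutput(Seq(streamA))

    println(strictlyCompatible(batch0, batchN))               // false
    println(strictlyCompatible(widen(batch0), widen(batchN))) // true
  }
}
```

Once both schemas are widened, the plan shape no longer influences the comparison, which is the property the fix relies on.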
1 parent dcaa427 commit f4bec83

20 files changed

Lines changed: 976 additions & 93 deletions

Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, ExprId}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataType, StructType}
+
+/**
+ * Shared helpers for the stateful-operator nullability fix. The fix has three
+ * independent components, all gated by
+ * [[SQLConf.STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT]] (pinned per-query via the
+ * offset log so existing queries keep their pre-fix behavior on restart):
+ *
+ * - (a) `widenStateSchema`: explicit `asNullable` at every state-schema construction
+ *   site in each stateful physical exec.
+ * - (b) `widenOutputForStatefulOp`: a per-op `output` override on every stateful logical
+ *   and physical operator, used by the operator's `output` definition.
+ * - (c) [[WidenStatefulOperatorAttributeNullability]] (defined below in this file): a
+ *   custom optimizer rule that widens `AttributeReference`s inside stateful ops'
+ *   internal expressions and propagates upward to ancestor expressions.
+ */
+object WidenStatefulOpNullability {
+
+  def isEnabled: Boolean =
+    SQLConf.get.getConf(SQLConf.STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT)
+
+  /**
+   * Recursively widens an attribute to be fully nullable: outer `nullable = true` plus
+   * every nested `StructField.nullable`, `ArrayType.containsNull`, and
+   * `MapType.valueContainsNull` flipped to `true` via
+   * [[org.apache.spark.sql.types.DataType#asNullable]].
+   */
+  def deepWidenAttribute(a: Attribute): Attribute = a match {
+    case ref: AttributeReference =>
+      AttributeReference(
+        ref.name, ref.dataType.asNullable, nullable = true, ref.metadata)(
+        ref.exprId, ref.qualifier)
+    case other => other.withNullability(true)
+  }
+
+  /**
+   * Component (a): widens a state schema to fully nullable. Stateful physical execs apply
+   * this at every `validateAndMaybeEvolveStateSchema(...)` call site and every
+   * `mapPartitionsWith*StateStore(...)` call site. When the conf is off, returns the
+   * schema unchanged.
+   */
+  def widenStateSchema(schema: StructType): StructType =
+    if (isEnabled) schema.asNullable else schema
+
+  /**
+   * Component (b): wraps a stateful operator's `output` to be fully nullable. The caller
+   * is responsible for only calling this from within an `output` definition on a stateful
+   * operator; gating is handled here via [[isEnabled]].
+   */
+  def widenOutputForStatefulOp(base: Seq[Attribute]): Seq[Attribute] =
+    if (isEnabled) base.map(deepWidenAttribute) else base
+
+  /**
+   * Recursively walks a schema and replaces any nested `StructType` that
+   * structurally matches `original` (by field names and base types, ignoring
+   * nullability) with `widened`. Used by TransformWithState execs to widen
+   * the grouping-key portion of col-family key schemas without touching
+   * user-defined key/value portions.
+   */
+  def widenGroupingKeyInSchema(
+      schema: StructType,
+      original: StructType,
+      widened: StructType): StructType = {
+    if (!isEnabled) return schema
+    if (DataType.equalsIgnoreNullability(schema, original)) {
+      widened
+    } else {
+      StructType(schema.fields.map { field =>
+        field.dataType match {
+          case st: StructType
+              if DataType.equalsIgnoreNullability(st, original) =>
+            field.copy(dataType = widened)
+          case st: StructType =>
+            field.copy(dataType =
+              widenGroupingKeyInSchema(st, original, widened))
+          case _ => field
+        }
+      })
+    }
+  }
+}
+
+/**
+ * Component (c) of the stateful-operator nullability fix: a custom optimizer rule that
+ * widens `AttributeReference`s inside streaming-stateful operators' internal expressions
+ * and propagates the widening upward to ancestor operators' expressions.
+ *
+ * The rule does NOT introduce any new logical or physical node. It is purely an
+ * attribute-rewrite pass using `resolveOperatorsUp` (bottom-up): for every node whose
+ * subtree contains a stateful operator, collect `exprId`s from children's output, then
+ * deep-widen every `AttributeReference` in the node's expressions whose `exprId` is in
+ * that set via [[WidenStatefulOpNullability#deepWidenAttribute]].
+ *
+ * At a stateful operator itself, all children's output attributes are included because
+ * the operator's internal expressions (e.g. grouping keys) reference them directly.
+ * At non-stateful ancestor operators, only children whose subtrees contain a stateful
+ * operator are included, to avoid unnecessary widening of non-stateful siblings.
+ * The node's own `p.output` is not needed for non-stateful ancestors because the
+ * bottom-up traversal guarantees children are already transformed, so their output
+ * attributes are already nullable and the ancestor's expressions reference those
+ * children's `exprId`s.
+ *
+ * '''Scope.''' The walk only fires on nodes whose subtree contains a stateful operator.
+ *
+ * '''Ordering constraint.''' This rule must run AFTER every `UpdateAttributeNullability`
+ * invocation in both the main optimizer and AQE.
+ *
+ * '''Idempotence.''' [[WidenStatefulOpNullability#deepWidenAttribute]] is idempotent.
+ */
+object WidenStatefulOperatorAttributeNullability extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    if (!conf.getConf(SQLConf.STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT) ||
+        !plan.containsStatefulOperator) {
+      return plan
+    }
+    plan.resolveOperatorsUp {
+      case p if !p.resolved => p
+      case p: LeafNode => p
+      case p if !p.containsStatefulOperator => p
+      case p =>
+        val widenableAttrs = if (p.isStateful) {
+          p.output ++ p.children.flatMap(_.output)
+        } else {
+          p.children.filter(_.containsStatefulOperator).flatMap(_.output)
+        }
+        val widenableExprIds: Set[ExprId] = widenableAttrs
+          .iterator.collect { case ar: AttributeReference => ar.exprId }.toSet
+        if (widenableExprIds.isEmpty) {
+          p
+        } else {
+          p.transformExpressions {
+            case ar: AttributeReference if widenableExprIds.contains(ar.exprId) =>
+              val widened = WidenStatefulOpNullability.deepWidenAttribute(ar)
+              if (ar.dataType == widened.dataType && ar.nullable == widened.nullable) {
+                ar
+              } else {
+                widened
+              }
+          }
+        }
+    }
+  }
+}
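The scoping of the rule above (bottom-up, widening only references that come from a stateful subtree) can be illustrated with a toy plan tree. This is a sketch under simplifying assumptions, not the Catalyst implementation: `Attr`, `Node`, and `ScopedWideningSketch` are hypothetical, and a node's output is modeled as simply the attributes its expressions reference.

```scala
// Toy attribute: an expression id plus a nullable flag.
final case class Attr(exprId: Int, nullable: Boolean)

// Toy plan node. `exprs` stands for the attribute references in its expressions.
final case class Node(
    name: String,
    exprs: Seq[Attr],
    children: Seq[Node] = Nil,
    isStateful: Boolean = false) {
  def containsStateful: Boolean = isStateful || children.exists(_.containsStateful)
  // Simplification for this sketch: output is whatever the expressions reference.
  def output: Seq[Attr] = exprs
}

object ScopedWideningSketch {
  def widen(node: Node): Node = {
    // Bottom-up: children are rewritten before the node itself.
    val kids = node.children.map(widen)
    val rebuilt = node.copy(children = kids)
    if (kids.isEmpty || !rebuilt.containsStateful) {
      rebuilt // leaves and subtrees without a stateful operator are untouched
    } else {
      val source =
        if (rebuilt.isStateful) rebuilt.output ++ kids.flatMap(_.output)
        else kids.filter(_.containsStateful).flatMap(_.output)
      val ids = source.map(_.exprId).toSet
      rebuilt.copy(exprs = rebuilt.exprs.map { a =>
        if (ids.contains(a.exprId)) a.copy(nullable = true) else a
      })
    }
  }

  def main(args: Array[String]): Unit = {
    val stream = Node("stream", Seq(Attr(1, nullable = false)))
    val agg = Node("agg", Seq(Attr(1, nullable = false)), Seq(stream), isStateful = true)
    val static = Node("static", Seq(Attr(2, nullable = false)))
    val join = Node("join", Seq(Attr(1, false), Attr(2, false)), Seq(agg, static))

    // Attribute 1 flows through the stateful subtree and is widened;
    // attribute 2 comes from the non-stateful sibling and is left alone.
    println(widen(join).exprs)
  }
}
```

In this toy run the join ends up referencing attribute 1 as nullable and attribute 2 as non-nullable, which mirrors the "non-stateful siblings are skipped" behavior described in the Scaladoc.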

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 25 additions & 7 deletions
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical

 import org.apache.spark.sql.catalyst.{AliasIdentifier, InternalRow, SQLConfHelper}
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, AnsiTypeCoercion, MultiInstanceRelation, Resolver, TypeCoercion, TypeCoercionBase, UnresolvedUnaryNode}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, AnsiTypeCoercion, MultiInstanceRelation, Resolver, TypeCoercion, TypeCoercionBase, UnresolvedUnaryNode, WidenStatefulOpNullability}
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable}
 import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_PLAN
 import org.apache.spark.sql.catalyst.expressions._
@@ -746,7 +746,10 @@ case class Join(
     }
   }

-  override def output: Seq[Attribute] = Join.computeOutput(joinType, left.output, right.output)
+  override def output: Seq[Attribute] = {
+    val base = Join.computeOutput(joinType, left.output, right.output)
+    if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base
+  }

   override def metadataOutput: Seq[Attribute] = {
     joinType match {
@@ -1226,7 +1229,10 @@ case class Aggregate(
     expressions.forall(_.resolved) && childrenResolved && !hasWindowExpressions
   }

-  override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
+  override def output: Seq[Attribute] = {
+    val base = aggregateExpressions.map(_.toAttribute)
+    if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base
+  }
   override def metadataOutput: Seq[Attribute] = Nil
   override def maxRows: Option[Long] = {
     if (groupingExpressions.isEmpty) {
@@ -1750,7 +1756,10 @@ object Limit {
  * order.
  */
 case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
-  override def output: Seq[Attribute] = child.output
+  override def output: Seq[Attribute] = {
+    val base = child.output
+    if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base
+  }
   override def maxRows: Option[Long] = {
     limitExpr match {
       case IntegerLiteral(limit) => Some(limit)
@@ -2005,7 +2014,10 @@ case class Sample(
  */
 case class Distinct(child: LogicalPlan) extends UnaryNode {
   override def maxRows: Option[Long] = child.maxRows
-  override def output: Seq[Attribute] = child.output
+  override def output: Seq[Attribute] = {
+    val base = child.output
+    if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base
+  }
   final override val nodePatterns: Seq[TreePattern] = Seq(DISTINCT_LIKE)
   override protected def withNewChildInternal(newChild: LogicalPlan): Distinct =
     copy(child = newChild)
@@ -2175,7 +2187,10 @@ case class Deduplicate(
     keys: Seq[Attribute],
     child: LogicalPlan) extends UnaryNode {
   override def maxRows: Option[Long] = child.maxRows
-  override def output: Seq[Attribute] = child.output
+  override def output: Seq[Attribute] = {
+    val base = child.output
+    if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base
+  }
   final override val nodePatterns: Seq[TreePattern] = Seq(DISTINCT_LIKE)
   override protected def withNewChildInternal(newChild: LogicalPlan): Deduplicate =
     copy(child = newChild)
@@ -2187,7 +2202,10 @@ case class DeduplicateWithinWatermark(keys: Seq[Attribute], child: LogicalPlan)
   override def references: AttributeSet = AttributeSet(keys) ++
     AttributeSet(child.output.filter(_.metadata.contains(EventTimeWatermark.delayKey)))
   override def maxRows: Option[Long] = child.maxRows
-  override def output: Seq[Attribute] = child.output
+  override def output: Seq[Attribute] = {
+    val base = child.output
+    if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base
+  }
   final override val nodePatterns: Seq[TreePattern] = Seq(DISTINCT_LIKE)
   override protected def withNewChildInternal(newChild: LogicalPlan): DeduplicateWithinWatermark =
     copy(child = newChild)
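Every hunk in this file applies the same pattern: compute the base output, then widen it only when the operator is stateful and the gate is on. A minimal toy version of that pattern follows. `Attr`, `Gate`, and `DedupLike` are hypothetical stand-ins rather than Spark classes, and treating "stateful" as "has a streaming child" is an assumption made for this sketch.

```scala
// Toy model (not Spark classes): an attribute is a name plus a nullable flag.
final case class Attr(name: String, nullable: Boolean)

// Hypothetical stand-in for the per-query config gate.
final case class Gate(enabled: Boolean)

// Toy operator whose raw output mirrors its child's output.
final case class DedupLike(childOutput: Seq[Attr], childIsStreaming: Boolean, gate: Gate) {
  // Assumption for this sketch: the operator is stateful only over a streaming child.
  def isStateful: Boolean = childIsStreaming

  def output: Seq[Attr] = {
    val base = childOutput
    // Widen only when stateful AND the gate is enabled; otherwise pass through.
    if (isStateful && gate.enabled) base.map(_.copy(nullable = true)) else base
  }
}

object OutputOverrideSketch {
  def main(args: Array[String]): Unit = {
    val cols = Seq(Attr("id", nullable = false))
    // Streaming child, gate on: "id" is reported as nullable.
    println(DedupLike(cols, childIsStreaming = true, Gate(enabled = true)).output)
    // Streaming child, gate off: output is unchanged.
    println(DedupLike(cols, childIsStreaming = true, Gate(enabled = false)).output)
    // Batch child: output is unchanged even with the gate on.
    println(DedupLike(cols, childIsStreaming = false, Gate(enabled = true)).output)
  }
}
```

Keeping the gate inside the widening helper, as the commit does, means each operator only needs the `isStateful` check at the call site.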

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala

Lines changed: 11 additions & 1 deletion
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.api.java.function.FilterFunction
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.{catalyst, Encoder, Row}
-import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedDeserializer}
+import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedDeserializer, WidenStatefulOpNullability}
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
@@ -568,6 +568,11 @@ case class FlatMapGroupsWithState(
       newLeft: LogicalPlan, newRight: LogicalPlan): FlatMapGroupsWithState =
     copy(child = newLeft, initialState = newRight)
   override def isStateful: Boolean = child.isStreaming
+
+  override def output: Seq[Attribute] = {
+    val base = super.output
+    if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base
+  }
 }

 object TransformWithState {
@@ -657,6 +662,11 @@ case class TransformWithState(
       newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithState =
     copy(child = newLeft, initialState = newRight)
   override def isStateful: Boolean = child.isStreaming
+
+  override def output: Seq[Attribute] = {
+    val base = super.output
+    if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base
+  }
 }

 /** Factory for constructing new `FlatMapGroupsInR` nodes. */

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala

Lines changed: 7 additions & 3 deletions
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical

 import org.apache.spark.resource.ResourceProfile
 import org.apache.spark.sql.catalyst.SQLConfHelper
-import org.apache.spark.sql.catalyst.analysis.{FunctionRegistryBase, MultiInstanceRelation, UnresolvedAttribute, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistryBase, MultiInstanceRelation, UnresolvedAttribute, UnresolvedStar, WidenStatefulOpNullability}
 import org.apache.spark.sql.catalyst.analysis.TableFunctionRegistry.TableFunctionBuilder
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, ExpressionDescription, ExpressionInfo, JsonToStructs, PythonUDF, PythonUDTF}
 import org.apache.spark.sql.catalyst.trees.TreePattern._
@@ -159,7 +159,9 @@ case class FlatMapGroupsInPandasWithState(
     timeout: GroupStateTimeout,
     child: LogicalPlan) extends UnaryNode {

-  override def output: Seq[Attribute] = outputAttrs
+  override def output: Seq[Attribute] =
+    if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(outputAttrs)
+    else outputAttrs

   override def producedAttributes: AttributeSet = AttributeSet(outputAttrs)

@@ -206,7 +208,9 @@ case class TransformWithStateInPySpark(

   override def right: LogicalPlan = initialState

-  override def output: Seq[Attribute] = outputAttrs
+  override def output: Seq[Attribute] =
+    if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(outputAttrs)
+    else outputAttrs

   override def producedAttributes: AttributeSet = AttributeSet(outputAttrs)


sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 18 additions & 0 deletions
@@ -3444,6 +3444,24 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)

+  val STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT =
+    buildConf("spark.sql.streaming.statefulOperator.alwaysNullableOutput.enabled")
+      .internal()
+      .withBindingPolicy(ConfigBindingPolicy.SESSION)
+      .doc("When true, every streaming stateful operator reports its output schema with " +
+        "nullable=true on all columns (including nested struct fields, array elements, and " +
+        "map values), and the state schema is widened at every construction site, so the " +
+        "existing state schema " +
+        "compatibility check trivially passes regardless of input nullability. " +
+        "This prevents query-optimizer decisions (e.g., PropagateEmptyRelation dropping a " +
+        "Union branch) from flipping the state schema nullability across microbatches or " +
+        "restarts. The effective value is pinned per query via the offset log at batch 0, " +
+        "so pre-existing queries keep their original behavior; only newly started queries " +
+        "pick this up.")
+      .version("4.3.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val FILESTREAM_SINK_METADATA_IGNORED =
     buildConf("spark.sql.streaming.fileStreamSink.ignoreMetadata")
       .internal()
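The config doc above says widening reaches nested struct fields, array elements, and map values. What "fully nullable" means can be sketched on a toy type model. The `DType` hierarchy and `DeepWidenSketch` below are hypothetical and only imitate the shape of Spark's type classes; they are not the real `asNullable` implementation.

```scala
// Toy data-type model (not Spark's org.apache.spark.sql.types).
sealed trait DType
case object IntT extends DType
case object StringT extends DType
final case class ArrayT(element: DType, containsNull: Boolean) extends DType
final case class MapT(key: DType, value: DType, valueContainsNull: Boolean) extends DType
final case class Field(name: String, dataType: DType, nullable: Boolean)
final case class StructT(fields: Seq[Field]) extends DType

object DeepWidenSketch {
  // Recursively flips every nullability flag to true, at any nesting depth.
  def asNullable(dt: DType): DType = dt match {
    case ArrayT(e, _)   => ArrayT(asNullable(e), containsNull = true)
    case MapT(k, v, _)  => MapT(asNullable(k), asNullable(v), valueContainsNull = true)
    case StructT(fs)    => StructT(fs.map(widenField))
    case leaf           => leaf
  }

  def widenField(f: Field): Field =
    Field(f.name, asNullable(f.dataType), nullable = true)

  def main(args: Array[String]): Unit = {
    val schema = StructT(Seq(
      Field("id", IntT, nullable = false),
      Field("tags", ArrayT(StringT, containsNull = false), nullable = false),
      Field("attrs", MapT(StringT, IntT, valueContainsNull = false), nullable = true)))

    val widened = asNullable(schema)
    println(widened)
    // Widening twice changes nothing further (idempotence).
    println(asNullable(widened) == widened) // true
  }
}
```

Because the result no longer depends on any input flag, two schemas that differ only in nullability widen to the same value.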

sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L
       .count()
       .selectExpr("window.start as timestamp", "count as num_events")

-    assert(countsDF.schema.toDDL == "timestamp TIMESTAMP,num_events BIGINT NOT NULL")
+    assert(countsDF.schema.toDDL == "timestamp TIMESTAMP,num_events BIGINT")

     // Start the query
     val queryName = "sparkConnectStreamingQuery"

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala

Lines changed: 3 additions & 2 deletions
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.adaptive

 import org.apache.spark.internal.LogKeys.{BATCH_NAME, RULE_NAME}
-import org.apache.spark.sql.catalyst.analysis.UpdateAttributeNullability
+import org.apache.spark.sql.catalyst.analysis.{UpdateAttributeNullability, WidenStatefulOperatorAttributeNullability}
 import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, EliminateLimits, OptimizeOneRowPlan}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity}
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
@@ -44,7 +44,8 @@ class AQEOptimizer(conf: SQLConf, extendedRuntimeOptimizerRules: Seq[Rule[Logica
     Batch("Dynamic Join Selection", Once, DynamicJoinSelection),
     Batch("Eliminate Limits", fixedPoint, EliminateLimits),
     Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan)) :+
-    Batch("User Provided Runtime Optimizers", fixedPoint, extendedRuntimeOptimizerRules: _*)
+    Batch("User Provided Runtime Optimizers", fixedPoint, extendedRuntimeOptimizerRules: _*) :+
+    Batch("Widen Stateful Op Nullability", Once, WidenStatefulOperatorAttributeNullability)

   final override protected def batches: Seq[Batch] = {
     val excludedRules = conf.getConf(SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES)
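The new batch is appended last, which matches the rule's stated ordering constraint (it must run after every `UpdateAttributeNullability` invocation). Why the order matters can be shown with a toy rule pipeline. Everything here is hypothetical: `narrow` is only a stand-in for any rule that may turn a nullable attribute non-nullable, and a plan is reduced to a map from column name to its nullable flag.

```scala
object RuleOrderSketch {
  // Toy plan: column name -> nullable flag.
  type Plan = Map[String, Boolean]
  type Rule = Plan => Plan

  // Hypothetical stand-in for a rule that may narrow nullability.
  val narrow: Rule = _.map { case (col, _) => col -> false }
  // Stand-in for the widening rule: everything becomes nullable.
  val widen: Rule = _.map { case (col, _) => col -> true }

  // Batches run in sequence, each seeing the previous batch's result.
  def run(batches: Seq[Rule], plan: Plan): Plan =
    batches.foldLeft(plan)((p, rule) => rule(p))

  def main(args: Array[String]): Unit = {
    val plan: Plan = Map("key" -> true)
    // Widening last: the final plan stays nullable.
    println(run(Seq(narrow, widen), plan)) // Map(key -> true)
    // Widening first: a later narrowing rule undoes it.
    println(run(Seq(widen, narrow), plan)) // Map(key -> false)
  }
}
```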

sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/FlatMapGroupsInPandasWithStateExec.scala

Lines changed: 3 additions & 1 deletion
@@ -20,6 +20,7 @@ import org.apache.spark.{JobArtifactSet, SparkException, SparkUnsupportedOperati
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout}
@@ -81,7 +82,8 @@ case class FlatMapGroupsInPandasWithStateExec(
   override protected val stateEncoder: ExpressionEncoder[Any] =
     ExpressionEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]]

-  override def output: Seq[Attribute] = outAttributes
+  override def output: Seq[Attribute] =
+    WidenStatefulOpNullability.widenOutputForStatefulOp(outAttributes)

   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
   private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
