
Commit f4b320e

aokolnychyi and gengliangwang authored and committed
[SPARK-56669][SQL] Implement group filtering for WriteDelta row level operations
### What changes were proposed in this pull request?

This PR implements group filtering for WriteDelta row-level operations. It re-applies #55612 (commit `5ef2e1ba174`, reverted in `8e8fee2692f`) and resolves the test failures reported in #55612 (comment) by updating the scan-count assertions in the transactional check tests in `MergeIntoTableSuiteBase` and `UpdateTableSuiteBase`.

With group filtering, `matchingRowsPlan` re-scans the target, and for MERGE `RewritePredicateSubquery` also re-scans the source. For MERGE, the delta scan counts now match the non-delta values, so the `deltaMerge` conditionals collapse. For UPDATE, the delta counts double but remain under the non-delta values because `ReplaceData` still adds further scans.

### Why are the changes needed?

These changes close the runtime group filtering gap between group-based (`ReplaceData`) and delta-based (`WriteDelta`) plans.

### Does this PR introduce _any_ user-facing change?

Changes are backward compatible.

### How was this patch tested?

This PR comes with tests. Locally verified all 9 affected suites are green (517 tests):

```
build/sbt 'sql/testOnly \
  org.apache.spark.sql.connector.DeltaBasedMergeIntoTableSuite \
  org.apache.spark.sql.connector.DeltaBasedMergeIntoTableWithDeletionVectorsSuite \
  org.apache.spark.sql.connector.DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite \
  org.apache.spark.sql.connector.DeltaBasedUpdateTableSuite \
  org.apache.spark.sql.connector.DeltaBasedUpdateTableWithDeletionVectorsSuite \
  org.apache.spark.sql.connector.DeltaBasedUpdateAsDeleteAndInsertTableSuite \
  org.apache.spark.sql.connector.DeltaBasedNoMetadataDeleteFromTableSuite \
  org.apache.spark.sql.connector.GroupBasedMergeIntoTableSuite \
  org.apache.spark.sql.connector.GroupBasedUpdateTableSuite'
```

### Was this patch authored or co-authored using generative AI tooling?

Claude Code v2.1.123.

Closes #55635 from gengliangwang/spark-56669-redo.

Lead-authored-by: Anton Okolnychyi <aokolnychyi@apache.org>
Co-authored-by: Anton Okolnychyi <aokolnychyi@apache.org>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
1 parent 294f6c1 commit f4b320e

16 files changed

Lines changed: 357 additions & 97 deletions
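Before the per-file diffs, a short hedged illustration of the end-to-end behavior this commit enables (hypothetical `target`/`source` tables; assumes a running `SparkSession` named `spark` and a v2 source whose row-level operation scan implements `SupportsRuntimeV2Filtering`):

```scala
// With group filtering on (the default), this delta-based MERGE first runs a
// lightweight subquery that projects only the join/condition columns, finds
// which groups (e.g. files) contain matching rows, and feeds that back to the
// main row-level operation scan so unaffected groups are never read.
spark.sql("""
  MERGE INTO target t
  USING source s
  ON t.id = s.id
  WHEN MATCHED THEN UPDATE SET t.value = s.value
  WHEN NOT MATCHED THEN INSERT *
""")
```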


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala

Lines changed: 6 additions & 1 deletion
```diff
@@ -295,7 +295,12 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
     // build a plan to write the row delta to the table
     val writeRelation = relation.copy(table = operationTable)
     val projections = buildWriteDeltaProjections(mergeRowsPlan, rowAttrs, rowIdAttrs, metadataAttrs)
-    WriteDelta(writeRelation, cond, mergeRowsPlan, relation, projections)
+    val groupFilterCond = if (notMatchedBySourceActions.isEmpty && groupFilterEnabled) {
+      Some(toGroupFilterCondition(relation, source, cond))
+    } else {
+      None
+    }
+    WriteDelta(writeRelation, cond, mergeRowsPlan, relation, projections, groupFilterCond)
   }
 
   private def chooseWriteDeltaJoinType(
```
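The `notMatchedBySourceActions.isEmpty` guard above matters; a hedged illustration of the case it excludes (hypothetical `target`/`source` tables, assuming a running `SparkSession` named `spark`):

```scala
// No group filter is attached to this MERGE: the NOT MATCHED BY SOURCE
// clause modifies target rows that do NOT satisfy the join condition, so a
// filter derived from the source could wrongly discard affected groups.
spark.sql("""
  MERGE INTO target t
  USING source s
  ON t.id = s.id
  WHEN MATCHED THEN UPDATE SET t.value = s.value
  WHEN NOT MATCHED BY SOURCE THEN DELETE
""")
```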

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala

Lines changed: 2 additions & 1 deletion
```diff
@@ -174,7 +174,8 @@ object RewriteUpdateTable extends RewriteRowLevelCommand {
     // build a plan to write the row delta to the table
     val writeRelation = relation.copy(table = operationTable)
     val projections = buildWriteDeltaProjections(rowDeltaPlan, rowAttrs, rowIdAttrs, metadataAttrs)
-    WriteDelta(writeRelation, cond, rowDeltaPlan, relation, projections)
+    val groupFilterCond = if (groupFilterEnabled) Some(cond) else None
+    WriteDelta(writeRelation, cond, rowDeltaPlan, relation, projections, groupFilterCond)
   }
 
   // this method assumes the assignments have been already aligned before
```
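For UPDATE the command's own condition is reused verbatim as the group filter; a small hedged example (hypothetical table, assuming a running `SparkSession` named `spark`):

```scala
// groupFilterCond is the WHERE predicate itself, so groups (e.g. files)
// containing no rows with dep = 'hr' can be skipped by the source at runtime.
spark.sql("UPDATE target SET salary = salary * 2 WHERE dep = 'hr'")
```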

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala

Lines changed: 4 additions & 1 deletion
```diff
@@ -60,7 +60,10 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
       val newCond = replaceNullWithFalse(cond)
       val newGroupFilterCond = groupFilterCond.map(replaceNullWithFalse)
       rd.copy(condition = newCond, groupFilterCondition = newGroupFilterCond)
-    case wd @ WriteDelta(_, cond, _, _, _, _) => wd.copy(condition = replaceNullWithFalse(cond))
+    case wd @ WriteDelta(_, cond, _, _, _, groupFilterCond, _) =>
+      val newCond = replaceNullWithFalse(cond)
+      val newGroupFilterCond = groupFilterCond.map(replaceNullWithFalse)
+      wd.copy(condition = newCond, groupFilterCondition = newGroupFilterCond)
     case d @ DeleteFromTable(_, cond) => d.copy(condition = replaceNullWithFalse(cond))
     case u @ UpdateTable(_, _, Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond)))
     case m: MergeIntoTable =>
```
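The null-to-false rewrite extends to the group filter condition because it, too, is evaluated as a filter predicate, where SQL's three-valued logic treats UNKNOWN like FALSE. A minimal self-contained sketch of that argument (plain Scala, illustrative only, not Spark code):

```scala
object NullAsFalseDemo extends App {
  // A SQL filter keeps a row only when the predicate is TRUE, so an UNKNOWN
  // (null) predicate discards the row exactly like FALSE does. Modeling the
  // predicate result as Option[Boolean]:
  def keepsRow(pred: Option[Boolean]): Boolean = pred.contains(true)

  assert(keepsRow(Some(true)))
  assert(!keepsRow(Some(false)))
  assert(!keepsRow(None)) // UNKNOWN behaves like FALSE in a filter context
}
```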

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 28 additions & 2 deletions
```diff
@@ -432,7 +432,7 @@ object ExtractSingleColumnNullAwareAntiJoin extends JoinSelectionHelper with Pre
  * - the read relation that can be either [[DataSourceV2Relation]] or [[DataSourceV2ScanRelation]]
  *   depending on whether the planning has already happened;
  */
-object GroupBasedRowLevelOperation {
+object GroupBasedRowLevelOperation extends RowLevelOperationExtractor {
   type ReturnType = (ReplaceData, Expression, Option[Expression], LogicalPlan)
 
   def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
@@ -445,8 +445,34 @@ object GroupBasedRowLevelOperation {
     case _ =>
       None
   }
+}
+
+/**
+ * An extractor for row-level commands such as DELETE, UPDATE, MERGE that were rewritten using
+ * plans that operate on individual rows (row deltas).
+ *
+ * This class extracts the following entities:
+ * - the delta-based rewrite plan;
+ * - the condition that defines matching rows;
+ * - the group filter condition;
+ * - the read relation that can be either [[DataSourceV2Relation]] or [[DataSourceV2ScanRelation]]
+ *   depending on whether the planning has already happened;
+ */
+object DeltaBasedRowLevelOperation extends RowLevelOperationExtractor {
+  type ReturnType = (WriteDelta, Expression, Option[Expression], LogicalPlan)
+
+  def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
+    case wd @ WriteDelta(ExtractV2Table(table), cond, query, _, _, groupFilterCond, _) =>
+      val readRelation = findReadRelation(table, query, allowMultipleReads = false)
+      readRelation.map((wd, cond, groupFilterCond, _))
+
+    case _ =>
+      None
+  }
+}
 
-  private def findReadRelation(
+trait RowLevelOperationExtractor {
+  protected def findReadRelation(
       table: Table,
       plan: LogicalPlan,
       allowMultipleReads: Boolean): Option[LogicalPlan] = {
```
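For context, a hedged sketch of how such an extractor is consumed (`IllustrativeRule` is a hypothetical name; the real consumer wired up in this commit is `RowLevelOperationRuntimeGroupFiltering`):

```scala
import org.apache.spark.sql.catalyst.planning.DeltaBasedRowLevelOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

object IllustrativeRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
    // ReturnType = (WriteDelta, condition, group filter condition, read relation)
    case DeltaBasedRowLevelOperation(writeDelta, cond, Some(groupFilterCond), readRelation) =>
      // only delta-based rewrites that carry a group filter condition match here
      writeDelta
  }
}
```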

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala

Lines changed: 2 additions & 0 deletions
```diff
@@ -425,6 +425,7 @@ case class ReplaceData(
  * @param query a query with a delta of records that should be written
  * @param originalTable a plan for the original table for which the row-level command was triggered
  * @param projections projections for row ID, row, metadata attributes
+ * @param groupFilterCondition a condition that can be used to filter groups at runtime
  * @param write a logical write, if already constructed
  */
 case class WriteDelta(
@@ -433,6 +434,7 @@ case class WriteDelta(
     query: LogicalPlan,
     originalTable: NamedRelation,
     projections: WriteDeltaProjections,
+    groupFilterCondition: Option[Expression] = None,
     write: Option[DeltaWrite] = None) extends RowLevelWrite {
 
   override val isByName: Boolean = false
```
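Because the new field is inserted before `write` with a default value, constructor call sites keep compiling while every positional pattern match needs an extra `_`, which is why so many files in this commit change by one character. A minimal self-contained sketch of that mechanic (plain Scala, not Spark code):

```scala
object DefaultParamDemo extends App {
  // Stand-in for WriteDelta: a new defaulted field before the last one.
  case class Delta(cond: String, groupFilterCond: Option[String] = None, write: Option[String] = None)

  val d = Delta("id = 1")            // existing positional constructor calls still compile
  val c = d match {
    case Delta(cond, _, _) => cond   // but every positional pattern gains an extra `_`
  }
  assert(c == "id = 1" && d.groupFilterCond.isEmpty)
}
```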

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 9 additions & 9 deletions
```diff
@@ -746,15 +746,15 @@ object SQLConf {
 
   val RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED =
     buildConf("spark.sql.optimizer.runtime.rowLevelOperationGroupFilter.enabled")
-      .doc("Enables runtime group filtering for group-based row-level operations. " +
-        "Data sources that replace groups of data (e.g. files, partitions) may prune entire " +
-        "groups using provided data source filters when planning a row-level operation scan. " +
-        "However, such filtering is limited as not all expressions can be converted into data " +
-        "source filters and some expressions can only be evaluated by Spark (e.g. subqueries). " +
-        "Since rewriting groups is expensive, Spark can execute a query at runtime to find what " +
-        "records match the condition of the row-level operation. The information about matching " +
-        "records will be passed back to the row-level operation scan, allowing data sources to " +
-        "discard groups that don't have to be rewritten.")
+      .doc("Enables runtime filtering for group-based and delta-based row-level operations. " +
+        "Data sources may prune entire file groups at runtime when planning a row-level " +
+        "operation scan. Planning-time filter pushdown is limited as not all expressions can " +
+        "be converted into data source filters and some expressions can only be evaluated by " +
+        "Spark (e.g. subqueries). Since rewriting groups or scanning unnecessary files is " +
+        "expensive, Spark can execute a lightweight query at runtime to find what records match " +
+        "the condition of the row-level operation. The information about matching records will " +
+        "be passed back to the row-level operation scan, allowing data sources to skip files " +
+        "that don't have to be processed.")
       .version("3.4.0")
       .booleanConf
       .createWithDefault(true)
```
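As a usage note, a minimal sketch of toggling the flag per session (assumes a running `SparkSession` named `spark`; the key is the one documented above):

```scala
// Runtime group filtering is on by default; disable it to compare plans.
spark.conf.set("spark.sql.optimizer.runtime.rowLevelOperationGroupFilter.enabled", "false")
// ... run the row-level command and compare plans / scan counts ...
spark.conf.set("spark.sql.optimizer.runtime.rowLevelOperationGroupFilter.enabled", "true")
```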

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -525,7 +525,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
         r.name) :: Nil
 
     case wd @ WriteDelta(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, projections,
-        Some(write)) =>
+        _, Some(write)) =>
       WriteDeltaExec(
         planLater(query),
         refreshCache(r), // use the original relation to refresh the cache
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -114,7 +114,7 @@ object OptimizeMetadataOnlyDeleteFromTable extends Rule[LogicalPlan] with Predic
       val command = rd.operation.command
       Some(rd, command, cond, originalTable)
 
-    case wd @ WriteDelta(_, cond, _, originalTable, _, _) =>
+    case wd @ WriteDelta(_, cond, _, originalTable, _, _, _) =>
       val command = wd.operation.command
       Some(wd, command, cond, originalTable)
 
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -113,7 +113,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
       val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
       rd.copy(write = Some(write), query = newQuery)
 
-    case wd @ WriteDelta(r: DataSourceV2Relation, _, query, _, projections, None) =>
+    case wd @ WriteDelta(r: DataSourceV2Relation, _, query, _, projections, _, None) =>
       val writeOptions = mergeOptions(Map.empty, r.options.asCaseSensitiveMap.asScala.toMap)
       val deltaWriteBuilder = newDeltaWriteBuilder(r.table, writeOptions, projections)
       val deltaWrite = deltaWriteBuilder.build()
```

sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala

Lines changed: 52 additions & 41 deletions
```diff
@@ -21,11 +21,10 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, DynamicPruningExpression, Expression, InSubquery, ListQuery, PredicateHelper, V2ExpressionUtils}
 import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.optimizer.RewritePredicateSubquery
-import org.apache.spark.sql.catalyst.planning.GroupBasedRowLevelOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan}
+import org.apache.spark.sql.catalyst.planning.{DeltaBasedRowLevelOperation, GroupBasedRowLevelOperation}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, RowLevelWrite}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering
-import org.apache.spark.sql.connector.write.RowLevelOperation.Command
 import org.apache.spark.sql.connector.write.RowLevelOperation.Command.{DELETE, MERGE, UPDATE}
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Implicits, DataSourceV2Relation, DataSourceV2ScanRelation, ExtractV2Scan}
 import org.apache.spark.util.ArrayImplicits._
@@ -34,66 +33,78 @@ import org.apache.spark.util.ArrayImplicits._
  * A rule that assigns a subquery to filter groups in row-level operations at runtime.
  *
  * Data skipping during job planning for row-level operations is limited to expressions that can be
- * converted to data source filters. Since not all expressions can be pushed down that way and
- * rewriting groups is expensive, Spark allows data sources to filter group at runtime.
- * If the primary scan in a group-based row-level operation supports runtime filtering, this rule
- * will inject a subquery to find all rows that match the condition so that data sources know
- * exactly which groups must be rewritten.
+ * converted to data source filters. Since not all expressions can be pushed down that way, Spark
+ * allows data sources to filter groups at runtime. If the primary scan in a row-level operation
+ * supports runtime filtering, this rule will inject a subquery to find all rows that match the
+ * condition so that data sources know exactly which groups have changes.
  *
- * Note this rule only applies to group-based row-level operations.
+ * Note that this rule is also beneficial for operations that deal with deltas of rows. Even if
+ * the data source is capable of handling specific changes, it is useful to first discard entire
+ * groups that are not modified. The cost of the runtime query is small as it only projects columns
+ * required to evaluate the row level operation condition. The main scan, on the other hand, must
+ * project all columns, meaning the cost of reading unaffected groups can dominate the runtime.
  */
 class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPlan])
   extends Rule[LogicalPlan] with PredicateHelper {
 
   import DataSourceV2Implicits._
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
-    // apply special dynamic filtering only for group-based row-level operations
     case GroupBasedRowLevelOperation(replaceData, _, Some(cond),
-        ExtractV2Scan(scan: SupportsRuntimeV2Filtering))
-        if conf.runtimeRowLevelOperationGroupFilterEnabled && cond != TrueLiteral
-          && scan.filterAttributes().nonEmpty =>
-
-      // use reference equality on scan to find required scan relations
-      val newQuery = replaceData.query transformUp {
-        case r: DataSourceV2ScanRelation if r.scan eq scan =>
-          // use the original table instance that was loaded for this row-level operation
-          // in order to leverage a regular batch scan in the group filter query
-          val originalTable = r.relation.table.asRowLevelOperationTable.table
-          val relation = r.relation.copy(table = originalTable)
-          val tableAttrs = replaceData.table.output
-          val command = replaceData.operation.command
-          val matchingRowsPlan = buildMatchingRowsPlan(relation, cond, tableAttrs, command)
-
-          val filterAttrs = scan.filterAttributes.toImmutableArraySeq
-          val buildKeys = V2ExpressionUtils.resolveRefs[Attribute](filterAttrs, matchingRowsPlan)
-          val pruningKeys = V2ExpressionUtils.resolveRefs[Attribute](filterAttrs, r)
-          val dynamicPruningCond = buildDynamicPruningCond(matchingRowsPlan, buildKeys, pruningKeys)
-
-          Filter(dynamicPruningCond, r)
-      }
-
-      // optimize subqueries to rewrite them as joins and trigger job planning
-      replaceData.copy(query = optimizeSubqueries(newQuery))
+        ExtractV2Scan(scan: SupportsRuntimeV2Filtering)) if canInjectGroupFilters(cond, scan) =>
+      injectGroupFilters(replaceData, cond, scan)
+
+    case DeltaBasedRowLevelOperation(writeDelta, _, Some(cond),
+        ExtractV2Scan(scan: SupportsRuntimeV2Filtering)) if canInjectGroupFilters(cond, scan) =>
+      injectGroupFilters(writeDelta, cond, scan)
+  }
+
+  private def canInjectGroupFilters(
+      cond: Expression,
+      scan: SupportsRuntimeV2Filtering): Boolean = {
+    conf.runtimeRowLevelOperationGroupFilterEnabled &&
+      cond != TrueLiteral &&
+      scan.filterAttributes.nonEmpty
+  }
+
+  private def injectGroupFilters(
+      write: RowLevelWrite,
+      cond: Expression,
+      scan: SupportsRuntimeV2Filtering): LogicalPlan = {
+    // use reference equality on scan to find required scan relations
+    val newQuery = write.query transformUp {
+      case r: DataSourceV2ScanRelation if r.scan eq scan =>
+        // use the original table instance that was loaded for this row-level operation
+        // in order to leverage a regular batch scan in the group filter query
+        val originalTable = r.relation.table.asRowLevelOperationTable.table
+        val relation = r.relation.copy(table = originalTable)
+        val matchingRowsPlan = buildMatchingRowsPlan(write, relation, cond)
+        val filterAttrs = scan.filterAttributes.toImmutableArraySeq
+        val buildKeys = V2ExpressionUtils.resolveRefs[Attribute](filterAttrs, matchingRowsPlan)
+        val pruningKeys = V2ExpressionUtils.resolveRefs[Attribute](filterAttrs, r)
+        Filter(buildDynamicPruningCond(matchingRowsPlan, buildKeys, pruningKeys), r)
+    }
+    // optimize subqueries to rewrite them as joins and trigger job planning
+    write.withNewQuery(optimizeSubqueries(newQuery))
   }
 
   private def buildMatchingRowsPlan(
+      write: RowLevelWrite,
       relation: DataSourceV2Relation,
-      cond: Expression,
-      tableAttrs: Seq[Attribute],
-      command: Command): LogicalPlan = {
+      cond: Expression): LogicalPlan = {
 
-    val matchingRowsPlan = command match {
+    val matchingRowsPlan = write.operation.command match {
       case DELETE =>
         Filter(cond, relation)
 
       case UPDATE =>
-        // UPDATEs with subqueries are rewritten using UNION with two identical scan relations
+        // UPDATEs with subqueries can be rewritten using UNION with two identical scan relations
         // the analyzer assigns fresh expr IDs for one of them so that attributes don't collide
         // this rule assigns runtime filters to both scan relations (will be shared at runtime)
         // and must transform the runtime filter condition to use correct expr IDs for each relation
+        // note this only applies to group-based row-level operations (i.e. ReplaceData)
         // see RewriteUpdateTable for more details
-        val attrMap = buildTableToScanAttrMap(tableAttrs, relation.output)
+        val attrMap = buildTableToScanAttrMap(write.table.output, relation.output)
         val transformedCond = cond transform {
           case attr: AttributeReference if attrMap.contains(attr) => attrMap(attr)
         }
```
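To make the effect concrete, a hedged sketch of the plan shape after this rule fires for a delta-based DELETE (node names abbreviated and illustrative; the exact plan depends on the source):

```
WriteDelta
+- Project [row ID, row, metadata projections]
   +- Filter dynamicpruning#1                      <- injected by this rule
      :  +- Aggregate [scan.filterAttributes]      <- distinct matching groups
      :     +- Filter <delete condition>           <- matchingRowsPlan
      :        +- DataSourceV2Relation target      <- regular batch scan
      +- DataSourceV2ScanRelation target           <- row-level operation scan
```

The build side projects only the columns needed to evaluate the condition, which is why the scaladoc above argues the runtime query stays cheap relative to the main scan.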
