Skip to content

Commit cb0ff92

Browse files
committed
[SPARK-56346][SQL] Use PartitionPredicate in DSV2 Metadata Only Delete
When `OptimizeMetadataOnlyDeleteFromTable` fails to push standard V2 predicates for a metadata-only delete, it now falls back to a second pass that converts partition-column filters to `PartitionPredicate`s (SPARK-55596) and combines them with translated V2 data filters.
1 parent 4142bc2 commit cb0ff92

5 files changed

Lines changed: 621 additions & 48 deletions

File tree

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.connector.catalog
19+
20+
import java.util
21+
22+
import org.apache.spark.sql.catalyst.InternalRow
23+
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
24+
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
25+
import org.apache.spark.sql.connector.expressions.Transform
26+
import org.apache.spark.sql.connector.expressions.filter.{PartitionPredicate, Predicate}
27+
import org.apache.spark.sql.types.StructType
28+
import org.apache.spark.util.ArrayImplicits._
29+
30+
/**
 * In-memory table that supports row-level operations and accepts [[PartitionPredicate]]s
 * in V2 [[canDeleteWhere]]/[[deleteWhere]] for metadata-only deletes.
 *
 * Contains some knobs to control acceptance of various partition and data predicates.
 *
 * @param name fully qualified table name
 * @param schema table schema
 * @param partitioning partition transforms
 * @param properties table properties; see the companion object for recognized knob keys
 */
class InMemoryPartitionPredicateDeleteTable(
    name: String,
    schema: StructType,
    partitioning: Array[Transform],
    properties: util.Map[String, String])
  extends InMemoryRowLevelOperationTable(name, schema, partitioning, properties) {

  // Knob: whether canDeleteWhere accepts PartitionPredicate instances (default: true).
  private val acceptPartitionPredicates: Boolean =
    properties.getOrDefault(
      InMemoryPartitionPredicateDeleteTable.AcceptPartitionPredicatesKey, "true").toBoolean

  // Knob: whether canDeleteWhere accepts standard predicates referencing non-partition
  // (data) columns (default: false).
  private val acceptDataPredicates: Boolean =
    properties.getOrDefault(
      InMemoryPartitionPredicateDeleteTable.AcceptDataPredicatesKey, "false").toBoolean

  // Dotted paths of all partition columns (e.g. "s.tz" for a nested partition field).
  private val partPaths = partCols.map(_.mkString(".")).toSet

  // True iff every column reference in the predicate resolves to a partition column.
  private def refsOnlyPartCols(p: Predicate): Boolean =
    p.references().forall(ref => partPaths.contains(ref.fieldNames().mkString(".")))

  /**
   * Accepts a predicate set when each [[PartitionPredicate]] is allowed by the
   * partition-predicate knob, and each standard predicate is both supported by
   * [[InMemoryTableWithV2Filter]] and either allowed by the data-predicate knob or
   * references only partition columns. An empty array is trivially accepted.
   */
  override def canDeleteWhere(predicates: Array[Predicate]): Boolean = {
    predicates.forall {
      case _: PartitionPredicate => acceptPartitionPredicates
      case p =>
        InMemoryTableWithV2Filter.supportsPredicates(Array(p)) &&
          (acceptDataPredicates || refsOnlyPartCols(p))
    }
  }

  /**
   * Deletes matching rows directly against the in-memory partition map.
   *
   * [[PartitionPredicate]]s and partition-column-only standard predicates are applied at
   * partition granularity; remaining standard predicates are evaluated row by row,
   * simulating a source that can evaluate data predicates during a metadata-only delete.
   */
  override def deleteWhere(predicates: Array[Predicate]): Unit = dataMap.synchronized {
    val (partPreds, standardPreds) = predicates.partition(_.isInstanceOf[PartitionPredicate])
    val (partStdPreds, dataStdPreds) = standardPreds.partition(refsOnlyPartCols)

    // Narrow candidate partitions using standard predicates on partition columns.
    val candidateKeys = if (partStdPreds.nonEmpty) {
      InMemoryTableWithV2Filter.filtersToKeys(
        dataMap.keys, partCols.map(_.toSeq.quoted).toImmutableArraySeq, partStdPreds)
    } else {
      dataMap.keys
    }

    // Narrow further by evaluating PartitionPredicates against each partition row.
    val keysToProcess = if (partPreds.nonEmpty) {
      val pArr = partPreds.map(_.asInstanceOf[PartitionPredicate])
      candidateKeys.filter { key =>
        val partRow = PartitionInternalRow(key.toArray)
        pArr.forall(_.eval(partRow))
      }
    } else {
      candidateKeys
    }

    // Handle data predicates (simulate data source with data column statistics).
    if (dataStdPreds.isEmpty) {
      // Materialize the keys before removal: when no predicate narrowed the set,
      // keysToProcess aliases dataMap.keys, and removing entries while iterating that
      // live key view of a mutable map is undefined.
      dataMap --= keysToProcess.toSeq
    } else {
      for (key <- keysToProcess.toSeq) {
        dataMap.get(key).foreach { splits =>
          // Keep only rows that do NOT match every data predicate.
          val filtered = splits.map { buffered =>
            val kept = new BufferedRows(key, buffered.schema)
            buffered.rows
              .filterNot(rowMatchesAll(_, dataStdPreds, buffered.schema))
              .foreach(kept.withRow)
            kept
          }
          // Drop the partition entirely once all of its splits are empty.
          if (filtered.forall(_.rows.isEmpty)) {
            dataMap.remove(key)
          } else {
            dataMap.update(key, filtered)
          }
        }
      }
    }
  }

  /**
   * Returns true iff the row satisfies every predicate, resolving column names to
   * values through the row's schema via [[InMemoryTableWithV2Filter.evalPredicate]].
   */
  private def rowMatchesAll(
      row: InternalRow,
      preds: Array[Predicate],
      rowSchema: StructType): Boolean = {
    val resolve: String => Any = colName => {
      val idx = rowSchema.fieldIndex(colName)
      row.get(idx, rowSchema(idx).dataType)
    }
    preds.forall(
      InMemoryTableWithV2Filter.evalPredicate(_, resolve))
  }
}
122+
123+
/** Table property keys recognized by [[InMemoryPartitionPredicateDeleteTable]]. */
object InMemoryPartitionPredicateDeleteTable {
  // Boolean property: whether canDeleteWhere accepts PartitionPredicates ("true" default).
  private[catalog] val AcceptPartitionPredicatesKey = "accept-partition-predicates"
  // Boolean property: whether canDeleteWhere accepts data-column predicates ("false" default).
  private[catalog] val AcceptDataPredicatesKey = "accept-data-predicates"
}
127+
128+
/**
 * Test catalog whose tables are [[InMemoryPartitionPredicateDeleteTable]]s, so deletes
 * routed through this catalog can receive [[PartitionPredicate]]s.
 */
class InMemoryPartitionPredicateDeleteCatalog extends InMemoryTableCatalog {
  import CatalogV2Implicits._

  /**
   * Creates and registers an [[InMemoryPartitionPredicateDeleteTable]] for the given
   * identifier, failing if a table with that identifier already exists.
   */
  override def createTable(ident: Identifier, tableInfo: TableInfo): Table = {
    if (tables.containsKey(ident)) {
      throw new TableAlreadyExistsException(ident.asMultipartIdentifier)
    }

    // Honor the test hook that forces table creation to fail.
    InMemoryTableCatalog.maybeSimulateFailedTableCreation(tableInfo.properties)

    val qualifiedName = s"$name.${ident.quoted}"
    val structType = CatalogV2Util.v2ColumnsToStructType(tableInfo.columns)
    val newTable = new InMemoryPartitionPredicateDeleteTable(
      qualifiedName, structType, tableInfo.partitions, tableInfo.properties)
    tables.put(ident, newTable)
    namespaces.putIfAbsent(ident.namespace.toList, Map())
    newTable
  }
}

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -140,31 +140,36 @@ object InMemoryTableWithV2Filter {
140140
partitionNames: Seq[String],
141141
filters: Array[Predicate]): Iterable[Seq[Any]] = {
142142
keys.filter { partValues =>
143-
filters.flatMap(splitAnd).forall {
144-
case p: Predicate if p.name().equals("=") =>
145-
p.children()(1).asInstanceOf[LiteralValue[_]].value ==
146-
InMemoryBaseTable.extractValue(p.children()(0).toString, partitionNames, partValues)
147-
case p: Predicate if p.name().equals("<=>") =>
148-
val attrVal = InMemoryBaseTable
149-
.extractValue(p.children()(0).toString, partitionNames, partValues)
150-
val value = p.children()(1).asInstanceOf[LiteralValue[_]].value
151-
if (attrVal == null && value == null) {
152-
true
153-
} else if (attrVal == null || value == null) {
154-
false
155-
} else {
156-
value == attrVal
157-
}
158-
case p: Predicate if p.name().equals("IS_NULL") =>
159-
val attr = p.children()(0).toString
160-
null == InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
161-
case p: Predicate if p.name().equals("IS_NOT_NULL") =>
162-
val attr = p.children()(0).toString
163-
null != InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
164-
case p: Predicate if p.name().equals("ALWAYS_TRUE") => true
165-
case f =>
166-
throw new IllegalArgumentException(s"Unsupported filter type: $f")
167-
}
143+
val resolve: String => Any = attr =>
144+
InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
145+
filters.flatMap(splitAnd).forall(evalPredicate(_, resolve))
146+
}
147+
}
148+
149+
/**
 * Evaluates a single V2 predicate by resolving column values through the
 * given function. Supports =, <=>, IS_NULL, IS_NOT_NULL, and ALWAYS_TRUE.
 *
 * @param pred the predicate to evaluate
 * @param resolveValue resolves a column name to its value for the current row/partition
 * @throws IllegalArgumentException for any predicate name outside the supported set
 */
def evalPredicate(
    pred: Predicate,
    resolveValue: String => Any): Boolean = {
  // Deferred accessors: only touched by cases that actually have these children,
  // so ALWAYS_TRUE (no children) never evaluates them.
  def columnValue: Any = resolveValue(pred.children()(0).toString)
  def literalValue: Any = pred.children()(1).asInstanceOf[LiteralValue[_]].value
  pred.name() match {
    case "ALWAYS_TRUE" => true
    case "IS_NULL" => columnValue == null
    case "IS_NOT_NULL" => columnValue != null
    case "=" => columnValue == literalValue
    case "<=>" =>
      // Null-safe equality: two nulls match, one null never matches.
      val left = columnValue
      val right = literalValue
      if (left == null || right == null) left == null && right == null else left == right
    case other =>
      throw new IllegalArgumentException(
        s"Unsupported filter type: $other")
  }
}
170175

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,12 @@ object OptimizeMetadataOnlyDeleteFromTable extends Rule[LogicalPlan] with Predic
4343
case table: SupportsDeleteV2 if !SubqueryExpression.hasSubquery(cond) =>
4444
val predicates = splitConjunctivePredicates(cond)
4545
val normalizedPredicates = DataSourceStrategy.normalizeExprs(predicates, relation.output)
46-
val filters = toDataSourceV2Filters(normalizedPredicates)
47-
val allPredicatesTranslated = normalizedPredicates.size == filters.length
48-
if (allPredicatesTranslated && table.canDeleteWhere(filters)) {
49-
logDebug(s"Switching to delete with filters: ${filters.mkString("[", ", ", "]")}")
50-
DeleteFromTableWithFilters(relation, filters.toImmutableArraySeq)
46+
val filtersOpt = tryTranslateToV2(normalizedPredicates)
47+
if (filtersOpt.exists(table.canDeleteWhere)) {
48+
DeleteFromTableWithFilters(relation, filtersOpt.get.toImmutableArraySeq)
5149
} else {
52-
rowLevelPlan
50+
tryDeleteWithPartitionPredicates(table, relation, normalizedPredicates)
51+
.getOrElse(rowLevelPlan)
5352
}
5453

5554
case _: TruncatableTable if cond == TrueLiteral =>
@@ -70,6 +69,40 @@ object OptimizeMetadataOnlyDeleteFromTable extends Rule[LogicalPlan] with Predic
7069
}.toArray
7170
}
7271

72+
/**
 * Attempts to convert partition-column filters to [[PartitionPredicate]]s and
 * combine them with translated V2 data filters for a metadata-only delete. (See SPARK-55596)
 *
 * Returns [[Some]] with the plan if the table accepts the combined predicates,
 * or [[None]] if partition predicates cannot be created or the table rejects them.
 */
private def tryDeleteWithPartitionPredicates(
    table: SupportsDeleteV2,
    relation: DataSourceV2Relation,
    normalizedPredicates: Seq[Expression]): Option[LogicalPlan] = {
  PushDownUtils.getPartitionPredicateSchema(relation).flatMap { partitionFields =>
    // Flatten nested partition field references so they match the partition fields.
    val flattenedFilters = PushDownUtils.flattenNestedPartitionFilters(
      normalizedPredicates, partitionFields).keys.toSeq
    val (candidatePredicates, remainingFilters) =
      PushDownUtils.createPartitionPredicates(flattenedFilters, partitionFields)
    if (candidatePredicates.isEmpty) {
      // No partition predicates could be created; nothing to gain over the row-level plan.
      None
    } else {
      // Every remaining filter must translate to a V2 predicate, or we bail out.
      tryTranslateToV2(remainingFilters).flatMap { dataV2Filters =>
        val combined = candidatePredicates.toArray ++ dataV2Filters
        if (table.canDeleteWhere(combined)) {
          Some(DeleteFromTableWithFilters(relation, combined.toImmutableArraySeq))
        } else {
          None
        }
      }
    }
  }
}
99+
100+
/** Translates all expressions to V2 filters, or returns [[None]] if any fail. */
private def tryTranslateToV2(predicates: Seq[Expression]): Option[Array[Predicate]] = {
  val translated = toDataSourceV2Filters(predicates)
  // A dropped filter means at least one expression was untranslatable.
  if (translated.length == predicates.size) Some(translated) else None
}
105+
73106
private object RewrittenRowLevelCommand {
74107
type ReturnType = (RowLevelWrite, RowLevelOperation.Command, Expression, LogicalPlan)
75108

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ object PushDownUtils extends Logging {
149149
case _ => None
150150
}
151151
if (fields.length == transforms.length) {
152-
Some(fields.toSeq)
152+
Some(fields.toSeq).filter(_.nonEmpty)
153153
} else {
154154
None
155155
}
@@ -177,6 +177,32 @@ object PushDownUtils extends Logging {
177177
}
178178
}
179179

180+
/**
 * Separates partition filters from data filters and converts pushable partition
 * filters to [[PartitionPredicateImpl]] instances.
 *
 * Callers must first flatten nested partition field references via
 * [[flattenNestedPartitionFilters]] with [[ExprId]] matching the [[PartitionPredicateField]]s.
 *
 * @param flattenedFilters Catalyst filter expressions with partition field references
 *                         already flattened.
 * @param partitionFields Partition field metadata.
 * @return a pair of (created partition predicates, remaining filters not converted).
 */
private[v2] def createPartitionPredicates(
    flattenedFilters: Seq[Expression],
    partitionFields: Seq[PartitionPredicateField])
  : (Seq[PartitionPredicateImpl], Seq[Expression]) = {
  val partitionAttrs = partitionFields.map(_.attrRef)
  val (partitionFilters, dataFilters) =
    DataSourceUtils.getPartitionFiltersAndDataFilters(partitionAttrs, flattenedFilters)
  val (pushableFilters, unpushableFilters) =
    partitionFilters.partition(isPushablePartitionFilter)
  // Attempt conversion per filter; keep the original expression for any that fail.
  val conversions = pushableFilters.map(e => (e, PartitionPredicateImpl(e, partitionFields)))
  val createdPredicates = conversions.collect { case (_, Some(predicate)) => predicate }
  val failedFilters = conversions.collect { case (e, None) => e }
  (createdPredicates, dataFilters ++ unpushableFilters ++ failedFilters)
}
205+
180206
/**
181207
* If the scan supports iterative filtering, infer additional partition filters,
182208
* convert these and unused partition filters to PartitionPredicates,
@@ -186,22 +212,15 @@ object PushDownUtils extends Logging {
186212
scanBuilder: SupportsPushDownV2Filters,
187213
partitionFields: Seq[PartitionPredicateField],
188214
remainingFilters: Seq[Expression]): Seq[Expression] = {
189-
val normalizedToOriginal = normalizeNestedPartitionFilters(remainingFilters, partitionFields)
190-
val normalized = normalizedToOriginal.keys.toSeq
191-
val partitionAttributes = partitionFields.map(_.attrRef)
192-
// may infer additional partition filters
193-
val (partFilters, nonPartitionFilters) =
194-
DataSourceUtils.getPartitionFiltersAndDataFilters(partitionAttributes, normalized)
195-
val (pushable, nonPushable) = partFilters.partition(isPushablePartitionFilter)
196-
val (partitionPredicates, errorPartitionPredicates) = pushable.partitionMap { e =>
197-
PartitionPredicateImpl(e, partitionFields).toLeft(e)
198-
}
199-
val rejectedPartitionFilters = scanBuilder.pushPredicates(partitionPredicates.toArray).map {
215+
val flattenedToOriginal = flattenNestedPartitionFilters(remainingFilters, partitionFields)
216+
val flattened = flattenedToOriginal.keys.toSeq
217+
val (partPredicates, remaining) = createPartitionPredicates(flattened, partitionFields)
218+
val rejectedPartitionFilters = scanBuilder.pushPredicates(partPredicates.toArray).map {
200219
p => p.asInstanceOf[PartitionPredicateImpl].expression
201220
}.toSeq
202-
(nonPartitionFilters ++ nonPushable ++ errorPartitionPredicates ++ rejectedPartitionFilters)
203-
.filter(normalizedToOriginal.contains)
204-
.map(normalizedToOriginal)
221+
(remaining ++ rejectedPartitionFilters)
222+
.filter(flattenedToOriginal.contains)
223+
.map(flattenedToOriginal)
205224
}
206225

207226
private def isPushablePartitionFilter(f: Expression) =
@@ -218,9 +237,9 @@ object PushDownUtils extends Logging {
218237
* (identity transform on a nested field), the analyzer produces
219238
* `GetStructField(attr("s"), "tz")`. This method replaces that chain with `attr("s.tz")`.
220239
*
221-
* Returns a map from normalized expression to original.
240+
* Returns a map from flattened expression to original.
222241
*/
223-
private def normalizeNestedPartitionFilters(
242+
private[v2] def flattenNestedPartitionFilters(
224243
filters: Seq[Expression],
225244
partitionFields: Seq[PartitionPredicateField])
226245
: Map[Expression, Expression] = {

0 commit comments

Comments
 (0)