
Commit 66a5a89

anton5798 authored and cloud-fan committed
[SPARK-56467][SQL] Route scalar subquery partition filters into DSv2 runtime filtering
### What changes were proposed in this pull request?

Scalar subquery filters on partition columns (e.g., `WHERE d_date_sk = (SELECT min(d_date_sk) FROM ...)`) are excluded from pushdown in DSv2 at every stage. The filter lands as a `FilterExec` above `BatchScanExec`, evaluated row-by-row. The scan reads all partitions, so no partition pruning occurs.

DSv1 already handles this: `FileSourceStrategy` puts subquery filters in `partitionFilters`, `isDynamicFilter` classifies them as dynamic, and `getPartitionPruningFilterFromBroadcast` calls `ScalarSubquery.toLiteral` at execution time for partition pruning via `listFiles()`.

This PR routes partition-column scalar subquery filters into `BatchScanExec.runtimeFilters`, leveraging the existing `SupportsRuntimeV2Filtering.filter()` infrastructure:

- **DataSourceV2Strategy**: When the scan implements `SupportsRuntimeV2Filtering`, extract subquery filters from `postScanFilters` whose references are a subset of the scan's runtime-filterable (partition) columns. Add them to `runtimeFilters` alongside existing DPP filters. They remain in `postScanFilters` as a correctness safety net (V2 `filter()` is advisory).
- **BatchScanExec**: In `filteredPartitions`, non-DPP runtime filters are literalized (replacing `ExecScalarSubquery` with its resolved literal) and translated to V2 predicates via `translateFilterV2`.
- **InMemoryTableWithV2Filter** (test infra): Added `=` predicate handling in `filter()` alongside the existing `IN`, plus a `case _ =>` catch-all.

No new interfaces, no config flags, no connector changes needed.

### Why are the changes needed?

TPC-DS queries with scalar subquery partition filters (e.g., Q5, Q12, Q16, Q20, Q37, Q77, Q80, Q92, Q94, Q95) read all partitions in DSv2 scans even though the subquery resolves to a single value at runtime. This causes significant I/O overhead that DSv1 avoids.

### Does this PR introduce _any_ user-facing change?

No API changes. Queries with scalar subquery filters on partition columns will now benefit from partition pruning in DSv2 scans, reducing I/O.

### How was this patch tested?

New unit test in `DataSourceV2SQLSuiteV2Filter`:
- Creates a 10-partition table and a dimension table
- Runs `SELECT * FROM t WHERE part = (SELECT max(val) FROM dim)`
- Asserts query correctness, scalar subquery presence in `runtimeFilters`, and exactly 1 partition after pruning

### Was this patch authored or co-authored using generative AI tooling?

Yes, co-authored with Claude Code.

Closes #55335 from anton5798/scalar-subquery-dsv2-pruning.

Lead-authored-by: Anton Lykov <25360033+anton5798@users.noreply.github.com>
Co-authored-by: Anton Lykov <antony.lykov@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
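To make the mechanism concrete, the sketch below shows the query shape this PR targets and how its plan changes. The table names and plan fragments are illustrative, not verbatim output from this commit:

```scala
// Illustrative only: a TPC-DS-style query with a scalar subquery filter on a
// partition column (table names are examples, not part of this commit).
val df = spark.sql("""
  SELECT ss_item_sk FROM store_sales
  WHERE ss_sold_date_sk = (SELECT min(d_date_sk) FROM date_dim)
""")

// Before: the subquery filter is excluded from DSv2 pushdown, so it sits in a
//   FilterExec above BatchScanExec and the scan reads every partition.
// After: the same filter is also placed in BatchScanExec.runtimeFilters; at
//   execution time the evaluated subquery is replaced by its literal result
//   (e.g. ss_sold_date_sk = 2451545) and handed to
//   SupportsRuntimeV2Filtering.filter(), pruning partitions before the scan.
//   The FilterExec stays in place as the correctness safety net.
```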
1 parent 26a86f9 · commit 66a5a89

5 files changed: 101 additions & 5 deletions


sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala

Lines changed: 15 additions & 2 deletions
```diff
@@ -22,7 +22,7 @@ import java.util.{Optional, OptionalLong}
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation, TimeTravelSpec}
 import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics}
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, AttributeSet, Expression, SortOrder, V2ExpressionUtils}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, ExposesMetadataColumns, Histogram, HistogramBin, LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.catalyst.streaming.{StreamingSourceIdentifyingName, Unassigned}
@@ -31,11 +31,12 @@ import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils}
 import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, SupportsMetadataColumns, Table, TableCapability, TableCatalog, V2TableUtil}
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper
 import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}
-import org.apache.spark.sql.connector.read.{Scan, Statistics => V2Statistics, SupportsReportStatistics}
+import org.apache.spark.sql.connector.read.{Scan, Statistics => V2Statistics, SupportsReportStatistics, SupportsRuntimeV2Filtering}
 import org.apache.spark.sql.connector.read.colstats.{ColumnStatistics, Histogram => V2Histogram, HistogramBin => V2HistogramBin}
 import org.apache.spark.sql.connector.read.streaming.{Offset, SparkDataStream}
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.Utils
 
 /**
@@ -174,6 +175,18 @@ case class DataSourceV2ScanRelation(
   // skip adding IsNotNull when the scan already implies it, or infer new filters across
   // joins), so plan stability testing is needed first.
 
+  /**
+   * Resolved attributes that the scan declares for runtime filtering via
+   * [[SupportsRuntimeV2Filtering.filterAttributes]]. Empty when the scan
+   * does not implement [[SupportsRuntimeV2Filtering]] or exposes no attributes.
+   */
+  lazy val runtimeFilterAttrs: AttributeSet = scan match {
+    case s: SupportsRuntimeV2Filtering =>
+      AttributeSet(V2ExpressionUtils.resolveRefs[Attribute](
+        s.filterAttributes.toImmutableArraySeq, this))
+    case _ => AttributeSet.empty
+  }
+
   override def name: String = relation.name
 
   override def simpleString(maxFields: Int): String = {
```
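For orientation, here is a minimal connector-side sketch of what makes `runtimeFilterAttrs` non-empty. The class and column names are hypothetical; no connector changes are part of this commit:

```scala
import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, SupportsRuntimeV2Filtering}

// Hypothetical scan that declares its partition column as runtime-filterable.
// For such a scan, DataSourceV2ScanRelation.runtimeFilterAttrs resolves to the
// corresponding output attribute; for any other scan it stays empty.
abstract class MyPartitionedScan extends Scan with SupportsRuntimeV2Filtering {
  override def filterAttributes(): Array[NamedReference] =
    Array(FieldReference("part"))

  override def filter(predicates: Array[Predicate]): Unit = {
    // Advisory: drop planned input partitions that cannot match `predicates`.
    // Spark keeps a FilterExec above the scan, so conservative filtering is safe.
  }
}
```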

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala

Lines changed: 16 additions & 0 deletions
```diff
@@ -87,6 +87,22 @@ class InMemoryTableWithV2Filter(
           })
         }
       }
+      case p : Predicate if p.name().equals("=") =>
+        if (p.children().length == 2) {
+          val filterRef = p.children()(0).asInstanceOf[FieldReference].references.head
+          if (filterRef.toString.equals(ref.toString)) {
+            val matchingKey = p.children()(1).asInstanceOf[LiteralValue[_]].value
+            if (matchingKey != null) {
+              data = data.filter(partition => {
+                val key = partition.asInstanceOf[BufferedRows].keyString()
+                key == matchingKey.toString
+              })
+            } else {
+              data = Seq.empty // NULL = anything is always false
+            }
+          }
+        }
+      case _ => // Ignore unsupported predicate types
     }
   }
 }
```
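At runtime, the new `=` branch receives a connector-facing predicate shaped like the one below. This is a sketch with invented values, matching what the test table expects:

```scala
import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.types.IntegerType

// part = 3, as InMemoryTableWithV2Filter.filter() now sees it.
val children: Array[V2Expression] =
  Array(FieldReference("part"), LiteralValue(3, IntegerType))
val eq = new Predicate("=", children)
// The table keeps only partitions whose key string equals "3"; a NULL literal
// clears all partitions, since `NULL = x` is never true.
```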

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -63,7 +63,7 @@ case class BatchScanExec(
   @transient private[sql] lazy val filteredPartitions: Seq[Option[InputPartition]] = {
     val dataSourceFilters = runtimeFilters.flatMap {
       case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e)
-      case _ => None
+      case f => DataSourceV2Strategy.translateScalarSubqueryFilterV2(f)
     }
 
     val originalPartitioning = outputPartitioning
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala

Lines changed: 32 additions & 2 deletions
```diff
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning
 import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.trees.TreePattern.SCALAR_SUBQUERY
 import org.apache.spark.sql.catalyst.util.{toPrettySQL, GeneratedColumn, IdentityColumn, ResolveDefaultColumns, ResolveTableConstraints, V2ExpressionBuilder}
 import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsDeleteV2, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, TableCapability, TableCatalog, TruncatableTable, V1Table}
@@ -42,7 +43,7 @@ import org.apache.spark.sql.connector.read.LocalScan
 import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream, SupportsRealTimeMode}
 import org.apache.spark.sql.connector.write.V1Write
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.execution.{FilterExec, InSubqueryExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan, SparkStrategy => Strategy}
+import org.apache.spark.sql.execution.{FilterExec, InSubqueryExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, ScalarSubquery => ExecScalarSubquery, SparkPlan, SparkStrategy => Strategy}
 import org.apache.spark.sql.execution.command.CommandUtils
 import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelationWithTable, PushableColumnAndNestedColumn}
 import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
@@ -155,10 +156,26 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
       // projection and filters were already pushed down in the optimizer.
       // this uses PhysicalOperation to get the projection and ensure that if the batch scan does
       // not support columnar, a projection is added to convert the rows to UnsafeRow.
-      val (runtimeFilters, postScanFilters) = filters.partition {
+      val (dynamicFilters, postScanFilters) = filters.partition {
         case _: DynamicPruning => true
         case _ => false
       }
+
+      // Extract scalar subquery filters on runtime-filterable columns for runtime pushdown.
+      // These filters stay in postScanFilters for correctness (FilterExec above scan),
+      // but are also routed into runtimeFilters so BatchScanExec can use them for
+      // partition pruning via SupportsRuntimeV2Filtering.filter().
+      val scalarSubqueryFilters = if (relation.runtimeFilterAttrs.nonEmpty) {
+        postScanFilters.filter { f =>
+          f.containsPattern(SCALAR_SUBQUERY) &&
+            f.references.nonEmpty &&
+            f.references.subsetOf(relation.runtimeFilterAttrs)
+        }
+      } else {
+        Seq.empty
+      }
+      val runtimeFilters = dynamicFilters ++ scalarSubqueryFilters
+
       val batchExec = BatchScanExec(relation.output, relation.scan, runtimeFilters,
         relation.ordering, relation.relation.table, relation.keyGroupedPartitioning)
       DataSourceV2Strategy.withProjectAndFilter(
@@ -746,6 +763,19 @@ private[sql] object DataSourceV2Strategy extends Logging {
     None
   }
 
+  /**
+   * Literalizes scalar subqueries in the given expression and translates the result to a V2
+   * [[Predicate]]. Used at runtime in [[BatchScanExec]] after scalar subqueries have been
+   * evaluated.
+   */
+  protected[sql] def translateScalarSubqueryFilterV2(
+      expr: Expression): Option[Predicate] = {
+    val literalized = expr.transform {
+      case s: ExecScalarSubquery => s.toLiteral
+    }
+    translateFilterV2(literalized)
+  }
+
   /**
    * Creates new spark plan that should apply given filters and projections to given scan node
    * @param project Projection list that should be output of returned spark plan
```
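A hedged sketch of what the new helper amounts to once the subquery has been evaluated: the filter becomes an ordinary comparison against a literal, which the existing `translateFilterV2` can convert. The attribute and value are invented, and the call assumes package-internal access, since `translateFilterV2` is `protected[sql]`:

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy
import org.apache.spark.sql.types.IntegerType

// What `part = (SELECT max(val) FROM dim)` looks like after literalization,
// assuming the subquery evaluated to 3.
val part = AttributeReference("part", IntegerType)()
val literalized = EqualTo(part, Literal(3))
// Translate to a connector-facing V2 predicate, e.g. Predicate("=", [part, 3]).
val v2Predicate = DataSourceV2Strategy.translateFilterV2(literalized)
```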

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala

Lines changed: 37 additions & 0 deletions
```diff
@@ -4315,7 +4315,44 @@ class DataSourceV2SQLSuiteV1Filter
 }
 
 class DataSourceV2SQLSuiteV2Filter extends DataSourceV2SQLSuite {
+  import org.apache.spark.sql.catalyst.expressions.DynamicPruning
+  import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+
   override protected val catalogAndNamespace = "testv2filter.ns1.ns2."
+
+  test("SPARK-56467: scalar subquery filters on partition columns are pushed into runtimeFilters") {
+    val tbl = s"${catalogAndNamespace}tbl"
+    val dim = s"${catalogAndNamespace}dim"
+    withTable(tbl, dim) {
+      sql(s"CREATE TABLE $tbl (id INT, part INT) USING $v2Format PARTITIONED BY (part)")
+      for (i <- 0 until 10) {
+        sql(s"INSERT INTO $tbl VALUES ($i, $i)")
+      }
+
+      sql(s"CREATE TABLE $dim (val INT) USING $v2Format")
+      sql(s"INSERT INTO $dim VALUES (3)")
+
+      val df = sql(s"SELECT * FROM $tbl WHERE part = (SELECT max(val) FROM $dim)")
+
+      // Verify query correctness
+      checkAnswer(df, Row(3, 3))
+
+      // Verify runtime filters contain the scalar subquery filter
+      val batchScan = collect(df.queryExecution.executedPlan) {
+        case b: BatchScanExec => b
+      }.head
+      assert(batchScan.runtimeFilters.nonEmpty,
+        "Expected runtimeFilters to contain scalar subquery filter")
+      assert(!batchScan.runtimeFilters.exists(
+        _.isInstanceOf[DynamicPruning]),
+        "Expected non-DPP runtime filter (scalar subquery)")
+
+      // Verify partition pruning: only 1 of 10 partitions should remain
+      val numPartitions = batchScan.filteredPartitions.count(_.isDefined)
+      assert(numPartitions == 1,
+        s"Expected 1 partition after scalar subquery pruning, got $numPartitions")
+    }
+  }
 }
 
 class ReserveSchemaNullabilityCatalog extends InMemoryCatalog {
```
