@@ -21,11 +21,14 @@ import org.apache.spark.SparkConf
2121import org .apache .spark .sql .catalyst .analysis .UnresolvedAttribute
2222import org .apache .spark .sql .catalyst .dsl .expressions ._
2323import org .apache .spark .sql .catalyst .expressions ._
24+ import org .apache .spark .sql .catalyst .expressions .aggregate .Min
2425import org .apache .spark .sql .catalyst .expressions .variant .VariantGet
2526import org .apache .spark .sql .catalyst .util .V2ExpressionBuilder
2627import org .apache .spark .sql .connector .expressions .{Expression => V2Expression , FieldReference , GeneralScalarExpression , LiteralValue , VariantGet => V2VariantGet }
2728import org .apache .spark .sql .connector .expressions .filter .{AlwaysFalse , AlwaysTrue , And => V2And , Not => V2Not , Or => V2Or , Predicate }
2829import org .apache .spark .sql .execution .TestPredicateRuntimeReplaceable
30+ import org .apache .spark .sql .execution .TestRuntimeReplaceable
31+ import org .apache .spark .sql .execution .datasources .DataSourceStrategy
2932import org .apache .spark .sql .internal .SQLConf
3033import org .apache .spark .sql .test .SharedSparkSession
3134import org .apache .spark .sql .types .{BooleanType , DoubleType , IntegerType , LongType , StringType , StructField , StructType , TimestampType , VariantType }
@@ -1040,6 +1043,28 @@ class DataSourceV2StrategySuite extends SharedSparkSession {
10401043 }
10411044 }
10421045
1046+ test(" SPARK-57512: aggregate and group-by pushdown unfold surviving RuntimeReplaceable" ) {
1047+ val attrInt = $" cint" .int
1048+ // A surviving RuntimeReplaceable whose `replacement` is `cint + 1`. It appears both as the MIN
1049+ // argument and as the group-by expression -- both route through `V2ExpressionBuilder`.
1050+ def wrapped : Expression = TestRuntimeReplaceable (attrInt, Literal (1 ))
1051+ def unfolded : Expression = Add (attrInt, Literal (1 ))
1052+
1053+ val actual = DataSourceStrategy .translateAggregation(
1054+ Seq (Min (wrapped).toAggregateExpression()), Seq (wrapped))
1055+ // Translating the already-unfolded equivalent gives the reference result.
1056+ val expected = DataSourceStrategy .translateAggregation(
1057+ Seq (Min (unfolded).toAggregateExpression()), Seq (unfolded))
1058+
1059+ assert(expected.isDefined)
1060+ assert(actual.isDefined,
1061+ " aggregate/group-by containing a surviving RuntimeReplaceable should translate" )
1062+ assert(actual.get.groupByExpressions().map(_.describe()).toSeq ==
1063+ expected.get.groupByExpressions().map(_.describe()).toSeq)
1064+ assert(actual.get.aggregateExpressions().map(_.describe()).toSeq ==
1065+ expected.get.aggregateExpressions().map(_.describe()).toSeq)
1066+ }
1067+
10431068 /**
10441069 * Translate the given Catalyst [[Expression ]] into data source V2 [[Predicate ]]
10451070 * then verify against the given [[Predicate ]].
0 commit comments