Skip to content

Commit 3c280e7

Browse files
authored
[GH-2884] Add ST_Extent aggregate (returns Box2D) (#2898)
1 parent a30fb9b commit 3c280e7

3 files changed

Lines changed: 95 additions & 1 deletion

File tree

spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,5 +538,10 @@ object Catalog extends AbstractCatalog with Logging {
538538
// are only constructed when registerAll is called and Spark is set up. This lets the
539539
// categorization invariant test access `Catalog.expressions` without bootstrapping Spark.
540540
lazy val aggregateExpressions: Seq[Aggregator[Geometry, _, _]] =
541-
Seq(new ST_Envelope_Aggr, new ST_Intersection_Aggr, new ST_Union_Aggr(), new ST_Collect_Agg())
541+
Seq(
542+
new ST_Envelope_Aggr,
543+
new ST_Extent,
544+
new ST_Intersection_Aggr,
545+
new ST_Union_Aggr(),
546+
new ST_Collect_Agg())
542547
}

spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package org.apache.spark.sql.sedona_sql.expressions
2020

2121
import org.apache.sedona.common.Functions
22+
import org.apache.sedona.common.geometryObjects.Box2D
2223
import org.apache.spark.sql.{Encoder, Encoders}
2324
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
2425
import org.apache.spark.sql.expressions.Aggregator
@@ -164,6 +165,51 @@ private[apache] class ST_Envelope_Aggr
164165
def zero: Option[EnvelopeBuffer] = None
165166
}
166167

168+
/**
169+
* Return the planar bounding box (Box2D) of all geometries in the given column. Returns NULL when
170+
* the input contains no rows or all rows are null/empty geometries. Mirrors PostGIS ST_Extent.
171+
*
172+
* ST_Envelope_Aggr is left untouched (returns a polygon Geometry) for backwards compatibility.
173+
*/
174+
private[apache] class ST_Extent extends Aggregator[Geometry, Option[EnvelopeBuffer], Box2D] {
175+
176+
val outputSerde: ExpressionEncoder[Box2D] = ExpressionEncoder[Box2D]()
177+
178+
def reduce(buffer: Option[EnvelopeBuffer], input: Geometry): Option[EnvelopeBuffer] = {
179+
if (input == null || input.isEmpty) return buffer
180+
val env = input.getEnvelopeInternal
181+
val envBuffer = EnvelopeBuffer(env.getMinX, env.getMaxX, env.getMinY, env.getMaxY)
182+
buffer match {
183+
case Some(b) => Some(b.merge(envBuffer))
184+
case None => Some(envBuffer)
185+
}
186+
}
187+
188+
def merge(
189+
buffer1: Option[EnvelopeBuffer],
190+
buffer2: Option[EnvelopeBuffer]): Option[EnvelopeBuffer] = {
191+
(buffer1, buffer2) match {
192+
case (Some(b1), Some(b2)) => Some(b1.merge(b2))
193+
case (Some(_), None) => buffer1
194+
case (None, Some(_)) => buffer2
195+
case (None, None) => None
196+
}
197+
}
198+
199+
def finish(reduction: Option[EnvelopeBuffer]): Box2D = {
200+
reduction match {
201+
case Some(b) => new Box2D(b.minX, b.minY, b.maxX, b.maxY)
202+
case None => null
203+
}
204+
}
205+
206+
def bufferEncoder: Encoder[Option[EnvelopeBuffer]] = Encoders.product[Option[EnvelopeBuffer]]
207+
208+
def outputEncoder: ExpressionEncoder[Box2D] = outputSerde
209+
210+
def zero: Option[EnvelopeBuffer] = None
211+
}
212+
167213
/**
168214
* Return the polygon intersection of all Polygon in the given column
169215
*/

spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
*/
1919
package org.apache.sedona.sql
2020

21+
import org.apache.sedona.common.geometryObjects.Box2D
2122
import org.apache.spark.sql.DataFrame
2223
import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory, Polygon}
2324

@@ -73,6 +74,48 @@ class aggregateFunctionTestScala extends TestBaseScala {
7374
assert(env.getMaxY == 4.0)
7475
}
7576

77+
it("Passed ST_Extent") {
78+
val df = sparkSession.sql(
79+
"SELECT ST_GeomFromWKT(wkt) AS geom FROM VALUES ('POINT (1 2)'), ('POINT (4 5)'), ('LINESTRING (-3 0, 0 0)') AS t(wkt)")
80+
df.createOrReplaceTempView("t")
81+
val bbox =
82+
sparkSession.sql("SELECT ST_Extent(geom) AS bbox FROM t").take(1)(0).getAs[Box2D](0)
83+
assert(bbox.getXMin == -3.0)
84+
assert(bbox.getYMin == 0.0)
85+
assert(bbox.getXMax == 4.0)
86+
assert(bbox.getYMax == 5.0)
87+
}
88+
89+
it("ST_Extent returns null over zero rows") {
90+
val emptyDf = sparkSession.sql(
91+
"SELECT ST_GeomFromWKT(wkt) AS geom FROM VALUES (NULL) AS t(wkt) WHERE wkt IS NOT NULL")
92+
emptyDf.createOrReplaceTempView("empty_extent")
93+
val result = sparkSession.sql("SELECT ST_Extent(geom) FROM empty_extent")
94+
assert(result.take(1)(0).get(0) == null)
95+
}
96+
97+
it("ST_Extent returns null when all inputs are null or empty") {
98+
val nullDf = sparkSession.sql(
99+
"SELECT ST_GeomFromWKT(wkt) AS geom FROM VALUES (CAST(NULL AS STRING)), ('POINT EMPTY'), ('POLYGON EMPTY') AS t(wkt)")
100+
nullDf.createOrReplaceTempView("null_extent")
101+
val result = sparkSession.sql("SELECT ST_Extent(geom) FROM null_extent")
102+
assert(result.take(1)(0).get(0) == null)
103+
}
104+
105+
it("ST_Extent ignores null and empty rows mixed with valid geometries") {
106+
val mixedDf = sparkSession.sql(
107+
"SELECT ST_GeomFromWKT(wkt) AS geom FROM VALUES (CAST(NULL AS STRING)), ('POINT EMPTY'), ('POINT (10 20)'), ('POINT (-5 -5)') AS t(wkt)")
108+
mixedDf.createOrReplaceTempView("mixed_extent")
109+
val bbox = sparkSession
110+
.sql("SELECT ST_Extent(geom) FROM mixed_extent")
111+
.take(1)(0)
112+
.getAs[Box2D](0)
113+
assert(bbox.getXMin == -5.0)
114+
assert(bbox.getYMin == -5.0)
115+
assert(bbox.getXMax == 10.0)
116+
assert(bbox.getYMax == 20.0)
117+
}
118+
76119
it("Passed ST_Union_aggr") {
77120

78121
var polygonCsvDf = sparkSession.read

0 commit comments

Comments
 (0)