Skip to content

Commit e27e024

Browse files
authored
[GH-2939] Box2D spatial join: ST_BoxIntersects / ST_BoxContains (#2953)
1 parent 29109af commit e27e024

5 files changed

Lines changed: 299 additions & 24 deletions

File tree

spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -323,11 +323,10 @@ case class BroadcastIndexJoinExec(
323323
})
324324
case Some(distanceExpression) =>
325325
streamResultsRaw.map(row => {
326-
val geom = boundStreamShape.eval(row).asInstanceOf[Array[Byte]]
327-
if (geom == null) {
326+
val geometry = TraitJoinQueryBase.shapeToGeometry(boundStreamShape, row)
327+
if (geometry == null) {
328328
(null, row)
329329
} else {
330-
val geometry = GeometrySerializer.deserialize(geom)
331330
val radius = BindReferences
332331
.bindReference(distanceExpression, streamed.output)
333332
.eval(row)
@@ -351,23 +350,21 @@ case class BroadcastIndexJoinExec(
351350
})
352351
case _ =>
353352
streamResultsRaw.map(row => {
354-
val serializedObject = boundStreamShape.eval(row).asInstanceOf[Array[Byte]]
355-
if (serializedObject == null) {
356-
(null, row)
357-
} else {
358-
val shape = if (isRasterPredicate) {
359-
if (boundStreamShape.dataType.isInstanceOf[RasterUDT]) {
360-
val raster = RasterSerializer.deserialize(serializedObject)
361-
JoinedGeometryRaster.rasterToWGS84Envelope(raster)
362-
} else {
363-
val geom = GeometrySerializer.deserialize(serializedObject)
364-
JoinedGeometryRaster.geometryToWGS84Envelope(geom)
365-
}
353+
val shape = if (isRasterPredicate) {
354+
// Raster path keeps the legacy bytes-only handling — Box2D doesn't apply here.
355+
val serializedObject = boundStreamShape.eval(row).asInstanceOf[Array[Byte]]
356+
if (serializedObject == null) null
357+
else if (boundStreamShape.dataType.isInstanceOf[RasterUDT]) {
358+
val raster = RasterSerializer.deserialize(serializedObject)
359+
JoinedGeometryRaster.rasterToWGS84Envelope(raster)
366360
} else {
367-
GeometrySerializer.deserialize(serializedObject)
361+
val geom = GeometrySerializer.deserialize(serializedObject)
362+
JoinedGeometryRaster.geometryToWGS84Envelope(geom)
368363
}
369-
(shape, row)
364+
} else {
365+
TraitJoinQueryBase.shapeToGeometry(boundStreamShape, row)
370366
}
367+
(shape, row)
371368
})
372369
}
373370
}

spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,31 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy {
291291
rightShape,
292292
SpatialPredicate.EQUALS,
293293
extraCondition)
294+
// Box2D predicates. Both shape expressions resolve to Box2DUDT; the executors
295+
// materialise each Box2D as a rectangular Polygon so the existing partitioner /
296+
// R-tree / refine machinery applies unchanged. ST_BoxContains is closed-interval
297+
// containment, so it maps to SpatialPredicate.COVERS (JTS `contains` would reject
298+
// edge-touching cases).
299+
case ST_BoxIntersects(Seq(leftShape, rightShape)) =>
300+
Some(
301+
JoinQueryDetection(
302+
left,
303+
right,
304+
leftShape,
305+
rightShape,
306+
SpatialPredicate.INTERSECTS,
307+
isGeography = false,
308+
extraCondition))
309+
case ST_BoxContains(Seq(leftShape, rightShape)) =>
310+
Some(
311+
JoinQueryDetection(
312+
left,
313+
right,
314+
leftShape,
315+
rightShape,
316+
SpatialPredicate.COVERS,
317+
isGeography = false,
318+
extraCondition))
294319
case pred: ST_Predicate =>
295320
getJoinDetection(left, right, pred, extraCondition)
296321
case pred: RS_Predicate =>

spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/OptimizableJoinCondition.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ case class OptimizableJoinCondition(left: LogicalPlan, right: LogicalPlan) {
6363
expression match {
6464
case _: ST_Intersects | _: ST_Contains | _: ST_Covers | _: ST_Within | _: ST_CoveredBy |
6565
_: ST_Overlaps | _: ST_Touches | _: ST_Equals | _: ST_Crosses | _: ST_KNN |
66-
_: RS_Predicate =>
66+
_: ST_BoxIntersects | _: ST_BoxContains | _: RS_Predicate =>
6767
val leftShape = expression.children.head
6868
val rightShape = expression.children(1)
6969
ExpressionUtils.matchExpressionsToPlans(leftShape, rightShape, left, right).isDefined

spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,17 @@
1818
*/
1919
package org.apache.spark.sql.sedona_sql.strategy.join
2020

21+
import org.apache.sedona.common.Constructors
2122
import org.apache.sedona.common.S2Geography.GeographyWKBSerializer
2223
import org.apache.sedona.core.spatialRDD.SpatialRDD
2324
import org.apache.sedona.core.utils.SedonaConf
2425
import org.apache.sedona.sql.utils.{GeometrySerializer, RasterSerializer}
2526
import org.apache.spark.rdd.RDD
27+
import org.apache.spark.sql.catalyst.InternalRow
2628
import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeRow}
2729
import org.apache.spark.sql.execution.SparkPlan
28-
import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
29-
import org.locationtech.jts.geom.Geometry
30+
import org.apache.spark.sql.sedona_sql.UDT.{Box2DUDT, RasterUDT}
31+
import org.locationtech.jts.geom.{Geometry, GeometryFactory}
3032

3133
trait TraitJoinQueryBase {
3234
self: SparkPlan =>
@@ -49,8 +51,10 @@ trait TraitJoinQueryBase {
4951
spatialRdd.setRawSpatialRDD(
5052
rdd
5153
.map { x =>
52-
val shape =
53-
GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[Array[Byte]])
54+
// Null shape rows materialise as an empty geometry collection so they carry the row
55+
// payload through the partitioner / index without participating in any spatial match
56+
// — mirrors the pre-existing `GeometrySerializer.deserialize(null)` fallback.
57+
val shape = TraitJoinQueryBase.shapeToGeometryOrEmpty(shapeExpression, x)
5458
shape.setUserData(x.copy)
5559
shape
5660
}
@@ -123,8 +127,7 @@ trait TraitJoinQueryBase {
123127
spatialRdd.setRawSpatialRDD(
124128
rdd
125129
.map { x =>
126-
val shape =
127-
GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[Array[Byte]])
130+
val shape = TraitJoinQueryBase.shapeToGeometryOrEmpty(shapeExpression, x)
128131
val distance = boundRadius.eval(x).asInstanceOf[Double]
129132
val expandedEnvelope =
130133
JoinedGeometry.geometryToExpandedEnvelope(shape, distance, isGeography)
@@ -178,3 +181,61 @@ trait TraitJoinQueryBase {
178181
}
179182
}
180183
}
184+
185+
object TraitJoinQueryBase {
186+
187+
/**
188+
* Materialise a shape column value as a JTS [[Geometry]]. Box2D-typed columns are turned into
189+
* the closed rectangular polygon implied by their `(xmin, ymin, xmax, ymax)` bounds; all other
190+
* shape columns are deserialised from the Sedona geometry binary form.
191+
*
192+
* Producing a JTS rectangle here lets the rest of the join machinery — partitioner, R-tree
193+
* `IndexBuilder`, refine evaluator — stay shape-agnostic. JTS already short-circuits
194+
* rectangle-rectangle predicates (`Polygon.isRectangle` triggers `RectangleIntersects` /
195+
* `RectangleContains`), so a `ST_BoxIntersects` join naturally pays only the four-double
196+
* envelope comparison at refine time.
197+
*
198+
* Inverted Box2D bounds (`xmin > xmax` / `ymin > ymax`) are rejected with the same
199+
* `IllegalArgumentException` raised by `Predicates.boxIntersects` / `boxContains`. Inverted
200+
* bounds have no defined planar meaning today (they are reserved for future
201+
* antimeridian-wraparound semantics on Geography bboxes) and would silently mis-prune the
202+
* R-tree if accepted here.
203+
*
204+
* Returns `null` when the shape column evaluates to NULL; the caller is expected to either skip
205+
* the row or substitute an empty geometry.
206+
*/
207+
def shapeToGeometry(shapeExpression: Expression, row: InternalRow): Geometry = {
208+
val evaluated = shapeExpression.eval(row)
209+
if (evaluated == null) {
210+
null
211+
} else
212+
shapeExpression.dataType match {
213+
case _: Box2DUDT =>
214+
val box = evaluated.asInstanceOf[InternalRow]
215+
val xmin = box.getDouble(0)
216+
val ymin = box.getDouble(1)
217+
val xmax = box.getDouble(2)
218+
val ymax = box.getDouble(3)
219+
if (xmin > xmax || ymin > ymax) {
220+
throw new IllegalArgumentException(
221+
"Box2D join input has inverted bounds (xmin > xmax or ymin > ymax). " +
222+
"Planar Box2D predicates require ordered intervals; inverted bounds are " +
223+
"reserved for future antimeridian wraparound semantics.")
224+
}
225+
Constructors.polygonFromEnvelope(xmin, ymin, xmax, ymax)
226+
case _ =>
227+
GeometrySerializer.deserialize(evaluated.asInstanceOf[Array[Byte]])
228+
}
229+
}
230+
231+
/**
232+
* Convenience wrapper that substitutes an empty geometry collection for NULL shapes. Used by
233+
* the partitioned-RDD path where each row must carry a non-null geometry so the original
234+
* `UnsafeRow` survives to outer-join output; spatial predicates against the empty geometry
235+
* produce no matches, matching the legacy `GeometrySerializer.deserialize(null)` behaviour.
236+
*/
237+
def shapeToGeometryOrEmpty(shapeExpression: Expression, row: InternalRow): Geometry = {
238+
val shape = shapeToGeometry(shapeExpression, row)
239+
if (shape == null) new GeometryFactory().createGeometryCollection() else shape
240+
}
241+
}
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.sedona.sql
20+
21+
import org.apache.spark.sql.DataFrame
22+
import org.apache.spark.sql.functions.{broadcast, expr}
23+
import org.apache.spark.sql.sedona_sql.strategy.join.{BroadcastIndexJoinExec, RangeJoinExec}
24+
25+
class Box2DJoinSuite extends TestBaseScala {
26+
27+
import Box2DJoinSuite.TestBox
28+
29+
/**
30+
* Three left-side boxes and three right-side boxes wired so we can predict exact result sizes:
31+
*
32+
* - L1=(0,0,10,10) R1=(5,5,15,15) — overlapping
33+
* - L1=(0,0,10,10) R2=(2,2,8,8) — R2 fully inside L1
34+
* - L2=(0,0,10,10) R1=(5,5,15,15) — overlapping
35+
* - L2=(0,0,10,10) R2=(2,2,8,8) — R2 fully inside L2
36+
* - L3 and R3 are disjoint from everything else; (L3,R3) is itself disjoint.
37+
*
38+
* Intersection-pair count: 4. Containment-pair count: 2 (L1⊇R2, L2⊇R2).
39+
*/
40+
private def leftBoxes: DataFrame = {
41+
import sparkSession.implicits._
42+
Seq(TestBox(1, 0, 0, 10, 10), TestBox(2, 0, 0, 10, 10), TestBox(3, 20, 20, 30, 30))
43+
.toDF("id", "xmin", "ymin", "xmax", "ymax")
44+
.selectExpr("id", "ST_MakeBox2D(ST_Point(xmin, ymin), ST_Point(xmax, ymax)) AS box")
45+
}
46+
47+
private def rightBoxes: DataFrame = {
48+
import sparkSession.implicits._
49+
Seq(TestBox(11, 5, 5, 15, 15), TestBox(12, 2, 2, 8, 8), TestBox(13, 40, 40, 50, 50))
50+
.toDF("id", "xmin", "ymin", "xmax", "ymax")
51+
.selectExpr("id", "ST_MakeBox2D(ST_Point(xmin, ymin), ST_Point(xmax, ymax)) AS box")
52+
}
53+
54+
describe("Box2D spatial join") {
55+
56+
it("ST_BoxIntersects: broadcast index join produces correct pairs") {
57+
val df = leftBoxes
58+
.alias("L")
59+
.join(broadcast(rightBoxes.alias("R")), expr("ST_BoxIntersects(L.box, R.box)"))
60+
val plan = df.queryExecution.sparkPlan
61+
assert(
62+
plan.collect { case b: BroadcastIndexJoinExec => b }.size == 1,
63+
"Expected BroadcastIndexJoinExec in the plan")
64+
assert(df.count() == 4)
65+
}
66+
67+
it("ST_BoxIntersects: argument order is symmetric") {
68+
val swapped = leftBoxes
69+
.alias("L")
70+
.join(broadcast(rightBoxes.alias("R")), expr("ST_BoxIntersects(R.box, L.box)"))
71+
assert(swapped.count() == 4)
72+
assert(swapped.queryExecution.sparkPlan.collect { case b: BroadcastIndexJoinExec =>
73+
b
74+
}.size == 1)
75+
}
76+
77+
it("ST_BoxContains: broadcast index join uses COVERS semantics") {
78+
val df = leftBoxes
79+
.alias("L")
80+
.join(broadcast(rightBoxes.alias("R")), expr("ST_BoxContains(L.box, R.box)"))
81+
assert(df.queryExecution.sparkPlan.collect { case b: BroadcastIndexJoinExec =>
82+
b
83+
}.size == 1)
84+
assert(df.count() == 2)
85+
}
86+
87+
it("ST_BoxContains: edge-touching boxes count (closed-interval semantics)") {
88+
// R contained in L sharing an edge: ST_BoxContains is closed-interval, so this matches.
89+
// JTS Polygon.contains would reject (strict-interior), JTS Polygon.covers accepts; the
90+
// detector maps ST_BoxContains → SpatialPredicate.COVERS specifically for this case.
91+
import sparkSession.implicits._
92+
val outer = Seq(TestBox(1, 0, 0, 10, 10))
93+
.toDF("id", "xmin", "ymin", "xmax", "ymax")
94+
.selectExpr("id", "ST_MakeBox2D(ST_Point(xmin, ymin), ST_Point(xmax, ymax)) AS box")
95+
// edge-sharing box: same xmax, shares the right edge with outer.
96+
val inner = Seq(TestBox(11, 5, 5, 10, 10))
97+
.toDF("id", "xmin", "ymin", "xmax", "ymax")
98+
.selectExpr("id", "ST_MakeBox2D(ST_Point(xmin, ymin), ST_Point(xmax, ymax)) AS box")
99+
val df = outer
100+
.alias("O")
101+
.join(broadcast(inner.alias("I")), expr("ST_BoxContains(O.box, I.box)"))
102+
assert(df.count() == 1, "Closed-interval containment must include edge-touching boxes")
103+
}
104+
105+
it("ST_BoxIntersects: non-broadcast range join produces the same count") {
106+
val df = leftBoxes
107+
.alias("L")
108+
.join(rightBoxes.alias("R"), expr("ST_BoxIntersects(L.box, R.box)"))
109+
assert(
110+
df.queryExecution.sparkPlan.collect { case r: RangeJoinExec => r }.size == 1,
111+
"Expected RangeJoinExec in the plan")
112+
assert(df.count() == 4)
113+
}
114+
115+
it("Null Box2D rows are safe and produce no matches") {
116+
// A null shape on either side must not crash the executor and must not contribute matches
117+
// (mirrors the existing GeometrySerializer.deserialize(null) → empty-collection fallback).
118+
import sparkSession.implicits._
119+
val withNullLeft = leftBoxes
120+
.selectExpr("id", "box AS box")
121+
.union(Seq((99, null.asInstanceOf[org.apache.sedona.common.geometryObjects.Box2D]))
122+
.toDF("id", "box"))
123+
val df = withNullLeft
124+
.alias("L")
125+
.join(broadcast(rightBoxes.alias("R")), expr("ST_BoxIntersects(L.box, R.box)"))
126+
assert(df.count() == 4) // unchanged from the non-null fixture
127+
// Range join path (no broadcast) also tolerates nulls.
128+
val rangeDf = withNullLeft
129+
.alias("L")
130+
.join(rightBoxes.alias("R"), expr("ST_BoxIntersects(L.box, R.box)"))
131+
assert(rangeDf.count() == 4)
132+
}
133+
134+
it("Inverted Box2D bounds in a join throw IllegalArgumentException") {
135+
import sparkSession.implicits._
136+
// Construct an inverted Box2D directly via the Java constructor (the SQL ST_MakeBox2D
137+
// doesn't validate, so this is how a stored column with inverted bounds would look).
138+
val invertedLeft =
139+
Seq((1, new org.apache.sedona.common.geometryObjects.Box2D(10.0, 0.0, 0.0, 10.0)))
140+
.toDF("id", "box")
141+
val df = invertedLeft
142+
.alias("L")
143+
.join(broadcast(rightBoxes.alias("R")), expr("ST_BoxIntersects(L.box, R.box)"))
144+
// Confirm the join is actually planned as BroadcastIndexJoinExec so the throw originates
145+
// from the join-side `shapeToGeometry` validation, not from a row-by-row fallback that
146+
// also happens to throw via `Predicates.boxIntersects`.
147+
assert(
148+
df.queryExecution.sparkPlan.collect { case b: BroadcastIndexJoinExec => b }.size == 1,
149+
"Expected BroadcastIndexJoinExec — without it the test could pass via row-by-row " +
150+
"predicate evaluation, hiding a regression in join optimization")
151+
val ex = intercept[org.apache.spark.SparkException](df.collect())
152+
val cause = Iterator
153+
.iterate(ex: Throwable)(_.getCause)
154+
.takeWhile(_ != null)
155+
.find(_.isInstanceOf[IllegalArgumentException])
156+
assert(cause.isDefined, s"Expected IllegalArgumentException in cause chain, got: $ex")
157+
assert(cause.get.getMessage.contains("inverted bounds"))
158+
}
159+
160+
it("Result is equivalent to ST_Intersects on the Box2D-as-polygon envelopes") {
161+
val viaBox = leftBoxes
162+
.alias("L")
163+
.join(broadcast(rightBoxes.alias("R")), expr("ST_BoxIntersects(L.box, R.box)"))
164+
.selectExpr("L.id AS l", "R.id AS r")
165+
.orderBy("l", "r")
166+
.collect()
167+
.toSeq
168+
169+
// ST_GeomFromBox2D is the function-form equivalent of `CAST(box AS geometry)`. The cast
170+
// syntax requires the Sedona SQL parser extension; this suite runs under the common test
171+
// base, which doesn't wire that extension, so we go through the function form here.
172+
val asPolygons = leftBoxes
173+
.selectExpr("id", "ST_GeomFromBox2D(box) AS g")
174+
.alias("L")
175+
.join(
176+
broadcast(rightBoxes.selectExpr("id", "ST_GeomFromBox2D(box) AS g").alias("R")),
177+
expr("ST_Intersects(L.g, R.g)"))
178+
.selectExpr("L.id AS l", "R.id AS r")
179+
.orderBy("l", "r")
180+
.collect()
181+
.toSeq
182+
183+
assert(viaBox == asPolygons)
184+
}
185+
}
186+
187+
}
188+
189+
object Box2DJoinSuite {
190+
// Top-level case class so Spark's encoder doesn't need an outer-class reference.
191+
case class TestBox(id: Int, xmin: Double, ymin: Double, xmax: Double, ymax: Double)
192+
}

0 commit comments

Comments
 (0)