Skip to content

Commit a30fb9b

Browse files
authored
[GH-2830] Geography BC join strategy for newly added predicates (#2871)
1 parent 1361fed commit a30fb9b

8 files changed

Lines changed: 603 additions & 30 deletions

File tree

common/src/main/java/org/apache/sedona/common/sphere/Haversine.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ public class Haversine {
4040
*/
4141
public static final double AVG_EARTH_RADIUS = 6371008.0;
4242

43+
/**
44+
* Approximate polar radius of the WGS-84 spheroid (semi-minor axis ≈ 6356752.3 m, rounded), in meters. Used as a sphere radius when expanding
45+
* envelopes so that the expansion upper-bounds both spherical and spheroidal distances.
46+
*/
47+
public static final double EARTH_POLAR_RADIUS = 6357000.0;
48+
4349
public static double distance(Geometry geom1, Geometry geom2, double avg_earth_radius) {
4450
Coordinate coordinate1 =
4551
geom1.getGeometryType().equals("Point")

common/src/test/java/org/apache/sedona/common/sphere/HaversineEnvelopeTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import org.locationtech.jts.geom.Point;
2727

2828
public class HaversineEnvelopeTest {
29-
private static final int SPHERE_RADIUS = 6357000;
29+
private static final double SPHERE_RADIUS = Haversine.EARTH_POLAR_RADIUS;
3030
private static final GeometryFactory factory = new GeometryFactory();
3131

3232
@Test

spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala

Lines changed: 77 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,12 @@ case class BroadcastIndexJoinExec(
142142
// asymmetric and we cannot rely on a SpatialPredicateEvaluator.
143143
private lazy val refinerSwap: Boolean = indexBuildSide != windowJoinSide
144144

145-
private def newRefiner(): JoinRefiner =
146-
if (geographyShape) new GeographyContainsRefiner(refinerSwap)
147-
else new JtsRefiner(evaluator)
145+
/**
 * Picks the refiner used to confirm index candidates for this join.
 *
 * Geography joins use `GeographyDistanceRefiner` when a distance expression is present and
 * `GeographyRelationRefiner` (driven by `spatialPredicate`) otherwise; every non-Geography
 * join uses `JtsRefiner` with the shared `evaluator`.
 */
private def newRefiner(): JoinRefiner =
  (geographyShape, distance) match {
    case (false, _) => new JtsRefiner(evaluator)
    case (true, Some(_)) => new GeographyDistanceRefiner(refinerSwap)
    case (true, None) => new GeographyRelationRefiner(spatialPredicate, refinerSwap)
  }
148151

149152
private def innerJoin(
150153
streamIter: Iterator[(Geometry, UnsafeRow)],
@@ -288,6 +291,36 @@ case class BroadcastIndexJoinExec(
288291
streamResultsRaw: RDD[UnsafeRow],
289292
boundStreamShape: Expression) = {
290293
distance match {
294+
case Some(distanceExpression) if geographyShape =>
295+
val boundDistance =
296+
BindReferences.bindReference(distanceExpression, streamed.output)
297+
// When the broadcast side already expanded its envelope by `d` (the
298+
// literal-radius case, where the planner forwards the same distance to
299+
// both sides so the per-row radius is also available here for the
300+
// refiner), keep the stream envelope unexpanded. The coarse filter then
301+
// matches the geometry path: `expand(build, d) ∩ stream`, not the wider
302+
// `expand(build, d) ∩ expand(stream, d)`. When only the stream side
303+
// received the distance (per-row radius bound to the stream side), the
304+
// build side is unexpanded and we still need to expand the stream
305+
// envelope to ensure the index returns all candidates within `d`.
306+
val streamSideExpands = broadcast.distance.isEmpty
307+
streamResultsRaw.map(row => {
308+
val serialized = boundStreamShape.eval(row).asInstanceOf[Array[Byte]]
309+
if (serialized == null) {
310+
(null, row)
311+
} else {
312+
val geog = GeographyWKBSerializer.deserialize(serialized)
313+
val radius = boundDistance.eval(row).asInstanceOf[Double]
314+
val baseEnvelope = JoinedGeometry.geographyToEnvelopeGeometry(geog)
315+
val shape = if (streamSideExpands) {
316+
JoinedGeometry.geometryToExpandedEnvelope(baseEnvelope, radius, isGeography = true)
317+
} else {
318+
baseEnvelope
319+
}
320+
shape.setUserData(GeographyJoinShape(geog, row, radius))
321+
(shape, row)
322+
}
323+
})
291324
case Some(distanceExpression) =>
292325
streamResultsRaw.map(row => {
293326
val geom = boundStreamShape.eval(row).asInstanceOf[Array[Byte]]
@@ -348,8 +381,9 @@ case class BroadcastIndexJoinExec(
348381
* Per-iter helper that decides whether a candidate from the broadcast index actually satisfies
349382
* the spatial predicate, and unpacks the candidate's `userData` into the output row.
350383
*
351-
* Two implementations: `JtsRefiner` for the planar JTS path (existing behaviour, byte-equivalent
352-
* to the previous inline code), and `GeographyContainsRefiner` for the new Geography-on-S2 path.
384+
* Three implementations: `JtsRefiner` for the planar JTS path, `GeographyRelationRefiner` for
385+
* non-distance Geography predicates (CONTAINS, WITHIN, INTERSECTS, EQUALS), and
386+
* `GeographyDistanceRefiner` for ST_DWithin on Geography.
353387
*/
354388
private sealed trait JoinRefiner {
355389
def matches(candidate: Geometry, streamShape: Geometry): Boolean
@@ -368,22 +402,51 @@ private final class JtsRefiner(evaluator: SpatialPredicateEvaluator) extends Joi
368402
}
369403

370404
/**
371-
* Refines candidates with `Functions.contains` (S2 spherical containment). Caching of the per-
372-
* Geography S2 ShapeIndex happens inside `WKBGeography.getShapeIndexGeography()`, so we do not
373-
* need a JTS-style PreparedGeometry cache here — the build side keeps the same Geography JVM
374-
* instances for the lifetime of the broadcast.
405+
* Refines candidates with the appropriate `org.apache.sedona.common.geography.Functions`
406+
* predicate (CONTAINS / WITHIN / INTERSECTS / EQUALS). Caching of the per-Geography S2 ShapeIndex
407+
* happens inside `WKBGeography.getShapeIndexGeography()`, so we do not need a JTS-style
408+
* PreparedGeometry cache here — the build side keeps the same Geography JVM instances for the
409+
* lifetime of the broadcast. `swap` flips operand order when the build side does not correspond
410+
* to the predicate's left-hand argument (handles the `RIGHT JOIN` / right-broadcast case for
411+
* asymmetric predicates).
375412
*/
376-
private final class GeographyContainsRefiner(swap: Boolean) extends JoinRefiner {
413+
private final class GeographyRelationRefiner(predicate: SpatialPredicate, swap: Boolean)
414+
extends JoinRefiner {
377415
override def matches(candidate: Geometry, streamShape: Geometry): Boolean = {
378416
val buildShape = candidate.getUserData.asInstanceOf[GeographyJoinShape]
379417
val streamShapeData = streamShape.getUserData.asInstanceOf[GeographyJoinShape]
380-
if (swap) {
381-
GeographyFunctions.contains(streamShapeData.geog, buildShape.geog)
382-
} else {
383-
GeographyFunctions.contains(buildShape.geog, streamShapeData.geog)
418+
val (a, b) =
419+
if (swap) (streamShapeData.geog, buildShape.geog)
420+
else (buildShape.geog, streamShapeData.geog)
421+
predicate match {
422+
case SpatialPredicate.CONTAINS => GeographyFunctions.contains(a, b)
423+
case SpatialPredicate.WITHIN => GeographyFunctions.within(a, b)
424+
case SpatialPredicate.INTERSECTS => GeographyFunctions.intersects(a, b)
425+
case SpatialPredicate.EQUALS => GeographyFunctions.equals(a, b)
426+
case other =>
427+
throw new UnsupportedOperationException(
428+
s"Geography broadcast spatial join does not support predicate $other")
384429
}
385430
}
386431

387432
override def unpackRow(candidate: Geometry): UnsafeRow =
388433
candidate.getUserData.asInstanceOf[GeographyJoinShape].row
389434
}
435+
436+
/**
 * `JoinRefiner` for Geography `ST_DWithin` joins.
 *
 * Both the index candidate and the stream shape carry a `GeographyJoinShape` in `userData`.
 * The distance threshold is always read from the stream-side payload's `radius`. When `swap`
 * is set, the stream Geography is passed as the first operand of `GeographyFunctions.dWithin`
 * and the build Geography as the second; otherwise the order is build, then stream.
 */
private final class GeographyDistanceRefiner(swap: Boolean) extends JoinRefiner {
  override def matches(candidate: Geometry, streamShape: Geometry): Boolean = {
    val indexed = payloadOf(candidate)
    val probe = payloadOf(streamShape)
    if (swap) GeographyFunctions.dWithin(probe.geog, indexed.geog, probe.radius)
    else GeographyFunctions.dWithin(indexed.geog, probe.geog, probe.radius)
  }

  /** Returns the row stored in the candidate's `GeographyJoinShape` payload. */
  override def unpackRow(candidate: Geometry): UnsafeRow = payloadOf(candidate).row

  /** Reads the `GeographyJoinShape` attached to a shape's `userData`. */
  private def payloadOf(shape: Geometry): GeographyJoinShape =
    shape.getUserData.asInstanceOf[GeographyJoinShape]
}

spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala

Lines changed: 76 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -228,9 +228,21 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy {
228228
isGeography = false,
229229
extraCondition,
230230
geographyShape = geographyShape))
231-
// ST_Intersects / ST_Within / ST_Equals on Geography have no broadcast index path
232-
// yet (the Geography refiner is ST_Contains-specific), so gate Geography inputs and
233-
// let them fall back to row-by-row evaluation.
231+
// ST_Intersects / ST_Within / ST_Equals on Geography route through the
232+
// Geography-aware refiner (`GeographyRelationRefiner`); geometry inputs continue
233+
// through `inferredJoinDetection`.
234+
case ST_Intersects(Seq(leftShape, rightShape))
235+
if isGeographyInput(leftShape) || isGeographyInput(rightShape) =>
236+
Some(
237+
JoinQueryDetection(
238+
left,
239+
right,
240+
leftShape,
241+
rightShape,
242+
SpatialPredicate.INTERSECTS,
243+
isGeography = false,
244+
extraCondition,
245+
geographyShape = true))
234246
case ST_Intersects(Seq(leftShape, rightShape)) =>
235247
inferredJoinDetection(
236248
left,
@@ -239,6 +251,18 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy {
239251
rightShape,
240252
SpatialPredicate.INTERSECTS,
241253
extraCondition)
254+
case ST_Within(Seq(leftShape, rightShape))
255+
if isGeographyInput(leftShape) || isGeographyInput(rightShape) =>
256+
Some(
257+
JoinQueryDetection(
258+
left,
259+
right,
260+
leftShape,
261+
rightShape,
262+
SpatialPredicate.WITHIN,
263+
isGeography = false,
264+
extraCondition,
265+
geographyShape = true))
242266
case ST_Within(Seq(leftShape, rightShape)) =>
243267
inferredJoinDetection(
244268
left,
@@ -247,6 +271,18 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy {
247271
rightShape,
248272
SpatialPredicate.WITHIN,
249273
extraCondition)
274+
case ST_Equals(Seq(leftShape, rightShape))
275+
if isGeographyInput(leftShape) || isGeographyInput(rightShape) =>
276+
Some(
277+
JoinQueryDetection(
278+
left,
279+
right,
280+
leftShape,
281+
rightShape,
282+
SpatialPredicate.EQUALS,
283+
isGeography = false,
284+
extraCondition,
285+
geographyShape = true))
250286
case ST_Equals(Seq(leftShape, rightShape)) =>
251287
inferredJoinDetection(
252288
left,
@@ -260,16 +296,22 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy {
260296
case pred: RS_Predicate =>
261297
getRasterJoinDetection(left, right, pred, extraCondition)
262298
case ST_DWithin(Seq(leftShape, rightShape, distance)) =>
299+
val geographyShape =
300+
isGeographyInput(leftShape) || isGeographyInput(rightShape)
263301
Some(
264302
JoinQueryDetection(
265303
left,
266304
right,
267305
leftShape,
268306
rightShape,
269307
SpatialPredicate.INTERSECTS,
270-
isGeography = false,
308+
isGeography = geographyShape,
271309
condition,
272-
Some(distance)))
310+
Some(distance),
311+
geographyShape = geographyShape))
312+
// Note: the 4-arg ST_DWithin is geometry-only; on Geography input the Spark
313+
// analyzer rejects the call before reaching this matcher, so no Geography
314+
// guard is needed here.
273315
case ST_DWithin(Seq(leftShape, rightShape, distance, useSpheroid)) =>
274316
val useSpheroidUnwrapped = useSpheroid.eval().asInstanceOf[Boolean]
275317
Some(
@@ -484,7 +526,8 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy {
484526
}
485527
} else {
486528
queryDetection match {
487-
// Geography ST_Contains has no partition/range path — fall back to row-by-row.
529+
// Geography predicates (ST_Contains/Within/Intersects/Equals/DWithin) have no
530+
// partition/range path — fall back to row-by-row.
488531
case Some(detection) if detection.geographyShape =>
489532
Nil
490533
case Some(
@@ -843,9 +886,33 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy {
843886
.map { distanceExpr =>
844887
matchDistanceExpressionToJoinSide(distanceExpr, left, right) match {
845888
case Some(side) =>
846-
if (broadcastSide.get == side) (Some(distanceExpr), None)
847-
else if (distanceExpr.references.isEmpty) (Some(distanceExpr), None)
848-
else (None, Some(distanceExpr))
889+
if (geographyShape) {
890+
// Geography distance joins read the per-row radius from the stream-side
891+
// GeographyJoinShape inside GeographyDistanceRefiner, so the radius MUST
892+
// be available on the streamed side. The stream-side expression is later
893+
// re-bound against `streamed.output` in BroadcastIndexJoinExec, so we
894+
// can only forward it to the stream side when it is either a literal
895+
// (no references) or already bound to the streamed side.
896+
if (distanceExpr.references.isEmpty) {
897+
// Literal: keep build-side expansion AND populate stream-side radius.
898+
(Some(distanceExpr), Some(distanceExpr))
899+
} else if (broadcastSide.get == side) {
900+
// Non-literal expression bound to the broadcast/index side cannot be
901+
// re-bound against streamed.output. Reject up front rather than
902+
// planning a broadcast geography join that will fail at execution.
903+
throw new UnsupportedOperationException(
904+
"Geography distance broadcast joins do not support non-literal " +
905+
"distance expressions bound to the broadcast/index side; bind " +
906+
"the distance expression to the streamed side or use a literal.")
907+
} else {
908+
// Bound to the streamed side: stream-only (no build-side expansion).
909+
(None, Some(distanceExpr))
910+
}
911+
} else {
912+
if (broadcastSide.get == side) (Some(distanceExpr), None)
913+
else if (distanceExpr.references.isEmpty) (Some(distanceExpr), None)
914+
else (None, Some(distanceExpr))
915+
}
849916
case _ =>
850917
throw new IllegalArgumentException(
851918
"Distance expression must be bound to one side of the join")

spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinedGeometry.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@ import org.locationtech.jts.geom.{Envelope, Geometry, GeometryFactory}
2525

2626
/**
2727
* Payload stored in `userData` on each Geography index entry. Carries both the deserialized
28-
* Geography (for S2 predicate refinement) and the original row (for the join output).
28+
* Geography (for S2 predicate refinement) and the original row (for the join output). For
29+
* `ST_DWithin` joins, `radius` carries the per-row distance threshold from the row that produced
30+
* this shape; for non-distance predicates it remains 0.0.
2931
*/
30-
case class GeographyJoinShape(geog: Geography, row: UnsafeRow)
32+
case class GeographyJoinShape(geog: Geography, row: UnsafeRow, radius: Double = 0.0)
3133

3234
/**
3335
* Utility functions for generating geometries for spatial join.
@@ -89,9 +91,8 @@ object JoinedGeometry {
8991
* in meter
9092
*/
9193
private def expandEnvelopeForGeography(envelope: Envelope, distance: Double): Envelope = {
92-
// Here we use the polar radius of the spheroid as the radius of the sphere, so that the expanded
93-
// envelope will work for both spherical and spheroidal distances.
94-
val sphereRadius = 6357000.0
95-
Haversine.expandEnvelope(envelope, distance, sphereRadius)
94+
// Use the polar radius of the spheroid as the radius of the sphere so that the expanded
95+
// envelope upper-bounds both spherical and spheroidal distances.
96+
Haversine.expandEnvelope(envelope, distance, Haversine.EARTH_POLAR_RADIUS)
9697
}
9798
}

spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ case class SpatialIndexExec(
5656
val boundShape = BindReferences.bindReference(shape, child.output)
5757
val resultRaw = child.execute().asInstanceOf[RDD[UnsafeRow]].coalesce(1)
5858
val spatialRDD = distance match {
59+
case Some(distanceExpression) if geographyShape =>
60+
toExpandedGeographyEnvelopeRDD(
61+
resultRaw,
62+
boundShape,
63+
BindReferences.bindReference(distanceExpression, child.output))
5964
case Some(distanceExpression) =>
6065
toExpandedEnvelopeRDD(
6166
resultRaw,

spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,38 @@ trait TraitJoinQueryBase {
135135
spatialRdd
136136
}
137137

138+
/**
 * Geography counterpart of [[toExpandedEnvelopeRDD]].
 *
 * For every input row the shape expression is evaluated and deserialized into a Geography;
 * rows whose shape evaluates to null produce no entry. The Geography's envelope geometry
 * (from [[JoinedGeometry.geographyToEnvelopeGeometry]]) is expanded by the row's radius, in
 * meters, through [[JoinedGeometry.geometryToExpandedEnvelope]] with `isGeography = true`.
 * The resulting geometry carries a [[GeographyJoinShape]] in `userData` holding the
 * Geography, a copy of the row, and the radius, so the join executor can refine candidates.
 *
 * @param rdd
 *   rows to turn into index entries
 * @param shapeExpression
 *   expression yielding the serialized Geography for a row
 * @param boundRadius
 *   expression yielding the per-row distance threshold
 * @return
 *   a [[SpatialRDD]] whose raw RDD holds one expanded envelope geometry per non-null shape
 */
def toExpandedGeographyEnvelopeRDD(
    rdd: RDD[UnsafeRow],
    shapeExpression: Expression,
    boundRadius: Expression): SpatialRDD[Geometry] = {
  val indexEntries = rdd.flatMap { row =>
    shapeExpression.eval(row).asInstanceOf[Array[Byte]] match {
      case null => None // null shape: the row contributes no index entry
      case wkb =>
        val geography = GeographyWKBSerializer.deserialize(wkb)
        val radiusMeters = boundRadius.eval(row).asInstanceOf[Double]
        val envelope = JoinedGeometry.geometryToExpandedEnvelope(
          JoinedGeometry.geographyToEnvelopeGeometry(geography),
          radiusMeters,
          isGeography = true)
        // The payload keeps a copy of the row rather than the row itself
        // (presumably because the upstream iterator may reuse the UnsafeRow buffer — confirm).
        envelope.setUserData(GeographyJoinShape(geography, row.copy, radiusMeters))
        Some(envelope)
    }
  }
  val result = new SpatialRDD[Geometry]
  result.setRawSpatialRDD(indexEntries.toJavaRDD())
  result
}
169+
138170
def doSpatialPartitioning(
139171
dominantShapes: SpatialRDD[Geometry],
140172
followerShapes: SpatialRDD[Geometry],

0 commit comments

Comments
 (0)