diff --git a/common/src/main/java/org/apache/sedona/common/S2Geography/WkbS2Shape.java b/common/src/main/java/org/apache/sedona/common/S2Geography/WkbS2Shape.java index d1ed1f80723..4fa108a5465 100644 --- a/common/src/main/java/org/apache/sedona/common/S2Geography/WkbS2Shape.java +++ b/common/src/main/java/org/apache/sedona/common/S2Geography/WkbS2Shape.java @@ -105,34 +105,48 @@ public WkbS2Shape(byte[] wkb) { this.chainLengths = new int[numRings]; this.vertexOffsets = new int[numRings]; - // First pass: count total vertices and compute offsets + // First pass: count total vertices and compute offsets. Sedona's WKBWriter writes + // open rings (n unique vertices, no closing duplicate); standard WKB writes closed + // rings (n+1 coords with last == first). Detect the closing-duplicate case by + // comparing the first and last (lon, lat) pair so we get the right edge count + // either way: edges = uniqueVertices = closed ? ringCoords - 1 : ringCoords. int totalVerts = 0; int edgeCount = 0; int byteOffset = payloadOffset + 4; int[] ringCoordCounts = new int[numRings]; int[] ringByteOffsets = new int[numRings]; + boolean[] ringClosed = new boolean[numRings]; for (int r = 0; r < numRings; r++) { int ringCoords = buf.getInt(byteOffset); ringCoordCounts[r] = ringCoords; ringByteOffsets[r] = byteOffset + 4; + boolean closed = + ringCoords >= 2 && firstAndLastEqual(buf, ringByteOffsets[r], ringCoords); + ringClosed[r] = closed; byteOffset += 4 + ringCoords * 16; - int ringEdges = Math.max(0, ringCoords - 1); + int ringEdges = closed ? Math.max(0, ringCoords - 1) : ringCoords; + int storedVerts = closed ? ringCoords : ringCoords; chainStarts[r] = edgeCount; chainLengths[r] = ringEdges; vertexOffsets[r] = totalVerts; edgeCount += ringEdges; - totalVerts += ringCoords; + totalVerts += storedVerts + (closed ? 
0 : 1); // append closing duplicate for open rings } this.totalEdges = edgeCount; - // Second pass: read all vertices at once + // Second pass: read all vertices, appending a closing duplicate for open rings so + // the rest of the shape interface (getEdge, getChainEdge, computeContainsOrigin) + // can index `vertexOffsets[r] + (i % chainLengths[r])` uniformly. this.vertices = new S2Point[totalVerts]; int vi = 0; for (int r = 0; r < numRings; r++) { S2Point[] ringVerts = readVertices(buf, ringByteOffsets[r], ringCoordCounts[r]); System.arraycopy(ringVerts, 0, vertices, vi, ringVerts.length); vi += ringVerts.length; + if (!ringClosed[r] && ringVerts.length > 0) { + vertices[vi++] = ringVerts[0]; + } } // Eagerly compute containsOrigin from first ring @@ -229,6 +243,18 @@ private int findChain(int edgeId) { return 0; } + /** + * Returns true when the ring's first and last vertex compare equal as raw doubles, i.e. the ring + * is closed in the standard WKB sense. Sedona's own WKBWriter produces open rings, so this cheap + * numeric comparison on the in-buffer bytes lets us distinguish the two cases without running + * through the S2Point conversion. + */ + private static boolean firstAndLastEqual(ByteBuffer buf, int byteOffset, int numCoords) { + int lastOffset = byteOffset + (numCoords - 1) * 16; + return buf.getDouble(byteOffset) == buf.getDouble(lastOffset) + && buf.getDouble(byteOffset + 8) == buf.getDouble(lastOffset + 8); + } + /** Read numCoords (lon, lat) doubles from WKB and convert to S2Points. 
*/ private static S2Point[] readVertices(ByteBuffer buf, int byteOffset, int numCoords) { S2Point[] pts = new S2Point[numCoords]; diff --git a/common/src/test/java/org/apache/sedona/common/S2Geography/WkbContainsRoundtripTest.java b/common/src/test/java/org/apache/sedona/common/S2Geography/WkbContainsRoundtripTest.java new file mode 100644 index 00000000000..c8720f7e6de --- /dev/null +++ b/common/src/test/java/org/apache/sedona/common/S2Geography/WkbContainsRoundtripTest.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.common.S2Geography; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import org.apache.sedona.common.geography.Constructors; +import org.apache.sedona.common.geography.Functions; +import org.junit.Test; + +/** + * Regression tests for a Geography ST_Contains correctness bug seen when arguments come from a + * DataFrame (i.e. after a GeographyWKBSerializer round-trip). The WKBGeography fast path + * (getShapeIndexGeography → WkbS2Shape) returned the wrong answer for some polygon-point pairs; + * the slow path through S2Polygon (and direct ST_GeogFromWKT in a SELECT) is correct. 
+ */ +public class WkbContainsRoundtripTest { + + /** + * Direct Functions.contains call, no round-trip. Should return false: (10, 10) is far outside the + * small polygon at (2..3, 2..3). + */ + @Test + public void containsIsFalseWithoutRoundTrip() throws Exception { + Geography poly = Constructors.geogFromWKT("POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))", 4326); + Geography pt = Constructors.geogFromWKT("POINT(10 10)", 4326); + assertFalse(Functions.contains(poly, pt)); + assertTrue(Functions.contains(poly, Constructors.geogFromWKT("POINT(2.5 2.5)", 4326))); + } + + /** Control test mirroring GeographyFunctionTest's "ST_Contains point outside polygon". */ + @Test + public void controlPolygonAtOrigin() throws Exception { + Geography poly = Constructors.geogFromWKT("POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))", 4326); + Geography ptOutside = Constructors.geogFromWKT("POINT(2 2)", 4326); + Geography ptInside = Constructors.geogFromWKT("POINT(0.5 0.5)", 4326); + assertFalse("polygon at origin must NOT contain (2, 2)", Functions.contains(poly, ptOutside)); + assertTrue("polygon at origin must contain (0.5, 0.5)", Functions.contains(poly, ptInside)); + } + + /** + * Bypass WkbS2Shape and feed the polygon through PolygonGeography directly. If this passes while + * the equivalent WKBGeography case fails, the bug is localised to WkbS2Shape (or to the + * `result.shapeIndex.add(new WkbS2Shape(...))` path in WKBGeography.getShapeIndexGeography). + */ + @Test + public void bypassWkbS2ShapeViaPolygonGeography() throws Exception { + // Force the slow path: parse via WKTReader then DON'T wrap in WKBGeography. 
+ Geography poly = new WKTReader().read("POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))"); + poly.setSRID(4326); + Geography ptOutside = new WKTReader().read("POINT(10 10)"); + ptOutside.setSRID(4326); + Geography ptInside = new WKTReader().read("POINT(2.5 2.5)"); + ptInside.setSRID(4326); + assertFalse( + "[slow path] polygon at (2..3,2..3) must NOT contain (10, 10)", + Functions.contains(poly, ptOutside)); + assertTrue( + "[slow path] polygon at (2..3,2..3) must contain (2.5, 2.5)", + Functions.contains(poly, ptInside)); + } + + /** + * Same logical inputs, but each Geography goes through the WKB serializer round-trip first — + * which is what happens whenever a GeographyUDT column is read back from a DataFrame. + */ + @Test + public void containsIsFalseAfterWkbRoundTrip() throws Exception { + Geography poly = + GeographyWKBSerializer.deserialize( + GeographyWKBSerializer.serialize( + Constructors.geogFromWKT("POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))", 4326))); + Geography ptOutside = + GeographyWKBSerializer.deserialize( + GeographyWKBSerializer.serialize(Constructors.geogFromWKT("POINT(10 10)", 4326))); + Geography ptInside = + GeographyWKBSerializer.deserialize( + GeographyWKBSerializer.serialize(Constructors.geogFromWKT("POINT(2.5 2.5)", 4326))); + assertFalse( + "polygon at (2..3,2..3) must NOT contain (10, 10)", Functions.contains(poly, ptOutside)); + assertTrue( + "polygon at (2..3,2..3) must contain (2.5, 2.5)", Functions.contains(poly, ptInside)); + } +} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala index 854d8510869..24654fe109e 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala @@ -18,6 +18,8 @@ */ package 
org.apache.spark.sql.sedona_sql.strategy.join +import org.apache.sedona.common.S2Geography.GeographyWKBSerializer +import org.apache.sedona.common.geography.{Functions => GeographyFunctions} import org.apache.sedona.core.spatialOperator.{SpatialPredicate, SpatialPredicateEvaluators} import org.apache.sedona.core.spatialOperator.SpatialPredicateEvaluators.SpatialPredicateEvaluator import org.apache.sedona.sql.utils.{GeometrySerializer, RasterSerializer} @@ -50,7 +52,8 @@ case class BroadcastIndexJoinExec( joinType: JoinType, spatialPredicate: SpatialPredicate, extraCondition: Option[Expression] = None, - distance: Option[Expression] = None) + distance: Option[Expression] = None, + geographyShape: Boolean = false) extends SedonaBinaryExecNode with TraitJoinQueryBase with Logging { @@ -133,56 +136,59 @@ case class BroadcastIndexJoinExec( SpatialPredicateEvaluators.create(SpatialPredicate.inverse(spatialPredicate)) } + // True when the refine arguments must be swapped, i.e. the index build side is NOT the + // window side (`indexBuildSide != windowJoinSide`). Mirrors the inverse logic used by + // `evaluator` for the JTS path; required for Geography because Functions.contains is + // asymmetric and we cannot rely on a SpatialPredicateEvaluator. 
+ private lazy val refinerSwap: Boolean = indexBuildSide != windowJoinSide + + private def newRefiner(): JoinRefiner = + if (geographyShape) new GeographyContainsRefiner(refinerSwap) + else new JtsRefiner(evaluator) + private def innerJoin( streamIter: Iterator[(Geometry, UnsafeRow)], index: Broadcast[SpatialIndex]): Iterator[InternalRow] = { - val factory = new PreparedGeometryFactory() - val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry] + val refiner = newRefiner() val joinedRow = new JoinedRow streamIter.flatMap { case (geom, row) => - joinedRow.withLeft(row) - index.value - .query(geom.getEnvelopeInternal) - .iterator - .asScala - .asInstanceOf[Iterator[Geometry]] - .filter(candidate => - evaluator.eval( - preparedGeometries.getOrElseUpdate(candidate, { factory.create(candidate) }), - geom)) - .map(candidate => joinedRow.withRight(candidate.getUserData.asInstanceOf[UnsafeRow])) - .filter(boundCondition) + if (geom == null) { + Iterator.empty + } else { + joinedRow.withLeft(row) + index.value + .query(geom.getEnvelopeInternal) + .iterator + .asScala + .asInstanceOf[Iterator[Geometry]] + .filter(candidate => refiner.matches(candidate, geom)) + .map(candidate => joinedRow.withRight(refiner.unpackRow(candidate))) + .filter(boundCondition) + } } } private def semiJoin( streamIter: Iterator[(Geometry, UnsafeRow)], index: Broadcast[SpatialIndex]): Iterator[InternalRow] = { - val factory = new PreparedGeometryFactory() - val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry] + val refiner = newRefiner() val joinedRow = new JoinedRow streamIter.flatMap { case (geom, row) => val left = row - joinedRow.withLeft(left) - val anyMatches = index.value - .query(geom.getEnvelopeInternal) - .iterator - .asScala - .asInstanceOf[Iterator[Geometry]] - .filter(candidate => - evaluator.eval( - preparedGeometries.getOrElseUpdate( - candidate, { - factory.create(candidate) - }), - geom)) - .map(candidate => 
joinedRow.withRight(candidate.getUserData.asInstanceOf[UnsafeRow])) - .exists(boundCondition) - - if (anyMatches) { - Iterator.single(left) - } else { + if (geom == null) { Iterator.empty + } else { + joinedRow.withLeft(left) + val anyMatches = index.value + .query(geom.getEnvelopeInternal) + .iterator + .asScala + .asInstanceOf[Iterator[Geometry]] + .filter(candidate => refiner.matches(candidate, geom)) + .map(candidate => joinedRow.withRight(refiner.unpackRow(candidate))) + .exists(boundCondition) + + if (anyMatches) Iterator.single(left) else Iterator.empty } } } @@ -190,8 +196,7 @@ case class BroadcastIndexJoinExec( private def antiJoin( streamIter: Iterator[(Geometry, UnsafeRow)], index: Broadcast[SpatialIndex]): Iterator[InternalRow] = { - val factory = new PreparedGeometryFactory() - val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry] + val refiner = newRefiner() val joinedRow = new JoinedRow streamIter.flatMap { case (geom, row) => val left = row @@ -199,14 +204,8 @@ case class BroadcastIndexJoinExec( val anyMatches = (if (geom == null) Collections.EMPTY_LIST else index.value.query(geom.getEnvelopeInternal)).iterator.asScala .asInstanceOf[Iterator[Geometry]] - .filter(candidate => - evaluator.eval( - preparedGeometries.getOrElseUpdate( - candidate, { - factory.create(candidate) - }), - geom)) - .map(candidate => joinedRow.withRight(candidate.getUserData.asInstanceOf[UnsafeRow])) + .filter(candidate => refiner.matches(candidate, geom)) + .map(candidate => joinedRow.withRight(refiner.unpackRow(candidate))) .exists(boundCondition) if (anyMatches) { @@ -220,8 +219,7 @@ case class BroadcastIndexJoinExec( private def outerJoin( streamIter: Iterator[(Geometry, UnsafeRow)], index: Broadcast[SpatialIndex]): Iterator[InternalRow] = { - val factory = new PreparedGeometryFactory() - val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry] + val refiner = newRefiner() val joinedRow = new JoinedRow val nullRow = new 
GenericInternalRow(broadcast.output.length) @@ -230,19 +228,13 @@ case class BroadcastIndexJoinExec( val candidates = (if (geom == null) Collections.EMPTY_LIST else index.value.query(geom.getEnvelopeInternal)).iterator.asScala .asInstanceOf[Iterator[Geometry]] - .filter(candidate => - evaluator.eval( - preparedGeometries.getOrElseUpdate( - candidate, { - factory.create(candidate) - }), - geom)) + .filter(candidate => refiner.matches(candidate, geom)) new RowIterator { private var found = false override def advanceNext(): Boolean = { while (candidates.hasNext) { - val candidateRow = candidates.next().getUserData.asInstanceOf[UnsafeRow] + val candidateRow = refiner.unpackRow(candidates.next()) if (boundCondition(joinedRow.withRight(candidateRow))) { found = true return true @@ -312,6 +304,18 @@ case class BroadcastIndexJoinExec( (geometry.getFactory.toGeometry(envelope), row) } }) + case _ if geographyShape => + streamResultsRaw.map(row => { + val serialized = boundStreamShape.eval(row).asInstanceOf[Array[Byte]] + if (serialized == null) { + (null, row) + } else { + val geog = GeographyWKBSerializer.deserialize(serialized) + val shape = JoinedGeometry.geographyToEnvelopeGeometry(geog) + shape.setUserData(GeographyJoinShape(geog, row)) + (shape, row) + } + }) case _ => streamResultsRaw.map(row => { val serializedObject = boundStreamShape.eval(row).asInstanceOf[Array[Byte]] @@ -339,3 +343,47 @@ case class BroadcastIndexJoinExec( copy(left = newLeft, right = newRight) } } + +/** + * Per-iter helper that decides whether a candidate from the broadcast index actually satisfies + * the spatial predicate, and unpacks the candidate's `userData` into the output row. + * + * Two implementations: `JtsRefiner` for the planar JTS path (existing behaviour, byte-equivalent + * to the previous inline code), and `GeographyContainsRefiner` for the new Geography-on-S2 path. 
+ */ +private sealed trait JoinRefiner { + def matches(candidate: Geometry, streamShape: Geometry): Boolean + def unpackRow(candidate: Geometry): UnsafeRow +} + +private final class JtsRefiner(evaluator: SpatialPredicateEvaluator) extends JoinRefiner { + private val factory = new PreparedGeometryFactory() + private val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry] + override def matches(candidate: Geometry, streamShape: Geometry): Boolean = + evaluator.eval( + preparedGeometries.getOrElseUpdate(candidate, factory.create(candidate)), + streamShape) + override def unpackRow(candidate: Geometry): UnsafeRow = + candidate.getUserData.asInstanceOf[UnsafeRow] +} + +/** + * Refines candidates with `Functions.contains` (S2 spherical containment). Caching of the per- + * Geography S2 ShapeIndex happens inside `WKBGeography.getShapeIndexGeography()`, so we do not + * need a JTS-style PreparedGeometry cache here — the build side keeps the same Geography JVM + * instances for the lifetime of the broadcast. 
+ */ +private final class GeographyContainsRefiner(swap: Boolean) extends JoinRefiner { + override def matches(candidate: Geometry, streamShape: Geometry): Boolean = { + val buildShape = candidate.getUserData.asInstanceOf[GeographyJoinShape] + val streamShapeData = streamShape.getUserData.asInstanceOf[GeographyJoinShape] + if (swap) { + GeographyFunctions.contains(streamShapeData.geog, buildShape.geog) + } else { + GeographyFunctions.contains(buildShape.geog, streamShapeData.geog) + } + } + + override def unpackRow(candidate: Geometry): UnsafeRow = + candidate.getUserData.asInstanceOf[GeographyJoinShape].row +} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala index 233ed8a806f..306dd18b940 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala @@ -42,7 +42,10 @@ case class JoinQueryDetection( spatialPredicate: SpatialPredicate, isGeography: Boolean, extraCondition: Option[Expression] = None, - distance: Option[Expression] = None) + distance: Option[Expression] = None, + // True when the join key columns are GeographyUDT (independent of `isGeography`, + // which means "use spheroid distance"). Currently only ST_Contains is supported. + geographyShape: Boolean = false) /** * Plans `RangeJoinExec` for inner joins on spatial relationships ST_Contains(a, b) and @@ -54,11 +57,11 @@ case class JoinQueryDetection( */ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy { - // Geography spatial joins are not supported in this PR — TraitJoinQueryBase.toSpatialRDD - // deserializes join keys with GeometrySerializer, which would fail on Geography bytes. 
- // ST_Contains is the only spatial predicate currently wired for Geography (via InferredExpression - // dual dispatch); when either side is GeographyUDT we skip join planning and let Spark evaluate - // the predicate row-by-row. Other ST_Predicates reject Geography inputs at analysis time, so no + // ST_Contains is the only spatial predicate currently wired for Geography (via + // InferredExpression dual dispatch). For broadcast joins we route GeographyUDT inputs through + // a dedicated index/refine path (see SpatialIndexExec.geographyShape and + // BroadcastIndexJoinExec.geographyShape); for the partition/range path we still fall back to + // row-by-row evaluation. Other ST_Predicates reject Geography inputs at analysis time, so no // guard is needed there. private def isGeographyInput(shape: Expression): Boolean = shape.dataType.isInstanceOf[GeographyUDT] @@ -209,10 +212,13 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy { case joinConditionMatcher(predicate, extraCondition) => predicate match { // ST_Contains is an InferredExpression (not ST_Predicate) so it can't sit inside - // getJoinDetection; it's also the only predicate currently accepting Geography - // inputs and therefore the only one needing the Geography guard. - case ST_Contains(Seq(leftShape, rightShape)) - if !isGeographyInput(leftShape) && !isGeographyInput(rightShape) => + // getJoinDetection. When either operand is GeographyUDT we still detect the join + // here and set `geographyShape = true`; planBroadcastJoin will route the work to + // the Geography-aware index/refine path. Non-broadcast plans bail out in `apply` + // below and fall back to row-by-row evaluation. 
+ case ST_Contains(Seq(leftShape, rightShape)) => + val geographyShape = + isGeographyInput(leftShape) || isGeographyInput(rightShape) Some( JoinQueryDetection( left, @@ -220,8 +226,9 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy { leftShape, rightShape, SpatialPredicate.CONTAINS, - false, - extraCondition)) + isGeography = false, + extraCondition, + geographyShape = geographyShape)) case pred: ST_Predicate => getJoinDetection(left, right, pred, extraCondition) case pred: RS_Predicate => @@ -432,33 +439,28 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy { if ((broadcastLeft || broadcastRight) && sedonaConf.getUseIndex) { queryDetection match { - case Some( - JoinQueryDetection( - left, - right, - leftShape, - rightShape, - spatialPredicate, - isGeography, - extraCondition, - distance)) => + case Some(detection) => planBroadcastJoin( - left, - right, - Seq(leftShape, rightShape), + detection.left, + detection.right, + Seq(detection.leftShape, detection.rightShape), joinType, - spatialPredicate, + detection.spatialPredicate, sedonaConf.getIndexType, broadcastLeft, broadcastRight, - isGeography, - extraCondition, - distance) + detection.isGeography, + detection.extraCondition, + detection.distance, + detection.geographyShape) case _ => Nil } } else { queryDetection match { + // Geography ST_Contains has no partition/range path — fall back to row-by-row. 
+ case Some(detection) if detection.geographyShape => + Nil case Some( JoinQueryDetection( left, @@ -468,7 +470,8 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy { spatialPredicate, isGeography, extraCondition, - None)) => + None, + _)) => planSpatialJoin( left, right, @@ -485,7 +488,8 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy { spatialPredicate, isGeography, extraCondition, - Some(distance))) => + Some(distance), + _)) => Option(spatialPredicate) match { case Some(SpatialPredicate.KNN) => planKNNJoin( @@ -714,7 +718,8 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy { broadcastRight: Boolean, isGeography: Boolean, extraCondition: Option[Expression], - distance: Option[Expression]): Seq[SparkPlan] = { + distance: Option[Expression], + geographyShape: Boolean = false): Seq[SparkPlan] = { val broadcastSide = joinType match { case Inner if broadcastLeft => Some(LeftSide) @@ -834,7 +839,8 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy { indexType, isRasterPredicate, isGeography, - distanceOnIndexSide), + distanceOnIndexSide, + geographyShape), planLater(right), b, LeftSide) @@ -846,7 +852,8 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy { indexType, isRasterPredicate, isGeography, - distanceOnIndexSide), + distanceOnIndexSide, + geographyShape), planLater(right), a, RightSide) @@ -859,7 +866,8 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy { indexType, isRasterPredicate, isGeography, - distanceOnIndexSide), + distanceOnIndexSide, + geographyShape), a, LeftSide) case (RightSide, true) => // Broadcast the right side, objects on the left @@ -871,7 +879,8 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy { indexType, isRasterPredicate, isGeography, - distanceOnIndexSide), + distanceOnIndexSide, + geographyShape), b, RightSide) } @@ -884,7 +893,8 @@ class 
JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy { joinType, spatialPredicate, extraCondition, - distanceOnStreamSide) :: Nil + distanceOnStreamSide, + geographyShape) :: Nil case None => logInfo( s"Spatial join for $relationship with arguments not aligned " + diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinedGeometry.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinedGeometry.scala index 6d85f963619..7b0d9a59f1a 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinedGeometry.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinedGeometry.scala @@ -18,14 +18,42 @@ */ package org.apache.spark.sql.sedona_sql.strategy.join +import org.apache.sedona.common.S2Geography.Geography import org.apache.sedona.common.sphere.Haversine -import org.locationtech.jts.geom.{Envelope, Geometry} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.locationtech.jts.geom.{Envelope, Geometry, GeometryFactory} + +/** + * Payload stored in `userData` on each Geography index entry. Carries both the deserialized + * Geography (for S2 predicate refinement) and the original row (for the join output). + */ +case class GeographyJoinShape(geog: Geography, row: UnsafeRow) /** * Utility functions for generating geometries for spatial join. */ object JoinedGeometry { + private val DEFAULT_FACTORY = new GeometryFactory() + + /** + * Convert a Geography to a JTS Geometry whose envelope covers the Geography's lat/lng bounding + * rectangle. When the rectangle wraps the antimeridian we expand to the full longitude range + * [-180, 180]; this is a coarse filter that keeps the planar index simple (apache/sedona-db PR + * #775 made the same trade-off; #782 tracks the eventual split-at-±180 optimisation). 
+ */ + def geographyToEnvelopeGeometry(geog: Geography): Geometry = { + val rect = geog.region().getRectBound + val latLo = rect.latLo().degrees() + val latHi = rect.latHi().degrees() + val (lngLo, lngHi) = if (rect.lng().isInverted) { + (-180.0, 180.0) + } else { + (rect.lngLo().degrees(), rect.lngHi().degrees()) + } + DEFAULT_FACTORY.toGeometry(new Envelope(lngLo, lngHi, latLo, latHi)) + } + /** * Convert the given geometry to an envelope expanded by distance. * @param geom diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala index f38d7646c72..bbe42a127d1 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala @@ -39,7 +39,8 @@ case class SpatialIndexExec( indexType: IndexType, isRasterPredicate: Boolean, isGeography: Boolean, - distance: Option[Expression] = None) + distance: Option[Expression] = None, + geographyShape: Boolean = false) extends SedonaUnaryExecNode with TraitJoinQueryBase with Logging { @@ -64,6 +65,8 @@ case class SpatialIndexExec( case None => if (isRasterPredicate) { toWGS84EnvelopeRDD(resultRaw, boundShape) + } else if (geographyShape) { + toGeographySpatialRDD(resultRaw, boundShape) } else { toSpatialRDD(resultRaw, boundShape) } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala index 80ef9f2b984..515b54d7036 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala @@ -18,6 +18,7 @@ */ package 
org.apache.spark.sql.sedona_sql.strategy.join +import org.apache.sedona.common.S2Geography.GeographyWKBSerializer import org.apache.sedona.core.spatialRDD.SpatialRDD import org.apache.sedona.core.utils.SedonaConf import org.apache.sedona.sql.utils.{GeometrySerializer, RasterSerializer} @@ -57,6 +58,34 @@ trait TraitJoinQueryBase { spatialRdd } + /** + * Builds a SpatialRDD from a column of GeographyUDT bytes. Each row becomes a JTS geometry + * whose envelope is the Geography's lat/lng bounding rectangle (full-longitude when the + * rectangle wraps the antimeridian). The Geography object is carried alongside the original row + * in `userData` via [[GeographyJoinShape]] so the join executor can perform S2-based predicate + * refinement and emit the row. + */ + def toGeographySpatialRDD( + rdd: RDD[UnsafeRow], + shapeExpression: Expression): SpatialRDD[Geometry] = { + val spatialRdd = new SpatialRDD[Geometry] + spatialRdd.setRawSpatialRDD( + rdd + .flatMap { x => + val geogBytes = shapeExpression.eval(x).asInstanceOf[Array[Byte]] + if (geogBytes == null) { + None + } else { + val geog = GeographyWKBSerializer.deserialize(geogBytes) + val shape = JoinedGeometry.geographyToEnvelopeGeometry(geog) + shape.setUserData(GeographyJoinShape(geog, x.copy)) + Some(shape) + } + } + .toJavaRDD()) + spatialRdd + } + def toWGS84EnvelopeRDD( rdd: RDD[UnsafeRow], shapeExpression: Expression): SpatialRDD[Geometry] = { diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/geography/BroadcastIndexJoinGeographySuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/geography/BroadcastIndexJoinGeographySuite.scala new file mode 100644 index 00000000000..8f70a30aff4 --- /dev/null +++ b/spark/common/src/test/scala/org/apache/sedona/sql/geography/BroadcastIndexJoinGeographySuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.geography + +import org.apache.sedona.sql.TestBaseScala +import org.apache.spark.sql.functions.{broadcast, expr} +import org.apache.spark.sql.sedona_sql.strategy.join.BroadcastIndexJoinExec + +class BroadcastIndexJoinGeographySuite extends TestBaseScala { + + // Three unit-square polygons centred at integer offsets. + private lazy val polygonGeogDf = { + val rows = (0 until 3).map { i => + (i, s"POLYGON((${i} ${i}, ${i + 1} ${i}, ${i + 1} ${i + 1}, ${i} ${i + 1}, ${i} ${i}))") + } + import sparkSession.implicits._ + rows + .toDF("poly_id", "wkt") + .selectExpr("poly_id", "ST_GeogFromWKT(wkt, 4326) AS poly_geog") + } + + // One point inside each of the three polygons (3 hits) plus three points outside. 
+ private lazy val pointGeogDf = { + val rows = Seq( + (0, "POINT(0.5 0.5)"), // in polygon 0 + (1, "POINT(1.5 1.5)"), // in polygon 1 + (2, "POINT(2.5 2.5)"), // in polygon 2 + (3, "POINT(10 10)"), // outside + (4, "POINT(20 20)"), // outside + (5, "POINT(30 30)") // outside + ) + import sparkSession.implicits._ + rows + .toDF("pt_id", "wkt") + .selectExpr("pt_id", "ST_GeogFromWKT(wkt, 4326) AS pt_geog") + } + + private def planUsesBroadcastIndexJoin(df: org.apache.spark.sql.DataFrame): Boolean = + df.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.nonEmpty + + describe("Geography broadcast spatial join (ST_Contains)") { + + it("plans BroadcastIndexJoinExec when the polygon side is broadcast") { + val joined = + pointGeogDf.join(broadcast(polygonGeogDf), expr("ST_Contains(poly_geog, pt_geog)")) + assert(planUsesBroadcastIndexJoin(joined)) + assert(joined.count() === 3) + } + + it("plans BroadcastIndexJoinExec when the point side is broadcast") { + val joined = + polygonGeogDf.join(broadcast(pointGeogDf), expr("ST_Contains(poly_geog, pt_geog)")) + assert(planUsesBroadcastIndexJoin(joined)) + assert(joined.count() === 3) + } + + it("returns the correct (poly_id, pt_id) pairs") { + val rows = pointGeogDf + .join(broadcast(polygonGeogDf), expr("ST_Contains(poly_geog, pt_geog)")) + .selectExpr("poly_id", "pt_id") + .collect() + .map(r => (r.getInt(0), r.getInt(1))) + .toSet + assert(rows === Set((0, 0), (1, 1), (2, 2))) + } + + it("handles antimeridian-spanning polygons correctly") { + // Polygon spanning longitude 170 → -170 (10° on each side of the antimeridian). + // The lat/lng rect is "inverted" so we expand to the full longitude range as a coarse + // filter; the S2 refine step is what guarantees the correct answer here. 
+ import sparkSession.implicits._ + val polyDf = Seq((100, "POLYGON((170 -1, -170 -1, -170 1, 170 1, 170 -1))")) + .toDF("poly_id", "wkt") + .selectExpr("poly_id", "ST_GeogFromWKT(wkt, 4326) AS poly_geog") + + val ptDf = Seq( + (1, "POINT(175 0)"), // inside on the +180 side + (2, "POINT(-175 0)"), // inside on the −180 side + (3, "POINT(0 0)") // far outside + ).toDF("pt_id", "wkt") + .selectExpr("pt_id", "ST_GeogFromWKT(wkt, 4326) AS pt_geog") + + val joined = ptDf.join(broadcast(polyDf), expr("ST_Contains(poly_geog, pt_geog)")) + assert(planUsesBroadcastIndexJoin(joined)) + val matched = joined.selectExpr("pt_id").collect().map(_.getInt(0)).toSet + assert(matched === Set(1, 2)) + } + + it("auto-broadcasts the small side when sedona.join.autoBroadcastJoinThreshold permits") { + // Bump the threshold so the small Geography frames qualify for auto-broadcast. + withConf(Map("sedona.join.autoBroadcastJoinThreshold" -> "10485760")) { + val joined = polygonGeogDf.join(pointGeogDf, expr("ST_Contains(poly_geog, pt_geog)")) + assert(planUsesBroadcastIndexJoin(joined)) + assert(joined.count() === 3) + } + } + + it("supports LEFT OUTER with the polygon side broadcast") { + val joined = pointGeogDf + .join(broadcast(polygonGeogDf), expr("ST_Contains(poly_geog, pt_geog)"), "left_outer") + assert(planUsesBroadcastIndexJoin(joined)) + // 6 stream rows; 3 match a polygon, 3 are emitted with NULL polygon columns. + assert(joined.count() === 6) + val nullPolygonCount = + joined.where("poly_id IS NULL").count() + assert(nullPolygonCount === 3) + } + + it("supports RIGHT OUTER with the polygon side broadcast (build = left)") { + // For RIGHT OUTER the planner requires broadcastLeft, so we broadcast the polygon side + // and stream the points. Every right-side (point) row must appear; unmatched points + // come back with NULL polygon columns. 
+ val joined = broadcast(polygonGeogDf) + .join(pointGeogDf, expr("ST_Contains(poly_geog, pt_geog)"), "right_outer") + assert(planUsesBroadcastIndexJoin(joined)) + assert(joined.count() === 6) + val unmatchedPoints = + joined.where("poly_id IS NULL").selectExpr("pt_id").collect().map(_.getInt(0)).toSet + assert(unmatchedPoints === Set(3, 4, 5)) + } + + it("does NOT plan BroadcastIndexJoinExec without a broadcast hint") { + // autoBroadcastJoinThreshold = -1 in TestBaseScala, so neither side auto-broadcasts. + // Geography ST_Contains has no partition/range-join path, so Spark falls back to a + // row-by-row evaluation (BroadcastNestedLoopJoinExec). The result must still be + // correct. + val joined = polygonGeogDf.join(pointGeogDf, expr("ST_Contains(poly_geog, pt_geog)")) + assert(!planUsesBroadcastIndexJoin(joined)) + val pairs = joined + .selectExpr("poly_id", "pt_id") + .collect() + .map(r => (r.getInt(0), r.getInt(1))) + .toSet + assert(pairs === Set((0, 0), (1, 1), (2, 2))) + } + } +}