diff --git a/common/src/main/java/org/apache/sedona/common/sphere/Haversine.java b/common/src/main/java/org/apache/sedona/common/sphere/Haversine.java index 698ef36d145..c2ea4c1c367 100644 --- a/common/src/main/java/org/apache/sedona/common/sphere/Haversine.java +++ b/common/src/main/java/org/apache/sedona/common/sphere/Haversine.java @@ -40,6 +40,12 @@ public class Haversine { */ public static final double AVG_EARTH_RADIUS = 6371008.0; + /** + * Polar radius of the WGS-84 spheroid, in meters. Used as a sphere radius when expanding + * envelopes so that the expansion upper-bounds both spherical and spheroidal distances. + */ + public static final double EARTH_POLAR_RADIUS = 6357000.0; + public static double distance(Geometry geom1, Geometry geom2, double avg_earth_radius) { Coordinate coordinate1 = geom1.getGeometryType().equals("Point") diff --git a/common/src/test/java/org/apache/sedona/common/sphere/HaversineEnvelopeTest.java b/common/src/test/java/org/apache/sedona/common/sphere/HaversineEnvelopeTest.java index b39ec644fe4..9552c105cf0 100644 --- a/common/src/test/java/org/apache/sedona/common/sphere/HaversineEnvelopeTest.java +++ b/common/src/test/java/org/apache/sedona/common/sphere/HaversineEnvelopeTest.java @@ -26,7 +26,7 @@ import org.locationtech.jts.geom.Point; public class HaversineEnvelopeTest { - private static final int SPHERE_RADIUS = 6357000; + private static final double SPHERE_RADIUS = Haversine.EARTH_POLAR_RADIUS; private static final GeometryFactory factory = new GeometryFactory(); @Test diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala index 24654fe109e..44772f70ea1 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala @@ -142,9 +142,12 @@ case class BroadcastIndexJoinExec( // asymmetric and we cannot rely on a SpatialPredicateEvaluator. private lazy val refinerSwap: Boolean = indexBuildSide != windowJoinSide - private def newRefiner(): JoinRefiner = - if (geographyShape) new GeographyContainsRefiner(refinerSwap) - else new JtsRefiner(evaluator) + private def newRefiner(): JoinRefiner = { + if (geographyShape) { + if (distance.isDefined) new GeographyDistanceRefiner(refinerSwap) + else new GeographyRelationRefiner(spatialPredicate, refinerSwap) + } else new JtsRefiner(evaluator) + } private def innerJoin( streamIter: Iterator[(Geometry, UnsafeRow)], @@ -288,6 +291,36 @@ case class BroadcastIndexJoinExec( streamResultsRaw: RDD[UnsafeRow], boundStreamShape: Expression) = { distance match { + case Some(distanceExpression) if geographyShape => + val boundDistance = + BindReferences.bindReference(distanceExpression, streamed.output) + // When the broadcast side already expanded its envelope by `d` (the + // literal-radius case, where the planner forwards the same distance to + // both sides so the per-row radius is also available here for the + // refiner), keep the stream envelope unexpanded. The coarse filter then + // matches the geometry path: `expand(build, d) ∩ stream`, not the wider + // `expand(build, d) ∩ expand(stream, d)`. When only the stream side + // received the distance (per-row radius bound to the stream side), the + // build side is unexpanded and we still need to expand the stream + // envelope to ensure the index returns all candidates within `d`. + val streamSideExpands = broadcast.distance.isEmpty + streamResultsRaw.map(row => { + val serialized = boundStreamShape.eval(row).asInstanceOf[Array[Byte]] + if (serialized == null) { + (null, row) + } else { + val geog = GeographyWKBSerializer.deserialize(serialized) + val radius = boundDistance.eval(row).asInstanceOf[Double] + val baseEnvelope = JoinedGeometry.geographyToEnvelopeGeometry(geog) + val shape = if (streamSideExpands) { + JoinedGeometry.geometryToExpandedEnvelope(baseEnvelope, radius, isGeography = true) + } else { + baseEnvelope + } + shape.setUserData(GeographyJoinShape(geog, row, radius)) + (shape, row) + } + }) case Some(distanceExpression) => streamResultsRaw.map(row => { val geom = boundStreamShape.eval(row).asInstanceOf[Array[Byte]] @@ -348,8 +381,9 @@ case class BroadcastIndexJoinExec( * Per-iter helper that decides whether a candidate from the broadcast index actually satisfies * the spatial predicate, and unpacks the candidate's `userData` into the output row. * - * Two implementations: `JtsRefiner` for the planar JTS path (existing behaviour, byte-equivalent - * to the previous inline code), and `GeographyContainsRefiner` for the new Geography-on-S2 path. + * Three implementations: `JtsRefiner` for the planar JTS path, `GeographyRelationRefiner` for + * non-distance Geography predicates (CONTAINS, WITHIN, INTERSECTS, EQUALS), and + * `GeographyDistanceRefiner` for ST_DWithin on Geography. */ private sealed trait JoinRefiner { def matches(candidate: Geometry, streamShape: Geometry): Boolean @@ -368,22 +402,51 @@ private final class JtsRefiner(evaluator: SpatialPredicateEvaluator) extends Joi } /** - * Refines candidates with `Functions.contains` (S2 spherical containment). Caching of the per- - * Geography S2 ShapeIndex happens inside `WKBGeography.getShapeIndexGeography()`, so we do not - * need a JTS-style PreparedGeometry cache here — the build side keeps the same Geography JVM - * instances for the lifetime of the broadcast. + * Refines candidates with the appropriate `org.apache.sedona.common.geography.Functions` + * predicate (CONTAINS / WITHIN / INTERSECTS / EQUALS). Caching of the per-Geography S2 ShapeIndex + * happens inside `WKBGeography.getShapeIndexGeography()`, so we do not need a JTS-style + * PreparedGeometry cache here — the build side keeps the same Geography JVM instances for the + * lifetime of the broadcast. `swap` flips operand order when the build side does not correspond + * to the predicate's left-hand argument (handles the `RIGHT JOIN` / right-broadcast case for + * asymmetric predicates). */ -private final class GeographyContainsRefiner(swap: Boolean) extends JoinRefiner { +private final class GeographyRelationRefiner(predicate: SpatialPredicate, swap: Boolean) + extends JoinRefiner { override def matches(candidate: Geometry, streamShape: Geometry): Boolean = { val buildShape = candidate.getUserData.asInstanceOf[GeographyJoinShape] val streamShapeData = streamShape.getUserData.asInstanceOf[GeographyJoinShape] - if (swap) { - GeographyFunctions.contains(streamShapeData.geog, buildShape.geog) - } else { - GeographyFunctions.contains(buildShape.geog, streamShapeData.geog) + val (a, b) = + if (swap) (streamShapeData.geog, buildShape.geog) + else (buildShape.geog, streamShapeData.geog) + predicate match { + case SpatialPredicate.CONTAINS => GeographyFunctions.contains(a, b) + case SpatialPredicate.WITHIN => GeographyFunctions.within(a, b) + case SpatialPredicate.INTERSECTS => GeographyFunctions.intersects(a, b) + case SpatialPredicate.EQUALS => GeographyFunctions.equals(a, b) + case other => + throw new UnsupportedOperationException( + s"Geography broadcast spatial join does not support predicate $other") } } override def unpackRow(candidate: Geometry): UnsafeRow = candidate.getUserData.asInstanceOf[GeographyJoinShape].row } + +/** + * Refines candidates for ST_DWithin on Geography. The per-row distance threshold is carried on + * the stream-side `GeographyJoinShape.radius`, populated when the stream shape is built. + */ +private final class GeographyDistanceRefiner(swap: Boolean) extends JoinRefiner { + override def matches(candidate: Geometry, streamShape: Geometry): Boolean = { + val buildShape = candidate.getUserData.asInstanceOf[GeographyJoinShape] + val streamShapeData = streamShape.getUserData.asInstanceOf[GeographyJoinShape] + val (a, b) = + if (swap) (streamShapeData.geog, buildShape.geog) + else (buildShape.geog, streamShapeData.geog) + GeographyFunctions.dWithin(a, b, streamShapeData.radius) + } + + override def unpackRow(candidate: Geometry): UnsafeRow = + candidate.getUserData.asInstanceOf[GeographyJoinShape].row +} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala index e0840d474b1..e9db76cffda 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala @@ -228,9 +228,21 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy { isGeography = false, extraCondition, geographyShape = geographyShape)) - // ST_Intersects / ST_Within / ST_Equals on Geography have no broadcast index path - // yet (the Geography refiner is ST_Contains-specific), so gate Geography inputs and - // let them fall back to row-by-row evaluation. + // ST_Intersects / ST_Within / ST_Equals on Geography route through the + // Geography-aware refiner (`GeographyRelationRefiner`); geometry inputs continue + // through `inferredJoinDetection`. + case ST_Intersects(Seq(leftShape, rightShape)) + if isGeographyInput(leftShape) || isGeographyInput(rightShape) => + Some( + JoinQueryDetection( + left, + right, + leftShape, + rightShape, + SpatialPredicate.INTERSECTS, + isGeography = false, + extraCondition, + geographyShape = true)) case ST_Intersects(Seq(leftShape, rightShape)) => inferredJoinDetection( left, @@ -239,6 +251,18 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy { rightShape, SpatialPredicate.INTERSECTS, extraCondition) + case ST_Within(Seq(leftShape, rightShape)) + if isGeographyInput(leftShape) || isGeographyInput(rightShape) => + Some( + JoinQueryDetection( + left, + right, + leftShape, + rightShape, + SpatialPredicate.WITHIN, + isGeography = false, + extraCondition, + geographyShape = true)) case ST_Within(Seq(leftShape, rightShape)) => inferredJoinDetection( left, @@ -247,6 +271,18 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy { rightShape, SpatialPredicate.WITHIN, extraCondition) + case ST_Equals(Seq(leftShape, rightShape)) + if isGeographyInput(leftShape) || isGeographyInput(rightShape) => + Some( + JoinQueryDetection( + left, + right, + leftShape, + rightShape, + SpatialPredicate.EQUALS, + isGeography = false, + extraCondition, + geographyShape = true)) case ST_Equals(Seq(leftShape, rightShape)) => inferredJoinDetection( left, @@ -260,6 +296,8 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy { case pred: RS_Predicate => getRasterJoinDetection(left, right, pred, extraCondition) case ST_DWithin(Seq(leftShape, rightShape, distance)) => + val geographyShape = + isGeographyInput(leftShape) || isGeographyInput(rightShape) Some( JoinQueryDetection( left, @@ -267,9 +305,13 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy { leftShape, rightShape, SpatialPredicate.INTERSECTS, - isGeography = false, + isGeography = geographyShape, condition, - Some(distance))) + Some(distance), + geographyShape = geographyShape)) + // Note: the 4-arg ST_DWithin is geometry-only; on Geography input the Spark + // analyzer rejects the call before reaching this matcher, so no Geography + // guard is needed here. case ST_DWithin(Seq(leftShape, rightShape, distance, useSpheroid)) => val useSpheroidUnwrapped = useSpheroid.eval().asInstanceOf[Boolean] Some( @@ -484,7 +526,8 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy { } } else { queryDetection match { - // Geography ST_Contains has no partition/range path — fall back to row-by-row. + // Geography predicates (ST_Contains/Within/Intersects/Equals/DWithin) have no + // partition/range path — fall back to row-by-row. case Some(detection) if detection.geographyShape => Nil case Some( @@ -843,9 +886,33 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy { .map { distanceExpr => matchDistanceExpressionToJoinSide(distanceExpr, left, right) match { case Some(side) => - if (broadcastSide.get == side) (Some(distanceExpr), None) - else if (distanceExpr.references.isEmpty) (Some(distanceExpr), None) - else (None, Some(distanceExpr)) + if (geographyShape) { + // Geography distance joins read the per-row radius from the stream-side + // GeographyJoinShape inside GeographyDistanceRefiner, so the radius MUST + // be available on the streamed side. The stream-side expression is later + // re-bound against `streamed.output` in BroadcastIndexJoinExec, so we + // can only forward it to the stream side when it is either a literal + // (no references) or already bound to the streamed side. + if (distanceExpr.references.isEmpty) { + // Literal: keep build-side expansion AND populate stream-side radius. + (Some(distanceExpr), Some(distanceExpr)) + } else if (broadcastSide.get == side) { + // Non-literal expression bound to the broadcast/index side cannot be + // re-bound against streamed.output. Reject up front rather than + // planning a broadcast geography join that will fail at execution. + throw new UnsupportedOperationException( + "Geography distance broadcast joins do not support non-literal " + + "distance expressions bound to the broadcast/index side; bind " + + "the distance expression to the streamed side or use a literal.") + } else { + // Bound to the streamed side: stream-only (no build-side expansion). + (None, Some(distanceExpr)) + } + } else { + if (broadcastSide.get == side) (Some(distanceExpr), None) + else if (distanceExpr.references.isEmpty) (Some(distanceExpr), None) + else (None, Some(distanceExpr)) + } case _ => throw new IllegalArgumentException( "Distance expression must be bound to one side of the join") diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinedGeometry.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinedGeometry.scala index 7b0d9a59f1a..5efc36b4dce 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinedGeometry.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinedGeometry.scala @@ -25,9 +25,11 @@ import org.locationtech.jts.geom.{Envelope, Geometry, GeometryFactory} /** * Payload stored in `userData` on each Geography index entry. Carries both the deserialized - * Geography (for S2 predicate refinement) and the original row (for the join output). + * Geography (for S2 predicate refinement) and the original row (for the join output). For + * `ST_DWithin` joins, `radius` carries the per-row distance threshold from the row that produced + * this shape; for non-distance predicates it remains 0.0. */ -case class GeographyJoinShape(geog: Geography, row: UnsafeRow) +case class GeographyJoinShape(geog: Geography, row: UnsafeRow, radius: Double = 0.0) /** * Utility functions for generating geometries for spatial join. @@ -89,9 +91,8 @@ object JoinedGeometry { * in meter */ private def expandEnvelopeForGeography(envelope: Envelope, distance: Double): Envelope = { - // Here we use the polar radius of the spheroid as the radius of the sphere, so that the expanded - // envelope will work for both spherical and spheroidal distances. - val sphereRadius = 6357000.0 - Haversine.expandEnvelope(envelope, distance, sphereRadius) + // Use the polar radius of the spheroid as the radius of the sphere so that the expanded + // envelope upper-bounds both spherical and spheroidal distances. + Haversine.expandEnvelope(envelope, distance, Haversine.EARTH_POLAR_RADIUS) } } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala index bbe42a127d1..c3c5c0392ba 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala @@ -56,6 +56,11 @@ case class SpatialIndexExec( val boundShape = BindReferences.bindReference(shape, child.output) val resultRaw = child.execute().asInstanceOf[RDD[UnsafeRow]].coalesce(1) val spatialRDD = distance match { + case Some(distanceExpression) if geographyShape => + toExpandedGeographyEnvelopeRDD( + resultRaw, + boundShape, + BindReferences.bindReference(distanceExpression, child.output)) case Some(distanceExpression) => toExpandedEnvelopeRDD( resultRaw, diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala index 515b54d7036..2c449056102 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala @@ -135,6 +135,38 @@ trait TraitJoinQueryBase { spatialRdd } + /** + * Geography variant of [[toExpandedEnvelopeRDD]]. Each row becomes a JTS geometry whose + * envelope is the Geography's lat/lng bounding rectangle expanded by `boundRadius` meters using + * the Haversine-based polar-radius approximation in + * [[JoinedGeometry.geometryToExpandedEnvelope]] (with `isGeography=true`). The Geography object + * and per-row radius are carried alongside the original row in `userData` via + * [[GeographyJoinShape]] so the join executor can perform S2-based ST_DWithin refinement. + */ + def toExpandedGeographyEnvelopeRDD( + rdd: RDD[UnsafeRow], + shapeExpression: Expression, + boundRadius: Expression): SpatialRDD[Geometry] = { + val spatialRdd = new SpatialRDD[Geometry] + spatialRdd.setRawSpatialRDD(rdd + .flatMap { x => + val geogBytes = shapeExpression.eval(x).asInstanceOf[Array[Byte]] + if (geogBytes == null) { + None + } else { + val geog = GeographyWKBSerializer.deserialize(geogBytes) + val distance = boundRadius.eval(x).asInstanceOf[Double] + val baseEnvelope = JoinedGeometry.geographyToEnvelopeGeometry(geog) + val expandedEnvelope = + JoinedGeometry.geometryToExpandedEnvelope(baseEnvelope, distance, isGeography = true) + expandedEnvelope.setUserData(GeographyJoinShape(geog, x.copy, distance)) + Some(expandedEnvelope) + } + } + .toJavaRDD()) + spatialRdd + } + def doSpatialPartitioning( dominantShapes: SpatialRDD[Geometry], followerShapes: SpatialRDD[Geometry], diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/geography/BroadcastIndexJoinGeographySuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/geography/BroadcastIndexJoinGeographySuite.scala index 8f70a30aff4..a1c5afe99dc 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/geography/BroadcastIndexJoinGeographySuite.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/geography/BroadcastIndexJoinGeographySuite.scala @@ -150,4 +150,403 @@ class BroadcastIndexJoinGeographySuite extends TestBaseScala { assert(pairs === Set((0, 0), (1, 1), (2, 2))) } } + + describe("Geography broadcast spatial join (ST_Within)") { + + it("plans BroadcastIndexJoinExec when the polygon side is broadcast") { + val joined = + pointGeogDf.join(broadcast(polygonGeogDf), expr("ST_Within(pt_geog, poly_geog)")) + assert(planUsesBroadcastIndexJoin(joined)) + assert(joined.count() === 3) + } + + it("plans BroadcastIndexJoinExec when the point side is broadcast") { + val joined = + polygonGeogDf.join(broadcast(pointGeogDf), expr("ST_Within(pt_geog, poly_geog)")) + assert(planUsesBroadcastIndexJoin(joined)) + assert(joined.count() === 3) + } + + it("returns the correct (poly_id, pt_id) pairs") { + val rows = pointGeogDf + .join(broadcast(polygonGeogDf), expr("ST_Within(pt_geog, poly_geog)")) + .selectExpr("poly_id", "pt_id") + .collect() + .map(r => (r.getInt(0), r.getInt(1))) + .toSet + assert(rows === Set((0, 0), (1, 1), (2, 2))) + } + + it("supports LEFT OUTER with the polygon side broadcast") { + val joined = pointGeogDf + .join(broadcast(polygonGeogDf), expr("ST_Within(pt_geog, poly_geog)"), "left_outer") + assert(planUsesBroadcastIndexJoin(joined)) + assert(joined.count() === 6) + assert(joined.where("poly_id IS NULL").count() === 3) + } + } + + describe("Geography broadcast spatial join (ST_Intersects)") { + + it("plans BroadcastIndexJoinExec when the polygon side is broadcast") { + val joined = + pointGeogDf.join(broadcast(polygonGeogDf), expr("ST_Intersects(poly_geog, pt_geog)")) + assert(planUsesBroadcastIndexJoin(joined)) + assert(joined.count() === 3) + } + + it("returns the correct (poly_id, pt_id) pairs") { + val rows = pointGeogDf + .join(broadcast(polygonGeogDf), expr("ST_Intersects(poly_geog, pt_geog)")) + .selectExpr("poly_id", "pt_id") + .collect() + .map(r => (r.getInt(0), r.getInt(1))) + .toSet + assert(rows === Set((0, 0), (1, 1), (2, 2))) + } + + it("handles antimeridian-spanning polygons correctly") { + import sparkSession.implicits._ + val polyDf = Seq((100, "POLYGON((170 -1, -170 -1, -170 1, 170 1, 170 -1))")) + .toDF("poly_id", "wkt") + .selectExpr("poly_id", "ST_GeogFromWKT(wkt, 4326) AS poly_geog") + + val ptDf = Seq((1, "POINT(175 0)"), (2, "POINT(-175 0)"), (3, "POINT(0 0)")) + .toDF("pt_id", "wkt") + .selectExpr("pt_id", "ST_GeogFromWKT(wkt, 4326) AS pt_geog") + + val joined = ptDf.join(broadcast(polyDf), expr("ST_Intersects(poly_geog, pt_geog)")) + assert(planUsesBroadcastIndexJoin(joined)) + val matched = joined.selectExpr("pt_id").collect().map(_.getInt(0)).toSet + assert(matched === Set(1, 2)) + } + } + + private lazy val pointsLeftDf = { + import sparkSession.implicits._ + Seq((0, "POINT(0 0)"), (1, "POINT(1 1)"), (2, "POINT(2 2)"), (3, "POINT(99 99)")) + .toDF("id_l", "wkt") + .selectExpr("id_l", "ST_GeogFromWKT(wkt, 4326) AS geog_l") + } + private lazy val pointsRightDf = { + import sparkSession.implicits._ + Seq((10, "POINT(0 0)"), (11, "POINT(1 1)"), (12, "POINT(2 2)"), (13, "POINT(50 50)")) + .toDF("id_r", "wkt") + .selectExpr("id_r", "ST_GeogFromWKT(wkt, 4326) AS geog_r") + } + + describe("Geography broadcast spatial join (ST_Equals)") { + + it("plans BroadcastIndexJoinExec and matches identical points") { + val joined = + pointsLeftDf.join(broadcast(pointsRightDf), expr("ST_Equals(geog_l, geog_r)")) + assert(planUsesBroadcastIndexJoin(joined)) + val pairs = joined + .selectExpr("id_l", "id_r") + .collect() + .map(r => (r.getInt(0), r.getInt(1))) + .toSet + assert(pairs === Set((0, 10), (1, 11), (2, 12))) + } + } + + private lazy val pointsADf = { + import sparkSession.implicits._ + Seq((0, "POINT(0 0)"), (1, "POINT(1 0)"), (2, "POINT(2 0)")) + .toDF("id_a", "wkt") + .selectExpr("id_a", "ST_GeogFromWKT(wkt, 4326) AS geog_a") + } + private lazy val pointsBDf = { + import sparkSession.implicits._ + Seq( + (10, "POINT(0 0)"), // 0 m from (0,0) + (11, "POINT(1 0)"), // 0 m from (1,0) + (12, "POINT(0 1)") // ~111 km north of (0,0) + ).toDF("id_b", "wkt") + .selectExpr("id_b", "ST_GeogFromWKT(wkt, 4326) AS geog_b") + } + + describe("Geography broadcast spatial join (ST_DWithin)") { + + it("plans BroadcastIndexJoinExec when the right side is broadcast") { + val joined = + pointsADf.join(broadcast(pointsBDf), expr("ST_DWithin(geog_a, geog_b, 1000.0)")) + assert(planUsesBroadcastIndexJoin(joined)) + } + + it("returns only same-location pairs at a tight threshold (1 km)") { + val pairs = pointsADf + .join(broadcast(pointsBDf), expr("ST_DWithin(geog_a, geog_b, 1000.0)")) + .selectExpr("id_a", "id_b") + .collect() + .map(r => (r.getInt(0), r.getInt(1))) + .toSet + assert(pairs === Set((0, 10), (1, 11))) + } + + it("returns the additional cross-row pair at a wide threshold (200 km)") { + // 200 km covers the ~111 km north neighbour from (0,0) -> (0,1) and the ~111 km + // east-west neighbours. + val pairs = pointsADf + .join(broadcast(pointsBDf), expr("ST_DWithin(geog_a, geog_b, 200000.0)")) + .selectExpr("id_a", "id_b") + .collect() + .map(r => (r.getInt(0), r.getInt(1))) + .toSet + // (0,0)↔(0,0), (0,0)↔(1,0), (0,0)↔(0,1) + // (1,0)↔(0,0), (1,0)↔(1,0), (1,0)↔(0,1) + // (2,0)↔(1,0) (only one within 200 km — (0,0) is ~222 km, (0,1) is ~244 km) + assert(pairs.contains((0, 10))) + assert(pairs.contains((0, 11))) + assert(pairs.contains((0, 12))) + assert(pairs.contains((1, 10))) + assert(pairs.contains((1, 11))) + assert(pairs.contains((1, 12))) + assert(pairs.contains((2, 11))) + assert(!pairs.contains((2, 10))) + } + + it("supports a per-row column-distance threshold") { + import sparkSession.implicits._ + val withRadius = + Seq((0, "POINT(0 0)", 1000.0), (1, "POINT(1 0)", 1.0), (2, "POINT(2 0)", 200000.0)) + .toDF("id_a", "wkt", "radius_m") + .selectExpr("id_a", "ST_GeogFromWKT(wkt, 4326) AS geog_a", "radius_m") + + val joined = + withRadius.join(broadcast(pointsBDf), expr("ST_DWithin(geog_a, geog_b, radius_m)")) + assert(planUsesBroadcastIndexJoin(joined)) + val pairs = joined + .selectExpr("id_a", "id_b") + .collect() + .map(r => (r.getInt(0), r.getInt(1))) + .toSet + // id_a=0 with 1 km: only (0,0) self-match — id_b=10 + // id_a=1 with 1 m: only (1,0) self-match — id_b=11 + // id_a=2 with 200 km: only (1,0) at ~111 km — id_b=11 + assert(pairs === Set((0, 10), (1, 11), (2, 11))) + } + + it("supports LEFT OUTER with the right side broadcast") { + val joined = pointsADf.join( + broadcast(pointsBDf), + expr("ST_DWithin(geog_a, geog_b, 1000.0)"), + "left_outer") + assert(planUsesBroadcastIndexJoin(joined)) + // id_a=2 has no match within 1km -> NULL right side. Counts: (0,10),(1,11),(2,NULL). + assert(joined.count() === 3) + assert(joined.where("id_b IS NULL").count() === 1) + } + + it("matches points at the literal-radius boundary (single-side build expansion)") { + // (0,0) and (0,1) are ~111 km apart on WGS-84. With a literal radius of 200 km + // they must match. Regression check for the literal-radius coarse filter: only + // the build side is expanded by `d` (mirroring the geometry path); the stream + // envelope is left unexpanded but the radius is still carried in + // GeographyJoinShape for the refiner. + import sparkSession.implicits._ + val a = Seq((0, "POINT(0 0)")) + .toDF("id_a", "wkt") + .selectExpr("id_a", "ST_GeogFromWKT(wkt, 4326) AS geog_a") + val b = Seq((10, "POINT(0 1)")) + .toDF("id_b", "wkt") + .selectExpr("id_b", "ST_GeogFromWKT(wkt, 4326) AS geog_b") + val joined = a.join(broadcast(b), expr("ST_DWithin(geog_a, geog_b, 200000.0)")) + assert(planUsesBroadcastIndexJoin(joined)) + assert(joined.count() === 1) + } + + it("filters out NULL geographies on either side from the inner join") { + import sparkSession.implicits._ + val a = Seq((0, "POINT(0 0)"), (1, "POINT(1 0)"), (98, null)) + .toDF("id_a", "wkt") + .selectExpr("id_a", "ST_GeogFromWKT(wkt, 4326) AS geog_a") + val b = Seq((10, "POINT(0 0)"), (11, "POINT(1 0)"), (99, null)) + .toDF("id_b", "wkt") + .selectExpr("id_b", "ST_GeogFromWKT(wkt, 4326) AS geog_b") + val joined = a.join(broadcast(b), expr("ST_DWithin(geog_a, geog_b, 1000.0)")) + assert(planUsesBroadcastIndexJoin(joined)) + val pairs = joined + .selectExpr("id_a", "id_b") + .collect() + .map(r => (r.getInt(0), r.getInt(1))) + .toSet + assert(pairs === Set((0, 10), (1, 11))) + } + + it("LEFT OUTER preserves stream rows whose geography is NULL") { + import sparkSession.implicits._ + val a = Seq((0, "POINT(0 0)"), (1, "POINT(1 0)"), (98, null)) + .toDF("id_a", "wkt") + .selectExpr("id_a", "ST_GeogFromWKT(wkt, 4326) AS geog_a") + val b = Seq((10, "POINT(0 0)"), (11, "POINT(1 0)")) + .toDF("id_b", "wkt") + .selectExpr("id_b", "ST_GeogFromWKT(wkt, 4326) AS geog_b") + val joined = a.join(broadcast(b), expr("ST_DWithin(geog_a, geog_b, 1000.0)"), "left_outer") + assert(planUsesBroadcastIndexJoin(joined)) + // 2 matches + 1 NULL stream row preserved with NULL right side. + assert(joined.count() === 3) + val nullRightIds = joined + .where("id_b IS NULL") + .selectExpr("id_a") + .collect() + .map(_.getInt(0)) + .toSet + assert(nullRightIds === Set(98)) + } + + it("rejects ST_DWithin(geog, geog, dist, useSpheroid) at analysis time") { + // The 4-arg ST_DWithin is geometry-only; passing Geography arguments fails at + // analysis time with a DATATYPE_MISMATCH before the planner runs. There is no + // Geography overload of the 4-arg form because Geography is always spheroidal. + val ex = intercept[Throwable] { + pointsADf + .join(broadcast(pointsBDf), expr("ST_DWithin(geog_a, geog_b, 1000.0, true)")) + .queryExecution + .sparkPlan + } + val msg = Iterator + .iterate[Throwable](ex)(t => if (t == null) null else t.getCause) + .takeWhile(_ != null) + .map(_.getMessage) + .mkString(" | ") + // Normalize for cross-version stability: assert on either the stable Spark + // error class (`DATATYPE_MISMATCH`) or the human-readable text. + val normalizedMsg = msg.toLowerCase(java.util.Locale.ROOT) + assert( + normalizedMsg.contains("st_dwithin") && + (normalizedMsg.contains("datatype_mismatch") || + normalizedMsg.contains("data type mismatch")), + s"expected analysis-time DATATYPE_MISMATCH on st_dwithin; got: $msg") + } + } + + // ---- NULL geography & empty-input coverage for the non-distance predicates ---- + // + // The relation refiner path (ST_Within / ST_Intersects / ST_Equals) shares the + // non-distance stream-shape constructor in BroadcastIndexJoinExec, so one test + // per predicate is sufficient for NULL filtering; empty inputs are covered once + // per direction (build / stream) since they exercise the same broadcast index + // build code path regardless of predicate. + + private lazy val polygonGeogDfWithNulls = { + import sparkSession.implicits._ + Seq( + (0, "POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))"), + (1, "POLYGON((1 1, 2 1, 2 2, 1 2, 1 1))"), + (99, null) // NULL geography on the build side + ).toDF("poly_id", "wkt") + .selectExpr("poly_id", "ST_GeogFromWKT(wkt, 4326) AS poly_geog") + } + + private lazy val pointGeogDfWithNulls = { + import sparkSession.implicits._ + Seq( + (0, "POINT(0.5 0.5)"), // matches polygon 0 + (1, "POINT(1.5 1.5)"), // matches polygon 1 + (98, null) // NULL geography on the stream side + ).toDF("pt_id", "wkt") + .selectExpr("pt_id", "ST_GeogFromWKT(wkt, 4326) AS pt_geog") + } + + private lazy val emptyPolygonGeogDf = { + import sparkSession.implicits._ + Seq + .empty[(Int, String)] + .toDF("poly_id", "wkt") + .selectExpr("poly_id", "ST_GeogFromWKT(wkt, 4326) AS poly_geog") + } + + private lazy val emptyPointGeogDf = { + import sparkSession.implicits._ + Seq + .empty[(Int, String)] + .toDF("pt_id", "wkt") + .selectExpr("pt_id", "ST_GeogFromWKT(wkt, 4326) AS pt_geog") + } + + describe("Geography broadcast spatial join (NULL geographies and empty inputs)") { + + it("ST_Within: NULL geographies on either side are filtered from inner join") { + val joined = pointGeogDfWithNulls.join( + broadcast(polygonGeogDfWithNulls), + expr("ST_Within(pt_geog, poly_geog)")) + assert(planUsesBroadcastIndexJoin(joined)) + val pairs = joined + .selectExpr("poly_id", "pt_id") + .collect() + .map(r => (r.getInt(0), r.getInt(1))) + .toSet + assert(pairs === Set((0, 0), (1, 1))) + } + + it("ST_Intersects: NULL geographies on either side are filtered from inner join") { + val joined = pointGeogDfWithNulls.join( + broadcast(polygonGeogDfWithNulls), + expr("ST_Intersects(poly_geog, pt_geog)")) + assert(planUsesBroadcastIndexJoin(joined)) + assert(joined.count() === 2) + } + + it("ST_Equals: NULL geographies on either side are filtered from inner join") { + import sparkSession.implicits._ + val left = Seq((0, "POINT(0 0)"), (1, "POINT(1 1)"), (98, null)) + .toDF("id_l", "wkt") + .selectExpr("id_l", "ST_GeogFromWKT(wkt, 4326) AS geog_l") + val right = Seq((10, "POINT(0 0)"), (11, "POINT(1 1)"), (99, null)) + .toDF("id_r", "wkt") + .selectExpr("id_r", "ST_GeogFromWKT(wkt, 4326) AS geog_r") + val joined = left.join(broadcast(right), expr("ST_Equals(geog_l, geog_r)")) + assert(planUsesBroadcastIndexJoin(joined)) + val pairs = joined + .selectExpr("id_l", "id_r") + .collect() + .map(r => (r.getInt(0), r.getInt(1))) + .toSet + assert(pairs === Set((0, 10), (1, 11))) + } + + it("LEFT OUTER preserves stream rows with NULL geography (ST_Within)") { + val joined = pointGeogDfWithNulls.join( + broadcast(polygonGeogDfWithNulls), + expr("ST_Within(pt_geog, poly_geog)"), + "left_outer") + assert(planUsesBroadcastIndexJoin(joined)) + assert(joined.count() === 3) + val nullPolyIds = joined + .where("poly_id IS NULL") + .selectExpr("pt_id") + .collect() + .map(_.getInt(0)) + .toSet + assert(nullPolyIds === Set(98)) + } + + // Empty-input tests do NOT assert planUsesBroadcastIndexJoin: when one side of the + // join is statically empty, Spark's logical optimizer (PropagateEmptyRelation / + // ReplaceEmptyRelation) replaces the join with a LocalRelation/EmptyRelation + // before the spatial-join strategy runs, so there is no BroadcastIndexJoinExec + // to assert on. The intent of these tests is correctness, not plan shape. + + it("Empty broadcast side: inner join returns 0 rows") { + val joined = + pointGeogDf.join(broadcast(emptyPolygonGeogDf), expr("ST_Within(pt_geog, poly_geog)")) + assert(joined.count() === 0) + } + + it("Empty broadcast side: LEFT OUTER preserves all stream rows with NULL right side") { + val joined = pointGeogDf.join( + broadcast(emptyPolygonGeogDf), + expr("ST_Within(pt_geog, poly_geog)"), + "left_outer") + assert(joined.count() === 6) + assert(joined.where("poly_id IS NULL").count() === 6) + } + + it("Empty stream side: inner join returns 0 rows") { + val joined = + emptyPointGeogDf.join(broadcast(polygonGeogDf), expr("ST_Intersects(poly_geog, pt_geog)")) + assert(joined.count() === 0) + } + } }