Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ public class Haversine {
*/
public static final double AVG_EARTH_RADIUS = 6371008.0;

/**
* Polar radius of the WGS-84 spheroid, in meters. Used as a sphere radius when expanding
* envelopes so that the expansion upper-bounds both spherical and spheroidal distances.
*/
public static final double EARTH_POLAR_RADIUS = 6357000.0;

public static double distance(Geometry geom1, Geometry geom2, double avg_earth_radius) {
Coordinate coordinate1 =
geom1.getGeometryType().equals("Point")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import org.locationtech.jts.geom.Point;

public class HaversineEnvelopeTest {
private static final int SPHERE_RADIUS = 6357000;
private static final double SPHERE_RADIUS = Haversine.EARTH_POLAR_RADIUS;
private static final GeometryFactory factory = new GeometryFactory();

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,12 @@ case class BroadcastIndexJoinExec(
// asymmetric and we cannot rely on a SpatialPredicateEvaluator.
private lazy val refinerSwap: Boolean = indexBuildSide != windowJoinSide

/**
 * Creates the refiner that confirms index candidates for this join.
 *
 * Geography joins use `GeographyDistanceRefiner` when a distance expression is present and
 * `GeographyRelationRefiner` (driven by `spatialPredicate`) otherwise; both receive
 * `refinerSwap` so they can reorder operands. All other joins use the planar `JtsRefiner`.
 */
private def newRefiner(): JoinRefiner = {
  if (geographyShape) {
    if (distance.isDefined) new GeographyDistanceRefiner(refinerSwap)
    else new GeographyRelationRefiner(spatialPredicate, refinerSwap)
  } else new JtsRefiner(evaluator)
}

private def innerJoin(
streamIter: Iterator[(Geometry, UnsafeRow)],
Expand Down Expand Up @@ -288,6 +291,36 @@ case class BroadcastIndexJoinExec(
streamResultsRaw: RDD[UnsafeRow],
boundStreamShape: Expression) = {
distance match {
case Some(distanceExpression) if geographyShape =>
val boundDistance =
BindReferences.bindReference(distanceExpression, streamed.output)
// When the broadcast side already expanded its envelope by `d` (the
// literal-radius case, where the planner forwards the same distance to
// both sides so the per-row radius is also available here for the
// refiner), keep the stream envelope unexpanded. The coarse filter then
// matches the geometry path: `expand(build, d) ∩ stream`, not the wider
// `expand(build, d) ∩ expand(stream, d)`. When only the stream side
// received the distance (per-row radius bound to the stream side), the
// build side is unexpanded and we still need to expand the stream
// envelope to ensure the index returns all candidates within `d`.
val streamSideExpands = broadcast.distance.isEmpty
streamResultsRaw.map(row => {
val serialized = boundStreamShape.eval(row).asInstanceOf[Array[Byte]]
if (serialized == null) {
(null, row)
} else {
val geog = GeographyWKBSerializer.deserialize(serialized)
val radius = boundDistance.eval(row).asInstanceOf[Double]
val baseEnvelope = JoinedGeometry.geographyToEnvelopeGeometry(geog)
val shape = if (streamSideExpands) {
JoinedGeometry.geometryToExpandedEnvelope(baseEnvelope, radius, isGeography = true)
} else {
baseEnvelope
}
shape.setUserData(GeographyJoinShape(geog, row, radius))
(shape, row)
}
})
case Some(distanceExpression) =>
streamResultsRaw.map(row => {
val geom = boundStreamShape.eval(row).asInstanceOf[Array[Byte]]
Expand Down Expand Up @@ -348,8 +381,9 @@ case class BroadcastIndexJoinExec(
* Per-iter helper that decides whether a candidate from the broadcast index actually satisfies
* the spatial predicate, and unpacks the candidate's `userData` into the output row.
*
* Two implementations: `JtsRefiner` for the planar JTS path (existing behaviour, byte-equivalent
* to the previous inline code), and `GeographyContainsRefiner` for the new Geography-on-S2 path.
* Three implementations: `JtsRefiner` for the planar JTS path, `GeographyRelationRefiner` for
* non-distance Geography predicates (CONTAINS, WITHIN, INTERSECTS, EQUALS), and
* `GeographyDistanceRefiner` for ST_DWithin on Geography.
*/
private sealed trait JoinRefiner {
def matches(candidate: Geometry, streamShape: Geometry): Boolean
Expand All @@ -368,22 +402,51 @@ private final class JtsRefiner(evaluator: SpatialPredicateEvaluator) extends Joi
}

/**
 * Refines candidates with the appropriate `org.apache.sedona.common.geography.Functions`
 * predicate (CONTAINS / WITHIN / INTERSECTS / EQUALS). Caching of the per-Geography S2 ShapeIndex
 * happens inside `WKBGeography.getShapeIndexGeography()`, so we do not need a JTS-style
 * PreparedGeometry cache here — the build side keeps the same Geography JVM instances for the
 * lifetime of the broadcast. `swap` flips operand order when the build side does not correspond
 * to the predicate's left-hand argument (handles the `RIGHT JOIN` / right-broadcast case for
 * asymmetric predicates).
 *
 * Any predicate other than the four listed above makes `matches` throw
 * `UnsupportedOperationException`.
 */
private final class GeographyRelationRefiner(predicate: SpatialPredicate, swap: Boolean)
    extends JoinRefiner {
  override def matches(candidate: Geometry, streamShape: Geometry): Boolean = {
    val buildShape = candidate.getUserData.asInstanceOf[GeographyJoinShape]
    val streamShapeData = streamShape.getUserData.asInstanceOf[GeographyJoinShape]
    // Without swap the build-side Geography is the first operand; with swap the stream-side one is.
    val (a, b) =
      if (swap) (streamShapeData.geog, buildShape.geog)
      else (buildShape.geog, streamShapeData.geog)
    predicate match {
      case SpatialPredicate.CONTAINS => GeographyFunctions.contains(a, b)
      case SpatialPredicate.WITHIN => GeographyFunctions.within(a, b)
      case SpatialPredicate.INTERSECTS => GeographyFunctions.intersects(a, b)
      case SpatialPredicate.EQUALS => GeographyFunctions.equals(a, b)
      case other =>
        throw new UnsupportedOperationException(
          s"Geography broadcast spatial join does not support predicate $other")
    }
  }

  override def unpackRow(candidate: Geometry): UnsafeRow =
    candidate.getUserData.asInstanceOf[GeographyJoinShape].row
}

/**
 * ST_DWithin refiner for Geography joins. Both the index candidate and the stream shape carry a
 * `GeographyJoinShape` in `userData`; the distance threshold used for the check is the `radius`
 * stored on the stream-side payload.
 */
private final class GeographyDistanceRefiner(swap: Boolean) extends JoinRefiner {
  override def matches(candidate: Geometry, streamShape: Geometry): Boolean = {
    val indexed = candidate.getUserData.asInstanceOf[GeographyJoinShape]
    val probe = streamShape.getUserData.asInstanceOf[GeographyJoinShape]
    // `swap` only reverses the operand order handed to dWithin; the threshold always comes from
    // the stream-side payload.
    if (swap) GeographyFunctions.dWithin(probe.geog, indexed.geog, probe.radius)
    else GeographyFunctions.dWithin(indexed.geog, probe.geog, probe.radius)
  }

  override def unpackRow(candidate: Geometry): UnsafeRow = {
    val payload = candidate.getUserData.asInstanceOf[GeographyJoinShape]
    payload.row
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,21 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy {
isGeography = false,
extraCondition,
geographyShape = geographyShape))
// ST_Intersects / ST_Within / ST_Equals on Geography have no broadcast index path
// yet (the Geography refiner is ST_Contains-specific), so gate Geography inputs and
// let them fall back to row-by-row evaluation.
// ST_Intersects / ST_Within / ST_Equals on Geography route through the
// Geography-aware refiner (`GeographyRelationRefiner`); geometry inputs continue
// through `inferredJoinDetection`.
case ST_Intersects(Seq(leftShape, rightShape))
if isGeographyInput(leftShape) || isGeographyInput(rightShape) =>
Some(
JoinQueryDetection(
left,
right,
leftShape,
rightShape,
SpatialPredicate.INTERSECTS,
isGeography = false,
extraCondition,
geographyShape = true))
case ST_Intersects(Seq(leftShape, rightShape)) =>
inferredJoinDetection(
left,
Expand All @@ -239,6 +251,18 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy {
rightShape,
SpatialPredicate.INTERSECTS,
extraCondition)
case ST_Within(Seq(leftShape, rightShape))
if isGeographyInput(leftShape) || isGeographyInput(rightShape) =>
Some(
JoinQueryDetection(
left,
right,
leftShape,
rightShape,
SpatialPredicate.WITHIN,
isGeography = false,
extraCondition,
geographyShape = true))
case ST_Within(Seq(leftShape, rightShape)) =>
inferredJoinDetection(
left,
Expand All @@ -247,6 +271,18 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy {
rightShape,
SpatialPredicate.WITHIN,
extraCondition)
case ST_Equals(Seq(leftShape, rightShape))
if isGeographyInput(leftShape) || isGeographyInput(rightShape) =>
Some(
JoinQueryDetection(
left,
right,
leftShape,
rightShape,
SpatialPredicate.EQUALS,
isGeography = false,
extraCondition,
geographyShape = true))
case ST_Equals(Seq(leftShape, rightShape)) =>
inferredJoinDetection(
left,
Expand All @@ -260,16 +296,22 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy {
case pred: RS_Predicate =>
getRasterJoinDetection(left, right, pred, extraCondition)
case ST_DWithin(Seq(leftShape, rightShape, distance)) =>
val geographyShape =
isGeographyInput(leftShape) || isGeographyInput(rightShape)
Some(
JoinQueryDetection(
left,
right,
leftShape,
rightShape,
SpatialPredicate.INTERSECTS,
isGeography = false,
isGeography = geographyShape,
condition,
Some(distance)))
Some(distance),
geographyShape = geographyShape))
// Note: the 4-arg ST_DWithin is geometry-only; on Geography input the Spark
// analyzer rejects the call before reaching this matcher, so no Geography
// guard is needed here.
case ST_DWithin(Seq(leftShape, rightShape, distance, useSpheroid)) =>
val useSpheroidUnwrapped = useSpheroid.eval().asInstanceOf[Boolean]
Some(
Expand Down Expand Up @@ -484,7 +526,8 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy {
}
} else {
queryDetection match {
// Geography ST_Contains has no partition/range path — fall back to row-by-row.
// Geography predicates (ST_Contains/Within/Intersects/Equals/DWithin) have no
// partition/range path — fall back to row-by-row.
case Some(detection) if detection.geographyShape =>
Nil
case Some(
Expand Down Expand Up @@ -843,9 +886,33 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy {
.map { distanceExpr =>
matchDistanceExpressionToJoinSide(distanceExpr, left, right) match {
case Some(side) =>
if (broadcastSide.get == side) (Some(distanceExpr), None)
else if (distanceExpr.references.isEmpty) (Some(distanceExpr), None)
else (None, Some(distanceExpr))
if (geographyShape) {
// Geography distance joins read the per-row radius from the stream-side
// GeographyJoinShape inside GeographyDistanceRefiner, so the radius MUST
// be available on the streamed side. The stream-side expression is later
// re-bound against `streamed.output` in BroadcastIndexJoinExec, so we
// can only forward it to the stream side when it is either a literal
// (no references) or already bound to the streamed side.
if (distanceExpr.references.isEmpty) {
// Literal: keep build-side expansion AND populate stream-side radius.
(Some(distanceExpr), Some(distanceExpr))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this PR expand both sides of the join (double expand)?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Yes, it did double-expand. The geography path was expanding both sides, while the geometry path expands only one. I have fixed this so that the geography path follows the geometry approach.

} else if (broadcastSide.get == side) {
// Non-literal expression bound to the broadcast/index side cannot be
// re-bound against streamed.output. Reject up front rather than
// planning a broadcast geography join that will fail at execution.
throw new UnsupportedOperationException(
"Geography distance broadcast joins do not support non-literal " +
"distance expressions bound to the broadcast/index side; bind " +
"the distance expression to the streamed side or use a literal.")
} else {
// Bound to the streamed side: stream-only (no build-side expansion).
(None, Some(distanceExpr))
}
} else {
if (broadcastSide.get == side) (Some(distanceExpr), None)
else if (distanceExpr.references.isEmpty) (Some(distanceExpr), None)
else (None, Some(distanceExpr))
}
case _ =>
throw new IllegalArgumentException(
"Distance expression must be bound to one side of the join")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ import org.locationtech.jts.geom.{Envelope, Geometry, GeometryFactory}

/**
 * Payload stored in `userData` on each Geography index entry. Carries both the deserialized
 * Geography (for S2 predicate refinement) and the original row (for the join output). For
 * `ST_DWithin` joins, `radius` carries the per-row distance threshold from the row that produced
 * this shape; for non-distance predicates it remains 0.0.
 */
case class GeographyJoinShape(geog: Geography, row: UnsafeRow, radius: Double = 0.0)

/**
* Utility functions for generating geometries for spatial join.
Expand Down Expand Up @@ -89,9 +91,8 @@ object JoinedGeometry {
* in meter
*/
private def expandEnvelopeForGeography(envelope: Envelope, distance: Double): Envelope = {
  // Use the polar radius of the spheroid as the radius of the sphere so that the expanded
  // envelope upper-bounds both spherical and spheroidal distances. The radius comes from the
  // shared Haversine constant rather than a local literal, and the envelope is expanded once.
  Haversine.expandEnvelope(envelope, distance, Haversine.EARTH_POLAR_RADIUS)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ case class SpatialIndexExec(
val boundShape = BindReferences.bindReference(shape, child.output)
val resultRaw = child.execute().asInstanceOf[RDD[UnsafeRow]].coalesce(1)
val spatialRDD = distance match {
case Some(distanceExpression) if geographyShape =>
toExpandedGeographyEnvelopeRDD(
resultRaw,
boundShape,
BindReferences.bindReference(distanceExpression, child.output))
case Some(distanceExpression) =>
toExpandedEnvelopeRDD(
resultRaw,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,38 @@ trait TraitJoinQueryBase {
spatialRdd
}

/**
 * Geography counterpart of [[toExpandedEnvelopeRDD]]. For every row whose shape expression
 * evaluates to non-null bytes, the Geography is deserialized, converted to its envelope geometry
 * and expanded by the evaluated `boundRadius` (meters) through
 * [[JoinedGeometry.geometryToExpandedEnvelope]] with `isGeography = true`. Rows whose shape is
 * null are dropped. Each resulting geometry carries a [[GeographyJoinShape]] in `userData`
 * holding the Geography, a copy of the source row and the radius, so the join executor can run
 * the S2-based ST_DWithin refinement.
 */
def toExpandedGeographyEnvelopeRDD(
    rdd: RDD[UnsafeRow],
    shapeExpression: Expression,
    boundRadius: Expression): SpatialRDD[Geometry] = {
  val indexedShapes = rdd.flatMap { row =>
    // Option(...) maps a null serialized shape to None, which flatMap then discards.
    Option(shapeExpression.eval(row).asInstanceOf[Array[Byte]]).map { bytes =>
      val geography = GeographyWKBSerializer.deserialize(bytes)
      val radiusMeters = boundRadius.eval(row).asInstanceOf[Double]
      val expanded = JoinedGeometry.geometryToExpandedEnvelope(
        JoinedGeometry.geographyToEnvelopeGeometry(geography),
        radiusMeters,
        isGeography = true)
      // The payload stores a copy of the row rather than the row instance handed to this closure.
      expanded.setUserData(GeographyJoinShape(geography, row.copy, radiusMeters))
      expanded
    }
  }
  val result = new SpatialRDD[Geometry]
  result.setRawSpatialRDD(indexedShapes.toJavaRDD())
  result
}

def doSpatialPartitioning(
dominantShapes: SpatialRDD[Geometry],
followerShapes: SpatialRDD[Geometry],
Expand Down
Loading
Loading