// Aggregate expressions are only constructed when registerAll is called and Spark is set
// up. This lets the categorization invariant test access `Catalog.expressions` without
// bootstrapping Spark.
lazy val aggregateExpressions: Seq[Aggregator[Geometry, _, _]] =
  Seq(
    // Consistent explicit `()` on every constructor (the original mixed bare and
    // parenthesized `new` calls in the same Seq).
    new ST_Envelope_Aggr(),
    new ST_Extent(),
    new ST_Intersection_Aggr(),
    new ST_Union_Aggr(),
    new ST_Collect_Agg())
/**
 * Return the planar bounding box (Box2D) of all geometries in the given column, mirroring
 * PostGIS ST_Extent. Null and empty geometries are skipped; when the input has no rows or
 * every row is null/empty, the result is NULL.
 *
 * ST_Envelope_Aggr is left untouched (returns a polygon Geometry) for backwards
 * compatibility.
 */
private[apache] class ST_Extent extends Aggregator[Geometry, Option[EnvelopeBuffer], Box2D] {

  val outputSerde: ExpressionEncoder[Box2D] = ExpressionEncoder[Box2D]()

  /**
   * Fold one input geometry into the running envelope. Null and empty geometries leave the
   * buffer unchanged, so they never contribute to (or zero out) the extent.
   */
  def reduce(buffer: Option[EnvelopeBuffer], input: Geometry): Option[EnvelopeBuffer] = {
    if (input == null || input.isEmpty) {
      buffer
    } else {
      val env = input.getEnvelopeInternal
      val envBuffer = EnvelopeBuffer(env.getMinX, env.getMaxX, env.getMinY, env.getMaxY)
      // Expression form instead of an early `return`: start from this row's envelope and
      // widen the existing buffer with it when one is present.
      Some(buffer.fold(envBuffer)(_.merge(envBuffer)))
    }
  }

  /** Combine two partial envelopes; an absent side contributes nothing. */
  def merge(
      buffer1: Option[EnvelopeBuffer],
      buffer2: Option[EnvelopeBuffer]): Option[EnvelopeBuffer] = {
    (buffer1, buffer2) match {
      case (Some(b1), Some(b2)) => Some(b1.merge(b2))
      // The remaining cases collapse: when one side is None the other side (possibly also
      // None) is the correct result.
      case (None, _) => buffer2
      case (_, None) => buffer1
    }
  }

  /** Produce the final Box2D, or null (SQL NULL) when no non-empty geometry was seen. */
  def finish(reduction: Option[EnvelopeBuffer]): Box2D =
    reduction.map(b => new Box2D(b.minX, b.minY, b.maxX, b.maxY)).orNull

  def bufferEncoder: Encoder[Option[EnvelopeBuffer]] = Encoders.product[Option[EnvelopeBuffer]]

  def outputEncoder: ExpressionEncoder[Box2D] = outputSerde

  // Identity element: no geometry observed yet.
  def zero: Option[EnvelopeBuffer] = None
}
  it("Passed ST_Extent") {
    // Three non-empty geometries: the aggregate extent must cover all of them,
    // i.e. x in [-3, 4] and y in [0, 5].
    val df = sparkSession.sql(
      "SELECT ST_GeomFromWKT(wkt) AS geom FROM VALUES ('POINT (1 2)'), ('POINT (4 5)'), ('LINESTRING (-3 0, 0 0)') AS t(wkt)")
    df.createOrReplaceTempView("t")
    val bbox =
      sparkSession.sql("SELECT ST_Extent(geom) AS bbox FROM t").take(1)(0).getAs[Box2D](0)
    assert(bbox.getXMin == -3.0)
    assert(bbox.getYMin == 0.0)
    assert(bbox.getXMax == 4.0)
    assert(bbox.getYMax == 5.0)
  }

  it("ST_Extent returns null over zero rows") {
    // The WHERE clause filters out the only row, so the aggregate runs over an empty
    // relation and must yield SQL NULL (the Aggregator's `zero` finished as null).
    val emptyDf = sparkSession.sql(
      "SELECT ST_GeomFromWKT(wkt) AS geom FROM VALUES (NULL) AS t(wkt) WHERE wkt IS NOT NULL")
    emptyDf.createOrReplaceTempView("empty_extent")
    val result = sparkSession.sql("SELECT ST_Extent(geom) FROM empty_extent")
    assert(result.take(1)(0).get(0) == null)
  }

  it("ST_Extent returns null when all inputs are null or empty") {
    // A null geometry plus empty POINT/POLYGON: none contributes an envelope, so the
    // result is NULL rather than a degenerate box.
    val nullDf = sparkSession.sql(
      "SELECT ST_GeomFromWKT(wkt) AS geom FROM VALUES (CAST(NULL AS STRING)), ('POINT EMPTY'), ('POLYGON EMPTY') AS t(wkt)")
    nullDf.createOrReplaceTempView("null_extent")
    val result = sparkSession.sql("SELECT ST_Extent(geom) FROM null_extent")
    assert(result.take(1)(0).get(0) == null)
  }

  it("ST_Extent ignores null and empty rows mixed with valid geometries") {
    // Null/empty rows must be skipped without poisoning the extent of the two valid
    // points: x in [-5, 10], y in [-5, 20].
    val mixedDf = sparkSession.sql(
      "SELECT ST_GeomFromWKT(wkt) AS geom FROM VALUES (CAST(NULL AS STRING)), ('POINT EMPTY'), ('POINT (10 20)'), ('POINT (-5 -5)') AS t(wkt)")
    mixedDf.createOrReplaceTempView("mixed_extent")
    val bbox = sparkSession
      .sql("SELECT ST_Extent(geom) FROM mixed_extent")
      .take(1)(0)
      .getAs[Box2D](0)
    assert(bbox.getXMin == -5.0)
    assert(bbox.getYMin == -5.0)
    assert(bbox.getXMax == 10.0)
    assert(bbox.getYMax == 20.0)
  }