Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.sedona.spark.SedonaContext
import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.parser.ParserFactory
import org.apache.spark.sql.sedona_sql.optimization.Box2DCastResolutionRule
import org.apache.spark.sql.sedona_sql.optimization.{Box2DCastResolutionRule, Box3DCastResolutionRule}
import org.slf4j.{Logger, LoggerFactory}

class SedonaSqlExtensions extends (SparkSessionExtensions => Unit) {
Expand All @@ -42,6 +42,11 @@ class SedonaSqlExtensions extends (SparkSessionExtensions => Unit) {
// refusing arbitrary UDT-to-UDT casts.
e.injectResolutionRule(_ => new Box2DCastResolutionRule)

// Resolve geometry→Box3D casts during analysis. Only the forward direction lands here;
// the inverse cast (`CAST(box3d AS geometry)`) is deferred until Box3D has an
// `ST_GeomFromBox3D` counterpart driven by a concrete consumer.
e.injectResolutionRule(_ => new Box3DCastResolutionRule)

// Inject Sedona SQL parser
if (enableParser) {
// Try to inject the Sedona SQL parser but gracefully handle initialization failures.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.spark.sql.sedona_sql.optimization

import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.sedona_sql.UDT.{Box3DUDT, GeometryUDT}
import org.apache.spark.sql.sedona_sql.expressions.ST_Box3D

/**
* Analyzer rule that resolves Catalyst casts from Geometry to Box3D. Spark's `Cast.canCast`
* returns `false` for arbitrary UDT-to-UDT casts, so without this rule the analyzer would reject
* `CAST(geom AS box3d)`. We rewrite during analysis (before `CheckAnalysis`) so the downstream
* optimizer and codegen path see the expression tree of an ordinary Sedona expression:
*
* - `CAST(geom AS box3d)` → `ST_Box3D(geom)` (planar 3D bounding box of the geometry;
* geometries without a Z dimension fold into `zmin = zmax = 0` per PostGIS)
*
* The inverse direction (`CAST(box3d AS geometry)`) is intentionally deferred until Box3D has a
* `ST_GeomFromBox3D` counterpart and a concrete consumer has driven the choice of output geometry
* shape. Implicit type coercion is also out of scope here; it requires hooking into Catalyst's
* type coercion rules and is tracked separately.
*/
class Box3DCastResolutionRule extends Rule[LogicalPlan] {

  /**
   * Visits every expression in `plan` and replaces each matching geometry-to-Box3D `Cast` with
   * `ST_Box3D` applied to the cast's child. Expressions that do not match are left as they are.
   *
   * @param plan
   *   the logical plan to rewrite
   * @return
   *   the plan with matching casts replaced
   */
  override def apply(plan: LogicalPlan): LogicalPlan =
    plan.transformAllExpressions {
      case cast: Cast if isGeometryToBox3D(cast) => ST_Box3D(Seq(cast.child))
    }

  /**
   * True when `cast` has a resolved, geometry-typed child and targets the Box3D type.
   */
  private def isGeometryToBox3D(cast: Cast): Boolean = {
    val source = cast.child
    // `resolved` is tested first so the child's data type is only consulted once it is resolved.
    source.resolved && ((source.dataType, cast.dataType) match {
      case (_: GeometryUDT, _: Box3DUDT) => true
      case _ => false
    })
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sedona.sql

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, Expression, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.sedona_sql.UDT.{Box3DUDT, GeometryUDT}
import org.apache.spark.sql.sedona_sql.expressions.ST_Box3D
import org.apache.spark.sql.sedona_sql.optimization.Box3DCastResolutionRule
import org.apache.spark.sql.types.LongType
import org.scalatest.funspec.AnyFunSpec

class Box3DCastResolutionRuleSuite extends AnyFunSpec {

  private val rule = new Box3DCastResolutionRule

  /** Builds `Project([expr AS out], LocalRelation(input))` as a minimal plan for the rule. */
  private def projectExprPlan(input: AttributeReference, expr: Expression): LogicalPlan =
    Project(Seq(Alias(expr, "out")()), LocalRelation(input))

  /**
   * Runs the rule over a single-column projection of `expr` and returns the expression that ends
   * up underneath the `out` alias.
   */
  private def rewrittenOutput(input: AttributeReference, expr: Expression): Expression = {
    val planAfterRule = rule(projectExprPlan(input, expr))
    planAfterRule.asInstanceOf[Project].projectList.head.asInstanceOf[Alias].child
  }

  describe("Box3DCastResolutionRule") {
    it("rewrites Cast(geometry-typed expression, Box3DUDT) into ST_Box3D") {
      val geometryColumn = AttributeReference("g", GeometryUDT(), nullable = true)()
      val result = rewrittenOutput(geometryColumn, Cast(geometryColumn, Box3DUDT))
      assert(result.isInstanceOf[ST_Box3D])
      assert(result.asInstanceOf[ST_Box3D].inputExpressions == Seq(geometryColumn))
      assert(result.dataType.isInstanceOf[Box3DUDT])
    }

    it("leaves Cast(Box3D-typed expression, GeometryUDT) untouched (inverse cast not in scope)") {
      val boxColumn = AttributeReference("b", Box3DUDT, nullable = true)()
      val result = rewrittenOutput(boxColumn, Cast(boxColumn, GeometryUDT()))
      assert(result.isInstanceOf[Cast])
    }

    it("leaves unrelated casts untouched") {
      // The geometry attribute only supplies the relation's output; the cast itself is Int → Long.
      val geometryColumn = AttributeReference("g", GeometryUDT(), nullable = true)()
      val result = rewrittenOutput(geometryColumn, Cast(Literal(1), LongType))
      assert(result.isInstanceOf[Cast])
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,20 @@ package org.apache.sedona.sql.parser

import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
import org.apache.spark.sql.execution.SparkSqlAstBuilder
import org.apache.spark.sql.sedona_sql.UDT.{Box2DUDT, GeometryUDT}
import org.apache.spark.sql.sedona_sql.UDT.{Box2DUDT, Box3DUDT, GeometryUDT}
import org.apache.spark.sql.types.DataType

class SedonaSqlAstBuilder extends SparkSqlAstBuilder {

/**
* Recognize Sedona UDT names (GEOMETRY, BOX2D) as primitive data types so SQL `CAST(... AS
* geometry)` / `CAST(... AS box2d)` parse to the matching UDT.
* Recognize Sedona UDT names (GEOMETRY, BOX2D, BOX3D) as primitive data types so SQL `CAST(... AS
* geometry)` / `CAST(... AS box2d)` / `CAST(... AS box3d)` parse to the matching UDT.
*/
override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = {
  // Upper-case the token text so the type name is matched case-insensitively.
  val typeName = ctx.getText.toUpperCase()
  typeName match {
    case "GEOMETRY" => GeometryUDT()
    case "BOX2D" => Box2DUDT
    case "BOX3D" => Box3DUDT
    // Anything else is handed to the parent builder unchanged.
    case _ => super.visitPrimitiveDataType(ctx)
  }
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sedona.sql

import org.apache.sedona.common.geometryObjects.Box3D
import org.apache.spark.sql.functions.{col, expr}
import org.apache.spark.sql.sedona_sql.UDT.Box3DUDT

class Box3DCastSuite extends TestBaseScala {

  /**
   * Whether the running session can parse `CAST(... AS box3d)`, which needs Sedona's
   * `SedonaSqlAstBuilder` to be active. The test base randomizes
   * `spark.sedona.enableParserExtension` across CI runs, and because `SparkContext` is a JVM
   * singleton the effective value may not match this suite's session-level config. Instead of
   * reading the config, run a tiny CAST once and remember whether it parsed. The DataFrame
   * `.cast(...)` tests do not consult this flag because the resolution rule is always injected.
   */
  private lazy val sqlCastSupported: Boolean = {
    val probeQuery = "SELECT CAST(ST_GeomFromText('POINT (0 0)') AS box3d) AS b"
    try {
      sparkSession.sql(probeQuery).collect()
      true
    } catch {
      // Only a parse failure means "unsupported"; any other error still propagates.
      case _: org.apache.spark.sql.catalyst.parser.ParseException => false
    }
  }

  /** Collects `frame` and returns the `b` column of its first row as a `Box3D`. */
  private def firstBox(frame: org.apache.spark.sql.DataFrame): Box3D =
    frame.collect().head.getAs[Box3D]("b")

  /** Parses `wkt` into a geometry column and casts it to Box3D through the DataFrame API. */
  private def boxViaDataFrameCast(wkt: String): Box3D = {
    import sparkSession.implicits._
    val geometries = Seq(wkt)
      .toDF("wkt")
      .select(expr("ST_GeomFromText(wkt)").alias("g"))
    firstBox(geometries.select(col("g").cast(Box3DUDT).alias("b")))
  }

  /** Cancels the test when SQL `CAST(... AS box3d)` cannot be parsed, otherwise runs `query`. */
  private def boxViaSqlCast(query: String): Box3D = {
    assume(
      sqlCastSupported,
      "Sedona SQL parser extension is required for `CAST(... AS box3d)` syntax")
    firstBox(sparkSession.sql(query))
  }

  describe("Geometry → Box3D Catalyst cast") {

    it("DataFrame .cast(Box3DUDT) rewrites to ST_Box3D") {
      val actual = boxViaDataFrameCast("LINESTRING Z(0 0 -3, 5 10 7)")
      assert(actual == new Box3D(0.0, 0.0, -3.0, 5.0, 10.0, 7.0))
    }

    it("DataFrame .cast(Box3DUDT) on XY geometry folds Z = 0") {
      val actual = boxViaDataFrameCast("LINESTRING (0 0, 5 10)")
      assert(actual == new Box3D(0.0, 0.0, 0.0, 5.0, 10.0, 0.0))
    }

    it("DataFrame .cast(Box3DUDT) on NULL geometry returns null") {
      val nullGeometry = sparkSession.sql("SELECT ST_GeomFromText(NULL) AS g")
      val actual = firstBox(nullGeometry.select(col("g").cast(Box3DUDT).alias("b")))
      assert(actual == null)
    }

    it("SQL CAST(geom AS box3d) returns the 3D bbox") {
      val actual = boxViaSqlCast(
        "SELECT CAST(ST_GeomFromText('LINESTRING Z(0 0 -3, 5 10 7)') AS box3d) AS b")
      assert(actual == new Box3D(0.0, 0.0, -3.0, 5.0, 10.0, 7.0))
    }

    it("SQL CAST(geom AS box3d) on XY geometry folds Z = 0") {
      val actual =
        boxViaSqlCast("SELECT CAST(ST_GeomFromText('LINESTRING (0 0, 5 10)') AS box3d) AS b")
      assert(actual == new Box3D(0.0, 0.0, 0.0, 5.0, 10.0, 0.0))
    }

    it("SQL CAST(NULL geometry AS box3d) returns null") {
      val actual = boxViaSqlCast("SELECT CAST(ST_GeomFromText(NULL) AS box3d) AS b")
      assert(actual == null)
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,20 @@ package org.apache.sedona.sql.parser

import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
import org.apache.spark.sql.execution.SparkSqlAstBuilder
import org.apache.spark.sql.sedona_sql.UDT.{Box2DUDT, GeometryUDT}
import org.apache.spark.sql.sedona_sql.UDT.{Box2DUDT, Box3DUDT, GeometryUDT}
import org.apache.spark.sql.types.DataType

class SedonaSqlAstBuilder extends SparkSqlAstBuilder {

/**
* Recognize Sedona UDT names (GEOMETRY, BOX2D) as primitive data types so SQL `CAST(... AS
* geometry)` / `CAST(... AS box2d)` parse to the matching UDT.
* Recognize Sedona UDT names (GEOMETRY, BOX2D, BOX3D) as primitive data types so SQL `CAST(...
* AS geometry)` / `CAST(... AS box2d)` / `CAST(... AS box3d)` parse to the matching UDT.
*/
override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = {
  // Upper-case the token text so the type name is matched case-insensitively.
  val typeName = ctx.getText.toUpperCase()
  typeName match {
    case "GEOMETRY" => GeometryUDT()
    case "BOX2D" => Box2DUDT
    case "BOX3D" => Box3DUDT
    // Anything else is handed to the parent builder unchanged.
    case _ => super.visitPrimitiveDataType(ctx)
  }
}
Expand Down
Loading
Loading