Skip to content

Commit

Permalink
Merge pull request databrickslabs#507 from databrickslabs/st_z
Browse files Browse the repository at this point in the history
add ST_Z
  • Loading branch information
Milos Colic authored Feb 1, 2024
2 parents 5e6a189 + ca64b35 commit dfa372f
Show file tree
Hide file tree
Showing 11 changed files with 267 additions and 7 deletions.
57 changes: 55 additions & 2 deletions docs/source/api/spatial-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1721,7 +1721,7 @@ st_x

.. function:: st_x(col)

Returns the x coordinate of the input geometry.
Returns the x coordinate of the centroid point of the input geometry.

:param col: Geometry
:type col: Column
Expand Down Expand Up @@ -1880,7 +1880,7 @@ st_y
****
.. function:: st_y(col)

Returns the y coordinate of the input geometry.
Returns the y coordinate of the centroid point of the input geometry.

:param col: Geometry
:type col: Column
Expand Down Expand Up @@ -2036,6 +2036,59 @@ st_ymin
+-----------------+


st_z
****
.. function:: st_z(col)

Returns the z coordinate of an arbitrary point of the input geometry `geom`.

:param col: Point Geometry
:type col: Column
:rtype: Column: DoubleType

:example:

.. tabs::
.. code-tab:: py

df = spark.createDataFrame([{'wkt': 'POINT (30 10 20)'}])
df.select(st_z('wkt')).show()
+-----------------+
|st_z(wkt) |
+-----------------+
| 20.0|
+-----------------+

.. code-tab:: scala

val df = List(("POINT (30 10 20)")).toDF("wkt")
df.select(st_z(col("wkt"))).show()
+-----------------+
|st_z(wkt) |
+-----------------+
| 20.0|
+-----------------+

.. code-tab:: sql

SELECT st_z("POINT (30 10 20)")
+-----------------+
|st_z(wkt) |
+-----------------+
| 20.0|
+-----------------+

.. code-tab:: r R

df <- createDataFrame(data.frame(wkt = "POINT (30 10 20)"))
showDF(select(df, st_z(column("wkt"))), truncate=F)
+-----------------+
|st_z(wkt) |
+-----------------+
| 20.0|
+-----------------+


st_zmax
*******

Expand Down
21 changes: 19 additions & 2 deletions python/mosaic/api/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
"st_zmax",
"st_x",
"st_y",
"st_z",
"flatten_polygons",
"grid_boundaryaswkb",
"grid_boundary",
Expand Down Expand Up @@ -753,7 +754,7 @@ def st_updatesrid(

def st_x(geom: ColumnOrName) -> Column:
"""
Returns the x coordinate of the input geometry `geom`.
Returns the x coordinate of the centroid point of the input geometry `geom`.
Parameters
----------
Expand All @@ -769,7 +770,7 @@ def st_x(geom: ColumnOrName) -> Column:

def st_y(geom: ColumnOrName) -> Column:
"""
Returns the y coordinate of the input geometry `geom`.
Returns the y coordinate of the centroid point of the input geometry `geom`.
Parameters
----------
Expand All @@ -783,6 +784,22 @@ def st_y(geom: ColumnOrName) -> Column:
return config.mosaic_context.invoke_function("st_y", pyspark_to_java_column(geom))


def st_z(geom: ColumnOrName) -> Column:
"""
Returns the z coordinate of an arbitrary point of the input geometry `geom`.
Parameters
----------
geom : Column
Returns
-------
Column (DoubleType)
"""
return config.mosaic_context.invoke_function("st_z", pyspark_to_java_column(geom))


def st_geometrytype(geom: ColumnOrName) -> Column:
"""
Returns the type of the input geometry `geom` (“POINT”, “LINESTRING”, “POLYGON” etc.).
Expand Down
21 changes: 20 additions & 1 deletion python/test/test_vector_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import random

from pyspark.sql.functions import abs, col, first, lit, sqrt
from pyspark.sql.functions import abs, col, concat, first, lit, sqrt

from .context import api
from .utils import MosaicTestCase
Expand All @@ -27,6 +27,25 @@ def test_st_point(self):
)
self.assertListEqual([rw.points for rw in result], expected)

def test_st_z(self):
expected = [
0,
1,
]
result = (
self.spark.range(2)
.select(col("id").cast("double"))
.withColumn(
"points",
api.st_geomfromwkt(
concat(lit("POINT (9 9 "), "id", lit(")"))
),
)
.withColumn("z", api.st_z("points"))
.collect()
)
self.assertListEqual([rw.z for rw in result], expected)

def test_st_bindings_happy_flow(self):
# Checks that the python bindings do not throw exceptions
# Not testing the logic, since that is tested in Scala
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ trait MosaicGeometry extends GeometryWriter with Serializable {

def getCentroid: MosaicPoint

def getAnyPoint: MosaicPoint

def getDimension: Int

def isEmpty: Boolean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ abstract class MosaicGeometryJTS(geom: Geometry) extends MosaicGeometry {
MosaicPointJTS(centroid)
}

override def getAnyPoint: MosaicPointJTS = {
// while this doesn't return the centroid but an arbitrary point via getCoordinate in JTS,
// inlike getCentroid this supports a Z coordinate.

val coord = geom.getCoordinate
val gf = new GeometryFactory()
val point = gf.createPoint(coord)
MosaicPointJTS(point)
}

override def isEmpty: Boolean = geom.isEmpty

override def boundary: MosaicGeometryJTS = MosaicGeometryJTS(geom.getBoundary)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ object ST_X extends WithExpressionInfo {
override def name: String = "st_x"

override def usage: String =
"_FUNC_(expr1) - Returns x coordinate of a point or x coordinate of the centroid if the geometry isnt a point."
"_FUNC_(expr1) - Returns x coordinate of a point or x coordinate of the centroid if the geometry isn't a point."

override def example: String =
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ object ST_Y extends WithExpressionInfo {
override def name: String = "st_y"

override def usage: String =
"_FUNC_(expr1) - Returns y coordinate of a point or y coordinate of the centroid if the geometry isnt a point."
"_FUNC_(expr1) - Returns y coordinate of a point or y coordinate of the centroid if the geometry isn't a point."

override def example: String =
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package com.databricks.labs.mosaic.expressions.geometry

import com.databricks.labs.mosaic.core.geometry.MosaicGeometry
import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo}
import com.databricks.labs.mosaic.expressions.geometry.base.UnaryVectorExpression
import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types.{DataType, DoubleType}

/**
* SQL expression that returns Z coordinate of the input point. Input must be a point.
*
* @param inputGeom
* Expression containing the geometry.
* @param expressionConfig
* Mosaic execution context, e.g. geometryAPI, indexSystem, etc. Additional
* arguments for the expression (expressionConfigs).
*/
case class ST_Z(
inputGeom: Expression,
expressionConfig: MosaicExpressionConfig
) extends UnaryVectorExpression[ST_Z](inputGeom, returnsGeometry = false, expressionConfig) {

override def dataType: DataType = DoubleType

override def geometryTransform(geometry: MosaicGeometry): Any = geometry.getAnyPoint.getZ

override def geometryCodeGen(geometryRef: String, ctx: CodegenContext): (String, String) = {
val resultRef = ctx.freshName("result")
val code = s"""double $resultRef = $geometryRef.getAnyPoint().getZ();"""
(code, resultRef)
}

}

/** Expression info required for the expression registration for spark SQL. */
object ST_Z extends WithExpressionInfo {

override def name: String = "st_z"

override def usage: String =
"_FUNC_(expr1) - Returns z coordinate of a point or z coordinate of an arbitrary point in geometry if it isn't a point."

override def example: String =
"""
| Examples:
| > SELECT _FUNC_(a);
| 12.3
| """.stripMargin

override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = {
GenericExpressionFactory.getBaseBuilder[ST_Z](1, expressionConfig)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends
mosaicRegistry.registerExpression[ST_Within](expressionConfig)
mosaicRegistry.registerExpression[ST_X](expressionConfig)
mosaicRegistry.registerExpression[ST_Y](expressionConfig)
mosaicRegistry.registerExpression[ST_Z](expressionConfig)
mosaicRegistry.registerExpression[ST_Haversine](expressionConfig)

// noinspection ScalaDeprecation
Expand Down Expand Up @@ -600,6 +601,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends
ColumnAdapter(ST_Translate(geom1.expr, xd.expr, yd.expr, expressionConfig))
def st_x(geom: Column): Column = ColumnAdapter(ST_X(geom.expr, expressionConfig))
def st_y(geom: Column): Column = ColumnAdapter(ST_Y(geom.expr, expressionConfig))
def st_z(geom: Column): Column = ColumnAdapter(ST_Z(geom.expr, expressionConfig))
def st_xmax(geom: Column): Column = ColumnAdapter(ST_MinMaxXYZ(geom.expr, expressionConfig, "X", "MAX"))
def st_xmin(geom: Column): Column = ColumnAdapter(ST_MinMaxXYZ(geom.expr, expressionConfig, "X", "MIN"))
def st_ymax(geom: Column): Column = ColumnAdapter(ST_MinMaxXYZ(geom.expr, expressionConfig, "Y", "MAX"))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package com.databricks.labs.mosaic.expressions.geometry

import com.databricks.labs.mosaic.functions.MosaicContext
import com.databricks.labs.mosaic.test.MosaicSpatialQueryTest
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator}
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types._
import org.scalatest.matchers.must.Matchers.noException
import org.scalatest.matchers.should.Matchers.{an, be, convertToAnyShouldWrapper}

trait ST_ZBehaviors extends MosaicSpatialQueryTest {

def stzBehavior(mosaicContext: MosaicContext): Unit = {
spark.sparkContext.setLogLevel("FATAL")
val mc = mosaicContext
import mc.functions._
val sc = spark
import sc.implicits._
mc.register(spark)

val rows = List(
("POINT (2 3 5)", 5),
("POINT (7 11 13)", 13),
("POINT (17 19 23)", 23),
("POINT (29 31 37)", 37)
)

val result = rows
.toDF("wkt", "expected")
.withColumn("result", st_z($"wkt"))
.where($"expected" === $"result")

result.count shouldBe 4
}

def stzCodegen(mosaicContext: MosaicContext): Unit = {
spark.sparkContext.setLogLevel("FATAL")
val mc = mosaicContext
val sc = spark
import mc.functions._
import sc.implicits._
mc.register(spark)

val rows = List(
("POINT (2 3 5)", 5),
("POINT (7 11 13)", 13),
("POINT (17 19 23)", 23),
("POINT (29 31 37)", 37)
)

val points = rows.toDF("wkt", "expected")

val result = points
.withColumn("result", st_z($"wkt"))
.where($"expected" === $"result")

val queryExecution = result.queryExecution
val plan = queryExecution.executedPlan

val wholeStageCodegenExec = plan.find(_.isInstanceOf[WholeStageCodegenExec])

wholeStageCodegenExec.isDefined shouldBe true

val codeGenStage = wholeStageCodegenExec.get.asInstanceOf[WholeStageCodegenExec]
val (_, code) = codeGenStage.doCodeGen()

noException should be thrownBy CodeGenerator.compile(code)

val stZ = ST_Z(lit(1).expr, mc.expressionConfig)
val ctx = new CodegenContext
an[Error] should be thrownBy stZ.genCode(ctx)
}

def auxiliaryMethods(mosaicContext: MosaicContext): Unit = {
spark.sparkContext.setLogLevel("FATAL")
val mc = mosaicContext
mc.register(spark)

val stZ = ST_Z(lit("POINT (2 3 4)").expr, mc.expressionConfig)

stZ.child shouldEqual lit("POINT (2 3 4)").expr
stZ.dataType shouldEqual DoubleType
noException should be thrownBy stZ.makeCopy(Array(stZ.child))
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.databricks.labs.mosaic.expressions.geometry

import com.databricks.labs.mosaic.test.MosaicSpatialQueryTest
import org.apache.spark.sql.test.SharedSparkSession

class ST_ZTest extends MosaicSpatialQueryTest with SharedSparkSession with ST_ZBehaviors {

testAllGeometriesNoCodegen("Testing stZ NO_CODEGEN") { stzBehavior }
testAllGeometriesCodegen("Testing stZ CODEGEN") { stzBehavior }
testAllGeometriesCodegen("Testing stZ CODEGEN compilation") { stzCodegen }
testAllGeometriesNoCodegen("Testing stZ auxiliary methods") { auxiliaryMethods }

}

0 comments on commit dfa372f

Please sign in to comment.