diff --git a/geomesa-gt/geomesa-gt-partitioning/src/main/scala/org/locationtech/geomesa/gt/partition/postgis/dialect/PartitionedPostgisDialect.scala b/geomesa-gt/geomesa-gt-partitioning/src/main/scala/org/locationtech/geomesa/gt/partition/postgis/dialect/PartitionedPostgisDialect.scala index f2ef192cb34d..cbcd52c596cd 100644 --- a/geomesa-gt/geomesa-gt-partitioning/src/main/scala/org/locationtech/geomesa/gt/partition/postgis/dialect/PartitionedPostgisDialect.scala +++ b/geomesa-gt/geomesa-gt-partitioning/src/main/scala/org/locationtech/geomesa/gt/partition/postgis/dialect/PartitionedPostgisDialect.scala @@ -192,6 +192,7 @@ class PartitionedPostgisDialect(store: JDBCDataStore) extends PostGISDialect(sto } override def encodePostColumnCreateTable(att: AttributeDescriptor, sql: StringBuffer): Unit = { + import PartitionedPostgisDialect.Config.GeometryAttributeConversions att match { case gd: GeometryDescriptor => val nullable = gd.getMinOccurs <= 0 || gd.isNillable @@ -200,21 +201,16 @@ class PartitionedPostgisDialect(store: JDBCDataStore) extends PostGISDialect(sto if (i == -1 || (nullable && i != sql.length() - 8) || (!nullable && i != sql.length() - 17)) { logger.warn(s"Found geometry-type attribute but no geometry column binding: $sql") } else { - val srid = - Option(gd.getUserData.get(JDBCDataStore.JDBC_NATIVE_SRID).asInstanceOf[Integer]) - .orElse(Option(gd.getCoordinateReferenceSystem).flatMap(crs => Try(CRS.lookupEpsgCode(crs, true)).filter(_ != null).toOption)) - .map(_.intValue()) - .getOrElse(-1) + val srid = gd.getSrid.getOrElse(-1) val geomType = PartitionedPostgisDialect.GeometryMappings.getOrElse(gd.getType.getBinding, "GEOMETRY") - val geomTypeWithDims = - Option(gd.getUserData.get(Hints.COORDINATE_DIMENSION).asInstanceOf[Integer]).map(_.intValue) match { - case None | Some(2) => geomType - case Some(3) => s"${geomType}Z" - case Some(4) => s"${geomType}ZM" - case Some(d) => - throw new IllegalArgumentException( - s"PostGIS only supports geometries with 2, 3 and 4 dimensions, but found: $d") - } + val geomTypeWithDims = gd.getCoordinateDimensions match { + case None | Some(2) => geomType + case Some(3) => s"${geomType}Z" + case Some(4) => s"${geomType}ZM" + case Some(d) => + throw new IllegalArgumentException( + s"PostGIS only supports geometries with 2, 3 and 4 dimensions, but found: $d") + } sql.insert(i + 8, s" ($geomTypeWithDims, $srid)") } @@ -345,5 +341,15 @@ object PartitionedPostgisDialect { def getPagesPerRange: Int = Option(sft.getUserData.get(PagesPerRange).asInstanceOf[String]).map(int).getOrElse(128) def isFilterWholeWorld: Boolean = Option(sft.getUserData.get(FilterWholeWorld).asInstanceOf[String]).forall(_.toBoolean) } + + implicit class GeometryAttributeConversions(val d: GeometryDescriptor) extends AnyVal { + def getSrid: Option[Int] = + Option(d.getUserData.get(JDBCDataStore.JDBC_NATIVE_SRID)).map(int) + .orElse( + Option(d.getCoordinateReferenceSystem) + .flatMap(crs => Try(CRS.lookupEpsgCode(crs, true)).filter(_ != null).toOption.map(_.intValue()))) + def getCoordinateDimensions: Option[Int] = + Option(d.getUserData.get(Hints.COORDINATE_DIMENSION)).map(int) + } } } diff --git a/geomesa-gt/geomesa-gt-partitioning/src/test/scala/org/locationtech/geomesa/gt/partition/postgis/dialect/EscapeTest.scala b/geomesa-gt/geomesa-gt-partitioning/src/test/scala/org/locationtech/geomesa/gt/partition/postgis/dialect/PartionedPostgisDialectTest.scala similarity index 59% rename from geomesa-gt/geomesa-gt-partitioning/src/test/scala/org/locationtech/geomesa/gt/partition/postgis/dialect/EscapeTest.scala rename to geomesa-gt/geomesa-gt-partitioning/src/test/scala/org/locationtech/geomesa/gt/partition/postgis/dialect/PartionedPostgisDialectTest.scala index 33ede0fe2381..40a06762f929 100644 --- a/geomesa-gt/geomesa-gt-partitioning/src/test/scala/org/locationtech/geomesa/gt/partition/postgis/dialect/EscapeTest.scala +++ b/geomesa-gt/geomesa-gt-partitioning/src/test/scala/org/locationtech/geomesa/gt/partition/postgis/dialect/PartionedPostgisDialectTest.scala @@ -8,23 +8,40 @@ package org.locationtech.geomesa.gt.partition.postgis.dialect +import org.geotools.feature.AttributeTypeBuilder +import org.geotools.jdbc.JDBCDataStore import org.junit.runner.RunWith +import org.locationtech.jts.geom.Point import org.specs2.mutable.Specification import org.specs2.runner.JUnitRunner @RunWith(classOf[JUnitRunner]) -class EscapeTest extends Specification { +class PartionedPostgisDialectTest extends Specification { "PartitionedPostgisDialect" should { + "Escape literal values" in { SqlLiteral("foo'bar").raw mustEqual "foo'bar" SqlLiteral("foo'bar").quoted mustEqual "'foo''bar'" SqlLiteral("foo\"bar").quoted mustEqual "'foo\"bar'" } + "Escape identifiers" in { FunctionName("foo'bar").raw mustEqual "foo'bar" FunctionName("foo'bar").quoted mustEqual "\"foo'bar\"" FunctionName("foo\"bar").quoted mustEqual "\"foo\"\"bar\"" } + + "handle strings or ints as user data" in { + foreach(Seq("4326", 4326, Int.box(4326), null)) { srid => + val builder = new AttributeTypeBuilder().binding(classOf[Point]) + builder.userData(JDBCDataStore.JDBC_NATIVE_SRID, srid) + builder.crs(org.locationtech.geomesa.utils.geotools.CRS_EPSG_4326) + val attr = builder.buildDescriptor("geom") + val buf = new StringBuffer("geometry") + new PartitionedPostgisDialect(null).encodePostColumnCreateTable(attr, buf) + buf.toString mustEqual "geometry (POINT, 4326)" + } + } } }