diff --git a/.github/actions/r_build/action.yml b/.github/actions/r_build/action.yml index 2fa260bb4..e970f2fdb 100644 --- a/.github/actions/r_build/action.yml +++ b/.github/actions/r_build/action.yml @@ -6,23 +6,54 @@ runs: - name: Setup R build environment shell: bash run: | - sudo apt-get update && sudo apt-get install -y curl libcurl4-openssl-dev pkg-config libharfbuzz-dev libfribidi-dev - - name: Download and unpack Spark + sudo apt-get update && sudo apt-get install -y curl libcurl4-openssl-dev pkg-config libharfbuzz-dev libfribidi-dev + - name: Create download location for Spark shell: bash run: | - sudo mkdir -p /usr/spark-download/raw sudo mkdir -p /usr/spark-download/unzipped + sudo mkdir -p /usr/spark-download/raw sudo chown -R $USER: /usr/spark-download/ - wget -P /usr/spark-download/raw https://archive.apache.org/dist/spark/spark-3.2.1/spark-3.2.1-bin-hadoop2.7.tgz + - name: Cache Spark download + id: cache-spark + uses: actions/cache@v3 + with: + path: /usr/spark-download/unzipped + key: r_build-spark + - if: ${{ steps.cache-spark.outputs.cache-hit != 'true' }} + name: Download and unpack Spark + shell: bash + run: | + wget -P /usr/spark-download/raw https://archive.apache.org/dist/spark/spark-3.2.1/spark-3.2.1-bin-hadoop2.7.tgz tar zxvf /usr/spark-download/raw/spark-3.2.1-bin-hadoop2.7.tgz -C /usr/spark-download/unzipped - - - - name: Build R package + - name: Create R environment shell: bash run: | - cd R sudo mkdir -p /usr/lib/R/site-library sudo chown -R $USER: /usr/lib/R/site-library + - name: Setup R + uses: r-lib/actions/setup-r@v2 + with: + r-version: ${{ matrix.R }} + use-public-rspm: true + - name: Install R dependencies + shell: bash + run: | + cd R + Rscript --vanilla install_deps.R + - name: Generate R bindings + shell: bash + run: | + cd R + Rscript --vanilla generate_R_bindings.R ../src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala + - name: Build R docs + shell: bash + run: | + cd R + Rscript --vanilla 
generate_docs.R + - name: Build R package + shell: bash + run: | + cd R Rscript --vanilla build_r_package.R - name: Test R package shell: bash diff --git a/.github/actions/scala_build/action.yml b/.github/actions/scala_build/action.yml index fd663ae29..950ae2200 100644 --- a/.github/actions/scala_build/action.yml +++ b/.github/actions/scala_build/action.yml @@ -15,12 +15,21 @@ runs: - name: Test and build the scala JAR - skip tests is false if: inputs.skip_tests == 'false' shell: bash - run: sudo mvn -q clean install + run: | + pip install databricks-mosaic-gdal==3.4.3 + sudo tar -xf /home/runner/.local/lib/python3.8/site-packages/databricks-mosaic-gdal/resources/gdal-3.4.3-filetree.tar.xz -C / + sudo tar -xhf /home/runner/.local/lib/python3.8/site-packages/databricks-mosaic-gdal/resources/gdal-3.4.3-symlinks.tar.xz -C / + sudo add-apt-repository ppa:ubuntugis/ubuntugis-unstable + sudo apt clean && sudo apt -o Acquire::Retries=3 update --fix-missing -y + sudo apt-get -o Acquire::Retries=3 update -y + sudo apt-get -o Acquire::Retries=3 install -y gdal-bin=3.4.3+dfsg-1~focal0 libgdal-dev=3.4.3+dfsg-1~focal0 python3-gdal=3.4.3+dfsg-1~focal0 + sudo mvn -q clean install - name: Build the scala JAR - skip tests is true if: inputs.skip_tests == 'true' shell: bash - run: sudo mvn -q clean install -DskipTests + run: sudo mvn -q clean install -DskipTests -Dscoverage.skip - name: Publish test coverage + if: inputs.skip_tests == 'false' uses: codecov/codecov-action@v1 - name: Copy Scala artifacts to GH Actions run shell: bash diff --git a/.github/workflows/build_main.yml b/.github/workflows/build_main.yml index 97313288e..ac5cb0623 100644 --- a/.github/workflows/build_main.yml +++ b/.github/workflows/build_main.yml @@ -16,7 +16,7 @@ jobs: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} strategy: matrix: - python: [ 3.8.10 ] + python: [ 3.9 ] spark: [ 3.2.1 ] R: [ 4.1.2 ] steps: @@ -26,5 +26,7 @@ jobs: uses: ./.github/actions/scala_build - name: build python uses: 
./.github/actions/python_build + - name: build R + uses: ./.github/actions/r_build - name: upload artefacts uses: ./.github/actions/upload_artefacts \ No newline at end of file diff --git a/.github/workflows/build_python.yml b/.github/workflows/build_python.yml index b0f4d4aee..c2492002c 100644 --- a/.github/workflows/build_python.yml +++ b/.github/workflows/build_python.yml @@ -12,7 +12,7 @@ jobs: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} strategy: matrix: - python: [ 3.8.10 ] + python: [ 3.9 ] spark: [ 3.2.1 ] R: [ 4.1.2 ] steps: diff --git a/.github/workflows/build_r.yml b/.github/workflows/build_r.yml index 8ae7352b2..644ba9d7d 100644 --- a/.github/workflows/build_r.yml +++ b/.github/workflows/build_r.yml @@ -13,7 +13,7 @@ jobs: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} strategy: matrix: - python: [ 3.8.10 ] + python: [ 3.9 ] spark: [ 3.2.1 ] R: [ 4.1.2 ] steps: diff --git a/.github/workflows/build_scala.yml b/.github/workflows/build_scala.yml index 6f8b52fad..c6297e6f0 100644 --- a/.github/workflows/build_scala.yml +++ b/.github/workflows/build_scala.yml @@ -11,7 +11,7 @@ jobs: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} strategy: matrix: - python: [ 3.8.10 ] + python: [ 3.9 ] spark: [ 3.2.1 ] R: [ 4.1.2 ] steps: diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml index cd0480f1b..02a48c40d 100644 --- a/.github/workflows/pypi-release.yml +++ b/.github/workflows/pypi-release.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - python: [3.8.10] + python: [3.9] spark: [3.2.1] steps: - name: checkout code diff --git a/.gitignore b/.gitignore index b290fd7d8..975675c69 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ #IntelliJ files .idea *.iml +tmp_ #VSCode files .vscode @@ -65,6 +66,7 @@ coverage.xml .hypothesis/ .pytest_cache/ /python/test/.run/ +spatial_knn # Translations *.mo diff --git a/R/.gitignore b/R/.gitignore index bef411cb1..eb4bec116 100644 --- a/R/.gitignore +++ b/R/.gitignore @@ -1 +1,2 @@ **/.Rhistory 
+**/*.tar.gz diff --git a/R/build_r_package.R b/R/build_r_package.R index 807688f91..ab82b99b1 100644 --- a/R/build_r_package.R +++ b/R/build_r_package.R @@ -1,31 +1,3 @@ -repos = c( - "https://cran.ma.imperial.ac.uk" = "https://cran.ma.imperial.ac.uk" - ,"https://www.stats.bris.ac.uk/R" = "https://www.stats.bris.ac.uk/R" - ,"https://cran.rstudio.com/" = "https://cran.rstudio.com/" -) - -mirror_is_up <- function(x){ - out <- tryCatch({ - available.packages(contrib.url(x)) - } - ,error = function(cond){return(0)} - ,warning = function(cond){return(0)} - ,finally = function(cond){} - ) - return(length(out)) -} - -mirror_status = lapply(repos, mirror_is_up) -for(repo in names(mirror_status)){ - if (mirror_status[[repo]] > 1){ - repo <<- repo - break - } -} - -install.packages("pkgbuild", repos=repo) -install.packages("roxygen2", repos=repo) -install.packages("sparklyr", repos=repo) spark_location <- "/usr/spark-download/unzipped/spark-3.2.1-bin-hadoop2.7" Sys.setenv(SPARK_HOME = spark_location) @@ -33,22 +5,11 @@ library(SparkR, lib.loc = c(file.path(spark_location, "R", "lib"))) library(pkgbuild) -library(roxygen2) library(sparklyr) build_mosaic_bindings <- function(){ - # build functions - scala_file_path <- "../src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala" - system_cmd <- paste0(c("Rscript --vanilla generate_R_bindings.R", scala_file_path), collapse = " ") - system(system_cmd) - - # build doc - roxygen2::roxygenize("sparkR-mosaic/sparkrMosaic") - roxygen2::roxygenize("sparklyr-mosaic/sparklyrMosaic") - - ## build package pkgbuild::build("sparkR-mosaic/sparkrMosaic") pkgbuild::build("sparklyr-mosaic/sparklyrMosaic") diff --git a/R/generate_docs.R b/R/generate_docs.R new file mode 100644 index 000000000..06b23e6fa --- /dev/null +++ b/R/generate_docs.R @@ -0,0 +1,14 @@ +spark_location <- "/usr/spark-download/unzipped/spark-3.2.1-bin-hadoop2.7" +Sys.setenv(SPARK_HOME = spark_location) + +library(SparkR, lib.loc = c(file.path(spark_location, 
"R", "lib"))) +library(roxygen2) + +build_mosaic_docs <- function(){ + # build doc + roxygen2::roxygenize("sparkR-mosaic/sparkrMosaic") + roxygen2::roxygenize("sparklyr-mosaic/sparklyrMosaic") + +} + +build_mosaic_docs() \ No newline at end of file diff --git a/R/install_deps.R b/R/install_deps.R new file mode 100644 index 000000000..d05207329 --- /dev/null +++ b/R/install_deps.R @@ -0,0 +1,5 @@ +options(repos = c(CRAN = "https://packagemanager.posit.co/cran/__linux__/focal/latest")) + +install.packages("pkgbuild") +install.packages("roxygen2") +install.packages("sparklyr") \ No newline at end of file diff --git a/R/sparkR-mosaic/enableMosaic.R b/R/sparkR-mosaic/enableMosaic.R index 40a5c7b32..2f81da4f8 100644 --- a/R/sparkR-mosaic/enableMosaic.R +++ b/R/sparkR-mosaic/enableMosaic.R @@ -17,14 +17,11 @@ enableMosaic <- function( geometryAPI="JTS" ,indexSystem="H3" - ,rasterAPI="GDAL" ){ geometry_api <- sparkR.callJStatic(x="com.databricks.labs.mosaic.core.geometry.api.GeometryAPI", methodName="apply", geometryAPI) indexing_system <- sparkR.callJStatic(x="com.databricks.labs.mosaic.core.index.IndexSystemFactory", methodName="getIndexSystem", indexSystem) - raster_api <- sparkR.callJStatic(x="com.databricks.labs.mosaic.core.raster.api.RasterAPI", methodName="apply", rasterAPI) - - mosaic_context <- sparkR.newJObject(x="com.databricks.labs.mosaic.functions.MosaicContext", indexing_system, geometry_api, raster_api) + mosaic_context <- sparkR.newJObject(x="com.databricks.labs.mosaic.functions.MosaicContext", indexing_system, geometry_api) functions <<- sparkR.callJMethod(mosaic_context, "functions") # register the sql functions for use in sql() commands sparkR.callJMethod(mosaic_context, "register") diff --git a/R/sparkR-mosaic/sparkrMosaic/DESCRIPTION b/R/sparkR-mosaic/sparkrMosaic/DESCRIPTION index 3fe2ab1a6..eb25a280a 100644 --- a/R/sparkR-mosaic/sparkrMosaic/DESCRIPTION +++ b/R/sparkR-mosaic/sparkrMosaic/DESCRIPTION @@ -8,7 +8,7 @@ Description: This package extends 
SparkR to bring the Databricks Mosaic for geos License: Databricks Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.0 +RoxygenNote: 7.2.3 Collate: 'enableMosaic.R' 'generics.R' diff --git a/R/sparkR-mosaic/tests.R b/R/sparkR-mosaic/tests.R index 34b0c61f7..f4071c3d2 100644 --- a/R/sparkR-mosaic/tests.R +++ b/R/sparkR-mosaic/tests.R @@ -1,5 +1,3 @@ -repo<-"https://cran.ma.imperial.ac.uk/" - spark_location <- "/usr/spark-download/unzipped/spark-3.2.1-bin-hadoop2.7" Sys.setenv(SPARK_HOME = spark_location) library(SparkR, lib.loc = c(file.path(spark_location, "R", "lib"))) @@ -66,7 +64,7 @@ sdf <- withColumn(sdf, "transformed_geom", st_transform(column("geom_with_srid") # Grid functions sdf <- withColumn(sdf, "grid_longlatascellid", grid_longlatascellid(lit(1), lit(1), lit(1L))) sdf <- withColumn(sdf, "grid_pointascellid", grid_pointascellid(column("point_wkt"), lit(1L))) -sdf <- withColumn(sdf, "grid_boundaryaswkb", grid_boundaryaswkb( SparkR::cast(lit(1), "long"))) +sdf <- withColumn(sdf, "grid_boundaryaswkb", grid_boundaryaswkb(column("grid_pointascellid"))) sdf <- withColumn(sdf, "grid_polyfill", grid_polyfill(column("wkt"), lit(1L))) sdf <- withColumn(sdf, "grid_tessellateexplode", grid_tessellateexplode(column("wkt"), lit(1L))) sdf <- withColumn(sdf, "grid_tessellate", grid_tessellate(column("wkt"), lit(1L))) @@ -74,7 +72,7 @@ sdf <- withColumn(sdf, "grid_tessellate", grid_tessellate(column("wkt"), lit(1L) # Deprecated sdf <- withColumn(sdf, "point_index_lonlat", point_index_lonlat(lit(1), lit(1), lit(1L))) sdf <- withColumn(sdf, "point_index_geom", point_index_geom(column("point_wkt"), lit(1L))) -sdf <- withColumn(sdf, "index_geometry", index_geometry( SparkR::cast(lit(1), "long"))) +sdf <- withColumn(sdf, "index_geometry", index_geometry(column("point_index_geom"))) sdf <- withColumn(sdf, "polyfill", polyfill(column("wkt"), lit(1L))) sdf <- withColumn(sdf, "mosaic_explode", mosaic_explode(column("wkt"), lit(1L))) sdf <- withColumn(sdf, 
"mosaicfill", mosaicfill(column("wkt"), lit(1L))) diff --git a/R/sparklyr-mosaic/sparklyrMosaic/DESCRIPTION b/R/sparklyr-mosaic/sparklyrMosaic/DESCRIPTION index 9bf65215a..0e5bcbb38 100644 --- a/R/sparklyr-mosaic/sparklyrMosaic/DESCRIPTION +++ b/R/sparklyr-mosaic/sparklyrMosaic/DESCRIPTION @@ -8,7 +8,7 @@ Description: This package extends sparklyr to bring the Databricks Mosaic for ge License: Databricks Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.0 +RoxygenNote: 7.2.3 Collate: 'enableMosaic.R' 'sparkFunctions.R' diff --git a/notebooks/prototypes/grid_tiles/00 Download STACs.py b/notebooks/prototypes/grid_tiles/00 Download STACs.py new file mode 100644 index 000000000..a19bfd9d1 --- /dev/null +++ b/notebooks/prototypes/grid_tiles/00 Download STACs.py @@ -0,0 +1,281 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC ## Install the libraries and prepare the environment + +# COMMAND ---------- + +# MAGIC %md +# MAGIC For this demo we will require a few spatial libraries that can be easily installed via pip install. We will be using gdal, rasterio, pystac and databricks-mosaic for data download and data manipulation. We will use planetary computer as the source of the raster data for the analysis. + +# COMMAND ---------- + +# MAGIC %pip install databricks-mosaic rasterio==1.3.5 --quiet gdal==3.4.3 pystac pystac_client planetary_computer tenacity rich + +# COMMAND ---------- + +import library +import pystac_client +import planetary_computer +import mosaic as mos + +from pyspark.sql import functions as F + +mos.enable_mosaic(spark, dbutils) +mos.enable_gdal(spark) + +# COMMAND ---------- + +# MAGIC %reload_ext autoreload +# MAGIC %autoreload 2 +# MAGIC %reload_ext library + +# COMMAND ---------- + +spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "false") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC We will download census data from TIGER feed for this demo. The data can be downloaded as a zip to dbfs (or managed volumes). 
+ +# COMMAND ---------- + +dbutils.fs.rm("/FileStore/geospatial/odin/census/", True) +dbutils.fs.mkdirs("/FileStore/geospatial/odin/census/") + +# COMMAND ---------- + +import urllib.request +urllib.request.urlretrieve( + "https://www2.census.gov/geo/tiger/TIGER2021/COUNTY/tl_2021_us_county.zip", + "/dbfs/FileStore/geospatial/odin/census/data.zip" +) + +# COMMAND ---------- + +# MAGIC %sh ls -al /dbfs/FileStore/geospatial/odin/census/ + +# COMMAND ---------- + +# MAGIC %md +# MAGIC Mosaic has specialised readers for shape files and other GDAL supported formats. We don't need to unzip the data zip file. Just need to pass "vsizip" option to the reader. + +# COMMAND ---------- + +census_df = mos.read().format("multi_read_ogr")\ + .option("vsizip", "true")\ + .option("chunkSize", "50")\ + .load("dbfs:/FileStore/geospatial/odin/census/data.zip")\ + .cache() # We will cache the loaded data to avoid schema inference being done repeatedly for each query + +# COMMAND ---------- + +# MAGIC %md +# MAGIC For this example we will focus on Alaska counties. Alaska state code is 02 so we will apply a filter to our ingested data. 
+ +# COMMAND ---------- + +census_df.where("STATEFP == 2").display() + +# COMMAND ---------- + +to_display = census_df\ + .where("STATEFP == 2")\ + .withColumn( + "geom_0", + mos.st_updatesrid("geom_0", "geom_0_srid", F.lit(4326)) + )\ + .select("geom_0") + +# COMMAND ---------- + +# MAGIC %%mosaic_kepler +# MAGIC to_display geom_0 geometry 50 + +# COMMAND ---------- + +cells = census_df\ + .where("STATEFP == 2")\ + .withColumn( + "geom_0", + mos.st_updatesrid("geom_0", "geom_0_srid", F.lit(4326)) + )\ + .withColumn("geom_0_srid", F.lit(4326))\ + .withColumn( + "grid", + mos.grid_tessellateexplode("geom_0", F.lit(3)) + ) + +# COMMAND ---------- + +cells.display() + +# COMMAND ---------- + +to_display = cells.select(mos.st_simplify("grid.wkb", F.lit(0.1)).alias("wkb")) + +# COMMAND ---------- + +# MAGIC %%mosaic_kepler +# MAGIC to_display wkb geometry 100000 + +# COMMAND ---------- + +# MAGIC %md +# MAGIC It is fairly easy to interface with the pystac_client and remote raster data catalogs. We can browse resource collections and individual assets. + +# COMMAND ---------- + +time_range = "2021-06-01/2021-06-30" + +# COMMAND ---------- + +cell_jsons = cells\ + .withColumn("area_id", F.hash("geom_0"))\ + .withColumn("h3", F.col("grid.index_id"))\ + .groupBy("h3")\ + .agg( + mos.st_union_agg("grid.wkb").alias("geom_1") + )\ + .withColumn("geojson", mos.st_asgeojson(mos.grid_boundaryaswkb("h3")))\ + .drop("count", "geom_1") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC Stac catalogs support easy download for area of interest provided as geojsons. With this in mind we will convert all our H3 cells of interest into geojsons and prepare stac requests. + +# COMMAND ---------- + +cell_jsons.display() + +# COMMAND ---------- + +cell_jsons.count() + +# COMMAND ---------- + +# MAGIC %%mosaic_kepler +# MAGIC cell_jsons h3 h3 + +# COMMAND ---------- + +# MAGIC %md +# MAGIC Our framework allows for easy preparation of stac requests with only one line of code. 
This data is delta ready as this point and can easily be stored for lineage purposes. + +# COMMAND ---------- + +eod_items = library.get_assets_for_cells(cell_jsons.repartition(200), time_range ,"sentinel-2-l2a" ).cache() +eod_items.display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC From this point we can easily extract the download links for items of interest. + +# COMMAND ---------- + +dbutils.fs.rm("/FileStore/geospatial/odin/alaska/", True) +dbutils.fs.mkdirs("/FileStore/geospatial/odin/alaska/") + +# COMMAND ---------- + +# MAGIC %sql +# MAGIC DROP DATABASE IF EXISTS odin_alaska CASCADE; +# MAGIC CREATE DATABASE IF NOT EXISTS odin_alaska; + +# COMMAND ---------- + +# MAGIC %sql +# MAGIC USE odin_alaska; + +# COMMAND ---------- + +def download_band(eod_items, band_name): + to_download = eod_items\ + .withColumn("timestamp", F.col("item_properties.datetime"))\ + .groupBy("item_id", "timestamp")\ + .agg( + *[F.first(cn).alias(cn) for cn in eod_items.columns if cn not in ["item_id"]] + )\ + .withColumn("date", F.to_date("timestamp"))\ + .withColumn("href", F.col("asset.href"))\ + .where( + f"asset.name == '{band_name}'" + ) + + spark.sql(f"DROP TABLE IF EXISTS alaska_{band_name}") + dbutils.fs.rm(f"/FileStore/geospatial/odin/alaska/{band_name}", True) + dbutils.fs.mkdirs(f"/FileStore/geospatial/odin/alaska/{band_name}") + + catalof_df = to_download\ + .withColumn( + "outputfile", + library.download_asset("href", F.lit(f"/dbfs/FileStore/geospatial/odin/alaska/{band_name}"), + F.concat(F.hash(F.rand()), F.lit(".tif"))) + ) + + catalof_df.write\ + .mode("overwrite")\ + .option("overwriteSchema", "true")\ + .format("delta")\ + .saveAsTable(f"alaska_{band_name}") + + +# COMMAND ---------- + +import rich.table + +region = census_df.where("STATEFP == 2").select(mos.st_asgeojson("geom_0").alias("geojson")).limit(1).collect()[0]["geojson"] + +catalog = pystac_client.Client.open( + "https://planetarycomputer.microsoft.com/api/stac/v1", + 
modifier=planetary_computer.sign_inplace, +) + +search = catalog.search( + collections=["sentinel-2-l2a"], + intersects=region, + datetime=time_range +) + +items = search.item_collection() + +table = rich.table.Table("Asset Key", "Description") +for asset_key, asset in items[0].assets.items(): + table.add_row(asset_key, asset.title) + +table + +# COMMAND ---------- + +bands = [] +for asset_key, asset in items[0].assets.items(): + bands.append(asset_key) + +bands = [b for b in bands if b not in ["visual", "preview", "safe-manifest", "tilejson", "rendered_preview", "granule-metadata", "inspire-metadata", "product-metadata", "datastrip-metadata"]] +bands + +# COMMAND ---------- + +for band in bands: + download_band(eod_items, band) + +# COMMAND ---------- + +# MAGIC %fs ls /FileStore/geospatial/odin/alaska/B08 + +# COMMAND ---------- + +import rasterio +from matplotlib import pyplot +from rasterio.plot import show + +fig, ax = pyplot.subplots(1, figsize=(12, 12)) +raster = rasterio.open("""/dbfs/FileStore/geospatial/odin/alaska/B08/2764922.tif""") +show(raster, ax=ax, cmap='Greens') +pyplot.show() + +# COMMAND ---------- + + diff --git a/notebooks/prototypes/grid_tiles/01 Gridded EOD Data - BNG.py b/notebooks/prototypes/grid_tiles/01 Gridded EOD Data - BNG.py new file mode 100644 index 000000000..88ac850b0 --- /dev/null +++ b/notebooks/prototypes/grid_tiles/01 Gridded EOD Data - BNG.py @@ -0,0 +1,442 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC ## Install the libraries and prepare the environment + +# COMMAND ---------- + +# MAGIC %pip install databricks-mosaic rasterio==1.3.5 --quiet gdal==3.4.3 pystac pystac_client planetary_computer tenacity rich pandas==1.5.3 + +# COMMAND ---------- + +spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "false") +spark.conf.set("spark.sql.adaptive.enabled", "false") +# spark.conf.set("spark.sql.inMemoryColumnarStorage.batchSize", "100") +spark.conf.set("spark.databricks.labs.mosaic.index.system", "BNG") 
+spark.conf.set("spark.databricks.labs.mosaic.geometry.api", "JTS") + +# COMMAND ---------- + +import mosaic as mos +from pyspark.sql import functions as F + +mos.enable_mosaic(spark, dbutils) +mos.enable_gdal(spark) + +# COMMAND ---------- + +import library +import rasterio + +from io import BytesIO +from matplotlib import pyplot +from rasterio.io import MemoryFile + +# COMMAND ---------- + +# MAGIC %reload_ext autoreload +# MAGIC %autoreload 2 +# MAGIC %reload_ext library + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Data load + +# COMMAND ---------- + +# MAGIC %md +# MAGIC We can easily browse the data we have downloaded in the notebook 00. The download metadata is stored as a delta table. + +# COMMAND ---------- + +# MAGIC %sql +# MAGIC USE odin_uk; +# MAGIC SHOW TABLES; + +# COMMAND ---------- + +catalog_df = \ + spark.read.table("uk_b02")\ + .withColumn("souce_band", F.lit("B02"))\ + .repartition(200)\ + .cache() +catalog_df.display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC For the purpose of raster data analysis mosaic framework provides a distributed gdal data readers. +# MAGIC We can also retile the images on read to make sure the imagery is balanced and more parallelised. + +# COMMAND ---------- + +# rst_tile -> rst_load +tiles_df = catalog_df\ + .repartition(200)\ + .withColumn("tile", mos.rst_tile("outputfile", F.lit(32)))\ + .withColumn("tile", mos.rst_subdivide("tile", F.lit(8))) + +# COMMAND ---------- + +tiles_df = ( + spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("dbfs:/FileStore/geospatial/odin/uk/B08") + .withColumn("tile", mos.rst_subdivide("tile", F.lit(32))) + .withColumn("size", mos.rst_memsize("tile")) +) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC At this point all our imagery is held in memory, but we can easily access it and visualise it. 
+ +# COMMAND ---------- + +tiles_df.count() + +# COMMAND ---------- + +to_plot = tiles_df.limit(50).collect() + +# COMMAND ---------- + +library.plot_raster(to_plot[4]["tile"]["raster"]) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC Mosaic framework provides the same tessellation principles for both vector and raster data. We can project both vector and raster data into a unified grid and from there it is very easy to combine and join raster to raster, vector to vector and raster to vector data. + +# COMMAND ---------- + +grid_tessellate_df = tiles_df\ + .repartition(200)\ + .withColumn("tile", mos.rst_tessellate("tile", F.lit("1km")))\ + .withColumn("index_id", F.col("tile.index_id")) + +to_plot = grid_tessellate_df.limit(50).collect() + +# COMMAND ---------- + +library.plot_raster(to_plot[23]["tile"]["raster"]) + +# COMMAND ---------- + +grid_tessellate_df.display() + +# COMMAND ---------- + +def index_band(band_table, resolution): + catalog_df = \ + spark.read.table(band_table)\ + .withColumn("souce_band", F.col("asset.name")) + + tiles_df = catalog_df\ + .repartition(200)\ + .withColumn("tile", mos.rst_tile("outputfile", F.lit(100)))\ + .where(mos.rst_tryopen("tile"))\ + .withColumn("tile", mos.rst_subdivide("tile", F.lit(8)))\ + .withColumn("size", mos.rst_memsize("tile")) + + grid_tessellate_df = tiles_df\ + .repartition(200)\ + .withColumn("tile", mos.rst_tessellate("tile", F.lit(resolution)))\ + .withColumn("index_id", F.col("tile.index_id"))\ + .repartition(200) + + grid_tessellate_df\ + .write.mode("overwrite")\ + .option("overwriteSchema", "true")\ + .format("delta")\ + .saveAsTable(f"{band_table}_indexed") + +# COMMAND ---------- + +tables_to_index = spark.sql("SHOW TABLES")\ + .where("tableName not like '%indexed'")\ + .where("tableName not like '%gridded'")\ + .where("tableName not like '%tmp%'")\ + .where("tableName not like '%tiles%'")\ + .select("tableName").collect() +tables_to_index = [tbl["tableName"] for tbl in tables_to_index] 
+tables_to_index + +# COMMAND ---------- + +for tbl in tables_to_index: + index_band(tbl, "1km") + +# COMMAND ---------- + +grid_tessellate_df = spark.read.table("uk_b02_indexed") +grid_tessellate_df.display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Raster for arbitrary corridors. + +# COMMAND ---------- + +# MAGIC %md +# MAGIC To illustrate how easy is to combine vector and raster data we will use a traditionally hard problem. Extraction of raster data for an arbitrary corridors. + +# COMMAND ---------- + +# MAGIC %fs ls /FileStore/geospatial/odin/os_census/ + +# COMMAND ---------- + +greenspace_df = mos.read().format("multi_read_ogr")\ + .option("vsizip", "false")\ + .option("chunkSize", "500")\ + .load("dbfs:/FileStore/geospatial/odin/os_census/OS Open Greenspace (ESRI Shape File) GB/data/GB_GreenspaceSite.shp")\ + .cache() + +# COMMAND ---------- + +greenspace_df.display() + +# COMMAND ---------- + +# MAGIC %%mosaic_kepler +# MAGIC greenspace_df geom_0 geometry(bng) 100 + +# COMMAND ---------- + +aoi_df = greenspace_df\ + .select(F.col("geom_0").alias("wkt"))\ + .withColumn("feature_id", F.hash("wkt"))\ + .select( + mos.grid_tessellateexplode("wkt", F.lit("1km")).alias("grid"), + "feature_id" + )\ + .select("grid.*", "feature_id")\ + .withColumn("wkt", mos.st_astext("wkb")) + +# COMMAND ---------- + +aoi_df.display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC We can visualise all the cells of interest for the provided arbitrary corridor. Since we are now operating in grid space it is very easy to get all raster images that match this specification. 
+ +# COMMAND ---------- + +# MAGIC %%mosaic_kepler +# MAGIC aoi_df wkt geometry(bng) + +# COMMAND ---------- + +cells_of_interest = grid_tessellate_df.repartition(200).join(aoi_df, on=["index_id"]) + +# COMMAND ---------- + +cells_of_interest.display() + +# COMMAND ---------- + +result = cells_of_interest\ + .withColumn("geojson", mos.st_buffer(mos.st_setsrid(mos.as_json(mos.st_asgeojson("wkb")), F.lit(27700)), F.lit(10)))\ + .select("date", "index_id", "tile", "geojson", "feature_id")\ + .withColumn("tile", mos.rst_clip("tile", "geojson"))\ + .withColumn("parent_id", F.substring("index_id", 0, 3))\ + .groupBy("feature_id", "parent_id").agg( + mos.rst_merge_agg("tile").alias("tile") + )\ + .groupBy("feature_id").agg( + mos.rst_merge_agg("tile").alias("tile") + ) + +result.display() + +# COMMAND ---------- + +result.write.mode("overwrite").saveAsTable("uk_green_spaces_tiles") + +# COMMAND ---------- + +spark.read.table("uk_green_spaces_tiles").count() + +# COMMAND ---------- + +to_plot = spark.read.table("uk_green_spaces_tiles").orderBy(F.length("tile.raster").desc()).limit(100).collect() + +# COMMAND ---------- + +library.plot_raster(to_plot[1]["tile"]["raster"]) + +# COMMAND ---------- + +library.plot_raster(to_plot[2]["tile"]["raster"]) + +# COMMAND ---------- + +library.plot_raster(to_plot[23]["tile"]["raster"]) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC Our framework provides a very easy way to provide rasterio lambda functions that we can distribute and scale up without any involvement from the end user. 
+ +# COMMAND ---------- + +def mean_band_1(dataset): + try: + return dataset.statistics(bidx = 1).mean + except: + return 0.0 + +with_measurement = cells_of_interest.withColumn( + "rasterio_lambda", library.rasterio_lambda("tile.raster", lambda dataset: mean_band_1(dataset) ) +) + +# COMMAND ---------- + +with_measurement.write.format("delta").mode("overwrite").option("overwriteSchema", "true").saveAsTable("mosaic_odin_uk_gridded") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Raster to Timeseries projection + +# COMMAND ---------- + +# MAGIC %md +# MAGIC With this power of expression that rasterio provides and power of distribution that mosaic provides we can easily convert rasters to numerical values with arbitrary mathematical complexity. Since all of our imagery is timestamped, our raster to number projection in effect is creating time series bound to BNG cells. + +# COMMAND ---------- + +with_measurement = spark.read.table("mosaic_odin_uk_gridded") + +# COMMAND ---------- + +with_measurement.where("rasterio_lambda > 0").display() + +# COMMAND ---------- + +measurements = with_measurement\ + .select( + "index_id", + "date", + "rasterio_lambda", + "wkb" + )\ + .where("rasterio_lambda > 0")\ + .groupBy("index_id", "date")\ + .agg( + F.avg("rasterio_lambda").alias("measure"), + F.first("wkb").alias("wkb") + ) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC At this point our data has effectively become timeseries data and can be modeled as virtual IOT devices that are fixed in spatial context. + +# COMMAND ---------- + +measurements.display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC We can easily visualise data for individual dates in a spatial context by leveraging our BNG locations. 
+ +# COMMAND ---------- + +df_06_03 = measurements.where("date == '2021-06-03'").drop("tile") + +# COMMAND ---------- + +# MAGIC %%mosaic_kepler +# MAGIC df_06_03 index_id bng 5000 + +# COMMAND ---------- + +grid_tessellate_df = spark.read.table("uk_b02_indexed") +grid_tessellate_df.display() + +# COMMAND ---------- + +kring_images = grid_tessellate_df\ + .where(mos.rst_tryopen("tile"))\ + .withColumn("kring", mos.grid_cellkringexplode("index_id", F.lit(1)))\ + .groupBy("kring")\ + .agg(mos.rst_merge_agg("tile").alias("tile")) + +# COMMAND ---------- + +to_plot = kring_images.limit(50).collect() + +# COMMAND ---------- + +library.plot_raster(to_plot[2][1]["raster"]) + +# COMMAND ---------- + +# for peatlands is 984 instead of 960 +expanded_imgs = kring_images\ + .withColumn("bbox", F.expr("rst_boundingbox(tile)"))\ + .withColumn("cent", mos.st_centroid("bbox"))\ + .withColumn("cent_cell", mos.grid_pointascellid("cent", F.lit("1km")))\ + .withColumn("clip_region", mos.grid_boundaryaswkb("cent_cell"))\ + .withColumn("clip_region", mos.st_buffer(mos.st_setsrid(mos.as_json(mos.st_asgeojson("clip_region")), F.lit(27700)), F.lit(16)))\ + .where(mos.st_area("clip_region") > 0)\ + .withColumn("tile", mos.rst_clip("tile", "clip_region")) + +# COMMAND ---------- + +to_plot = expanded_imgs.limit(50).collect() + +# COMMAND ---------- + +library.plot_raster(to_plot[2][1]["raster"]) + +# COMMAND ---------- + +to_plot = expanded_imgs\ + .where(F.array_contains(mos.grid_cellkring("kring", F.lit(1)), F.lit("TQ3757")))\ + .select("bbox", "clip_region", "cent_cell") + + +# COMMAND ---------- + +# MAGIC %%mosaic_kepler +# MAGIC to_plot bbox geometry(bng) + +# COMMAND ---------- + +# MAGIC %%mosaic_kepler +# MAGIC to_plot clip_region geometry(bng) + +# COMMAND ---------- + +rolling_tiles = expanded_imgs\ + .withColumn("tile", mos.rst_to_overlapping_tiles("tile", F.lit(30), F.lit(30), F.lit(50)))\ + .withColumn("bbox", mos.rst_boundingbox("tile")) + +# COMMAND ---------- + +to_plot = 
rolling_tiles.limit(50).collect() + +# COMMAND ---------- + +library.plot_raster(to_plot[1]["tile"]["raster"]) + +# COMMAND ---------- + +to_plot = rolling_tiles.select("bbox") + +# COMMAND ---------- + +# MAGIC %%mosaic_kepler +# MAGIC to_plot bbox geometry(bng) + +# COMMAND ---------- + + diff --git a/notebooks/prototypes/grid_tiles/01 Gridded EOD Data.py b/notebooks/prototypes/grid_tiles/01 Gridded EOD Data.py new file mode 100644 index 000000000..a6cd78bdf --- /dev/null +++ b/notebooks/prototypes/grid_tiles/01 Gridded EOD Data.py @@ -0,0 +1,292 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC ## Install the libraries and prepare the environment + +# COMMAND ---------- + +# MAGIC %pip install databricks-mosaic rasterio==1.3.5 --quiet gdal==3.4.3 pystac pystac_client planetary_computer tenacity rich + +# COMMAND ---------- + +spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "false") +spark.conf.set("spark.sql.adaptive.enabled", "false") + +# COMMAND ---------- + +import library +import mosaic as mos +import rasterio + +from io import BytesIO +from matplotlib import pyplot +from rasterio.io import MemoryFile +from pyspark.sql import functions as F + +mos.enable_mosaic(spark, dbutils) +mos.enable_gdal(spark) + +# COMMAND ---------- + +# MAGIC %reload_ext autoreload +# MAGIC %autoreload 2 +# MAGIC %reload_ext library + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Data load + +# COMMAND ---------- + +# MAGIC %md +# MAGIC We can easily browse the data we have downloaded in the notebook 00. The download metadata is stored as a delta table. + +# COMMAND ---------- + +# MAGIC %sql +# MAGIC USE odin_alaska; +# MAGIC SHOW TABLES; + +# COMMAND ---------- + +catalog_df = \ + spark.read.table("alaska_b04")\ + .withColumn("souce_band", F.lit("B04")) +catalog_df.display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC For the purpose of raster data analysis mosaic framework provides a distributed gdal data readers. 
+# MAGIC We can also retile the images on read to make sure the imagery is balanced and more parallelised.
+
+# COMMAND ----------
+
+tiles_df = catalog_df\
+    .repartition(200, F.rand())\
+    .withColumn("raster", mos.rst_subdivide("outputfile", F.lit(8)))\
+    .withColumn("size", mos.rst_memsize("raster"))\
+    .where(~mos.rst_isempty("raster"))
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC At this point all our imagery is held in memory, but we can easily access it and visualise it.
+
+# COMMAND ----------
+
+to_plot = tiles_df.limit(50).collect()
+
+# COMMAND ----------
+
+library.plot_raster(to_plot[7]["raster"])
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Mosaic framework provides the same tessellation principles for both vector and raster data. We can project both vector and raster data into a unified grid and from there it is very easy to combine and join raster to raster, vector to vector and raster to vector data.
+
+# COMMAND ----------
+
+grid_tessellate_df = tiles_df\
+    .repartition(200, F.rand())\
+    .withColumn("raster", mos.rst_tessellate("raster", F.lit(6)))\
+    .withColumn("index_id", F.col("raster.index_id"))
+
+# COMMAND ----------
+
+to_plot = grid_tessellate_df.limit(50).collect()
+
+# COMMAND ----------
+
+library.plot_raster(to_plot[2]["raster"]["raster"])
+
+# COMMAND ----------
+
+spark.read.table("alaska_b08")\
+    .withColumn("souce_band", F.col("asset.name"))\
+    .repartition(200, F.rand())\
+    .where(F.expr("rst_tryopen(outputfile)"))\
+    .withColumn("raster", mos.rst_subdivide("outputfile", F.lit(8)))\
+    .display()  # must sit directly after the trailing backslash; a blank line in between ends the statement
+
+
+# COMMAND ----------
+
+def index_band(band_table, resolution):
+    catalog_df = \
+        spark.read.table(band_table)\
+        .where(F.expr("rst_tryopen(outputfile)"))\
+        .withColumn("souce_band", F.col("asset.name"))
+
+    tiles_df = catalog_df\
+        .repartition(200, F.rand())\
+        .withColumn("raster", mos.rst_subdivide("outputfile", F.lit(8)))\
+        .withColumn("size", mos.rst_memsize("raster"))
+
+    grid_tessellate_df = tiles_df\
+
.repartition(200, F.rand())\ + .withColumn("raster", mos.rst_tessellate("raster", F.lit(resolution)))\ + .withColumn("index_id", F.col("raster.index_id")) + + grid_tessellate_df.write.mode("overwrite").format("delta").saveAsTable(f"{band_table}_indexed") + +# COMMAND ---------- + +tables_to_index = spark.sql("SHOW TABLES").where("tableName not like '%indexed'").select("tableName").collect() +tables_to_index = [tbl["tableName"] for tbl in tables_to_index] +tables_to_index + +# COMMAND ---------- + +index_band("alaska_b02", 6) + +# COMMAND ---------- + +index_band("alaska_b03", 6) + +# COMMAND ---------- + +index_band("alaska_b04", 6) + +# COMMAND ---------- + +index_band("alaska_b08", 6) + +# COMMAND ---------- + +grid_tessellate_df = spark.read.table("alaska_b02_indexed") +grid_tessellate_df.limit(20).display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Raster for arbitrary corridors. + +# COMMAND ---------- + +# MAGIC %md +# MAGIC To illustrate how easy is to combine vector and raster data we will use a traditionally hard problem. Extraction of raster data for an arbitrary corridors. 
+ +# COMMAND ---------- + +line_example = "LINESTRING(-158.34445841325555 68.0176784075422,-155.55393106950555 68.0423396963395,-154.82883341325555 67.84431100260183,-159.33322794450555 67.81114172848677,-160.01438028825555 67.47684671455214,-154.43332560075555 67.56925103744871,-154.01584513200555 67.30791374746678,-160.16818888200555 67.25700024664256,-160.58566935075555 66.94924133006975,-153.73020060075555 67.0693906319206,-154.49924356950555 66.70715520513478,-160.12424356950555 66.70715520513478,-159.02561075700555 66.37476822845568,-154.56516153825555 66.49774379983036,-155.04855997575555 66.22462528148408,-158.76193888200555 66.16254082040112,-157.94895060075555 65.94851918639993,-155.64182169450555 66.0021934684043,-158.58615763200555 66.55900493948819,-155.26828653825555 67.43472555587037,-161.64035685075555 67.86087797718164,-161.66232950700555 67.44315575603868)" + +line_df = spark.createDataFrame([line_example], "string")\ + .select(F.col("value").alias("wkt"))\ + .select( + mos.grid_tessellateexplode("wkt", F.lit(6)).alias("grid") + )\ + .select("grid.*") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC We can visualise all the cells of interest for the provided arbitrary corridor. Since we are now operating in grid space it is very easy to get all raster images that match this specification. + +# COMMAND ---------- + +# MAGIC %%mosaic_kepler +# MAGIC line_df index_id h3 + +# COMMAND ---------- + +cells_of_interest = grid_tessellate_df.repartition(40, F.rand()).join(line_df, on=["index_id"]) + +# COMMAND ---------- + +cells_of_interest.limit(10).display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC Our framework provides a very easy way to provide rasterio lambda functions that we can distribute and scale up without any involvement from the end user. 
+
+# COMMAND ----------
+
+with rasterio.open("/dbfs/FileStore/geospatial/odin/dais23demo/1805789896tif") as src:  # close the dataset once the statistic is read
+    avg = src.statistics(bidx = 1).mean
+avg
+
+# COMMAND ----------
+
+def mean_band_1(dataset):
+    try:
+        return dataset.statistics(bidx = 1).mean  # mean of band 1 of the opened rasterio dataset
+    except Exception:  # best effort: 0.0 marks a failed tile; rows with rasterio_lambda <= 0 are filtered out below
+        return 0.0
+
+with_measurement = cells_of_interest.withColumn(
+    "rasterio_lambda", library.rasterio_lambda("raster.raster", lambda dataset: mean_band_1(dataset) )
+)
+
+# COMMAND ----------
+
+with_measurement.write.format("delta").mode("overwrite").option("overwriteSchema", "true").saveAsTable("mosaic_odin_gridded")
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Raster to Timeseries projection
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC With this power of expression that rasterio provides and power of distribution that mosaic provides we can easily convert rasters to numerical values with arbitrary mathematical complexity. Since all of our imagery is timestamped, our raster to number projection in effect is creating time series bound to H3 cells.
+
+# COMMAND ----------
+
+with_measurement = spark.read.table("mosaic_odin_gridded")
+
+# COMMAND ----------
+
+with_measurement.where("rasterio_lambda > 0").display()
+
+# COMMAND ----------
+
+measurements = with_measurement\
+    .select(
+        "index_id",
+        "date",
+        "rasterio_lambda",
+        "wkb"
+    )\
+    .where("rasterio_lambda > 0")\
+    .groupBy("index_id", "date")\
+    .agg(
+        F.avg("rasterio_lambda").alias("measure"),
+        F.first("wkb").alias("wkb")
+    )
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC At this point our data has effectively become timeseries data and can be modeled as virtual IOT devices that are fixed in spatial context.
+
+# COMMAND ----------
+
+measurements.display()
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC We can easily visualise data for individual dates in a spatial context by leveraging our H3 locations.
+ +# COMMAND ---------- + +df_06_20 = measurements.where("date == '2021-06-20'") +df_06_03 = measurements.where("date == '2021-06-03'") + +# COMMAND ---------- + +# MAGIC %%mosaic_kepler +# MAGIC df_06_20 index_id h3 5000 + +# COMMAND ---------- + +# MAGIC %%mosaic_kepler +# MAGIC df_06_03 index_id h3 5000 + +# COMMAND ---------- + + diff --git a/notebooks/prototypes/grid_tiles/03 Band Stacking.py b/notebooks/prototypes/grid_tiles/03 Band Stacking.py new file mode 100644 index 000000000..da54328eb --- /dev/null +++ b/notebooks/prototypes/grid_tiles/03 Band Stacking.py @@ -0,0 +1,164 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC ## Install the libraries and prepare the environment + +# COMMAND ---------- + +# MAGIC %pip install databricks-mosaic rasterio==1.3.5 --quiet gdal==3.4.3 pystac pystac_client planetary_computer tenacity rich + +# COMMAND ---------- + +import library +import mosaic as mos +import rasterio + +from io import BytesIO +from matplotlib import pyplot +from rasterio.io import MemoryFile +from pyspark.sql import functions as F + +mos.enable_mosaic(spark, dbutils) +mos.enable_gdal(spark) + +# COMMAND ---------- + +# MAGIC %reload_ext autoreload +# MAGIC %autoreload 2 +# MAGIC %reload_ext library + +# COMMAND ---------- + +spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "false") +spark.conf.set("spark.sql.adaptive.enabled", "false") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Data load + +# COMMAND ---------- + +# MAGIC %md +# MAGIC We can easily browse the data we have downloaded in the notebook 00. The download metadata is stored as a delta table. 
+ +# COMMAND ---------- + +# MAGIC %sql +# MAGIC USE odin_alaska + +# COMMAND ---------- + +df_b02 = spark.read.table("alaska_b02_indexed")\ + .withColumn("h3", F.col("raster.index_id")) +df_b03 = spark.read.table("alaska_b03_indexed")\ + .withColumn("h3", F.col("raster.index_id")) +df_b04 = spark.read.table("alaska_b04_indexed")\ + .withColumn("h3", F.col("raster.index_id")) +df_b08 = spark.read.table("alaska_b08_indexed")\ + .withColumn("h3", F.col("raster.index_id")) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC For the purpose of raster data analysis mosaic framework provides a distributed gdal data readers. +# MAGIC We can also retile the images on read to make sure the imagery is balanced and more parallelised. + +# COMMAND ---------- + +df_b03.limit(20).display() + +# COMMAND ---------- + +df_b03.groupBy("h3", "date").count().display() + +# COMMAND ---------- + +to_plot = df_b03.limit(50).collect() + +# COMMAND ---------- + +library.plot_raster(to_plot[37]["raster"]["raster"]) + +# COMMAND ---------- + +counts = df_b03.select("h3", "date").groupBy("h3", "date").count() + +# COMMAND ---------- + +# MAGIC %%mosaic_kepler +# MAGIC counts h3 h3 100000000 + +# COMMAND ---------- + +df_b02_resolved = df_b02.groupBy("h3", "date")\ + .agg(mos.rst_merge(F.collect_list("raster.raster")).alias("raster")) + +df_b03_resolved = df_b03.groupBy("h3", "date")\ + .agg(mos.rst_merge(F.collect_list("raster.raster")).alias("raster")) + +df_b04_resolved = df_b04.groupBy("h3", "date")\ + .agg(mos.rst_merge(F.collect_list("raster.raster")).alias("raster")) + +df_b08_resolved = df_b08.groupBy("h3", "date")\ + .agg(mos.rst_merge(F.collect_list("raster.raster")).alias("raster")) + + +# COMMAND ---------- + +df_b03_resolved.groupBy("h3", "date").count().display() + +# COMMAND ---------- + +# MAGIC %%mosaic_kepler +# MAGIC df_b03_resolved h3 h3 1000000 + +# COMMAND ---------- + +stacked_df = df_b02_resolved\ + .repartition(200, F.rand())\ + .withColumnRenamed("raster", "b02")\ + 
.join(
+        df_b03_resolved\
+            .repartition(200, F.rand())\
+            .withColumnRenamed("raster", "b03"),
+        on = ["h3", "date"]
+    )\
+    .join(
+        df_b04_resolved\
+            .repartition(200, F.rand())\
+            .withColumnRenamed("raster", "b04"),
+        on = ["h3", "date"]
+    )\
+    .join(
+        df_b08_resolved\
+            .repartition(200, F.rand())\
+            .withColumnRenamed("raster", "b08"),
+        on = ["h3", "date"]
+    )\
+    .withColumn("raster", mos.rst_mergebands(F.array("b04", "b03", "b02", "b08"))) # b04 = red b03 = green b02 = blue b08 = nir
+
+# COMMAND ----------
+
+stacked_df.count()
+
+# COMMAND ----------
+
+stacked_df.limit(50).display()
+
+# COMMAND ----------
+
+ndvi_test = stacked_df.withColumn(
+    "ndvi", mos.rst_ndvi("raster", F.lit(1), F.lit(4))  # (red, nir) band numbers: b04 is band 1 and b08 is band 4, assuming bands are stacked in array order
+)
+
+# COMMAND ----------
+
+to_plot = ndvi_test.limit(50).collect()
+
+# COMMAND ----------
+
+library.plot_raster(to_plot[4]["ndvi"])
+
+# COMMAND ----------
+
+
diff --git a/notebooks/prototypes/grid_tiles/04 SAM Integration.py b/notebooks/prototypes/grid_tiles/04 SAM Integration.py
new file mode 100644
index 000000000..c9e516ef3
--- /dev/null
+++ b/notebooks/prototypes/grid_tiles/04 SAM Integration.py
@@ -0,0 +1,140 @@
+# Databricks notebook source
+
+
+# COMMAND ----------
+
+# MAGIC %pip install rasterio==1.3.5 --quiet gdal==3.4.3 pystac pystac_client planetary_computer torch transformers
+
+# COMMAND ----------
+
+import library
+import sam_lib
+import mosaic as mos
+import rasterio
+
+from io import BytesIO
+from matplotlib import pyplot
+from rasterio.io import MemoryFile
+from pyspark.sql import functions as F
+
+mos.enable_mosaic(spark, dbutils)
+mos.enable_gdal(spark)
+
+# COMMAND ----------
+
+# MAGIC %reload_ext autoreload
+# MAGIC %autoreload 2
+# MAGIC %reload_ext library
+# MAGIC %reload_ext sam_lib
+
+# COMMAND ----------
+
+spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "false")
+spark.conf.set("spark.sql.adaptive.enabled", "false")
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Data load and model load
+
+# COMMAND
---------- + +with_measurement = spark.read.table("mosaic_odin_gridded") + +# COMMAND ---------- + +library.plot_raster(with_measurement.limit(50).collect()[0]["raster"]) + +# COMMAND ---------- + +import torch +from PIL import Image +import requests +from transformers import SamModel, SamProcessor +import torch + +# COMMAND ---------- + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device) +processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + +# COMMAND ---------- + +tiles = with_measurement.limit(50).collect() + +# COMMAND ---------- + +raster = tiles[1]["raster"] +raw_image = Image.open(BytesIO(raster)) + +# COMMAND ---------- + +library.plot_raster(raster) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Apply SAM on one of the tiles + +# COMMAND ---------- + +def get_masks(raw_image): + inputs = processor(raw_image, return_tensors="pt").to(device) + image_embeddings = model.get_image_embeddings(inputs["pixel_values"]) + inputs.pop("pixel_values", None) + inputs.update({"image_embeddings": image_embeddings}) + + with torch.no_grad(): + outputs = model(**inputs) + + masks = processor.image_processor.post_process_masks( + outputs.pred_masks.cpu(), + inputs["original_sizes"].cpu(), + inputs["reshaped_input_sizes"].cpu() + ) + return masks + +def get_scores(raw_image): + inputs = processor(raw_image, return_tensors="pt").to(device) + image_embeddings = model.get_image_embeddings(inputs["pixel_values"]) + inputs.pop("pixel_values", None) + inputs.update({"image_embeddings": image_embeddings}) + + with torch.no_grad(): + outputs = model(**inputs) + + scores = outputs.iou_scores + return scores + +# COMMAND ---------- + +scores = get_scores(raw_image) +masks = get_masks(raw_image) + +# COMMAND ---------- + +sam_lib.show_masks_on_image(raw_image, masks[0], scores) + + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Scaling model scoring with pandas UDFs + +# COMMAND 
----------
+
+import pandas as pd
+from pyspark.sql.functions import col, pandas_udf
+from pyspark.sql.types import LongType
+
+# Declare the function and create the UDF
+def apply_sam(rasters: pd.Series) -> pd.Series:
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
+    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
+
+    return rasters\
+        .apply(lambda raster: Image.open(BytesIO(raster)))\
+        .apply(lambda image: get_masks(image))  # NOTE(review): get_masks reads the notebook-level model/processor, not the locals above
diff --git a/notebooks/prototypes/grid_tiles/library.py b/notebooks/prototypes/grid_tiles/library.py
new file mode 100644
index 000000000..830b59b9a
--- /dev/null
+++ b/notebooks/prototypes/grid_tiles/library.py
@@ -0,0 +1,148 @@
+import shapely.geometry
+import mosaic as mos
+import pystac_client
+import planetary_computer
+import json
+import requests
+import rasterio
+from io import BytesIO
+from matplotlib import pyplot
+from rasterio.io import MemoryFile
+import rasterio
+from matplotlib import pyplot
+from rasterio.plot import show
+
+from pyspark.sql.types import *
+from pyspark.sql import functions as F
+from pyspark.sql.functions import udf
+
+from pyspark.sql.functions import pandas_udf
+import pandas as pd
+
+catalog = pystac_client.Client.open(
+    "https://planetarycomputer.microsoft.com/api/stac/v1",
+    modifier=planetary_computer.sign_inplace
+)
+
+
+def generate_cells(extent, resolution, spark, mos):
+    polygon = shapely.geometry.box(*extent, ccw=True)
+    wkt_poly = str(polygon.wkt)
+    cells = spark.createDataFrame([[wkt_poly]], ["geom"])
+    cells = cells.withColumn("grid", mos.grid_tessellateexplode("geom", F.lit(resolution)))
+    return cells
+
+
+@udf("array<string>")
+def get_assets(item):
+    item_dict = json.loads(item)
+    assets = item_dict["assets"]
+    return [json.dumps({**{"name": asset}, **assets[asset]}) for asset in assets]  # one JSON string per asset, with its key added as "name"
+
+
+@pandas_udf("array<string>")
+def get_items(geojson: pd.Series, datetime: pd.Series, collections: pd.Series) ->
pd.Series:
+
+    from tenacity import retry, wait_exponential
+
+    @retry(wait=wait_exponential(multiplier=2, min=4, max=120))
+    def search_with_retry(geojson, catalog, collection, dt):
+        search = catalog.search(
+            collections = collection,
+            intersects = geojson,
+            datetime = dt
+        )
+        items = search.item_collection()
+        return [json.dumps(item.to_dict()) for item in items]
+
+    def search_catalog(geojson, catalog, collection, dt):
+        try:
+            return search_with_retry(geojson, catalog, collection, dt)
+        except Exception as inst:
+            return [str(inst)]
+
+    catalog = pystac_client.Client.open(
+        "https://planetarycomputer.microsoft.com/api/stac/v1",
+        modifier=planetary_computer.sign_inplace
+    )
+
+    dt = datetime[0]
+    coll = collections[0]
+    return geojson.apply(
+        lambda gj: search_catalog(gj, catalog, coll, dt)
+    )
+
+
+def get_assets_for_cells(cells_df, period, source):
+    return cells_df\
+        .withColumn("items", get_items("geojson", F.lit(period), F.array(F.lit(source))))\
+        .repartition(200, F.rand())\
+        .withColumn("items", F.explode("items"))\
+        .withColumn("assets", get_assets("items"))\
+        .repartition(200, F.rand())\
+        .withColumn("assets", F.explode("assets"))\
+        .withColumn("asset", F.from_json(F.col("assets"), MapType(StringType(), StringType())))\
+        .withColumn("item", F.from_json(F.col("items"), MapType(StringType(), StringType())))\
+        .withColumn("item_properties", F.from_json("item.properties", MapType(StringType(), StringType())))\
+        .withColumn("item_collection", F.col("item.collection"))\
+        .withColumn("item_bbox", F.col("item.bbox"))\
+        .withColumn("item_id", F.col("item.id"))\
+        .withColumn("stac_version", F.col("item.stac_version"))\
+        .drop("assets", "items", "item")\
+        .repartition(200, F.rand())
+
+
+def get_unique_hrefs(assets_df, item_name):
+    return assets_df\
+        .select(
+            "area_id",
+            "h3",
+            "asset.name",
+            "asset.href",
+            "item_id",
+            F.to_date("item_properties.datetime").alias("date")
+        ).where(
+            f"name == '{item_name}'"
+        ).groupBy(
+            "href", "item_id", "date"
+        )\
+        .agg(F.first("h3").alias("h3"))
+
+
+@udf("string")
+def download_asset(href, dir_path, filename):
+    try:
+        outpath = f"{dir_path}/{filename}"
+        # Make the actual request, set the timeout for no data to 100 seconds and enable streaming responses so we don't have to keep the large files in memory
+        request = requests.get(href, timeout=100, stream=True)
+        request.raise_for_status()  # an HTTP error must not be saved as if it were the asset; it falls through to the "" result below
+        # Open the output file and make sure we write in binary mode
+        with open(outpath, 'wb') as fh:
+            # Walk through the request response in chunks of 1024 * 1024 bytes, so 1MiB
+            for chunk in request.iter_content(1024 * 1024):
+                # Write the chunk to the file
+                fh.write(chunk)
+                # Optionally we can check here if the download is taking too long
+        return outpath
+    except Exception:  # best effort: any failure yields "" instead of failing the Spark task
+        return ""
+
+
+def plot_raster(raster):
+    fig, ax = pyplot.subplots(1, figsize=(12, 12))
+
+    with MemoryFile(BytesIO(raster)) as memfile:
+        with memfile.open() as src:
+            show(src, ax=ax)
+            pyplot.show()
+
+
+def rasterio_lambda(raster, lambda_f):
+    @udf("double")
+    def f_udf(f_raster):
+        with MemoryFile(BytesIO(f_raster)) as memfile:
+            with memfile.open() as dataset:
+                x = lambda_f(dataset)
+                return float(x)
+
+    return f_udf(raster)
\ No newline at end of file
diff --git a/notebooks/prototypes/grid_tiles/sam_lib.py b/notebooks/prototypes/grid_tiles/sam_lib.py
new file mode 100644
index 000000000..f785178f7
--- /dev/null
+++ b/notebooks/prototypes/grid_tiles/sam_lib.py
@@ -0,0 +1,91 @@
+import numpy as np
+import matplotlib.pyplot as plt
+
+def show_mask(mask, ax, random_color=False):
+    if random_color:
+        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+    else:
+        color = np.array([30/255, 144/255, 255/255, 0.6])
+    h, w = mask.shape[-2:]
+    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+    ax.imshow(mask_image)
+
+
+def show_box(box, ax):
+    x0, y0 = box[0], box[1]
+    w, h = box[2] - box[0], box[3] - box[1]
+    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
+ +def show_boxes_on_image(raw_image, boxes): + plt.figure(figsize=(10,10)) + plt.imshow(raw_image) + for box in boxes: + show_box(box, plt.gca()) + plt.axis('on') + plt.show() + +def show_points_on_image(raw_image, input_points, input_labels=None): + plt.figure(figsize=(10,10)) + plt.imshow(raw_image) + input_points = np.array(input_points) + if input_labels is None: + labels = np.ones_like(input_points[:, 0]) + else: + labels = np.array(input_labels) + show_points(input_points, labels, plt.gca()) + plt.axis('on') + plt.show() + +def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None): + plt.figure(figsize=(10,10)) + plt.imshow(raw_image) + input_points = np.array(input_points) + if input_labels is None: + labels = np.ones_like(input_points[:, 0]) + else: + labels = np.array(input_labels) + show_points(input_points, labels, plt.gca()) + for box in boxes: + show_box(box, plt.gca()) + plt.axis('on') + plt.show() + + +def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None): + plt.figure(figsize=(10,10)) + plt.imshow(raw_image) + input_points = np.array(input_points) + if input_labels is None: + labels = np.ones_like(input_points[:, 0]) + else: + labels = np.array(input_labels) + show_points(input_points, labels, plt.gca()) + for box in boxes: + show_box(box, plt.gca()) + plt.axis('on') + plt.show() + + +def show_points(coords, labels, ax, marker_size=375): + pos_points = coords[labels==1] + neg_points = coords[labels==0] + ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) + ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) + + +def show_masks_on_image(raw_image, masks, scores): + if len(masks.shape) == 4: + masks = masks.squeeze() + if scores.shape[0] == 1: + scores = scores.squeeze() + + nb_predictions = scores.shape[-1] + fig, axes = plt.subplots(1, 
nb_predictions, figsize=(15, 15)) + + for i, (mask, score) in enumerate(zip(masks, scores)): + mask = mask.cpu().detach() + axes[i].imshow(np.array(raw_image)) + show_mask(mask, axes[i]) + axes[i].title.set_text(f"Mask {i+1}, Score: {score.item():.3f}") + axes[i].axis("off") + plt.show() \ No newline at end of file diff --git a/pom.xml b/pom.xml index f9ec2b99a..eb4d8d27d 100644 --- a/pom.xml +++ b/pom.xml @@ -93,7 +93,7 @@ h3 - 3.7.0 + 3.7.3 org.locationtech.jts @@ -157,6 +157,27 @@ + + org.scoverage + scoverage-maven-plugin + 1.4.11 + + + scoverage-report + package + + check + report-only + + + + + ${minimum.coverage} + true + ${scala.version} + skipTests=false + + net.alchim31.maven @@ -185,6 +206,7 @@ true + @{argLine} -Djava.library.path=/usr/local/lib;/usr/java/packages/lib;/usr/lib64;/lib64;/lib;/usr/lib @@ -227,27 +249,6 @@ - - org.scoverage - scoverage-maven-plugin - 1.4.11 - - - scoverage-report - package - - check - report-only - - - - - ${minimum.coverage} - true - ${scala.version} - skipTests=false - - org.apache.maven.plugins maven-resources-plugin diff --git a/python/mosaic/api/aggregators.py b/python/mosaic/api/aggregators.py index 930352f92..3f1f54ac8 100644 --- a/python/mosaic/api/aggregators.py +++ b/python/mosaic/api/aggregators.py @@ -8,7 +8,17 @@ # Spatial aggregators # ####################### -__all__ = ["st_intersection_aggregate", "st_intersects_aggregate", "st_union_agg", "grid_cell_union_agg", "grid_cell_intersection_agg"] +__all__ = [ + "st_intersection_aggregate", + "st_intersects_aggregate", + "st_union_agg", + "grid_cell_union_agg", + "grid_cell_intersection_agg", + "rst_merge_agg", + "rst_combineavg_agg", + "st_intersection_agg", + "st_intersects_agg", +] def st_intersection_aggregate( @@ -38,6 +48,33 @@ def st_intersection_aggregate( ) +def st_intersection_agg( + leftIndex: ColumnOrName, rightIndex: ColumnOrName +) -> Column: + """ + Computes the intersection of all `leftIndex` : `rightIndex` pairs + and unions these to produce a 
single geometry. + + Parameters + ---------- + leftIndex : Column + The index field of the left-hand geometry + rightIndex : Column + The index field of the right-hand geometry + + Returns + ------- + Column + The aggregated intersection geometry. + + """ + return config.mosaic_context.invoke_function( + "st_intersection_aggregate", + pyspark_to_java_column(leftIndex), + pyspark_to_java_column(rightIndex), + ) + + def st_intersects_aggregate( leftIndex: ColumnOrName, rightIndex: ColumnOrName ) -> Column: @@ -62,6 +99,32 @@ def st_intersects_aggregate( pyspark_to_java_column(rightIndex), ) + +def st_intersects_agg( + leftIndex: ColumnOrName, rightIndex: ColumnOrName +) -> Column: + """ + Tests if any `leftIndex` : `rightIndex` pairs intersect. + + Parameters + ---------- + leftIndex : Column + The index field of the left-hand geometry + rightIndex : Column + The index field of the right-hand geometry + + Returns + ------- + Column (BooleanType) + + """ + return config.mosaic_context.invoke_function( + "st_intersects_aggregate", + pyspark_to_java_column(leftIndex), + pyspark_to_java_column(rightIndex), + ) + + def st_union_agg(geom: ColumnOrName) -> Column: """ Returns the point set union of the aggregated geometries. @@ -97,6 +160,7 @@ def grid_cell_intersection_agg(chips: ColumnOrName) -> Column: "grid_cell_intersection_agg", pyspark_to_java_column(chips) ) + def grid_cell_union_agg(chips: ColumnOrName) -> Column: """ Returns the chip representing the aggregated union of chips on some grid cell. @@ -112,4 +176,40 @@ def grid_cell_union_agg(chips: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( "grid_cell_union_agg", pyspark_to_java_column(chips) - ) \ No newline at end of file + ) + + +def rst_merge_agg(raster: ColumnOrName) -> Column: + """ + Returns the raster representing the aggregated union of rasters on some grid cell. + + Parameters + ---------- + raster: Column + + Returns + ------- + Column + The union raster. 
+ """ + return config.mosaic_context.invoke_function( + "rst_merge_agg", pyspark_to_java_column(raster) + ) + + +def rst_combineavg_agg(raster: ColumnOrName) -> Column: + """ + Returns the raster representing the aggregated average of rasters. + + Parameters + ---------- + raster: Column + + Returns + ------- + Column + The average raster. + """ + return config.mosaic_context.invoke_function( + "rst_combineavg_agg", pyspark_to_java_column(raster) + ) diff --git a/python/mosaic/api/enable.py b/python/mosaic/api/enable.py index 0996c24ef..c6dde26a4 100644 --- a/python/mosaic/api/enable.py +++ b/python/mosaic/api/enable.py @@ -56,10 +56,15 @@ def enable_mosaic(spark: SparkSession, dbutils=None) -> None: isSupported = config.mosaic_context._context.checkDBR(spark._jsparkSession) if not isSupported: - print("DEPRECATION WARNING: Mosaic is not supported on the selected Databricks Runtime. \n") - print("DEPRECATION WARNING: Mosaic will stop working on this cluster from version v0.4.0+. \n") - print("Please use a Databricks Photon-enabled Runtime (for performance benefits) or Runtime ML (for spatial AI benefits). \n") - + print( + "DEPRECATION WARNING: Mosaic is not supported on the selected Databricks Runtime. \n" + ) + print( + "DEPRECATION WARNING: Mosaic will stop working on this cluster from version v0.4.0+. \n" + ) + print( + "Please use a Databricks Photon-enabled Runtime (for performance benefits) or Runtime ML (for spatial AI benefits). 
\n" + ) # Not yet added to the pyspark API with warnings.catch_warnings(): diff --git a/python/mosaic/api/functions.py b/python/mosaic/api/functions.py index cc659d826..76df46554 100644 --- a/python/mosaic/api/functions.py +++ b/python/mosaic/api/functions.py @@ -18,6 +18,7 @@ "st_convexhull", "st_buffer", "st_bufferloop", + "st_buffer_cap_style", "st_dump", "st_envelope", "st_srid", @@ -48,13 +49,7 @@ "st_zmax", "st_x", "st_y", - - "rst_bandmetadata", - "rst_metadata", - "rst_subdatasets", - "flatten_polygons", - "grid_boundaryaswkb", "grid_boundary", "grid_longlatascellid", @@ -73,14 +68,12 @@ "grid_geometrykloop", "grid_geometrykringexplode", "grid_geometrykloopexplode", - "point_index_geom", "point_index_lonlat", "index_geometry", "polyfill", "mosaic_explode", "mosaicfill", - ] @@ -183,7 +176,9 @@ def st_buffer(geom: ColumnOrName, radius: ColumnOrName) -> Column: ) -def st_bufferloop(geom: ColumnOrName, inner_radius: ColumnOrName, outer_radius: ColumnOrName) -> Column: +def st_bufferloop( + geom: ColumnOrName, inner_radius: ColumnOrName, outer_radius: ColumnOrName +) -> Column: """ Compute the buffered geometry loop (hollow ring) based on geom and provided radius-es. The result geometry is a polygon/multipolygon with a hole in the center. @@ -209,7 +204,34 @@ def st_bufferloop(geom: ColumnOrName, inner_radius: ColumnOrName, outer_radius: "st_bufferloop", pyspark_to_java_column(geom), pyspark_to_java_column(inner_radius), - pyspark_to_java_column(outer_radius) + pyspark_to_java_column(outer_radius), + ) + + +def st_buffer_cap_style(geom: ColumnOrName, radius: ColumnOrName, cap_style: ColumnOrName) -> Column: + """ + Compute the buffered geometry based on geom and radius. 
+ + Parameters + ---------- + geom : Column + The input geometry + radius : Column + The radius of buffering + cap_style : Column + The cap style of the buffer + + Returns + ------- + Column + A geometry + + """ + return config.mosaic_context.invoke_function( + "st_buffer_cap_style", + pyspark_to_java_column(geom), + pyspark_to_java_column(radius), + pyspark_to_java_column(cap_style) ) @@ -309,7 +331,7 @@ def st_transform(geom: ColumnOrName, srid: ColumnOrName) -> Column: def st_hasvalidcoordinates( - geom: ColumnOrName, crs: ColumnOrName, which: ColumnOrName + geom: ColumnOrName, crs: ColumnOrName, which: ColumnOrName ) -> Column: """ Checks if all points in geometry are valid with respect to crs bounds. @@ -516,7 +538,10 @@ def st_distance(geom1: ColumnOrName, geom2: ColumnOrName) -> Column: pyspark_to_java_column(geom2), ) -def st_haversine(lat1: ColumnOrName, lng1: ColumnOrName, lat2: ColumnOrName, lng2: ColumnOrName) -> Column: + +def st_haversine( + lat1: ColumnOrName, lng1: ColumnOrName, lat2: ColumnOrName, lng2: ColumnOrName +) -> Column: """ Compute the haversine distance in kilometers between two latitude/longitude pairs. @@ -556,9 +581,7 @@ def st_difference(geom1: ColumnOrName, geom2: ColumnOrName) -> Column: The difference geometry. """ return config.mosaic_context.invoke_function( - "st_difference", - pyspark_to_java_column(geom1), - pyspark_to_java_column(geom2) + "st_difference", pyspark_to_java_column(geom1), pyspark_to_java_column(geom2) ) @@ -647,7 +670,9 @@ def st_union(left_geom: ColumnOrName, right_geom: ColumnOrName) -> Column: The union geometry. 
""" return config.mosaic_context.invoke_function( - "st_union", pyspark_to_java_column(left_geom), pyspark_to_java_column(right_geom) + "st_union", + pyspark_to_java_column(left_geom), + pyspark_to_java_column(right_geom), ) @@ -669,7 +694,9 @@ def st_unaryunion(geom: ColumnOrName) -> Column: ) -def st_updatesrid(geom: ColumnOrName, srcSRID: ColumnOrName, destSRID: ColumnOrName) -> Column: +def st_updatesrid( + geom: ColumnOrName, srcSRID: ColumnOrName, destSRID: ColumnOrName +) -> Column: """ Updates the SRID of the input geometry `geom` from `srcSRID` to `destSRID`. @@ -690,7 +717,10 @@ def st_updatesrid(geom: ColumnOrName, srcSRID: ColumnOrName, destSRID: ColumnOrN Geometry with updated SRID """ return config.mosaic_context.invoke_function( - "st_updatesrid", pyspark_to_java_column(geom), pyspark_to_java_column(srcSRID), pyspark_to_java_column(destSRID) + "st_updatesrid", + pyspark_to_java_column(geom), + pyspark_to_java_column(srcSRID), + pyspark_to_java_column(destSRID), ) @@ -707,9 +737,7 @@ def st_x(geom: ColumnOrName) -> Column: Column (DoubleType) """ - return config.mosaic_context.invoke_function( - "st_x", pyspark_to_java_column(geom) - ) + return config.mosaic_context.invoke_function("st_x", pyspark_to_java_column(geom)) def st_y(geom: ColumnOrName) -> Column: @@ -725,9 +753,7 @@ def st_y(geom: ColumnOrName) -> Column: Column (DoubleType) """ - return config.mosaic_context.invoke_function( - "st_y", pyspark_to_java_column(geom) - ) + return config.mosaic_context.invoke_function("st_y", pyspark_to_java_column(geom)) def st_geometrytype(geom: ColumnOrName) -> Column: @@ -856,87 +882,6 @@ def st_zmax(geom: ColumnOrName) -> Column: ) -def rst_metadata(raster: ColumnOrName, path: Any = "") -> Column: - """ - Extracts metadata from a raster row. - - Parameters - ---------- - raster : ColumnOrName - The input raster column. - path : ColumnOrName - The path of the metadata within the raster row. 
- - Returns - ------- - Column - A map column containing the metadata. - - """ - if type(path) == str: - path = lit(path) - return config.mosaic_context.invoke_function( - "rst_metadata", - pyspark_to_java_column(raster), - pyspark_to_java_column(path) - ) - - -def rst_subdatasets(raster: ColumnOrName, path: Any = "") -> Column: - """ - Extracts subdatasets from a raster row. - - Parameters - ---------- - raster : ColumnOrName - The input raster column. - path : ColumnOrName - The path of subdatasets within the raster row. - - Returns - ------- - Column - A map column containing the subdatasets. - - """ - if type(path) == str: - path = lit(path) - return config.mosaic_context.invoke_function( - "rst_subdatasets", - pyspark_to_java_column(raster), - pyspark_to_java_column(path) - ) - - -def rst_bandmetadata(raster: ColumnOrName, band: ColumnOrName, path: Any = "") -> Column: - """ - Extracts band metadata from a raster row. - - Parameters - ---------- - raster : ColumnOrName - The input raster column. - band : ColumnOrName - The band index. - path : ColumnOrName - The path of the metadata within the raster row and the band. - - Returns - ------- - Column - A map column containing the metadata. - - """ - if type(path) == str: - path = lit(path) - return config.mosaic_context.invoke_function( - "rst_bandmetadata", - pyspark_to_java_column(raster), - pyspark_to_java_column(band), - pyspark_to_java_column(path) - ) - - def flatten_polygons(geom: ColumnOrName) -> Column: """ Explodes a multi-geometry into one row per constituent geometry. @@ -975,6 +920,7 @@ def grid_boundaryaswkb(index_id: ColumnOrName) -> Column: "grid_boundaryaswkb", pyspark_to_java_column(index_id) ) + def grid_cellarea(index_id: ColumnOrName) -> Column: """ Returns the area of the grid cell in km^2. 
@@ -990,8 +936,7 @@ def grid_cellarea(index_id: ColumnOrName) -> Column: The area of the grid cell in km^2 """ return config.mosaic_context.invoke_function( - "grid_cellarea", - pyspark_to_java_column(index_id) + "grid_cellarea", pyspark_to_java_column(index_id) ) @@ -1014,12 +959,12 @@ def grid_boundary(index_id: ColumnOrName, format_name: ColumnOrName) -> Column: return config.mosaic_context.invoke_function( "grid_boundary", pyspark_to_java_column(index_id), - pyspark_to_java_column(format_name) + pyspark_to_java_column(format_name), ) def grid_longlatascellid( - lon: ColumnOrName, lat: ColumnOrName, resolution: ColumnOrName + lon: ColumnOrName, lat: ColumnOrName, resolution: ColumnOrName ) -> Column: """ Returns the grid's cell ID associated with the input `lng` and `lat` coordinates at a given grid `resolution`. @@ -1087,7 +1032,7 @@ def grid_polyfill(geom: ColumnOrName, resolution: ColumnOrName) -> Column: def grid_tessellate( - geom: ColumnOrName, resolution: ColumnOrName, keep_core_geometries: Any = True + geom: ColumnOrName, resolution: ColumnOrName, keep_core_geometries: Any = True ) -> Column: """ Generates: @@ -1122,7 +1067,7 @@ def grid_tessellate( def grid_tessellateexplode( - geom: ColumnOrName, resolution: ColumnOrName, keep_core_geometries: Any = True + geom: ColumnOrName, resolution: ColumnOrName, keep_core_geometries: Any = True ) -> Column: """ Generates: @@ -1154,9 +1099,8 @@ def grid_tessellateexplode( pyspark_to_java_column(keep_core_geometries), ) -def grid_cell_intersection( - left_chip: ColumnOrName, right_chip: ColumnOrName -) -> Column: + +def grid_cell_intersection(left_chip: ColumnOrName, right_chip: ColumnOrName) -> Column: """ Returns the chip representing the intersection of two chips based on the same grid cell. 
@@ -1176,9 +1120,8 @@ def grid_cell_intersection( pyspark_to_java_column(right_chip), ) -def grid_cell_union( - left_chip: ColumnOrName, right_chip: ColumnOrName -) -> Column: + +def grid_cell_union(left_chip: ColumnOrName, right_chip: ColumnOrName) -> Column: """ Returns the chip representing the union of two chips based on the same grid cell. @@ -1199,9 +1142,7 @@ def grid_cell_union( ) -def grid_cellkring( - cellid: ColumnOrName, k: ColumnOrName -) -> Column: +def grid_cellkring(cellid: ColumnOrName, k: ColumnOrName) -> Column: """ Returns the k-ring of cells around the input cell ID. @@ -1222,9 +1163,7 @@ def grid_cellkring( ) -def grid_cellkloop( - cellid: ColumnOrName, k: ColumnOrName -) -> Column: +def grid_cellkloop(cellid: ColumnOrName, k: ColumnOrName) -> Column: """ Returns the k loop (hollow ring) of cells around the input cell ID. @@ -1245,9 +1184,7 @@ def grid_cellkloop( ) -def grid_cellkringexplode( - cellid: ColumnOrName, k: ColumnOrName -) -> Column: +def grid_cellkringexplode(cellid: ColumnOrName, k: ColumnOrName) -> Column: """ Returns the exploded k-ring of cells around the input cell ID. @@ -1268,9 +1205,7 @@ def grid_cellkringexplode( ) -def grid_cellkloopexplode( - cellid: ColumnOrName, k: ColumnOrName -) -> Column: +def grid_cellkloopexplode(cellid: ColumnOrName, k: ColumnOrName) -> Column: """ Returns the exploded k loop (hollow ring) of cells around the input cell ID. @@ -1292,7 +1227,7 @@ def grid_cellkloopexplode( def grid_geometrykring( - geom: ColumnOrName, resolution: ColumnOrName, k: ColumnOrName + geom: ColumnOrName, resolution: ColumnOrName, k: ColumnOrName ) -> Column: """ Returns the k-ring of cells around the input geometry. @@ -1317,7 +1252,7 @@ def grid_geometrykring( def grid_geometrykloop( - geom: ColumnOrName, resolution: ColumnOrName, k: ColumnOrName + geom: ColumnOrName, resolution: ColumnOrName, k: ColumnOrName ) -> Column: """ Returns the k loop (hollow ring) of cells around the input geometry. 
@@ -1342,7 +1277,7 @@ def grid_geometrykloop( def grid_geometrykringexplode( - geom: ColumnOrName, resolution: ColumnOrName, k: ColumnOrName + geom: ColumnOrName, resolution: ColumnOrName, k: ColumnOrName ) -> Column: """ Returns the exploded k-ring of cells around the input geometry. @@ -1367,7 +1302,7 @@ def grid_geometrykringexplode( def grid_geometrykloopexplode( - geom: ColumnOrName, resolution: ColumnOrName, k: ColumnOrName + geom: ColumnOrName, resolution: ColumnOrName, k: ColumnOrName ) -> Column: """ Returns the exploded k loop (hollow ring) of cells around the input geometry. @@ -1414,7 +1349,7 @@ def point_index_geom(geom: ColumnOrName, resolution: ColumnOrName) -> Column: def point_index_lonlat( - lon: ColumnOrName, lat: ColumnOrName, resolution: ColumnOrName + lon: ColumnOrName, lat: ColumnOrName, resolution: ColumnOrName ) -> Column: """ [Deprecated] alias for `grid_longlatascellid` @@ -1471,7 +1406,7 @@ def polyfill(geom: ColumnOrName, resolution: ColumnOrName) -> Column: def mosaic_explode( - geom: ColumnOrName, resolution: ColumnOrName, keep_core_geometries: Any = True + geom: ColumnOrName, resolution: ColumnOrName, keep_core_geometries: Any = True ) -> Column: """ [Deprecated] alias for `grid_tessellateexplode` @@ -1506,7 +1441,7 @@ def mosaic_explode( def mosaicfill( - geom: ColumnOrName, resolution: ColumnOrName, keep_core_geometries: Any = True + geom: ColumnOrName, resolution: ColumnOrName, keep_core_geometries: Any = True ) -> Column: """ [Deprecated] alias for `grid_tessellate` diff --git a/python/mosaic/api/gdal.py b/python/mosaic/api/gdal.py index 31b229814..0887bfc00 100644 --- a/python/mosaic/api/gdal.py +++ b/python/mosaic/api/gdal.py @@ -2,15 +2,13 @@ from typing import Any import subprocess -__all__ = [ - "setup_gdal", - "enable_gdal" -] +__all__ = ["setup_gdal", "enable_gdal"] + def setup_gdal( - spark: SparkSession, - init_script_path: str = "/dbfs/FileStore/geospatial/mosaic/gdal/", - shared_objects_path: str = 
"/dbfs/FileStore/geospatial/mosaic/gdal/") -> None: + spark: SparkSession, + init_script_path: str = "/dbfs/FileStore/geospatial/mosaic/gdal/", +) -> None: """ Prepare GDAL init script and shared objects required for GDAL to run on spark. This function will generate the init script that will install GDAL on each worker node. @@ -22,9 +20,6 @@ def setup_gdal( The active SparkSession. init_script_path : str Path to write out the init script for GDAL installation. - shared_objects_path : str - Path to write out shared objects (libgdalalljni.so and libgdalalljni.so.30) that GDAL requires at runtime. - Note: If you dont use the default path you will need to update the generated init script. Returns ------- @@ -35,11 +30,13 @@ def setup_gdal( sc._jvm.com.databricks.labs.mosaic.functions, "MosaicContext" ) mosaicGDALObject = getattr(sc._jvm.com.databricks.labs.mosaic.gdal, "MosaicGDAL") - mosaicGDALObject.prepareEnvironment(spark._jsparkSession, init_script_path, shared_objects_path) + mosaicGDALObject.prepareEnvironment(spark._jsparkSession, init_script_path) print("GDAL setup complete.\n") print(f"Shared objects (*.so) stored in: {shared_objects_path}.\n") print(f"Init script stored in: {init_script_path}.\n") - print("Please restart the cluster with the generated init script to complete the setup.\n") + print( + "Please restart the cluster with the generated init script to complete the setup.\n" + ) def enable_gdal(spark: SparkSession) -> None: @@ -60,13 +57,21 @@ def enable_gdal(spark: SparkSession) -> None: mosaicContextClass = getattr( sc._jvm.com.databricks.labs.mosaic.functions, "MosaicContext" ) - mosaicGDALObject = getattr(sc._jvm.com.databricks.labs.mosaic.gdal, "MosaicGDAL") + mosaicGDALObject = getattr( + sc._jvm.com.databricks.labs.mosaic.gdal, "MosaicGDAL" + ) mosaicGDALObject.enableGDAL(spark._jsparkSession) print("GDAL enabled.\n") - result = subprocess.run(['gdalinfo', '--version'], stdout=subprocess.PIPE) + result = subprocess.run(["gdalinfo", 
"--version"], stdout=subprocess.PIPE) print(result.stdout.decode() + "\n") except Exception as e: - print("GDAL not enabled. Mosaic with GDAL requires that GDAL be installed on the cluster.\n") - print("Please run setup_gdal() to generate the init script for install GDAL install.\n") - print("After the init script is generated, please restart the cluster with the init script to complete the setup.\n") - print("Error: " + str(e)) \ No newline at end of file + print( + "GDAL not enabled. Mosaic with GDAL requires that GDAL be installed on the cluster.\n" + ) + print( + "Please run setup_gdal() to generate the init script for install GDAL install.\n" + ) + print( + "After the init script is generated, please restart the cluster with the init script to complete the setup.\n" + ) + print("Error: " + str(e)) diff --git a/python/mosaic/api/raster.py b/python/mosaic/api/raster.py index bc9a42a5c..5bc140f72 100644 --- a/python/mosaic/api/raster.py +++ b/python/mosaic/api/raster.py @@ -10,12 +10,22 @@ __all__ = [ "rst_bandmetadata", + "rst_boundingbox", + "rst_clip", + "rst_combineavg", + "rst_fromfile", + "rst_frombands", "rst_georeference", + "ret_getnodata", + "rst_getsubdataset", "rst_height", "rst_isempty", + "rst_initnodata", "rst_memsize", "rst_metadata", + "rst_merge", "rst_numbands", + "rst_ndvi", "rst_pixelheight", "rst_pixelwidth", "rst_rastertogridavg", @@ -30,17 +40,22 @@ "rst_rotation", "rst_scalex", "rst_scaley", + "rst_setnodata", "rst_skewx", "rst_skewy", "rst_srid", "rst_subdatasets", "rst_summary", + "rst_subdivide", + "rst_tessellate", + "rst_to_overlapping_tiles", + "rst_tryopen", "rst_upperleftx", "rst_upperlefty", "rst_width", "rst_worldtorastercoord", "rst_worldtorastercoordx", - "rst_worldtorastercoordy" + "rst_worldtorastercoordy", ] @@ -62,11 +77,74 @@ def rst_bandmetadata(raster: ColumnOrName, band: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_bandmetadata", - pyspark_to_java_column(raster), - 
pyspark_to_java_column(band) + "rst_bandmetadata", pyspark_to_java_column(raster), pyspark_to_java_column(band) + ) + + +def rst_boundingbox(raster: ColumnOrName) -> Column: + """ + Returns the bounding box of the raster as a WKT polygon. + + Parameters + ---------- + raster : Column (StringType) + Path to the raster file. + + Returns + ------- + Column (StringType) + A WKT polygon representing the bounding box of the raster. + + """ + return config.mosaic_context.invoke_function( + "rst_boundingbox", pyspark_to_java_column(raster) + ) + + +def rst_clip(raster: ColumnOrName, geometry: ColumnOrName) -> Column: + """ + Clips the raster to the given geometry. + The result is the path to the clipped raster. + The result is stored in the checkpoint directory. + + Parameters + ---------- + raster : Column (StringType) + Path to the raster file. + geometry : Column (StringType) + The geometry to clip the raster to. + + Returns + ------- + Column (StringType) + The path to the clipped raster. + + """ + return config.mosaic_context.invoke_function( + "rst_clip", pyspark_to_java_column(raster), pyspark_to_java_column(geometry) + ) + + +def rst_combineavg(rasters: ColumnOrName) -> Column: + """ + Combines the rasters into a single raster. + + Parameters + ---------- + rasters : Column (ArrayType(StringType)) + Raster tiles to combine. + + Returns + ------- + Column (RasterTile) + The combined raster tile. + + """ + return config.mosaic_context.invoke_function( + "rst_combineavg", pyspark_to_java_column(rasters) ) + def rst_georeference(raster: ColumnOrName) -> Column: """ Returns GeoTransform of the raster as a GT array of doubles. @@ -90,10 +168,57 @@ def rst_georeference(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_georeference", - pyspark_to_java_column(raster) + "rst_georeference", pyspark_to_java_column(raster) + ) + + +def ret_getnodata(raster: ColumnOrName) -> Column: + """ + Returns the nodata value of the band. 
+ + Parameters + ---------- + raster : Column (StringType) + Path to the raster file. + band : Column (IntegerType) + Band index, starts from 1. + + Returns + ------- + Column (DoubleType) + The nodata value of the band. + + """ + return config.mosaic_context.invoke_function( + "ret_getnodata", pyspark_to_java_column(raster) + ) + + +def rst_getsubdataset(raster: ColumnOrName, subdataset: ColumnOrName) -> Column: + """ + Returns the subdataset of the raster. + The subdataset is the path to the subdataset of the raster. + + Parameters + ---------- + raster : Column (StringType) + Path to the raster file. + subdataset : Column (IntegerType) + The index of the subdataset to get. + + Returns + ------- + Column (StringType) + The path to the subdataset. + + """ + return config.mosaic_context.invoke_function( + "rst_getsubdataset", + pyspark_to_java_column(raster), + pyspark_to_java_column(subdataset), ) + def rst_height(raster: ColumnOrName) -> Column: """ Parameters @@ -108,10 +233,31 @@ def rst_height(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_height", + "rst_height", pyspark_to_java_column(raster) + ) + + +def rst_initnodata(raster: ColumnOrName) -> Column: + """ + Initializes the nodata value of the band. + + Parameters + ---------- + raster : Column (StringType) + Path to the raster file. + + Returns + ------- + Column (StringType) + The path to the raster file. 
+ + """ + return config.mosaic_context.invoke_function( + "rst_initnodata", pyspark_to_java_column(raster) ) + def rst_isempty(raster: ColumnOrName) -> Column: """ Parameters @@ -126,10 +272,10 @@ def rst_isempty(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_isempty", - pyspark_to_java_column(raster) + "rst_isempty", pyspark_to_java_column(raster) ) + def rst_memsize(raster: ColumnOrName) -> Column: """ Parameters @@ -144,10 +290,10 @@ def rst_memsize(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_memsize", - pyspark_to_java_column(raster) + "rst_memsize", pyspark_to_java_column(raster) ) + def rst_metadata(raster: ColumnOrName) -> Column: """ Parameters @@ -162,10 +308,54 @@ def rst_metadata(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_metadata", - pyspark_to_java_column(raster) + "rst_metadata", pyspark_to_java_column(raster) + ) + + +def rst_merge(rasters: ColumnOrName) -> Column: + """ + Merges the rasters into a single raster. + The result is the path to the merged raster. + The result is stored in the checkpoint directory. + + Parameters + ---------- + rasters : Column (ArrayType(StringType)) + Paths to the rasters to merge. + + Returns + ------- + Column (StringType) + The path to the merged raster. + + """ + return config.mosaic_context.invoke_function( + "rst_merge", pyspark_to_java_column(rasters) ) + +def rst_frombands(bands: ColumnOrName) -> Column: + """ + Merges the bands into a single raster. + The result is the path to the merged raster. + The result is stored in the checkpoint directory. + + Parameters + ---------- + bands : Column (ArrayType(StringType)) + Paths to the bands to merge. + + Returns + ------- + Column (StringType) + The path to the merged raster. 
+ + """ + return config.mosaic_context.invoke_function( + "rst_frombands", pyspark_to_java_column(bands) + ) + + def rst_numbands(raster: ColumnOrName) -> Column: """ Parameters @@ -180,10 +370,39 @@ def rst_numbands(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_numbands", - pyspark_to_java_column(raster) + "rst_numbands", pyspark_to_java_column(raster) ) + +def rst_ndvi(raster: ColumnOrName, band1: ColumnOrName, band2: ColumnOrName) -> Column: + """ + Computes the NDVI of the raster. + The result is the path to the NDVI raster. + The result is stored in the checkpoint directory. + + Parameters + ---------- + raster : Column (StringType) + Path to the raster file. + band1 : Column (IntegerType) + The first band index. + band2 : Column (IntegerType) + The second band index. + + Returns + ------- + Column (StringType) + The path to the NDVI raster. + + """ + return config.mosaic_context.invoke_function( + "rst_ndvi", + pyspark_to_java_column(raster), + pyspark_to_java_column(band1), + pyspark_to_java_column(band2), + ) + + def rst_pixelheight(raster: ColumnOrName) -> Column: """ Parameters @@ -198,10 +417,10 @@ def rst_pixelheight(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_pixelheight", - pyspark_to_java_column(raster) + "rst_pixelheight", pyspark_to_java_column(raster) ) + def rst_pixelwidth(raster: ColumnOrName) -> Column: """ Parameters @@ -216,10 +435,10 @@ def rst_pixelwidth(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_pixelwidth", - pyspark_to_java_column(raster) + "rst_pixelwidth", pyspark_to_java_column(raster) ) + def rst_rastertogridavg(raster: ColumnOrName, resolution: ColumnOrName) -> Column: """ The result is a 2D array of cells, where each cell is a struct of (cellID, value). 
@@ -241,9 +460,10 @@ def rst_rastertogridavg(raster: ColumnOrName, resolution: ColumnOrName) -> Colum return config.mosaic_context.invoke_function( "rst_rastertogridavg", pyspark_to_java_column(raster), - pyspark_to_java_column(resolution) + pyspark_to_java_column(resolution), ) + def rst_rastertogridcount(raster: ColumnOrName, resolution: ColumnOrName) -> Column: """ The result is a 2D array of cells, where each cell is a struct of (cellID, value). @@ -265,9 +485,10 @@ def rst_rastertogridcount(raster: ColumnOrName, resolution: ColumnOrName) -> Col return config.mosaic_context.invoke_function( "rst_rastertogridcount", pyspark_to_java_column(raster), - pyspark_to_java_column(resolution) + pyspark_to_java_column(resolution), ) + def rst_rastertogridmax(raster: ColumnOrName, resolution: ColumnOrName) -> Column: """ The result is a 2D array of cells, where each cell is a struct of (cellID, value). @@ -289,9 +510,10 @@ def rst_rastertogridmax(raster: ColumnOrName, resolution: ColumnOrName) -> Colum return config.mosaic_context.invoke_function( "rst_rastertogridmax", pyspark_to_java_column(raster), - pyspark_to_java_column(resolution) + pyspark_to_java_column(resolution), ) + def rst_rastertogridmedian(raster: ColumnOrName, resolution: ColumnOrName) -> Column: """ The result is a 2D array of cells, where each cell is a struct of (cellID, value). @@ -313,9 +535,10 @@ def rst_rastertogridmedian(raster: ColumnOrName, resolution: ColumnOrName) -> Co return config.mosaic_context.invoke_function( "rst_rastertogridmedian", pyspark_to_java_column(raster), - pyspark_to_java_column(resolution) + pyspark_to_java_column(resolution), ) + def rst_rastertogridmin(raster: ColumnOrName, resolution: ColumnOrName) -> Column: """ The result is a 2D array of cells, where each cell is a struct of (cellID, value). 
@@ -337,10 +560,13 @@ def rst_rastertogridmin(raster: ColumnOrName, resolution: ColumnOrName) -> Colum return config.mosaic_context.invoke_function( "rst_rastertogridmin", pyspark_to_java_column(raster), - pyspark_to_java_column(resolution) + pyspark_to_java_column(resolution), ) -def rst_rastertoworldcoord(raster: ColumnOrName, x: ColumnOrName, y: ColumnOrName) -> Column: + +def rst_rastertoworldcoord( + raster: ColumnOrName, x: ColumnOrName, y: ColumnOrName +) -> Column: """ Computes the world coordinates of the raster pixel at the given x and y coordinates. The result is a WKT point geometry. @@ -361,10 +587,13 @@ def rst_rastertoworldcoord(raster: ColumnOrName, x: ColumnOrName, y: ColumnOrNam "rst_rastertoworldcoord", pyspark_to_java_column(raster), pyspark_to_java_column(x), - pyspark_to_java_column(y) + pyspark_to_java_column(y), ) -def rst_rastertoworldcoordx(raster: ColumnOrName, x: ColumnOrName, y: ColumnOrName) -> Column: + +def rst_rastertoworldcoordx( + raster: ColumnOrName, x: ColumnOrName, y: ColumnOrName +) -> Column: """ Computes the world coordinates of the raster pixel at the given x and y coordinates. The result is the X coordinate of the point after applying the GeoTransform of the raster. @@ -384,10 +613,13 @@ def rst_rastertoworldcoordx(raster: ColumnOrName, x: ColumnOrName, y: ColumnOrNa "rst_rastertoworldcoordx", pyspark_to_java_column(raster), pyspark_to_java_column(x), - pyspark_to_java_column(y) + pyspark_to_java_column(y), ) -def rst_rastertoworldcoordy(raster: ColumnOrName, x: ColumnOrName, y: ColumnOrName) -> Column: + +def rst_rastertoworldcoordy( + raster: ColumnOrName, x: ColumnOrName, y: ColumnOrName +) -> Column: """ Computes the world coordinates of the raster pixel at the given x and y coordinates. The result is the Y coordinate of the point after applying the GeoTransform of the raster. 
@@ -407,10 +639,13 @@ def rst_rastertoworldcoordy(raster: ColumnOrName, x: ColumnOrName, y: ColumnOrNa "rst_rastertoworldcoordy", pyspark_to_java_column(raster), pyspark_to_java_column(x), - pyspark_to_java_column(y) + pyspark_to_java_column(y), ) -def rst_retile(raster: ColumnOrName, tileWidth: ColumnOrName, tileHeight: ColumnOrName) -> Column: + +def rst_retile( + raster: ColumnOrName, tileWidth: ColumnOrName, tileHeight: ColumnOrName +) -> Column: """ Retiles the raster to the given tile size. The result is a collection of new raster files. The new rasters are stored in the checkpoint directory. @@ -432,9 +667,10 @@ def rst_retile(raster: ColumnOrName, tileWidth: ColumnOrName, tileHeight: Column "rst_retile", pyspark_to_java_column(raster), pyspark_to_java_column(tileWidth), - pyspark_to_java_column(tileHeight) + pyspark_to_java_column(tileHeight), ) + def rst_rotation(raster: ColumnOrName) -> Column: """ Computes the rotation of the raster in degrees. @@ -453,10 +689,10 @@ def rst_rotation(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_rotation", - pyspark_to_java_column(raster) + "rst_rotation", pyspark_to_java_column(raster) ) + def rst_scalex(raster: ColumnOrName) -> Column: """ Computes the scale of the raster in the X direction. @@ -473,10 +709,10 @@ def rst_scalex(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_scalex", - pyspark_to_java_column(raster) + "rst_scalex", pyspark_to_java_column(raster) ) + def rst_scaley(raster: ColumnOrName) -> Column: """ Computes the scale of the raster in the Y direction. @@ -493,10 +729,34 @@ def rst_scaley(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_scaley", - pyspark_to_java_column(raster) + "rst_scaley", pyspark_to_java_column(raster) + ) + + +def rst_setnodata(raster: ColumnOrName, nodata: ColumnOrName) -> Column: + """ + Sets the nodata value of the band. 
+ + Parameters + ---------- + raster : Column (StringType) + Path to the raster file. + nodata : Column (DoubleType) + The nodata value to set. + + Returns + ------- + Column (StringType) + The path to the raster file. + + """ + return config.mosaic_context.invoke_function( + "rst_setnodata", + pyspark_to_java_column(raster), + pyspark_to_java_column(nodata), ) + def rst_skewx(raster: ColumnOrName) -> Column: """ Computes the skew of the raster in the X direction. @@ -513,10 +773,10 @@ def rst_skewx(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_skewx", - pyspark_to_java_column(raster) + "rst_skewx", pyspark_to_java_column(raster) ) + def rst_skewy(raster: ColumnOrName) -> Column: """ Computes the skew of the raster in the Y direction. @@ -533,10 +793,10 @@ def rst_skewy(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_skewy", - pyspark_to_java_column(raster) + "rst_skewy", pyspark_to_java_column(raster) ) + def rst_srid(raster: ColumnOrName) -> Column: """ Computes the SRID of the raster. @@ -554,10 +814,10 @@ def rst_srid(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_srid", - pyspark_to_java_column(raster) + "rst_srid", pyspark_to_java_column(raster) ) + def rst_subdatasets(raster: ColumnOrName) -> Column: """ Computes the subdatasets of the raster. @@ -576,10 +836,10 @@ def rst_subdatasets(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_subdatasets", - pyspark_to_java_column(raster) + "rst_subdatasets", pyspark_to_java_column(raster) ) + def rst_summary(raster: ColumnOrName) -> Column: """ Computes the summary of the raster. 
@@ -599,10 +859,112 @@ def rst_summary(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_summary", - pyspark_to_java_column(raster) + "rst_summary", pyspark_to_java_column(raster) ) + +def rst_tessellate(raster: ColumnOrName, resolution: ColumnOrName) -> Column: + """ + Clip the raster into raster tiles where each tile is a grid tile for the given resolution. + The tile set union forms the original raster. + + Parameters + ---------- + raster : Column (StringType) + Path to the raster file. + resolution : Column (IntegerType) + The resolution of the tiles. + + Returns + ------- + Column (RasterTiles) + A struct containing the tiles of the raster. + + """ + return config.mosaic_context.invoke_function( + "rst_tessellate", + pyspark_to_java_column(raster), + pyspark_to_java_column(resolution), + ) + + +def rst_fromfile(raster: ColumnOrName, sizeInMB: ColumnOrName) -> Column: + """ + Tiles the raster into tiles of the given size. + :param raster: + :param sizeInMB: + :return: + """ + + return config.mosaic_context.invoke_function( + "rst_fromfile", + pyspark_to_java_column(raster), + pyspark_to_java_column(sizeInMB) + ) + + +def rst_to_overlapping_tiles(raster: ColumnOrName, width: ColumnOrName, height: ColumnOrName, overlap: ColumnOrName) -> Column: + """ + Tiles the raster into tiles of the given size. + :param raster: + :param sizeInMB: + :return: + """ + + return config.mosaic_context.invoke_function( + "rst_to_overlapping_tiles", + pyspark_to_java_column(raster), + pyspark_to_java_column(width), + pyspark_to_java_column(height), + pyspark_to_java_column(overlap) + ) + + +def rst_tryopen(raster: ColumnOrName) -> Column: + """ + Tries to open the raster and returns a flag indicating if the raster can be opened. + + Parameters + ---------- + raster : Column (StringType) + Path to the raster file. + + Returns + ------- + Column (BooleanType) + Whether the raster can be opened. 
+ + """ + return config.mosaic_context.invoke_function( + "rst_tryopen", pyspark_to_java_column(raster) + ) + + +def rst_subdivide(raster: ColumnOrName, size_in_mb: ColumnOrName) -> Column: + """ + Subdivides the raster into tiles that have to be smaller than the given size in MB. + All the tiles have the same aspect ratio as the original raster. + + Parameters + ---------- + raster : Column (StringType) + Path to the raster file. + size_in_mb : Column (IntegerType) + The size of the tiles in MB. + + Returns + ------- + Column (RasterTiles) + A collection of tiles of the raster. + + """ + return config.mosaic_context.invoke_function( + "rst_subdivide", + pyspark_to_java_column(raster), + pyspark_to_java_column(size_in_mb), + ) + + def rst_upperleftx(raster: ColumnOrName) -> Column: """ Computes the upper left X coordinate of the raster. @@ -620,10 +982,10 @@ def rst_upperleftx(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_upperleftx", - pyspark_to_java_column(raster) + "rst_upperleftx", pyspark_to_java_column(raster) ) + def rst_upperlefty(raster: ColumnOrName) -> Column: """ Computes the upper left Y coordinate of the raster. @@ -641,10 +1003,10 @@ def rst_upperlefty(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_upperlefty", - pyspark_to_java_column(raster) + "rst_upperlefty", pyspark_to_java_column(raster) ) + def rst_width(raster: ColumnOrName) -> Column: """ Computes the width of the raster in pixels. @@ -661,11 +1023,13 @@ def rst_width(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_width", - pyspark_to_java_column(raster) + "rst_width", pyspark_to_java_column(raster) ) -def rst_worldtorastercoord(raster: ColumnOrName, x: ColumnOrName, y: ColumnOrName) -> Column: + +def rst_worldtorastercoord( + raster: ColumnOrName, x: ColumnOrName, y: ColumnOrName +) -> Column: """ Computes the raster coordinates of the world coordinates. 
The raster coordinates are the pixel coordinates of the raster. @@ -684,11 +1048,13 @@ def rst_worldtorastercoord(raster: ColumnOrName, x: ColumnOrName, y: ColumnOrNam """ return config.mosaic_context.invoke_function( - "rst_worldtorastercoord", - pyspark_to_java_column(raster) + "rst_worldtorastercoord", pyspark_to_java_column(raster) ) -def rst_worldtorastercoordx(raster: ColumnOrName, x: ColumnOrName, y: ColumnOrName) -> Column: + +def rst_worldtorastercoordx( + raster: ColumnOrName, x: ColumnOrName, y: ColumnOrName +) -> Column: """ Computes the raster coordinates of the world coordinates. The raster coordinates are the pixel coordinates of the raster. @@ -708,11 +1074,13 @@ def rst_worldtorastercoordx(raster: ColumnOrName, x: ColumnOrName, y: ColumnOrNa """ return config.mosaic_context.invoke_function( - "rst_worldtorastercoordx", - pyspark_to_java_column(raster) + "rst_worldtorastercoordx", pyspark_to_java_column(raster) ) -def rst_worldtorastercoordy(raster: ColumnOrName, x: ColumnOrName, y: ColumnOrName) -> Column: + +def rst_worldtorastercoordy( + raster: ColumnOrName, x: ColumnOrName, y: ColumnOrName +) -> Column: """ Computes the raster coordinates of the world coordinates. The raster coordinates are the pixel coordinates of the raster. 
@@ -732,6 +1100,5 @@ def rst_worldtorastercoordy(raster: ColumnOrName, x: ColumnOrName, y: ColumnOrNa """ return config.mosaic_context.invoke_function( - "rst_worldtorastercoordy", - pyspark_to_java_column(raster) - ) \ No newline at end of file + "rst_worldtorastercoordy", pyspark_to_java_column(raster) + ) diff --git a/python/mosaic/core/library_handler.py b/python/mosaic/core/library_handler.py index bdd5e1832..86cb2b98d 100644 --- a/python/mosaic/core/library_handler.py +++ b/python/mosaic/core/library_handler.py @@ -52,11 +52,15 @@ def mosaic_library_location(self): except Py4JJavaError as e: self._jar_filename = f"mosaic-{importlib.metadata.version('databricks-mosaic')}-jar-with-dependencies.jar" try: - with importlib.resources.path("mosaic.lib", self._jar_filename) as p: + with importlib.resources.path( + "mosaic.lib", self._jar_filename + ) as p: self._jar_path = p.as_posix() except FileNotFoundError as fnf: self._jar_filename = f"mosaic-{importlib.metadata.version('databricks-mosaic')}-SNAPSHOT-jar-with-dependencies.jar" - with importlib.resources.path("mosaic.lib", self._jar_filename) as p: + with importlib.resources.path( + "mosaic.lib", self._jar_filename + ) as p: self._jar_path = p.as_posix() return self._jar_path diff --git a/python/mosaic/core/mosaic_context.py b/python/mosaic/core/mosaic_context.py index 10e321615..ff49a4e37 100644 --- a/python/mosaic/core/mosaic_context.py +++ b/python/mosaic/core/mosaic_context.py @@ -7,7 +7,6 @@ class MosaicContext: - _context = None _geometry_api: str _index_system: str @@ -23,8 +22,12 @@ def __init__(self, spark: SparkSession): ) self._mosaicPackageRef = getattr(sc._jvm.com.databricks.labs.mosaic, "package$") self._mosaicPackageObject = getattr(self._mosaicPackageRef, "MODULE$") - self._mosaicGDALObject = getattr(sc._jvm.com.databricks.labs.mosaic.gdal, "MosaicGDAL") - self._indexSystemFactory = getattr(sc._jvm.com.databricks.labs.mosaic.core.index, "IndexSystemFactory") + self._mosaicGDALObject = getattr( + 
sc._jvm.com.databricks.labs.mosaic.gdal, "MosaicGDAL" + ) + self._indexSystemFactory = getattr( + sc._jvm.com.databricks.labs.mosaic.core.index, "IndexSystemFactory" + ) try: self._geometry_api = spark.conf.get( @@ -41,18 +44,15 @@ def __init__(self, spark: SparkSession): self._index_system = "H3" try: - self._raster_api = spark.conf.get( - "spark.databricks.labs.mosaic.raster.api" - ) + self._raster_api = spark.conf.get("spark.databricks.labs.mosaic.raster.api") except Py4JJavaError as e: self._raster_api = "GDAL" IndexSystem = self._indexSystemFactory.getIndexSystem(self._index_system) GeometryAPIClass = getattr(self._mosaicPackageObject, self._geometry_api) - RasterAPIClass = getattr(self._mosaicPackageObject, self._raster_api) self._context = self._mosaicContextClass.build( - IndexSystem, GeometryAPIClass(), RasterAPIClass() + IndexSystem, GeometryAPIClass() ) def invoke_function(self, name: str, *args: Any) -> MosaicColumn: diff --git a/python/mosaic/core/mosaic_frame.py b/python/mosaic/core/mosaic_frame.py index ac6e633b5..f2a5eb80c 100644 --- a/python/mosaic/core/mosaic_frame.py +++ b/python/mosaic/core/mosaic_frame.py @@ -37,8 +37,8 @@ def __init__(self, df: DataFrame, geometry_column_name: str): """ def get_optimal_resolution_str( - self, sample_rows: Optional[int] = None, sample_fraction: Optional[float] = None - ) -> str: + self, sample_rows: Optional[int] = None, sample_fraction: Optional[float] = None + ) -> str: """ Analyzes the geometries in the currently selected geometry column and proposes an optimal grid-index resolution. 
@@ -75,7 +75,6 @@ def get_optimal_resolution_str( return self._mosaicFrame.analyzer().getOptimalResolutionStr(sampleStrategy) return self._mosaicFrame.analyzer().getOptimalResolutionStr() - def get_optimal_resolution( self, sample_rows: Optional[int] = None, sample_fraction: Optional[float] = None ) -> int: diff --git a/python/mosaic/models/__init__.py b/python/mosaic/models/__init__.py index 008f75454..6643ab150 100644 --- a/python/mosaic/models/__init__.py +++ b/python/mosaic/models/__init__.py @@ -1 +1 @@ -from .knn import SpatialKNN \ No newline at end of file +from .knn import SpatialKNN diff --git a/python/mosaic/models/knn/__init__.py b/python/mosaic/models/knn/__init__.py index 5810634a1..0c5f2ab5c 100644 --- a/python/mosaic/models/knn/__init__.py +++ b/python/mosaic/models/knn/__init__.py @@ -1 +1 @@ -from .spatial_knn import SpatialKNN \ No newline at end of file +from .spatial_knn import SpatialKNN diff --git a/python/mosaic/models/knn/spatial_knn.py b/python/mosaic/models/knn/spatial_knn.py index b3083636d..e5b10a394 100644 --- a/python/mosaic/models/knn/spatial_knn.py +++ b/python/mosaic/models/knn/spatial_knn.py @@ -1,6 +1,7 @@ from pyspark.sql import SparkSession, DataFrame, SQLContext from mosaic.utils import scala_utils + class SpatialKNN: """ @@ -54,7 +55,7 @@ def setLandmarksFeatureCol(self, feature): self.model.setLandmarksFeatureCol(feature) return self - + def setLandmarksRowID(self, rowID): """ Set the row ID column name for the landmarks. @@ -64,7 +65,7 @@ def setLandmarksRowID(self, rowID): self.model.setLandmarksRowID(rowID) return self - + def setCandidatesFeatureCol(self, feature): """ Ser the feature column name for the candidates. @@ -74,7 +75,7 @@ def setCandidatesFeatureCol(self, feature): self.model.setCandidatesFeatureCol(feature) return self - + def setCandidatesRowID(self, rowID): """ Set the row ID column name for the candidates. 
@@ -84,7 +85,7 @@ def setCandidatesRowID(self, rowID): self.model.setCandidatesRowID(rowID) return self - + def setDistanceThreshold(self, d): """ Set the distance threshold for the nearest neighbours. @@ -94,7 +95,7 @@ def setDistanceThreshold(self, d): self.model.setDistanceThreshold(d) return self - + def setIndexResolution(self, resolution): """ Set the index resolution for the spatial index. @@ -203,8 +204,6 @@ def read(self): return self.model.read def load(self, path): - """ - - """ + """ """ - return self.model.load(path) \ No newline at end of file + return self.model.load(path) diff --git a/python/mosaic/readers/__init__.py b/python/mosaic/readers/__init__.py index d3fb99709..d7ad59b3a 100644 --- a/python/mosaic/readers/__init__.py +++ b/python/mosaic/readers/__init__.py @@ -1,2 +1,2 @@ from .mosaic_data_frame_reader import MosaicDataFrameReader -from .readers import read \ No newline at end of file +from .readers import read diff --git a/python/mosaic/readers/mosaic_data_frame_reader.py b/python/mosaic/readers/mosaic_data_frame_reader.py index 34b749a6b..085aea4a7 100644 --- a/python/mosaic/readers/mosaic_data_frame_reader.py +++ b/python/mosaic/readers/mosaic_data_frame_reader.py @@ -14,7 +14,8 @@ def __init__(self): """ self.spark = SparkSession.builder.getOrCreate() self.reader = getattr( - self.spark._jvm.com.databricks.labs.mosaic.datasource.multiread, "MosaicDataFrameReader" + self.spark._jvm.com.databricks.labs.mosaic.datasource.multiread, + "MosaicDataFrameReader", )(self.spark._jsparkSession) def format(self, format): diff --git a/python/mosaic/utils/kepler_magic.py b/python/mosaic/utils/kepler_magic.py index b18457218..66dd418fb 100644 --- a/python/mosaic/utils/kepler_magic.py +++ b/python/mosaic/utils/kepler_magic.py @@ -9,7 +9,15 @@ from mosaic.api.accessors import st_astext, st_aswkt from mosaic.api.constructors import st_geomfromwkt, st_geomfromwkb -from mosaic.api.functions import st_centroid, grid_pointascellid, grid_boundaryaswkb, 
st_setsrid, st_transform, st_x, st_y +from mosaic.api.functions import ( + st_centroid, + grid_pointascellid, + grid_boundaryaswkb, + st_setsrid, + st_transform, + st_x, + st_y, +) from mosaic.config import config from mosaic.utils.kepler_config import mosaic_kepler_config @@ -22,14 +30,13 @@ class MosaicKepler(Magics): """ def __init__(self, shell): - Magics.__init__(self, shell) - self.bng_crsid = 27700 - self.osgb36_crsid = 27700 - self.wgs84_crsid = 4326 + Magics.__init__(self, shell) + self.bng_crsid = 27700 + self.osgb36_crsid = 27700 + self.wgs84_crsid = 4326 @staticmethod def displayKepler(map_instance, height, width): - """ Display Kepler map instance in Jupyter notebook. @@ -104,7 +111,6 @@ def get_spark_df(table_name): @staticmethod def set_centroid(pandas_data, feature_type, feature_name): - """ Sets the centroid of the geometry column. @@ -141,14 +147,8 @@ def set_centroid(pandas_data, feature_type, feature_name): tmp_sdf = tmp_sdf.withColumn(feature_name, grid_boundaryaswkb(feature_name)) centroid = ( - tmp_sdf.select( - st_centroid(feature_name).alias("centroid") - ).select( - struct( - st_x("centroid").alias("x"), - st_y("centroid").alias("y") - ) - ) + tmp_sdf.select(st_centroid(feature_name).alias("centroid")) + .select(struct(st_x("centroid").alias("x"), st_y("centroid").alias("y"))) .limit(1) .collect()[0][0] ) @@ -159,7 +159,6 @@ def set_centroid(pandas_data, feature_type, feature_name): @cell_magic def mosaic_kepler(self, *args): - """ A magic command for visualizing data in KeplerGl. 
@@ -215,14 +214,18 @@ def mosaic_kepler(self, *args): feature_name, lower(conv(col(feature_name), 10, 16)) ) elif feature_type == "bng": - data = (data - .withColumn(feature_name, grid_boundaryaswkb(feature_name)) + data = ( + data.withColumn(feature_name, grid_boundaryaswkb(feature_name)) .withColumn(feature_name, st_geomfromwkb(feature_name)) .withColumn( feature_name, - st_transform(st_setsrid(feature_name, lit(self.bng_crsid)), lit(self.wgs84_crsid)) + st_transform( + st_setsrid(feature_name, lit(self.bng_crsid)), + lit(self.wgs84_crsid), + ), ) - .withColumn(feature_name, st_aswkt(feature_name))) + .withColumn(feature_name, st_aswkt(feature_name)) + ) elif feature_type == "geometry": data = data.withColumn(feature_name, st_astext(col(feature_name))) elif re.search("^geometry\(.*\)$", feature_type).start() != None: @@ -231,13 +234,16 @@ def mosaic_kepler(self, *args): crsid = self.bng_crsid else: crsid = int(crsid) - data = (data - .withColumn(feature_name, st_geomfromwkt(st_aswkt(feature_name))) + data = ( + data.withColumn(feature_name, st_geomfromwkt(st_aswkt(feature_name))) .withColumn( feature_name, - st_transform(st_setsrid(feature_name, lit(crsid)), lit(self.wgs84_crsid)) + st_transform( + st_setsrid(feature_name, lit(crsid)), lit(self.wgs84_crsid) + ), ) - .withColumn(feature_name, st_aswkt(feature_name))) + .withColumn(feature_name, st_aswkt(feature_name)) + ) else: raise Exception(f"Unsupported geometry type: {feature_type}.") diff --git a/python/mosaic/utils/notebook_utils.py b/python/mosaic/utils/notebook_utils.py index 406f5848f..e7a74015a 100644 --- a/python/mosaic/utils/notebook_utils.py +++ b/python/mosaic/utils/notebook_utils.py @@ -1,4 +1,10 @@ +import os + + class NotebookUtils: @staticmethod def displayHTML(html: str): + if os.environ.get("MOSAIC_JUPYTER", "FALSE") == "TRUE": + with open("./mosaic_kepler_view.html", "w") as f: + f.write(html) print(html) diff --git a/python/mosaic/utils/scala_utils.py b/python/mosaic/utils/scala_utils.py 
index 8f9a58745..20c46696c 100644 --- a/python/mosaic/utils/scala_utils.py +++ b/python/mosaic/utils/scala_utils.py @@ -1,4 +1,3 @@ - def scala_map_to_dict(scala_map): result = dict() for i in range(0, scala_map.size()): @@ -6,4 +5,4 @@ def scala_map_to_dict(scala_map): curr_key = current._1() curr_val = current._2() result[curr_key] = curr_val - return result \ No newline at end of file + return result diff --git a/python/setup.cfg b/python/setup.cfg index c19b95a51..d7a7fbe05 100644 --- a/python/setup.cfg +++ b/python/setup.cfg @@ -19,7 +19,7 @@ packages = find: python_requires = >=3.7.0 install_requires = keplergl==0.3.2 - h3==3.7.0 + h3==3.7.3 ipython>=7.22.0 [options.package_data] diff --git a/python/test/context.py b/python/test/context.py index 3aa17e6b0..d118955b3 100644 --- a/python/test/context.py +++ b/python/test/context.py @@ -4,4 +4,6 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) import mosaic.api as api +import mosaic.readers as readers +import mosaic.api.raster as rst from mosaic.core import MosaicContext, MosaicLibraryHandler diff --git a/python/test/test_functions.py b/python/test/test_functions.py index 27b69e9fb..8b8666dc7 100644 --- a/python/test/test_functions.py +++ b/python/test/test_functions.py @@ -49,7 +49,10 @@ def test_st_bindings_happy_flow(self): .withColumn("st_centroid", api.st_centroid("wkt")) .withColumn("st_numpoints", api.st_numpoints("wkt")) .withColumn("st_length", api.st_length("wkt")) - .withColumn("st_haversine", api.st_haversine(lit(0.0), lit(90.0), lit(0.0), lit(0.0))) + .withColumn( + "st_haversine", + api.st_haversine(lit(0.0), lit(90.0), lit(0.0), lit(0.0)), + ) .withColumn("st_isvalid", api.st_isvalid("wkt")) .withColumn( "st_hasvalidcoordinates", @@ -70,7 +73,6 @@ def test_st_bindings_happy_flow(self): .withColumn("st_zmin", api.st_zmin("wkt")) .withColumn("st_zmax", api.st_zmax("wkt")) .withColumn("flatten_polygons", api.flatten_polygons("wkt")) - # SRID functions .withColumn( 
"geom_with_srid", api.st_setsrid(api.st_geomfromwkt("wkt"), lit(4326)) @@ -79,13 +81,18 @@ def test_st_bindings_happy_flow(self): .withColumn( "transformed_geom", api.st_transform("geom_with_srid", lit(3857)) ) - # Grid functions - .withColumn("grid_longlatascellid", api.grid_longlatascellid(lit(1), lit(1), lit(1))) - .withColumn("grid_pointascellid", api.grid_pointascellid("point_wkt", lit(1))) + .withColumn( + "grid_longlatascellid", api.grid_longlatascellid(lit(1), lit(1), lit(1)) + ) + .withColumn( + "grid_pointascellid", api.grid_pointascellid("point_wkt", lit(1)) + ) .withColumn("grid_boundaryaswkb", api.grid_boundaryaswkb(lit(1))) .withColumn("grid_polyfill", api.grid_polyfill("wkt", lit(1))) - .withColumn("grid_tessellateexplode", api.grid_tessellateexplode("wkt", lit(1))) + .withColumn( + "grid_tessellateexplode", api.grid_tessellateexplode("wkt", lit(1)) + ) .withColumn( "grid_tessellateexplode_no_core_chips", api.grid_tessellateexplode("wkt", lit(1), lit(False)), @@ -96,8 +103,6 @@ def test_st_bindings_happy_flow(self): ) .withColumn("grid_tessellate", api.grid_tessellate("wkt", lit(1))) .withColumn("grid_cellarea", api.grid_cellarea(lit(613177664827555839))) - - # Deprecated .withColumn( "point_index_lonlat", api.point_index_lonlat(lit(1), lit(1), lit(1)) @@ -122,7 +127,6 @@ def test_st_bindings_happy_flow(self): "mosaicfill_no_core_chips_bool", api.mosaicfill("wkt", lit(1), lit(False)), ) - ) self.assertEqual(result.count(), 1) @@ -213,16 +217,38 @@ def test_grid_kring_kloop(self): ], ["wkt", "point_wkt"], ) - result = (df - .withColumn("grid_longlatascellid", api.grid_longlatascellid(lit(1), lit(1), lit(1))) - .withColumn("grid_cellkring", api.grid_cellkring("grid_longlatascellid", lit(1))) - .withColumn("grid_cellkloop", api.grid_cellkloop("grid_longlatascellid", lit(1))) - .withColumn("grid_cellkringexplode", api.grid_cellkringexplode("grid_longlatascellid", lit(1))) - .withColumn("grid_cellkloopexplode", 
api.grid_cellkloopexplode("grid_longlatascellid", lit(1))) - .withColumn("grid_geometrykring", api.grid_geometrykring("wkt", lit(4), lit(1))) - .withColumn("grid_geometrykloop", api.grid_geometrykloop("wkt", lit(4), lit(1))) - .withColumn("grid_geometrykringexplode", api.grid_geometrykringexplode("wkt", lit(4), lit(1))) - .withColumn("grid_geometrykloopexplode", api.grid_geometrykloopexplode("wkt", lit(4), lit(1))) + result = ( + df.withColumn( + "grid_longlatascellid", api.grid_longlatascellid(lit(1), lit(1), lit(1)) + ) + .withColumn( + "grid_cellkring", api.grid_cellkring("grid_longlatascellid", lit(1)) + ) + .withColumn( + "grid_cellkloop", api.grid_cellkloop("grid_longlatascellid", lit(1)) + ) + .withColumn( + "grid_cellkringexplode", + api.grid_cellkringexplode("grid_longlatascellid", lit(1)), + ) + .withColumn( + "grid_cellkloopexplode", + api.grid_cellkloopexplode("grid_longlatascellid", lit(1)), + ) + .withColumn( + "grid_geometrykring", api.grid_geometrykring("wkt", lit(4), lit(1)) + ) + .withColumn( + "grid_geometrykloop", api.grid_geometrykloop("wkt", lit(4), lit(1)) + ) + .withColumn( + "grid_geometrykringexplode", + api.grid_geometrykringexplode("wkt", lit(4), lit(1)), + ) + .withColumn( + "grid_geometrykloopexplode", + api.grid_geometrykloopexplode("wkt", lit(4), lit(1)), + ) ) self.assertEqual(result.count() > 1, True) @@ -236,13 +262,13 @@ def test_grid_cell_union_intersection(self): ) df_chips = df.withColumn("chips", api.grid_tessellateexplode("wkt", lit(1))) - df_chips = ( - df_chips - .withColumn("intersection", api.grid_cell_intersection("chips", "chips")) - .withColumn("union", api.grid_cell_union("chips", "chips")) + df_chips = df_chips.withColumn( + "intersection", api.grid_cell_intersection("chips", "chips") + ).withColumn("union", api.grid_cell_union("chips", "chips")) + intersection = df_chips.groupBy("chips.index_id").agg( + api.grid_cell_intersection_agg("chips") ) - intersection = 
df_chips.groupBy("chips.index_id").agg(api.grid_cell_intersection_agg("chips")) self.assertEqual(intersection.count() >= 0, True) union = df_chips.groupBy("chips.index_id").agg(api.grid_cell_union_agg("chips")) - self.assertEqual(union.count() >= 0, True) \ No newline at end of file + self.assertEqual(union.count() >= 0, True) diff --git a/python/test/utils/mosaic_test_case_with_gdal.py b/python/test/utils/mosaic_test_case_with_gdal.py index c7f0e07bd..4a5fe321c 100644 --- a/python/test/utils/mosaic_test_case_with_gdal.py +++ b/python/test/utils/mosaic_test_case_with_gdal.py @@ -7,4 +7,5 @@ class MosaicTestCaseWithGDAL(MosaicTestCase): @classmethod def setUpClass(cls) -> None: super().setUpClass() - api.enable_mosaic(cls.spark, install_gdal=True) + api.enable_mosaic(cls.spark) + api.enable_gdal(cls.spark) diff --git a/python/test/utils/spark_test_case.py b/python/test/utils/spark_test_case.py index c69ce0924..29d9b916c 100644 --- a/python/test/utils/spark_test_case.py +++ b/python/test/utils/spark_test_case.py @@ -8,7 +8,6 @@ class SparkTestCase(unittest.TestCase): - spark = None library_location = None diff --git a/scripts/mosaic-docker.sh b/scripts/mosaic-docker.sh new file mode 100644 index 000000000..85c867c6c --- /dev/null +++ b/scripts/mosaic-docker.sh @@ -0,0 +1 @@ +docker run --name mosaic-dev --rm -p 5005:5005 -v $PWD:/root/mosaic -e JAVA_TOOL_OPTIONS="-agentlib:jdwp=transport=dt_socket,address=5005,server=y,suspend=n" -it mosaic-dev:jdk8-gdal3.4-spark3.2 /bin/bash \ No newline at end of file diff --git a/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 234e8ad6f..796353ac7 100644 --- a/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -2,5 +2,5 @@ 
com.databricks.labs.mosaic.datasource.ShapefileFileFormat com.databricks.labs.mosaic.datasource.GeoDBFileFormat com.databricks.labs.mosaic.datasource.OpenGeoDBFileFormat com.databricks.labs.mosaic.datasource.OGRFileFormat -com.databricks.labs.mosaic.datasource.GDALFileFormat +com.databricks.labs.mosaic.datasource.gdal.GDALFileFormat com.databricks.labs.mosaic.datasource.UserDefinedFileFormat \ No newline at end of file diff --git a/src/main/resources/gdal/ubuntu/lib/jni/libgdalalljni.so b/src/main/resources/gdal/ubuntu/lib/jni/libgdalalljni.so new file mode 100644 index 000000000..2bd18f5f0 Binary files /dev/null and b/src/main/resources/gdal/ubuntu/lib/jni/libgdalalljni.so differ diff --git a/src/main/resources/gdal/ubuntu/lib/jni/libgdalalljni.so.30 b/src/main/resources/gdal/ubuntu/lib/jni/libgdalalljni.so.30 new file mode 100644 index 000000000..2bd18f5f0 Binary files /dev/null and b/src/main/resources/gdal/ubuntu/lib/jni/libgdalalljni.so.30 differ diff --git a/src/main/resources/gdal/ubuntu/lib/ogdi/libgdal.so b/src/main/resources/gdal/ubuntu/lib/ogdi/libgdal.so new file mode 100644 index 000000000..014d78399 Binary files /dev/null and b/src/main/resources/gdal/ubuntu/lib/ogdi/libgdal.so differ diff --git a/src/main/resources/gdal/ubuntu/usr/lib/libgdal.so b/src/main/resources/gdal/ubuntu/usr/lib/libgdal.so new file mode 100644 index 000000000..4a28d7c70 Binary files /dev/null and b/src/main/resources/gdal/ubuntu/usr/lib/libgdal.so differ diff --git a/src/main/resources/gdal/ubuntu/usr/lib/libgdal.so.30 b/src/main/resources/gdal/ubuntu/usr/lib/libgdal.so.30 new file mode 100644 index 000000000..4a28d7c70 Binary files /dev/null and b/src/main/resources/gdal/ubuntu/usr/lib/libgdal.so.30 differ diff --git a/src/main/resources/gdal/ubuntu/usr/lib/libgdal.so.30.0.3 b/src/main/resources/gdal/ubuntu/usr/lib/libgdal.so.30.0.3 new file mode 100644 index 000000000..4a28d7c70 Binary files /dev/null and b/src/main/resources/gdal/ubuntu/usr/lib/libgdal.so.30.0.3 differ diff 
--git a/src/main/resources/scripts/apt_setup_docker.sh b/src/main/resources/scripts/apt_setup_docker.sh new file mode 100644 index 000000000..6a19f6362 --- /dev/null +++ b/src/main/resources/scripts/apt_setup_docker.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +apt-get update +apt-get install sudo +sudo apt-get install software-properties-common \ No newline at end of file diff --git a/src/main/resources/scripts/install-gdal-databricks.sh b/src/main/resources/scripts/install-gdal-databricks.sh index d3e249f4c..741ef5031 100644 --- a/src/main/resources/scripts/install-gdal-databricks.sh +++ b/src/main/resources/scripts/install-gdal-databricks.sh @@ -1,33 +1,28 @@ #!/bin/bash # -# File: init-gdal_3.4.3_ubuntugis.sh +# File: mosaic-gdal-3.4.3-filetree-init.sh # Author: Michael Johns -# Created: 2022-08-19 +# Modified: 2023-07-23 # +# !!! FOR DBR 11.x and 12.x ONLY [Ubuntu 20.04] !!! +# !!! NOT for DBR 13.x [Ubuntu 22.04] !!! +# +# 1. script is using custom tarballs for offline / self-contained install of GDAL +# 2. 
This will unpack files directly into the filetree across cluster nodes (vs run apt install) +# +# -- install databricks-mosaic-gdal on cluster +# - use version 3.4.3 (exactly) from pypi.org +pip install databricks-mosaic-gdal==3.4.3 -MOSAIC_GDAL_JNI_DIR="${MOSAIC_GDAL_JNI_DIR:-__DEFAULT_JNI_PATH__}" - -sudo rm -r /var/lib/apt/lists/* -sudo add-apt-repository main -sudo add-apt-repository universe -sudo add-apt-repository restricted -sudo add-apt-repository multiverse -sudo add-apt-repository ppa:ubuntugis/ubuntugis-unstable -sudo apt clean && sudo apt -o Acquire::Retries=3 update --fix-missing -y -sudo apt-get -o Acquire::Retries=3 update -y -sudo apt-get -o Acquire::Retries=3 install -y gdal-bin=3.4.3+dfsg-1~focal0 libgdal-dev=3.4.3+dfsg-1~focal0 python3-gdal=3.4.3+dfsg-1~focal0 +# -- find the install dir +# - if this were run in a notebook would use $VIRTUAL_ENV +# - since it is init script it lands in $DATABRICKS_ROOT_VIRTUALENV_ENV +GDAL_RESOURCE_DIR=$(find $DATABRICKS_ROOT_VIRTUALENV_ENV -name "databricks-mosaic-gdal") -# fix python file naming in osgeo package -cd /usr/lib/python3/dist-packages/osgeo \ - && mv _gdal.cpython-38-x86_64-linux-gnu.so _gdal.so \ - && mv _gdal_array.cpython-38-x86_64-linux-gnu.so _gdal_array.so \ - && mv _gdalconst.cpython-38-x86_64-linux-gnu.so _gdalconst.so \ - && mv _ogr.cpython-38-x86_64-linux-gnu.so _ogr.so \ - && mv _gnm.cpython-38-x86_64-linux-gnu.so _gnm.so \ - && mv _osr.cpython-38-x86_64-linux-gnu.so _osr.so +# -- untar files to root +# - from databricks-mosaic-gdal install dir +tar -xf $GDAL_RESOURCE_DIR/resources/gdal-3.4.3-filetree.tar.xz -C / -# add pre-build JNI shared object to the path -# please run MosaicGDAL.copySharedObjects("/dbfs/FileStore/geospatial/mosaic/gdal/") before enabling this init script -mkdir -p /usr/lib/jni -cp "${MOSAIC_GDAL_JNI_DIR}/libgdalalljni.so" /usr/lib/jni -cp "${MOSAIC_GDAL_JNI_DIR}/libgdalalljni.so.30" /usr/lib/jni \ No newline at end of file +# -- untar symlinks to root +# - from 
databricks-mosaic-gdal install dir +tar -xhf $GDAL_RESOURCE_DIR/resources/gdal-3.4.3-symlinks.tar.xz -C / diff --git a/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometry.scala b/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometry.scala index 9c0384be1..f6768d4c5 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometry.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometry.scala @@ -1,8 +1,12 @@ package com.databricks.labs.mosaic.core.geometry import com.databricks.labs.mosaic.core.crs.CRSBoundsProvider +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.geometry.linestring.MosaicLineString import com.databricks.labs.mosaic.core.geometry.point.MosaicPoint +import org.gdal.ogr.ogr +import org.gdal.osr.SpatialReference +import org.gdal.osr.osrConstants._ import org.locationtech.proj4j._ import java.util.Locale @@ -51,6 +55,8 @@ trait MosaicGeometry extends GeometryWriter with Serializable { def buffer(distance: Double): MosaicGeometry + def bufferCapStyle(distance: Double, capStyle: String): MosaicGeometry + def simplify(tolerance: Double): MosaicGeometry def intersection(other: MosaicGeometry): MosaicGeometry @@ -59,6 +65,16 @@ trait MosaicGeometry extends GeometryWriter with Serializable { def envelope: MosaicGeometry + def extent: (Double, Double, Double, Double) = { + val env = envelope + ( + env.minMaxCoord("X", "MIN"), + env.minMaxCoord("Y", "MIN"), + env.minMaxCoord("X", "MAX"), + env.minMaxCoord("Y", "MAX") + ) + } + def union(other: MosaicGeometry): MosaicGeometry def unaryUnion: MosaicGeometry @@ -97,6 +113,15 @@ trait MosaicGeometry extends GeometryWriter with Serializable { def transformCRSXY(sridTo: Int): MosaicGeometry + def osrTransformCRS(srcSR: SpatialReference, destSR: SpatialReference, geometryAPI: GeometryAPI): MosaicGeometry = { + if (srcSR.IsSame(destSR) == 1) return this + val ogcGeometry = 
ogr.CreateGeometryFromWkb(this.toWKB) + ogcGeometry.AssignSpatialReference(srcSR) + ogcGeometry.TransformTo(destSR) + val mosaicGeometry = geometryAPI.geometry(ogcGeometry.ExportToWkb, "WKB") + mosaicGeometry + } + def transformCRSXY(sridTo: Int, sridFrom: Int): MosaicGeometry = { transformCRSXY(sridTo, Some(sridFrom)) } @@ -127,6 +152,18 @@ trait MosaicGeometry extends GeometryWriter with Serializable { def setSpatialReference(srid: Int): Unit + def getSpatialReferenceOSR: SpatialReference = { + val srID = getSpatialReference + if (srID == 0) { + null + } else { + val geomCRS = new SpatialReference() + geomCRS.ImportFromEPSG(srID) + geomCRS.SetAxisMappingStrategy(OAMS_TRADITIONAL_GIS_ORDER) + geomCRS + } + } + def hasValidCoords(crsBoundsProvider: CRSBoundsProvider, crsCode: String, which: String): Boolean = { val crsCodeIn = crsCode.split(":") val crsBounds = which.toLowerCase(Locale.ROOT) match { diff --git a/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometryESRI.scala b/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometryESRI.scala index 3cf170289..75dc2c727 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometryESRI.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometryESRI.scala @@ -74,6 +74,9 @@ abstract class MosaicGeometryESRI(geom: OGCGeometry) extends MosaicGeometry { override def buffer(distance: Double): MosaicGeometryESRI = MosaicGeometryESRI(geom.buffer(distance)) + // This is NOOP in ESRI bindings, JTS provides a different implementation + override def bufferCapStyle(distance: Double, capStyle: String): MosaicGeometryESRI = buffer(distance) + override def simplify(tolerance: Double): MosaicGeometryESRI = MosaicGeometryESRI(geom.makeSimple()) override def envelope: MosaicGeometryESRI = MosaicGeometryESRI(geom.envelope()) @@ -218,11 +221,12 @@ object MosaicGeometryESRI extends GeometryReader { if (polygons.length == 1) Seq(polygons.head) else if 
(polygons.length > 1) Seq(polygons.reduce(_ union _)) else Nil val pieces = multiPoint ++ multiLine ++ multiPolygon - val result = if (pieces.length == 1) { - pieces.head - } else { - pieces.reduce(_ union _) - } + val result = + if (pieces.length == 1) { + pieces.head + } else { + pieces.reduce(_ union _) + } result.setSpatialReference(getSRID(srid)) result } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometryJTS.scala b/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometryJTS.scala index 6cb8cac1c..b360a1df8 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometryJTS.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometryJTS.scala @@ -15,6 +15,7 @@ import org.locationtech.jts.geom.{Geometry, GeometryCollection, GeometryFactory} import org.locationtech.jts.geom.util.AffineTransformation import org.locationtech.jts.io._ import org.locationtech.jts.io.geojson.{GeoJsonReader, GeoJsonWriter} +import org.locationtech.jts.operation.buffer.{BufferOp, BufferParameters} import org.locationtech.jts.simplify.DouglasPeuckerSimplifier import java.util @@ -63,6 +64,20 @@ abstract class MosaicGeometryJTS(geom: Geometry) extends MosaicGeometry { MosaicGeometryJTS(buffered) } + override def bufferCapStyle(distance: Double, capStyle: String): MosaicGeometryJTS = { + val capStyleConst = capStyle match { + case "round" => BufferParameters.CAP_ROUND + case "flat" => BufferParameters.CAP_FLAT + case "square" => BufferParameters.CAP_SQUARE + case _ => BufferParameters.CAP_ROUND + } + val gBuf = new BufferOp(geom) + gBuf.setEndCapStyle(capStyleConst) + val buffered = gBuf.getResultGeometry(distance) + buffered.setSRID(geom.getSRID) + MosaicGeometryJTS(buffered) + } + override def simplify(tolerance: Double = 1e-8): MosaicGeometryJTS = { val simplified = DouglasPeuckerSimplifier.simplify(geom, tolerance) simplified.setSRID(geom.getSRID) @@ -194,7 +209,7 @@ object MosaicGeometryJTS 
extends GeometryReader { override def fromWKT(wkt: String): MosaicGeometryJTS = MosaicGeometryJTS(new WKTReader().read(wkt)) - //noinspection DuplicatedCode + // noinspection DuplicatedCode def compactCollection(geometries: Seq[Geometry], srid: Int): Geometry = { def appendGeometries(geometries: util.ArrayList[Geometry], toAppend: Seq[Geometry]): Unit = { if (toAppend.length == 1 && !toAppend.head.isEmpty) { diff --git a/src/main/scala/com/databricks/labs/mosaic/core/geometry/api/GeometryAPI.scala b/src/main/scala/com/databricks/labs/mosaic/core/geometry/api/GeometryAPI.scala index 7061608e0..bbda2aaa9 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/geometry/api/GeometryAPI.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/geometry/api/GeometryAPI.scala @@ -1,10 +1,12 @@ package com.databricks.labs.mosaic.core.geometry.api +import com.databricks.labs.mosaic.MOSAIC_GEOMETRY_API import com.databricks.labs.mosaic.codegen.format._ import com.databricks.labs.mosaic.core.geometry._ import com.databricks.labs.mosaic.core.geometry.point._ import com.databricks.labs.mosaic.core.types._ import com.databricks.labs.mosaic.core.types.model.{Coordinates, GeometryTypeEnum} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -15,6 +17,16 @@ abstract class GeometryAPI( reader: GeometryReader ) extends Serializable { + def createBbox(xMin: Double, yMin: Double, xMax: Double, yMax: Double): MosaicGeometry = { + val p1 = fromGeoCoord(Coordinates(xMin, yMin)) + val p2 = fromGeoCoord(Coordinates(xMin, yMax)) + val p3 = fromGeoCoord(Coordinates(xMax, yMax)) + val p4 = fromGeoCoord(Coordinates(xMax, yMin)) + val p5 = fromGeoCoord(Coordinates(xMin, yMin)) + geometry(Seq(p1, p2, p3, p4, p5), GeometryTypeEnum.POLYGON) + } + + def name: String def geometry(input: Any, typeName: String): MosaicGeometry = { @@ -110,4 +122,9 @@ object GeometryAPI 
extends Serializable { case _ => throw new Error(s"Unsupported API name: $name.") } + def apply(sparkSession: SparkSession): GeometryAPI = { + val apiName = sparkSession.conf.get(MOSAIC_GEOMETRY_API, "JTS") + apply(apiName) + } + } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/geometry/geometrycollection/MosaicGeometryCollectionESRI.scala b/src/main/scala/com/databricks/labs/mosaic/core/geometry/geometrycollection/MosaicGeometryCollectionESRI.scala index dfdcb1bc2..b688962db 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/geometry/geometrycollection/MosaicGeometryCollectionESRI.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/geometry/geometrycollection/MosaicGeometryCollectionESRI.scala @@ -167,7 +167,7 @@ object MosaicGeometryCollectionESRI extends GeometryReader { // POINT by convention, MULTIPOINT are always flattened to POINT in the internal representation val coordinates = holesRings.head.head.coords MosaicPointESRI( - new OGCPoint(new Point(coordinates(0), coordinates(1)), spatialReference) + new OGCPoint(new Point(coordinates.head, coordinates(1)), spatialReference) ) } else { MosaicGeometryESRI.fromWKT("POINT EMPTY") diff --git a/src/main/scala/com/databricks/labs/mosaic/core/index/BNGIndexSystem.scala b/src/main/scala/com/databricks/labs/mosaic/core/index/BNGIndexSystem.scala index caa926fb9..f325ae8aa 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/index/BNGIndexSystem.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/index/BNGIndexSystem.scala @@ -26,22 +26,25 @@ import scala.util.{Success, Try} * @see * [[https://en.wikipedia.org/wiki/Ordnance_Survey_National_Grid]] */ +//noinspection ScalaWeakerAccess object BNGIndexSystem extends IndexSystem(StringType) with Serializable { + override def crsID: Int = 27700 + val name = "BNG" /** * Quadrant encodings. The order is determined in a way that preserves * similarity to space filling curves. 
*/ - val quadrants = Seq("", "SW", "NW", "NE", "SE") + val quadrants: Seq[String] = Seq("", "SW", "NW", "NE", "SE") /** * Resolution mappings from string names to integer encodings. Resolutions * are uses as integers in any index math so we need to convert sizes to * corresponding index resolutions. */ - val resolutionMap = + val resolutionMap: Map[String, Int] = Map( "500km" -> -1, "100km" -> 1, @@ -60,7 +63,7 @@ object BNGIndexSystem extends IndexSystem(StringType) with Serializable { /** * Mapping from string names to edge sizes expressed in eastings/northings. */ - val sizeMap = + val sizeMap: Map[String, Int] = Map( "500km" -> 500000, "100km" -> 100000, @@ -82,21 +85,22 @@ object BNGIndexSystem extends IndexSystem(StringType) with Serializable { * coverage of this index system having a lookup is more efficient than * performing any math transformations between ints and chars. */ - val letterMap = + val letterMap: Seq[Seq[String]] = Seq( - Seq("SV", "SW", "SX", "SY", "SZ", "TV", "TW"), - Seq("SQ", "SR", "SS", "ST", "SU", "TQ", "TR"), - Seq("SL", "SM", "SN", "SO", "SP", "TL", "TM"), - Seq("SF", "SG", "SH", "SJ", "SK", "TF", "TG"), - Seq("SA", "SB", "SC", "SD", "SE", "TA", "TB"), - Seq("NV", "NW", "NX", "NY", "NZ", "OV", "OW"), - Seq("NQ", "NR", "NS", "NT", "NU", "OQ", "OR"), - Seq("NL", "NM", "NN", "NO", "NP", "OL", "OM"), - Seq("NF", "NG", "NH", "NJ", "NK", "OF", "OG"), - Seq("NA", "NB", "NC", "ND", "NE", "OA", "OB"), - Seq("HV", "HW", "HX", "HY", "SZ", "JV", "JW"), - Seq("HQ", "HR", "HS", "HT", "HU", "JQ", "JR"), - Seq("HL", "HM", "HN", "HO", "HP", "JL", "JM") + Seq("SV", "SW", "SX", "SY", "SZ", "TV", "TW", "TX"), + Seq("SQ", "SR", "SS", "ST", "SU", "TQ", "TR", "TS"), + Seq("SL", "SM", "SN", "SO", "SP", "TL", "TM", "TN"), + Seq("SF", "SG", "SH", "SJ", "SK", "TF", "TG", "TH"), + Seq("SA", "SB", "SC", "SD", "SE", "TA", "TB", "TC"), + Seq("NV", "NW", "NX", "NY", "NZ", "OV", "OW", "OX"), + Seq("NQ", "NR", "NS", "NT", "NU", "OQ", "OR", "OS"), + Seq("NL", "NM", 
"NN", "NO", "NP", "OL", "OM", "ON"), + Seq("NF", "NG", "NH", "NJ", "NK", "OF", "OG", "OH"), + Seq("NA", "NB", "NC", "ND", "NE", "OA", "OB", "OC"), + Seq("HV", "HW", "HX", "HY", "HZ", "JV", "JW", "JX"), + Seq("HQ", "HR", "HS", "HT", "HU", "JQ", "JR", "JS"), + Seq("HL", "HM", "HN", "HO", "HP", "JL", "JM", "JN"), + Seq("HF", "HG", "HH", "HJ", "HK", "JF", "JG", "JH") ) /** @@ -254,7 +258,7 @@ object BNGIndexSystem extends IndexSystem(StringType) with Serializable { * @return * Boolean representing validity. */ - def isValid(index: Long): Boolean = { + override def isValid(index: Long): Boolean = { val digits = indexDigits(index) val xLetterIndex = digits.slice(3, 5).mkString.toInt val yLetterIndex = digits.slice(1, 3).mkString.toInt @@ -434,7 +438,9 @@ object BNGIndexSystem extends IndexSystem(StringType) with Serializable { val p2 = geometryAPI.fromCoords(Seq(x + edgeSize, y)) val p3 = geometryAPI.fromCoords(Seq(x + edgeSize, y + edgeSize)) val p4 = geometryAPI.fromCoords(Seq(x, y + edgeSize)) - geometryAPI.geometry(Seq(p1, p2, p3, p4, p1), POLYGON) + val geom = geometryAPI.geometry(Seq(p1, p2, p3, p4, p1), POLYGON) + geom.setSpatialReference(this.crsID) + geom } /** @@ -518,7 +524,7 @@ object BNGIndexSystem extends IndexSystem(StringType) with Serializable { val digits = indexDigits(index) val resolution = getResolution(digits) val edgeSize = getEdgeSize(resolution).asInstanceOf[Double] - val area = math.pow((edgeSize / 1000), 2) + val area = math.pow(edgeSize / 1000, 2) area } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/index/CustomIndexSystem.scala b/src/main/scala/com/databricks/labs/mosaic/core/index/CustomIndexSystem.scala index 28e765c68..8bd1b46c5 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/index/CustomIndexSystem.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/index/CustomIndexSystem.scala @@ -10,8 +10,13 @@ import org.apache.spark.unsafe.types.UTF8String import scala.util.{Success, Try} /** Implements the 
[[IndexSystem]] for any CRS system. */ +//noinspection ScalaWeakerAccess case class CustomIndexSystem(conf: GridConf) extends IndexSystem(LongType) with Serializable { + override def crsID: Int = conf.crsID.getOrElse( + throw new Error("CRS ID is not defined for this index system") + ) + val name = f"CUSTOM(${conf.boundXMin}, ${conf.boundXMax}, ${conf.boundYMin}, ${conf.boundYMax}, ${conf.cellSplits}, ${conf.rootCellSizeX}, ${conf.rootCellSizeY})" diff --git a/src/main/scala/com/databricks/labs/mosaic/core/index/GridConf.scala b/src/main/scala/com/databricks/labs/mosaic/core/index/GridConf.scala index 1c7f7853d..4df2c290b 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/index/GridConf.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/index/GridConf.scala @@ -7,7 +7,8 @@ case class GridConf( boundYMax: Long, cellSplits: Int, rootCellSizeX: Int, - rootCellSizeY: Int + rootCellSizeY: Int, + crsID: Option[Int] = None ) { private val spanX = boundXMax - boundXMin private val spanY = boundYMax - boundYMin @@ -15,16 +16,17 @@ case class GridConf( val resBits = 8 // We keep 8 Most Significant Bits for resolution val idBits = 56 // The rest can be used for the cell ID - val subCellsCount = cellSplits * cellSplits + //noinspection ScalaWeakerAccess + val subCellsCount: Int = cellSplits * cellSplits // We need a distinct value for each cell, plus one bit for the parent cell (all-zeroes for LSBs) // We compute it with log2(subCellsCount) - val bitsPerResolution = Math.ceil(Math.log10(subCellsCount) / Math.log10(2)).toInt + val bitsPerResolution: Int = Math.ceil(Math.log10(subCellsCount) / Math.log10(2)).toInt // A cell ID has to fit the reserved number of bits - val maxResolution = Math.min(20, Math.floor(idBits / bitsPerResolution).toInt) + val maxResolution: Int = Math.min(20, Math.floor(idBits / bitsPerResolution).toInt) - val rootCellCountX = Math.ceil(spanX.toDouble / rootCellSizeX).toInt - val rootCellCountY = Math.ceil(spanY.toDouble / 
rootCellSizeY).toInt + val rootCellCountX: Int = Math.ceil(spanX.toDouble / rootCellSizeX).toInt + val rootCellCountY: Int = Math.ceil(spanY.toDouble / rootCellSizeY).toInt } \ No newline at end of file diff --git a/src/main/scala/com/databricks/labs/mosaic/core/index/H3IndexSystem.scala b/src/main/scala/com/databricks/labs/mosaic/core/index/H3IndexSystem.scala index 79ba98390..8b5c2e6c5 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/index/H3IndexSystem.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/index/H3IndexSystem.scala @@ -22,6 +22,8 @@ import scala.util.{Success, Try} */ object H3IndexSystem extends IndexSystem(LongType) with Serializable { + override def crsID: Int = 4326 + val name = "H3" // An instance of H3Core to be used for IndexSystem implementation. @@ -95,8 +97,11 @@ object H3IndexSystem extends IndexSystem(LongType) with Serializable { val boundary = h3.h3ToGeoBoundary(index).asScala val extended = boundary ++ List(boundary.head) - if (crossesNorthPole(index) || crossesSouthPole(index)) makePoleGeometry(boundary, crossesNorthPole(index), geometryAPI) - else makeSafeGeometry(extended, geometryAPI) + val geom = + if (crossesNorthPole(index) || crossesSouthPole(index)) makePoleGeometry(boundary, crossesNorthPole(index), geometryAPI) + else makeSafeGeometry(extended, geometryAPI) + geom.setSpatialReference(crsID) + geom } /** @@ -199,8 +204,12 @@ object H3IndexSystem extends IndexSystem(LongType) with Serializable { val boundary = h3.h3ToGeoBoundary(index).asScala val extended = boundary ++ List(boundary.head) - if (crossesNorthPole(index) || crossesSouthPole(index)) makePoleGeometry(boundary, crossesNorthPole(index), geometryAPI) - else makeSafeGeometry(extended, geometryAPI) + val geom = + if (crossesNorthPole(index) || crossesSouthPole(index)) makePoleGeometry(boundary, crossesNorthPole(index), geometryAPI) + else makeSafeGeometry(extended, geometryAPI) + + geom.setSpatialReference(crsID) + geom } override def format(id: 
Long): String = { diff --git a/src/main/scala/com/databricks/labs/mosaic/core/index/IndexSystem.scala b/src/main/scala/com/databricks/labs/mosaic/core/index/IndexSystem.scala index 0144760a0..64ea08c7a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/index/IndexSystem.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/index/IndexSystem.scala @@ -6,6 +6,7 @@ import com.databricks.labs.mosaic.core.types.model.{Coordinates, GeometryTypeEnu import com.databricks.labs.mosaic.core.types.model.GeometryTypeEnum._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.gdal.osr.SpatialReference /** * Defines the API that all index systems need to respect for Mosaic to support @@ -13,6 +14,25 @@ import org.apache.spark.unsafe.types.UTF8String */ abstract class IndexSystem(var cellIdType: DataType) extends Serializable { + // Passthrough if not redefined + def isValid(cellID: Long): Boolean = true + + def crsID: Int + + /** + * Returns the spatial reference of the index system. This is only + * available when GDAL is available. For proj4j please use crsID method. 
+ * + * @return + * SpatialReference + */ + def osrSpatialRef: SpatialReference = { + val sr = new SpatialReference() + sr.ImportFromEPSG(crsID) + sr.SetAxisMappingStrategy(org.gdal.osr.osrConstants.OAMS_TRADITIONAL_GIS_ORDER) + sr + } + def distance(cellId: Long, cellId2: Long): Long def getCellIdDataType: DataType = cellIdType diff --git a/src/main/scala/com/databricks/labs/mosaic/core/index/IndexSystemFactory.scala b/src/main/scala/com/databricks/labs/mosaic/core/index/IndexSystemFactory.scala index 65c02970b..db6b3a111 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/index/IndexSystemFactory.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/index/IndexSystemFactory.scala @@ -1,24 +1,65 @@ package com.databricks.labs.mosaic.core.index +import org.apache.spark.sql.SparkSession + object IndexSystemFactory { + /** + * Returns the index system based on the spark session configuration. If + * the spark session configuration is not set, it will default to H3. + * @param spark + * SparkSession + * @return + * IndexSystem + */ + def getIndexSystem(spark: SparkSession): IndexSystem = { + val indexSystem = spark.conf.get("spark.databricks.labs.mosaic.index.system", "H3") + getIndexSystem(indexSystem) + } + + /** + * Returns the index system based on the name provided. If the name is not + * supported, it will throw an error. 
For custom index systems, the format + * is as follows: CUSTOM(xMin, xMax, yMin, yMax, splits, rootCellSizeX, + * rootCellSizeY, crsID) or CUSTOM(xMin, xMax, yMin, yMax, splits, + * rootCellSizeX, rootCellSizeY) + * @param name + * String + * @return + * IndexSystem + */ def getIndexSystem(name: String): IndexSystem = { val customIndexRE = "CUSTOM\\((-?\\d+), ?(-?\\d+), ?(-?\\d+), ?(-?\\d+), ?(\\d+), ?(\\d+), ?(\\d+) ?\\)".r + val customIndexWithCRSRE = "CUSTOM\\((-?\\d+), ?(-?\\d+), ?(-?\\d+), ?(-?\\d+), ?(\\d+), ?(\\d+), ?(\\d+), ?(\\d+) ?\\)".r name match { - case "H3" => H3IndexSystem - case "BNG" => BNGIndexSystem - case customIndexRE(xMin, xMax, yMin, yMax, splits, rootCellSizeX, rootCellSizeY) - => new CustomIndexSystem( - GridConf( - xMin.toInt, - xMax.toInt, - yMin.toInt, - yMax.toInt, - splits.toInt, - rootCellSizeX.toInt, - rootCellSizeY.toInt)) + case "H3" => H3IndexSystem + case "BNG" => BNGIndexSystem + case customIndexRE(xMin, xMax, yMin, yMax, splits, rootCellSizeX, rootCellSizeY) => CustomIndexSystem( + GridConf( + xMin.toInt, + xMax.toInt, + yMin.toInt, + yMax.toInt, + splits.toInt, + rootCellSizeX.toInt, + rootCellSizeY.toInt + ) + ) + case customIndexWithCRSRE(xMin, xMax, yMin, yMax, splits, rootCellSizeX, rootCellSizeY, crsID) => CustomIndexSystem( + GridConf( + xMin.toInt, + xMax.toInt, + yMin.toInt, + yMax.toInt, + splits.toInt, + rootCellSizeX.toInt, + rootCellSizeY.toInt, + Some(crsID.toInt) + ) + ) case _ => throw new Error("Index not supported yet!") } } -} \ No newline at end of file + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/MosaicRaster.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/MosaicRaster.scala deleted file mode 100644 index 7e2c937d9..000000000 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/MosaicRaster.scala +++ /dev/null @@ -1,90 +0,0 @@ -package com.databricks.labs.mosaic.core.raster - -import org.gdal.gdal.Dataset - -/** - * A base API for managing raster data in 
Mosaic. Any raster abstraction should - * extend this trait. - * - * @param path - * The path to the raster file. This has to be a path that can be read by the - * worker nodes. - * - * @param memSize - * The amount of memory occupied by the file in bytes. - */ -abstract class MosaicRaster(path: String, memSize: Long) extends Serializable { - - /** - * Writes out the current raster to the given checkpoint path. The raster - * is written out as a GeoTiff. Only single subdataset is supported. Apply - * mask to all bands. Trim down the raster to the provided extent. - * @param stageId - * the UUI of the computation stage generating the raster. Used to avoid - * writing collisions. - * @param rasterId - * the UUID of the raster. Used to avoid writing collisions. - * @param extent - * The extent to trim the raster to. - * @param checkpointPath - * The path to write the raster to. - * - * @return - * Returns the path to the written raster. - */ - def saveCheckpoint(stageId: String, rasterId: Long, extent: (Int, Int, Int, Int), checkpointPath: String): String - - /** @return Returns the metadata of the raster file. */ - def metadata: Map[String, String] - - /** - * @return - * Returns the key->value pairs of subdataset->description for the - * raster. - */ - def subdatasets: Map[String, String] - - /** @return Returns the number of bands in the raster. */ - def numBands: Int - - /** @return Returns the SRID in the raster. */ - def SRID: Int - - /** @return Returns the proj4 projection string in the raster. */ - def proj4String: String - - /** @return Returns the x size of the raster. */ - def xSize: Int - - /** @return Returns the y size of the raster. */ - def ySize: Int - - /** @return Returns the bandId-th Band from the raster. */ - def getBand(bandId: Int): MosaicRasterBand - - /** @return Returns the extent(xmin, ymin, xmax, ymax) of the raster. */ - def extent: Seq[Double] - - /** @return Returns the GDAL Dataset representing the raster. 
*/ - def getRaster: Dataset - - /** Cleans up the raster driver and references. */ - def cleanUp(): Unit - - /** @return Returns the amount of memory occupied by the file in bytes. */ - def getMemSize: Long = memSize - - /** - * A template method for transforming the raster bands into new bands. Each - * band is transformed into a new band using the transform function. - * Override this method for tiling, clipping, warping, etc. type of - * expressions. - * - * @tparam T - * The type of the result from the transformation of a band. - * @param f - * The transform function. Will be applied on each band. - */ - def transformBands[T](f: MosaicRasterBand => T): Seq[T] - -} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/MosaicRasterBand.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/MosaicRasterBand.scala deleted file mode 100644 index af8f204a0..000000000 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/MosaicRasterBand.scala +++ /dev/null @@ -1,132 +0,0 @@ -package com.databricks.labs.mosaic.core.raster - -/** - * A base API for managing raster bands in Mosaic. Any raster band abstraction - * should extend this trait. - */ -trait MosaicRasterBand extends Serializable { - - /** @return Returns the bandId of the band. */ - def index: Int - - /** @return Returns the description of the band. */ - def description: String - - /** @return Returns the metadata of the band. */ - def metadata: Map[String, String] - - /** @return Returns the unit type of the band pixels. */ - def units: String - - /** @return Returns the data type (numeric) of the band pixels. */ - def dataType: Int - - /** @return Returns the x size of the band. */ - def xSize: Int - - /** @return Returns the y size of the band. */ - def ySize: Int - - /** @return Returns the minimum pixel value of the band. */ - def minPixelValue: Double - - /** @return Returns the maximum pixel value of the band. 
*/ - def maxPixelValue: Double - - /** - * @return - * Returns the value used to represent transparent pixels of the band. - */ - def noDataValue: Double - - /** - * @return - * Returns the scale in which pixels are represented. It is the unit - * value of a pixel. If the pixel value is 5.1 and pixel scale is 10.0 - * then the actual pixel value is 51.0. - */ - def pixelValueScale: Double - - /** - * @return - * Returns the offset in which pixels are represented. It is the unit - * value of a pixel. If the pixel value is 5.1 and pixel offset is 10.0 - * then the actual pixel value is 15.1. - */ - def pixelValueOffset: Double - - /** - * @return - * Returns the pixel value with scale and offset applied. If the pixel - * value is 5.1 and pixel scale is 10.0 and pixel offset is 10.0 then the - * actual pixel value is 61.0. - */ - def pixelValueToUnitValue(pixelValue: Double): Double - - /** - * @return - * Returns the pixels of the raster as a 1D array. - */ - def values: Array[Double] = values(0, 0, xSize, ySize) - - /** - * @return - * Returns the pixels of the raster as a 1D array. - */ - def maskValues: Array[Double] = maskValues(0, 0, xSize, ySize) - - /** - * @param xOffset - * The x offset of the raster. The x offset is the number of pixels to - * skip from the left. 0 <= xOffset < xSize - * - * @param yOffset - * The y offset of the raster. The y offset is the number of pixels to - * skip from the top. 0 <= yOffset < ySize - * - * @param xSize - * The x size of the raster to be read. - * - * @param ySize - * The y size of the raster to be read. - * @return - * Returns the pixels of the raster as a 1D array with offset and size - * applied. - */ - def values(xOffset: Int, yOffset: Int, xSize: Int, ySize: Int): Array[Double] - - /** - * @param xOffset - * The x offset of the raster. The x offset is the number of pixels to - * skip from the left. 0 <= xOffset < xSize - * - * @param yOffset - * The y offset of the raster. 
The y offset is the number of pixels to - * skip from the top. 0 <= yOffset < ySize - * - * @param xSize - * The x size of the raster to be read. - * - * @param ySize - * The y size of the raster to be read. - * @return - * Returns the mask pixels of the raster as a 1D array with offset and size - * applied. - */ - def maskValues(xOffset: Int, yOffset: Int, xSize: Int, ySize: Int): Array[Double] - - /** - * Apply f to all pixels in the raster. Overridden in subclasses to define - * the behavior. - * @param f - * the function to apply to each pixel. - * @param default - * the default value to use if the pixel is noData. - * @tparam T - * the return type of the function. - * @return - * an array of the results of applying f to each pixel. - */ - def transformValues[T](f: (Int, Int, Double) => T, default: T = null): Seq[Seq[T]] - -} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/MosaicRasterGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/MosaicRasterGDAL.scala deleted file mode 100644 index cba9940b7..000000000 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/MosaicRasterGDAL.scala +++ /dev/null @@ -1,254 +0,0 @@ -package com.databricks.labs.mosaic.core.raster - -import org.gdal.gdal.{gdal, Dataset} -import org.gdal.gdalconst.gdalconstConstants._ -import org.gdal.osr.SpatialReference -import org.locationtech.proj4j.CRSFactory - -import java.io.File -import java.nio.file.{Files, Paths} -import java.nio.file.StandardCopyOption.REPLACE_EXISTING -import java.util.Locale -import scala.collection.JavaConverters.dictionaryAsScalaMapConverter -import scala.util.Try - -/** GDAL implementation of the MosaicRaster trait. 
*/ -//noinspection DuplicatedCode -case class MosaicRasterGDAL(raster: Dataset, path: String, memSize: Long) extends MosaicRaster(path, memSize) { - - import com.databricks.labs.mosaic.core.raster.MosaicRasterGDAL.toWorldCoord - - val crsFactory: CRSFactory = new CRSFactory - - override def metadata: Map[String, String] = { - Option(raster.GetMetadataDomainList()) - .map(_.toArray) - .map(domain => - domain - .map(domainName => - Option(raster.GetMetadata_Dict(domainName.toString)) - .map(_.asScala.toMap.asInstanceOf[Map[String, String]]) - .getOrElse(Map.empty[String, String]) - ) - .reduceOption(_ ++ _).getOrElse(Map.empty[String, String]) - ) - .getOrElse(Map.empty[String, String]) - - } - - override def subdatasets: Map[String, String] = { - val subdatasetsMap = Option(raster.GetMetadata_Dict("SUBDATASETS")) - .map(_.asScala.toMap.asInstanceOf[Map[String, String]]) - .getOrElse(Map.empty[String, String]) - val keys = subdatasetsMap.keySet - keys.flatMap(key => - if (key.toUpperCase(Locale.ROOT).contains("NAME")) { - val path = subdatasetsMap(key) - Seq( - key -> path.split(":").last, - path.split(":").last -> path - ) - } else Seq(key -> subdatasetsMap(key)) - ).toMap - } - - override def SRID: Int = { - Try(crsFactory.readEpsgFromParameters(proj4String)) - .filter(_ != null) - .getOrElse("EPSG:0") - .split(":") - .last - .toInt - } - - override def proj4String: String = Try(raster.GetSpatialRef.ExportToProj4).filter(_ != null).getOrElse("") - - override def getBand(bandId: Int): MosaicRasterBand = { - if (bandId > 0 && numBands >= bandId) { - MosaicRasterBandGDAL(raster.GetRasterBand(bandId), bandId) - } else { - throw new ArrayIndexOutOfBoundsException() - } - } - - override def numBands: Int = raster.GetRasterCount() - - // noinspection ZeroIndexToHead - override def extent: Seq[Double] = { - val minx = getGeoTransform(0) - val maxy = getGeoTransform(3) - val maxx = minx + getGeoTransform(1) * xSize - val miny = maxy + getGeoTransform(5) * ySize - Seq(minx, 
miny, maxx, maxy) - } - - override def xSize: Int = raster.GetRasterXSize - - override def ySize: Int = raster.GetRasterYSize - - def getGeoTransform: Array[Double] = raster.GetGeoTransform() - - def getGeoTransform(extent: (Int, Int, Int, Int)): Array[Double] = { - val gt = getGeoTransform - val (xmin, _, _, ymax) = extent - val (xUpperLeft, yUpperLeft) = toWorldCoord(gt, xmin, ymax) - Array(xUpperLeft, gt(1), gt(2), yUpperLeft, gt(4), gt(5)) - } - - override def getRaster: Dataset = this.raster - - def spatialRef: SpatialReference = raster.GetSpatialRef() - - override def cleanUp(): Unit = { - - /** Nothing to clean up = NOOP */ - } - - override def transformBands[T](f: MosaicRasterBand => T): Seq[T] = for (i <- 1 to numBands) yield f(getBand(i)) - - /** - * Write the raster to a file. GDAL cannot write directly to dbfs. Raster - * is written to a local file first. "../../tmp/_" is used for the - * temporary file. The file is then copied to the checkpoint directory. The - * local copy is then deleted. Temporary files are written as GeoTiffs. - * Files with subdatasets are not supported. They should be flattened - * first. - * - * @param stageId - * the UUI of the computation stage generating the raster. Used to avoid - * writing collisions. - * @param rasterId - * the UUID of the raster. Used to avoid writing collisions. - * @param extent - * the extent to clip the raster to. This is used for writing out partial - * rasters. - * @param checkpointPath - * the path to the checkpoint directory. - * @return - * A path to the written raster. 
- */ - override def saveCheckpoint(stageId: String, rasterId: Long, extent: (Int, Int, Int, Int), checkpointPath: String): String = { - val tmpDir = Files.createTempDirectory(s"mosaic_$stageId").toFile.getAbsolutePath - val outPath = s"$tmpDir/raster_${rasterId.toString.replace("-", "_")}.tif" - Files.createDirectories(Paths.get(outPath).getParent) - val (xmin, ymin, xmax, ymax) = extent - val xSize = xmax - xmin - val ySize = ymax - ymin - val outputDs = gdal.GetDriverByName("GTiff").Create(outPath, xSize, ySize, numBands, GDT_Float64) - for (i <- 1 to numBands) { - val band = getBand(i) - val data = band.values(xmin, ymin, xSize, ySize) - val maskData = band.maskValues(xmin, ymin, xSize, ySize) - val noDataValue = band.noDataValue - - val outBand = outputDs.GetRasterBand(i) - val maskBand = outBand.GetMaskBand() - - outBand.SetNoDataValue(noDataValue) - outBand.WriteRaster(0, 0, xSize, ySize, data) - maskBand.WriteRaster(0, 0, xSize, ySize, maskData) - outBand.FlushCache() - maskBand.FlushCache() - } - outputDs.SetGeoTransform(getGeoTransform(extent)) - outputDs.FlushCache() - - val destinationPath = Paths.get(checkpointPath.replace("dbfs:/", "/dbfs/"), s"raster_$rasterId.tif") - Files.createDirectories(destinationPath) - Files.copy(Paths.get(outPath), destinationPath, REPLACE_EXISTING) - Files.delete(Paths.get(outPath)) - destinationPath.toAbsolutePath.toString.replace("dbfs:/", "/dbfs/") - } - -} - -//noinspection ZeroIndexToHead -object MosaicRasterGDAL extends RasterReader { - - def apply(dataset: Dataset, path: String, memSize: Long): MosaicRasterGDAL = new MosaicRasterGDAL(dataset, path, memSize) - - /** - * Reads a raster from a file system path. Reads a subdataset if the path - * is to a subdataset. - * - * @example - * Raster: path = "file:///path/to/file.tif" Subdataset: path = - * "file:///path/to/file.tif:subdataset" - * @param inPath - * The path to the raster file. - * @return - * A MosaicRaster object. 
- */ - override def readRaster(inPath: String): MosaicRaster = { - val path = inPath.replace("dbfs:/", "/dbfs/").replace("file:/", "/") - val dataset = gdal.Open(path, GA_ReadOnly) - val size = new File(path).length() - MosaicRasterGDAL(dataset, path, size) - } - - /** - * Reads a raster band from a file system path. Reads a subdataset band if - * the path is to a subdataset. - * - * @example - * Raster: path = "file:///path/to/file.tif" Subdataset: path = - * "file:///path/to/file.tif:subdataset" - * @param path - * The path to the raster file. - * @param bandIndex - * The band index to read. - * @return - * A MosaicRaster object. - */ - override def readBand(path: String, bandIndex: Int): MosaicRasterBand = { - val raster = readRaster(path) - raster.getBand(bandIndex) - } - - /** - * Take a geo transform matrix and x and y coordinates of a pixel and - * returns the x and y coors in the projection of the raster. As per GDAL - * documentation, the origin is the top left corner of the top left pixel - * @see - * https://gdal.org/tutorials/raster_api_tut.html - * - * @param geoTransform - * The geo transform matrix of the raster. - * - * @param x - * The x coordinate of the pixel. - * @param y - * The y coordinate of the pixel. - * @return - * A tuple of doubles with the x and y coordinates in the projection of - * the raster. - */ - override def toWorldCoord(geoTransform: Seq[Double], x: Int, y: Int): (Double, Double) = { - val Xp = geoTransform(0) + x * geoTransform(1) + y * geoTransform(2) - val Yp = geoTransform(3) + x * geoTransform(4) + y * geoTransform(5) - (Xp, Yp) - } - - /** - * Take a geo transform matrix and x and y coordinates of a point and - * returns the x and y coordinates of the raster pixel. - * @see - * // Reference: - * https://gis.stackexchange.com/questions/221292/retrieve-pixel-value-with-geographic-coordinate-as-input-with-gdal - * - * @param geoTransform - * The geo transform matrix of the raster. 
- * @param xGeo - * The x coordinate of the point. - * @param yGeo - * The y coordinate of the point. - * @return - * A tuple of integers with the x and y coordinates of the raster pixel. - */ - override def fromWorldCoord(geoTransform: Seq[Double], xGeo: Double, yGeo: Double): (Int, Int) = { - val x = ((xGeo - geoTransform(0)) / geoTransform(1)).toInt - val y = ((yGeo - geoTransform(3)) / geoTransform(5)).toInt - (x, y) - } - -} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/RasterReader.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/RasterReader.scala deleted file mode 100644 index 34d88fd22..000000000 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/RasterReader.scala +++ /dev/null @@ -1,78 +0,0 @@ -package com.databricks.labs.mosaic.core.raster - -import org.apache.spark.internal.Logging - -/** - * RasterReader is a trait that defines the interface for reading raster data - * from a file system path. It is used by the RasterAPI to read raster and - * raster band data. - * @note - * For subdatasets the path should be the path to the subdataset and not to - * the file. - */ -trait RasterReader extends Logging { - - /** - * Reads a raster from a file system path. Reads a subdataset if the path - * is to a subdataset. - * - * @example - * Raster: path = "file:///path/to/file.tif" Subdataset: path = - * "file:///path/to/file.tif:subdataset" - * @param path - * The path to the raster file. - * @return - * A MosaicRaster object. - */ - def readRaster(path: String): MosaicRaster - - /** - * Reads a raster band from a file system path. Reads a subdataset band if - * the path is to a subdataset. - * @example - * Raster: path = "file:///path/to/file.tif" Subdataset: path = - * "file:///path/to/file.tif:subdataset" - * @param path - * The path to the raster file. - * - * @param bandIndex - * The band index to read. - * @return - * A MosaicRaster object. 
- */ - def readBand(path: String, bandIndex: Int): MosaicRasterBand - - /** - * Take a geo transform matrix and x and y coordinates of a pixel and - * returns the x and y coordinates in the projection of the raster. - * - * @param geoTransform - * The geo transform matrix of the raster. - * - * @param x - * The x coordinate of the pixel. - * @param y - * The y coordinate of the pixel. - * @return - * A tuple of doubles with the x and y coordinates in the projection of - * the raster. - */ - def toWorldCoord(geoTransform: Seq[Double], x: Int, y: Int): (Double, Double) - - /** - * Take a geo transform matrix and x and y coordinates of a point and - * returns the x and y coordinates of the raster pixel. - * - * @param geoTransform - * The geo transform matrix of the raster. - * - * @param x - * The x coordinate of the point. - * @param y - * The y coordinate of the point. - * @return - * A tuple of integers with the x and y coordinates of the raster pixel. - */ - def fromWorldCoord(geoTransform: Seq[Double], x: Double, y: Double): (Int, Int) - -} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala new file mode 100644 index 000000000..cdfaa76d2 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala @@ -0,0 +1,214 @@ +package com.databricks.labs.mosaic.core.raster.api + +import com.databricks.labs.mosaic.core.raster.gdal.{MosaicRasterBandGDAL, MosaicRasterGDAL} +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.raster.operator.transform.RasterTransform +import com.databricks.labs.mosaic.gdal.MosaicGDAL.configureGDAL +import org.apache.spark.sql.types.{BinaryType, DataType, StringType} +import org.apache.spark.unsafe.types.UTF8String +import org.gdal.gdal.gdal +import org.gdal.gdalconst.gdalconstConstants._ + +/** + * GDAL Raster API. 
It uses [[MosaicRasterGDAL]] as the + * [[com.databricks.labs.mosaic.core.raster.io.RasterReader]]. + */ +object GDAL { + + /** + * Returns the no data value for the given GDAL data type. For non-numeric + * data types, it returns 0.0. For numeric data types, it returns the + * minimum value of the data type. For unsigned data types, it returns the + * maximum value of the data type. + * + * @param gdalType + * The GDAL data type. + * @return + * Returns the no data value for the given GDAL data type. + */ + def getNoDataConstant(gdalType: Int): Double = { + gdalType match { + case GDT_Unknown => 0.0 + case GDT_Byte => 0.0 + // Unsigned Int16 is Char in scala + // https://www.tutorialspoint.com/scala/scala_data_types.htm + case GDT_UInt16 => Char.MaxValue.toDouble + case GDT_Int16 => Short.MinValue.toDouble + case GDT_UInt32 => 0xFFFFFFFFL.toDouble // max UInt32 (4294967295); must not be negative + case GDT_Int32 => Int.MinValue.toDouble + case GDT_Float32 => Float.MinValue.toDouble + case GDT_Float64 => Double.MinValue + case _ => 0.0 + } + } + + /** @return Returns the name of the raster API. */ + def name: String = "GDAL" + + /** + * Enables GDAL on the worker nodes. GDAL requires drivers to be registered + * on the worker nodes. This method registers all the drivers on the worker + * nodes. + */ + def enable(): Unit = { + configureGDAL() + gdal.UseExceptions() + gdal.AllRegister() + } + + /** + * Returns the extension of the given driver. + * @param driverShortName + * The short name of the driver. For example, GTiff. + * @return + * Returns the extension of the driver. For example, tif. + */ + def getExtension(driverShortName: String): String = { + val driver = gdal.GetDriverByName(driverShortName) + val result = driver.GetMetadataItem("DMD_EXTENSION") + driver.delete() + result + } + + /** + * Reads a raster from the given input data. If it is a byte array, it will + * read the raster from the byte array. If it is a string, it will read the + * raster from the path. 
If the path is a zip file, it will read the raster + * from the zip file. If the path is a subdataset, it will read the raster + * from the subdataset. + * + * @param inputRaster + * The path to the raster. This path has to be a path to a single raster. + * Rasters with subdatasets are supported. + * @return + * Returns a Raster object. + */ + def readRaster( + inputRaster: => Any, + parentPath: String, + shortDriverName: String, + inputDT: DataType + ): MosaicRasterGDAL = { + inputDT match { + case StringType => + val path = inputRaster.asInstanceOf[UTF8String].toString + MosaicRasterGDAL.readRaster(path, parentPath) + case BinaryType => + val bytes = inputRaster.asInstanceOf[Array[Byte]] + val raster = MosaicRasterGDAL.readRaster(bytes, parentPath, shortDriverName) + // If the raster is coming as a byte array, we can't check for zip condition. + // We first try to read the raster directly, if it fails, we read it as a zip. + if (raster == null) { + val zippedPath = s"/vsizip/$parentPath" + MosaicRasterGDAL.readRaster(bytes, zippedPath, shortDriverName) + } else { + raster + } + } + } + + /** + * Writes the given rasters to either a path or a byte array. + * + * @param generatedRasters + * The rasters to write. + * @param checkpointPath + * The path to write the rasters to. + * @return + * Returns the paths of the written rasters. + */ + def writeRasters(generatedRasters: => Seq[MosaicRasterGDAL], checkpointPath: String, rasterDT: DataType): Seq[Any] = { + generatedRasters.map(raster => + if (raster != null) { + rasterDT match { + case StringType => + val extension = GDAL.getExtension(raster.getDriversShortName) + val writePath = s"$checkpointPath/${raster.uuid}.$extension" + val outPath = raster.writeToPath(writePath) + RasterCleaner.dispose(raster) + UTF8String.fromString(outPath) + case BinaryType => + val bytes = raster.writeToBytes() + RasterCleaner.dispose(raster) + bytes + } + } else { + null + } + ) + } + + /** + * Reads a raster from the given path. 
Delegates to + * MosaicRasterGDAL.readRaster; there is no separate vsizip overload. + * + * @param path + * The path to the raster. This path has to be a path to a single raster. + * Rasters with subdatasets are supported. + * @return + * Returns a Raster object. + */ + def raster(path: String, parentPath: String): MosaicRasterGDAL = MosaicRasterGDAL.readRaster(path, parentPath) + + /** + * Reads a raster from the given byte array. If the byte array is a zip + * file, it will read the raster from the zip file. + * + * @param content + * The byte array to read the raster from. + * @return + * Returns a Raster object. + */ + def raster(content: => Array[Byte], parentPath: String, driverShortName: String): MosaicRasterGDAL = + MosaicRasterGDAL.readRaster(content, parentPath, driverShortName) + + /** + * Reads a raster from the given path. It extracts the specified band from + * the raster by delegating to MosaicRasterGDAL.readBand. + * + * @param path + * The path to the raster. This path has to be a path to a single raster. + * Rasters with subdatasets are supported. + * @param bandIndex + * The index of the band to read from the raster. + * @return + * Returns a Raster band object. + */ + def band(path: String, bandIndex: Int, parentPath: String): MosaicRasterBandGDAL = + MosaicRasterGDAL.readBand(path, bandIndex, parentPath) + + /** + * Converts raster x, y pixel coordinates to world x, y coordinates. + * + * @param gt + * Geo transform of the raster. + * @param x + * X coordinate of the raster. + * @param y + * Y coordinate of the raster. + * @return + * Returns a tuple of (xGeo, yGeo) world coordinates. + */ + def toWorldCoord(gt: Seq[Double], x: Int, y: Int): (Double, Double) = { + val (xGeo, yGeo) = RasterTransform.toWorldCoord(gt, x, y) + (xGeo, yGeo) + } + + /** + * Converts world x, y coordinates to raster x, y pixel coordinates. + * + * @param gt + * Geo transform of the raster. + * @param x + * X world coordinate (longitude for geographic rasters). + * @param y + * Y world coordinate (latitude for geographic rasters).
+ * @return + * Returns a tuple of (xPixel, yPixel). + */ + def fromWorldCoord(gt: Seq[Double], x: Double, y: Double): (Int, Int) = { + val (xPixel, yPixel) = RasterTransform.fromWorldCoord(gt, x, y) + (xPixel, yPixel) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/api/RasterAPI.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/api/RasterAPI.scala deleted file mode 100644 index bcd50e66e..000000000 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/api/RasterAPI.scala +++ /dev/null @@ -1,129 +0,0 @@ -package com.databricks.labs.mosaic.core.raster.api - -import com.databricks.labs.mosaic.core.raster._ -import org.gdal.gdal.gdal - -/** - * A base trait for all Raster API's. - * @param reader - * The RasterReader to use for reading the raster. - */ -abstract class RasterAPI(reader: RasterReader) extends Serializable { - - /** - * This method should be called in every raster expression if the RasterAPI - * requires enablement on worker nodes. - */ - def enable(): Unit - - /** @return Returns the name of the raster API. */ - def name: String - - /** - * Reads a raster from the given path. - * - * @param path - * The path to the raster. This path has to be a path to a single raster. - * Rasters with subdatasets are supported. - * @return - * Returns a Raster object. - */ - def raster(path: String): MosaicRaster = reader.readRaster(path) - - /** - * Reads a raster from the given path. It extracts the specified band from - * the raster. - * - * @param path - * The path to the raster. This path has to be a path to a single raster. - * Rasters with subdatasets are supported. - * @param bandIndex - * The index of the band to read from the raster. - * @return - * Returns a Raster band object. - */ - def band(path: String, bandIndex: Int): MosaicRasterBand = reader.readBand(path, bandIndex) - - /** - * Converts raster x, y coordinates to lat, lon coordinates. - * @param gt - * Geo transform of the raster. 
- * @param x - * X coordinate of the raster. - * @param y - * Y coordinate of the raster. - * @return - * Returns a tuple of (lat, lon). - */ - def toWorldCoord(gt: Seq[Double], x: Int, y: Int): (Double, Double) = { - val (xGeo, yGeo) = reader.toWorldCoord(gt, x, y) - (xGeo, yGeo) - } - - /** - * Converts lat, lon coordinates to raster x, y coordinates. - * @param gt - * Geo transform of the raster. - * @param x - * Latitude of the raster. - * @param y - * Longitude of the raster. - * @return - * Returns a tuple of (xPixel, yPixel). - */ - def fromWorldCoord(gt: Seq[Double], x: Double, y: Double): (Int, Int) = { - val (xPixel, yPixel) = reader.fromWorldCoord(gt, x, y) - (xPixel, yPixel) - } - -} - -/** - * A companion object for the RasterAPI trait. It provides a factory method for - * creating a RasterAPI object. - */ -object RasterAPI extends Serializable { - - /** - * Creates a RasterAPI object. - * @param name - * The name of the API to use. Currently only GDAL is supported. - * @return - * Returns a RasterAPI object. - */ - def apply(name: String): RasterAPI = - name match { - case "GDAL" => GDAL - } - - /** - * @param name - * The name of the API to use. Currently only GDAL is supported. - * @return - * Returns a RasterReader object. - */ - def getReader(name: String): RasterReader = - name match { - case "GDAL" => MosaicRasterGDAL - } - - /** - * GDAL Raster API. It uses [[MosaicRasterGDAL]] as the [[RasterReader]]. - */ - object GDAL extends RasterAPI(MosaicRasterGDAL) { - - override def name: String = "GDAL" - - /** - * Enables GDAL on the worker nodes. GDAL requires drivers to be - * registered on the worker nodes. This method registers all the - * drivers on the worker nodes. 
- */ - override def enable(): Unit = { - gdal.UseExceptions() - if (gdal.GetDriverCount() == 0) gdal.AllRegister() - } - - } - -} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/MosaicRasterBandGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala similarity index 56% rename from src/main/scala/com/databricks/labs/mosaic/core/raster/MosaicRasterBandGDAL.scala rename to src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala index 4d5bc60c6..48eef2f07 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/MosaicRasterBandGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala @@ -1,4 +1,4 @@ -package com.databricks.labs.mosaic.core.raster +package com.databricks.labs.mosaic.core.raster.gdal import org.gdal.gdal.Band import org.gdal.gdalconst.gdalconstConstants @@ -7,11 +7,33 @@ import scala.collection.JavaConverters.dictionaryAsScalaMapConverter import scala.util._ /** GDAL implementation of the MosaicRasterBand trait. */ -case class MosaicRasterBandGDAL(band: Band, id: Int) extends MosaicRasterBand { +class MosaicRasterBandGDAL(band: => Band, id: Int) { - override def index: Int = id + def getBand: Band = band - override def description: String = coerceNull(Try(band.GetDescription)) + /** + * @return + * The band's index. + */ + def index: Int = id + + /** + * @return + * The band's description. + */ + def description: String = coerceNull(Try(band.GetDescription)) + + /** + * @return + * Returns the pixels of the raster as a 1D array. + */ + def values: Array[Double] = values(0, 0, xSize, ySize) + + /** + * @return + * Returns the pixels of the raster as a 1D array. + */ + def maskValues: Array[Double] = maskValues(0, 0, xSize, ySize) /** * Get the band's metadata as a Map. @@ -19,12 +41,16 @@ case class MosaicRasterBandGDAL(band: Band, id: Int) extends MosaicRasterBand { * @return * A Map of the band's metadata. 
*/ - override def metadata: Map[String, String] = + def metadata: Map[String, String] = Option(band.GetMetadata_Dict) .map(_.asScala.toMap.asInstanceOf[Map[String, String]]) .getOrElse(Map.empty[String, String]) - override def units: String = coerceNull(Try(band.GetUnitType)) + /** + * @return + * Returns the band's unit type, or an empty string if it is unavailable. + */ + def units: String = coerceNull(Try(band.GetUnitType)) /** * Utility method to coerce a null value to an empty string. @@ -35,16 +61,40 @@ case class MosaicRasterBandGDAL(band: Band, id: Int) extends MosaicRasterBand { */ def coerceNull(tryVal: Try[String]): String = tryVal.filter(_ != null).getOrElse("") - override def dataType: Int = Try(band.getDataType).getOrElse(0) + /** + * @return + * Returns the band's data type code, or 0 if it cannot be read. + */ + def dataType: Int = Try(band.getDataType).getOrElse(0) - override def xSize: Int = Try(band.GetXSize).getOrElse(0) + /** + * @return + * Returns the band's x size, or 0 if it cannot be read. + */ + def xSize: Int = Try(band.GetXSize).getOrElse(0) - override def ySize: Int = Try(band.GetYSize).getOrElse(0) + /** + * @return + * Returns the band's y size, or 0 if it cannot be read. + */ + def ySize: Int = Try(band.GetYSize).getOrElse(0) - override def minPixelValue: Double = computeMinMax.head + /** + * @return + * Returns the band's min pixel value (first element of computeMinMax). + */ + def minPixelValue: Double = computeMinMax.head - override def maxPixelValue: Double = computeMinMax.last + /** + * @return + * Returns the band's max pixel value (last element of computeMinMax). + */ + def maxPixelValue: Double = computeMinMax.last + /** + * @return + * Returns the band's min and max pixel values, or NaN for both on failure. + */ def computeMinMax: Seq[Double] = { val minMaxVals = Array.fill[Double](2)(0) Try(band.ComputeRasterMinMax(minMaxVals, 0)) @@ -52,7 +102,11 @@ case class MosaicRasterBandGDAL(band: Band, id: Int) extends MosaicRasterBand { .getOrElse(Seq(Double.NaN, Double.NaN)) } - override def noDataValue: Double = { + /** + * @return + * Returns the band's no data value.
+ */ + def noDataValue: Double = { val noDataVal = Array.fill[java.lang.Double](1)(0) band.GetNoDataValue(noDataVal) noDataVal.head @@ -72,7 +126,7 @@ case class MosaicRasterBandGDAL(band: Band, id: Int) extends MosaicRasterBand { * @return * A 2D array of pixels from the band. */ - override def values(xOffset: Int, yOffset: Int, xSize: Int, ySize: Int): Array[Double] = { + def values(xOffset: Int, yOffset: Int, xSize: Int, ySize: Int): Array[Double] = { val flatArray = Array.ofDim[Double](xSize * ySize) (xSize, ySize) match { case (0, 0) => Array.empty[Double] @@ -83,20 +137,20 @@ case class MosaicRasterBandGDAL(band: Band, id: Int) extends MosaicRasterBand { } /** - * Get the band's pixels as a 1D array. - * - * @param xOffset - * The x offset to start reading from. - * @param yOffset - * The y offset to start reading from. - * @param xSize - * The number of pixels to read in the x direction. - * @param ySize - * The number of pixels to read in the y direction. - * @return - * A 2D array of pixels from the band. - */ - override def maskValues(xOffset: Int, yOffset: Int, xSize: Int, ySize: Int): Array[Double] = { + * Get the band's mask values as a 1D array. + * + * @param xOffset + * The x offset to start reading from. + * @param yOffset + * The y offset to start reading from. + * @param xSize + * The number of pixels to read in the x direction. + * @param ySize + * The number of pixels to read in the y direction. + * @return + * A flat (1D) array of xSize * ySize mask values from the band. + */ + def maskValues(xOffset: Int, yOffset: Int, xSize: Int, ySize: Int): Array[Double] = { val flatArray = Array.ofDim[Double](xSize * ySize) val maskBand = band.GetMaskBand (xSize, ySize) match { @@ -107,16 +161,24 @@ case class MosaicRasterBandGDAL(band: Band, id: Int) extends MosaicRasterBand { } } - override def pixelValueToUnitValue(pixelValue: Double): Double = (pixelValue * pixelValueScale) + pixelValueOffset + /** + * @return + * Returns the band's pixel value with scale and offset applied.
+ */ + def pixelValueToUnitValue(pixelValue: Double): Double = (pixelValue * pixelValueScale) + pixelValueOffset - override def pixelValueScale: Double = { + def pixelValueScale: Double = { val scale = Array.fill[java.lang.Double](1)(0) Try(band.GetScale(scale)) .map(_ => scale.head.doubleValue()) .getOrElse(0.0) } - override def pixelValueOffset: Double = { + /** + * @return + * Returns the band's pixel value offset. + */ + def pixelValueOffset: Double = { val offset = Array.fill[java.lang.Double](1)(0) Try(band.GetOffset(offset)) .map(_ => offset.head.doubleValue()) @@ -135,7 +197,7 @@ case class MosaicRasterBandGDAL(band: Band, id: Int) extends MosaicRasterBand { * @return * an array of the results of applying f to each pixel. */ - override def transformValues[T](f: (Int, Int, Double) => T, default: T = null): Seq[Seq[T]] = { + def transformValues[T](f: (Int, Int, Double) => T, default: T = null): Seq[Seq[T]] = { val maskBand = band.GetMaskBand() val bandValues = Array.ofDim[Double](band.GetXSize() * band.GetYSize()) val maskValues = Array.ofDim[Byte](band.GetXSize() * band.GetYSize()) @@ -157,4 +219,16 @@ case class MosaicRasterBandGDAL(band: Band, id: Int) extends MosaicRasterBand { } } + /** + * @return + * Returns the band's mask flags wrapped in a single-element Seq. + */ + def maskFlags: Seq[Any] = Seq(band.GetMaskFlags()) + + /** + * @return + * Returns true if the band is a no data mask.
+ */ + def isNoDataMask: Boolean = band.GetMaskFlags() == gdalconstConstants.GMF_NODATA + } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala new file mode 100644 index 000000000..42653c6bd --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala @@ -0,0 +1,705 @@ +package com.databricks.labs.mosaic.core.raster.gdal + +import com.databricks.labs.mosaic.core.geometry.MosaicGeometry +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner.dispose +import com.databricks.labs.mosaic.core.raster.io.{RasterCleaner, RasterReader, RasterWriter} +import com.databricks.labs.mosaic.core.raster.operator.clip.RasterClipByVector +import com.databricks.labs.mosaic.core.types.model.GeometryTypeEnum.POLYGON +import com.databricks.labs.mosaic.utils.PathUtils +import org.apache.orc.util.Murmur3 +import org.gdal.gdal.gdal.GDALInfo +import org.gdal.gdal.{Dataset, InfoOptions, gdal} +import org.gdal.gdalconst.gdalconstConstants._ +import org.gdal.osr +import org.gdal.osr.SpatialReference +import org.locationtech.proj4j.CRSFactory + +import java.nio.file.{Files, Paths, StandardCopyOption} +import java.util.{Locale, UUID, Vector => JVector} +import scala.collection.JavaConverters.dictionaryAsScalaMapConverter +import scala.util.Try + +/** GDAL implementation of the MosaicRaster trait. 
*/ +//noinspection DuplicatedCode +class MosaicRasterGDAL( + _uuid: Long, + raster: => Dataset, + path: String, + isTemp: Boolean, + parentPath: String, + driverShortName: String, + memSize: Long +) extends RasterWriter + with RasterCleaner { + + def getSpatialReference: SpatialReference = { + if (raster != null) { + raster.GetSpatialRef + } else { + val tmp = refresh() + val result = tmp.spatialRef + dispose(tmp) + result + } + } + + // Factory for creating CRS objects + protected val crsFactory: CRSFactory = new CRSFactory + + // Only use this with GDAL rasters + private val wsg84 = new osr.SpatialReference() + wsg84.ImportFromEPSG(4326) + wsg84.SetAxisMappingStrategy(osr.osrConstants.OAMS_TRADITIONAL_GIS_ORDER) + + /** + * @return + * The raster's driver short name. + */ + def getDriversShortName: String = driverShortName + + /** + * @return + * The raster's path on disk. Usually this is a parent file for the tile. + */ + def getParentPath: String = parentPath + + /** + * @return + * The diagonal size of a raster in pixels, computed in Double to avoid Int overflow. + */ + def diagSize: Double = math.sqrt(xSize.toDouble * xSize + ySize.toDouble * ySize) + + /** @return Returns pixel x size. */ + def pixelXSize: Double = getGeoTransform(1) + + /** @return Returns pixel y size. */ + def pixelYSize: Double = getGeoTransform(5) + + /** @return Returns the origin x coordinate. */ + def originX: Double = getGeoTransform(0) + + /** @return Returns the origin y coordinate. */ + def originY: Double = getGeoTransform(3) + + /** @return Returns the max x coordinate. */ + def xMax: Double = originX + xSize * pixelXSize + + /** @return Returns the max y coordinate. */ + def yMax: Double = originY + ySize * pixelYSize + + /** @return Returns the min x coordinate. */ + def xMin: Double = originX + + /** @return Returns the min y coordinate. */ + def yMin: Double = originY + + /** @return Returns the diagonal size of a pixel.
*/ + def pixelDiagSize: Double = math.sqrt(pixelXSize * pixelXSize + pixelYSize * pixelYSize) + + /** @return Returns file extension. */ + def getRasterFileExtension: String = getRaster.GetDriver().GetMetadataItem("DMD_EXTENSION") + + /** @return Returns the raster's bands as a Seq. */ + def getBands: Seq[MosaicRasterBandGDAL] = (1 to numBands).map(getBand) + + /** + * Flushes the cache of the raster. This is needed to ensure that the + * raster is written to disk. This is needed for operations like + * RasterProject. + * @return + * Returns the raster object. + */ + def flushCache(): MosaicRasterGDAL = { + // Note: Do not wrap GDAL objects into Option + if (getRaster != null) getRaster.FlushCache() + this.destroy() + this.refresh() + } + + /** + * Opens a raster from a file system path. + * @param path + * The path to the raster file. + * @return + * A MosaicRaster object. + */ + def openRaster(path: String): Dataset = { + MosaicRasterGDAL.openRaster(path, Some(driverShortName)) + } + + /** + * @return + * Returns the raster's metadata as a Map. + */ + def metadata: Map[String, String] = { + Option(raster.GetMetadataDomainList()) + .map(_.toArray) + .map(domain => + domain + .map(domainName => + Option(raster.GetMetadata_Dict(domainName.toString)) + .map(_.asScala.toMap.asInstanceOf[Map[String, String]]) + .getOrElse(Map.empty[String, String]) + ) + .reduceOption(_ ++ _) + .getOrElse(Map.empty[String, String]) + ) + .getOrElse(Map.empty[String, String]) + + } + + /** + * @return + * Returns the raster's subdatasets as a Map. 
+ */ + def subdatasets: Map[String, String] = { + val dict = raster.GetMetadata_Dict("SUBDATASETS") + val subdatasetsMap = Option(dict) + .map(_.asScala.toMap.asInstanceOf[Map[String, String]]) + .getOrElse(Map.empty[String, String]) + val keys = subdatasetsMap.keySet + keys.flatMap(key => + if (key.toUpperCase(Locale.ROOT).contains("NAME")) { + val path = subdatasetsMap(key) + val pieces = path.split(":") + Seq( + key -> pieces.last, + s"${pieces.last}_vsimem" -> path, + pieces.last -> s"${pieces.head}:$parentPath:${pieces.last}" + ) + } else Seq(key -> subdatasetsMap(key)) + ).toMap + } + + /** + * @return + * Returns the raster's SRID. This is the EPSG code of the raster's CRS. + */ + def SRID: Int = { + Try(crsFactory.readEpsgFromParameters(proj4String)) + .filter(_ != null) + .getOrElse("EPSG:0") + .split(":") + .last + .toInt + } + + /** + * @return + * Returns the raster's proj4 string. + */ + def proj4String: String = { + + try { + raster.GetSpatialRef.ExportToProj4 + } catch { + case _: Any => "" + } + } + + /** + * @param bandId + * The band index to read. + * @return + * Returns the raster's band as a MosaicRasterBand object. + */ + def getBand(bandId: Int): MosaicRasterBandGDAL = { + if (bandId > 0 && numBands >= bandId) { + new MosaicRasterBandGDAL(raster.GetRasterBand(bandId), bandId) + } else { + throw new ArrayIndexOutOfBoundsException() + } + } + + /** + * @return + * Returns the raster's number of bands. + */ + def numBands: Int = raster.GetRasterCount() + + // noinspection ZeroIndexToHead + /** + * @return + * Returns the raster's extent as a Seq(xmin, ymin, xmax, ymax). + */ + def extent: Seq[Double] = { + val minX = getGeoTransform(0) + val maxY = getGeoTransform(3) + val maxX = minX + getGeoTransform(1) * xSize + val minY = maxY + getGeoTransform(5) * ySize + Seq(minX, minY, maxX, maxY) + } + + /** + * @return + * Returns x size of the raster. + */ + def xSize: Int = raster.GetRasterXSize + + /** + * @return + * Returns y size of the raster. 
+ */ + def ySize: Int = raster.GetRasterYSize + + /** + * @return + * Returns the raster's geotransform as a Seq. + */ + def getGeoTransform: Array[Double] = raster.GetGeoTransform() + + /** + * @return + * Underlying GDAL raster object. + */ + def getRaster: Dataset = this.raster + + /** + * @return + * Returns the raster's spatial reference. + */ + def spatialRef: SpatialReference = raster.GetSpatialRef() + + /** + * Applies a function to each band of the raster. + * @param f + * The function to apply. + * @return + * Returns a Seq of the results of the function. + */ + def transformBands[T](f: => MosaicRasterBandGDAL => T): Seq[T] = for (i <- 1 to numBands) yield f(getBand(i)) + + /** + * @return + * Returns MosaicGeometry representing bounding box of the raster. + */ + def bbox(geometryAPI: GeometryAPI, destCRS: SpatialReference = wsg84): MosaicGeometry = { + val gt = getGeoTransform + + val sourceCRS = spatialRef + val transform = new osr.CoordinateTransformation(sourceCRS, destCRS) + + val bbox = geometryAPI.geometry( + Seq( + Seq(gt(0), gt(3)), + Seq(gt(0) + gt(1) * xSize, gt(3)), + Seq(gt(0) + gt(1) * xSize, gt(3) + gt(5) * ySize), + Seq(gt(0), gt(3) + gt(5) * ySize) + ).map(geometryAPI.fromCoords), + POLYGON + ) + + val geom1 = org.gdal.ogr.ogr.CreateGeometryFromWkb(bbox.toWKB) + geom1.Transform(transform) + + geometryAPI.geometry(geom1.ExportToWkb(), "WKB") + } + + /** + * @return + * True if the raster is empty, false otherwise. May be expensive to + * compute since it requires reading the raster and computing statistics. 
+ */ + def isEmpty: Boolean = { + import org.json4s._ + import org.json4s.jackson.JsonMethods._ + implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats + + val vector = new JVector[String]() + vector.add("-stats") + vector.add("-json") + val infoOptions = new InfoOptions(vector) + val gdalInfo = GDALInfo(raster, infoOptions) + val json = parse(gdalInfo).extract[Map[String, Any]] + + if (json.contains("STATISTICS_VALID_PERCENT")) { + json("STATISTICS_VALID_PERCENT").asInstanceOf[Double] == 0.0 + } else if (subdatasets.nonEmpty) { + false + } else { + getBandStats.values.map(_.getOrElse("mean", 0.0)).forall(_ == 0.0) + } + } + + /** + * @return + * Returns the raster's path. + */ + def getPath: String = path + + /** + * @return + * Returns the raster for a given cell ID. Used for tessellation. + */ + def getRasterForCell(cellID: Long, indexSystem: IndexSystem, geometryAPI: GeometryAPI): MosaicRasterGDAL = { + val cellGeom = indexSystem.indexToGeometry(cellID, geometryAPI) + val geomCRS = indexSystem.osrSpatialRef + RasterClipByVector.clip(this, cellGeom, geomCRS, geometryAPI) + } + + /** + * Cleans up the raster driver and references. + * + * Unlinks the raster file. After this operation the raster object is no + * longer usable. To be used as last step in expression after writing to + * bytes. 
+ */ + def cleanUp(): Unit = { + val isInMem = path.contains("/vsimem/") + val isSubdataset = PathUtils.isSubdataset(path) + val filePath = if (isSubdataset) PathUtils.fromSubdatasetPath(path) else path + val pamFilePath = s"$filePath.aux.xml" + if (isInMem) { + // Delete the raster from the virtual file system + // Note that Unlink is not the same as Delete + // Unlink may leave PAM residuals + Try(gdal.GetDriverByName(driverShortName).Delete(path)) + Try(gdal.GetDriverByName(driverShortName).Delete(filePath)) + Try(gdal.Unlink(path)) + Try(gdal.Unlink(filePath)) + Try(gdal.Unlink(pamFilePath)) + } + if (isTemp) { + Try(gdal.GetDriverByName(driverShortName).Delete(path)) + Try(Files.deleteIfExists(Paths.get(path))) + Try(Files.deleteIfExists(Paths.get(filePath))) + Try(Files.deleteIfExists(Paths.get(pamFilePath))) + val tmpParent = Paths.get(path).getParent + if (tmpParent != null) Try(Files.deleteIfExists(tmpParent)) + } + } + + /** + * @note + * If memory size is -1 this will destroy the raster and you will need to + * refresh it to use it again. + * @return + * Returns the amount of memory occupied by the file in bytes. + */ + def getMemSize: Long = { + if (memSize == -1) { + if (PathUtils.isInMemory(path)) { + val tempPath = PathUtils.createTmpFilePath(this.uuid.toString, GDAL.getExtension(driverShortName)) + writeToPath(tempPath) + val size = Files.size(Paths.get(tempPath)) + Files.delete(Paths.get(tempPath)) + size + } else { + Files.size(Paths.get(path)) + } + } else { + memSize + } + + } + + /** + * Writes a raster to a file system path. Unless dispose is false, this + * also disposes of the raster object; load it again from the path if needed. + * + * @param path + * The path to the raster file. + * @return + * The path the raster was written to.
+ */ + def writeToPath(path: String, dispose: Boolean = true): String = { + val isInMem = PathUtils.isInMemory(getPath) + if (isInMem) { + val driver = raster.GetDriver() + val ds = driver.CreateCopy(path, this.flushCache().getRaster) + ds.FlushCache() + ds.delete() + } else { + Files.copy(Paths.get(getPath), Paths.get(path), StandardCopyOption.REPLACE_EXISTING).toString + } + if (dispose) RasterCleaner.dispose(this) + path + } + + /** + * Writes a raster to a byte array. + * + * @return + * A byte array containing the raster data. + */ + def writeToBytes(dispose: Boolean = true): Array[Byte] = { + if (PathUtils.isInMemory(path)) { + // Create a temporary directory to store the raster + // This is needed because Files cannot read from /vsimem/ directly + val path = PathUtils.createTmpFilePath(uuid.toString, GDAL.getExtension(driverShortName)) + writeToPath(path, dispose) + val byteArray = Files.readAllBytes(Paths.get(path)) + Files.delete(Paths.get(path)) + byteArray + } else { + val byteArray = Files.readAllBytes(Paths.get(path)) + if (dispose) RasterCleaner.dispose(this) + byteArray + } + } + + /** + * Destroys the raster object. After this operation the raster object is no + * longer usable. If the raster is needed again, use the refresh method. + */ + def destroy(): Unit = { + val raster = getRaster + if (raster != null) { + raster.FlushCache() + raster.delete() + } + } + + /** + * Refreshes the raster object. This is needed after writing to a file + * system path. GDAL only properly writes to a file system path if the + * raster object is destroyed. After refresh operation the raster object is + * usable again. + */ + def refresh(): MosaicRasterGDAL = { + new MosaicRasterGDAL(uuid, openRaster(path), path, isTemp, parentPath, driverShortName, memSize) + } + + /** + * @return + * Returns the raster's UUID. + */ + def uuid: Long = _uuid + + /** + * @return + * Returns the raster's size. 
+ */ + def getDimensions: (Int, Int) = (xSize, ySize) + + /** + * @return + * Returns the raster's band statistics. + */ + def getBandStats: Map[Int, Map[String, Double]] = { + (1 to numBands) + .map(i => { + val band = raster.GetRasterBand(i) + val min = Array.ofDim[Double](1) + val max = Array.ofDim[Double](1) + val mean = Array.ofDim[Double](1) + val stddev = Array.ofDim[Double](1) + band.GetStatistics(true, true, min, max, mean, stddev) + i -> Map( + "min" -> min(0), + "max" -> max(0), + "mean" -> mean(0), + "stddev" -> stddev(0) + ) + }) + .toMap + } + + /** + * @param subsetName + * The name of the subdataset to get. + * @return + * Returns the raster's subdataset with given name. + */ + def getSubdataset(subsetName: String): MosaicRasterGDAL = { + subdatasets + val path = Option(raster.GetMetadata_Dict("SUBDATASETS")) + .map(_.asScala.toMap.asInstanceOf[Map[String, String]]) + .getOrElse(Map.empty[String, String]) + .values + .find(_.toUpperCase(Locale.ROOT).endsWith(subsetName.toUpperCase(Locale.ROOT))) + .getOrElse(throw new Exception(s"Subdataset $subsetName not found")) + val ds = openRaster(path) + // Avoid costly IO to compute MEM size here + // It will be available when the raster is serialized for next operation + // If value is needed then it will be computed when getMemSize is called + MosaicRasterGDAL(ds, path, isTemp = false, parentPath, driverShortName, -1) + } + +} + +//noinspection ZeroIndexToHead +/** Companion object for MosaicRasterGDAL Implements RasterReader APIs */ +object MosaicRasterGDAL extends RasterReader { + + /** + * Opens a raster from a file system path with a given driver. + * @param driverShortName + * The driver short name to use. If None, then GDAL will try to identify + * the driver from the file extension + * @param path + * The path to the raster file. + * @return + * A MosaicRaster object. 
+ */ + def openRaster(path: String, driverShortName: Option[String]): Dataset = { + driverShortName match { + case Some(driverShortName) => + val drivers = new JVector[String]() + drivers.add(driverShortName) + gdal.OpenEx(path, GA_ReadOnly, drivers) + case None => gdal.Open(path, GA_ReadOnly) + } + } + + /** + * Identifies the driver of a raster from a file system path. + * @param parentPath + * The path to the raster file. + * @return + * A string representing the driver short name. + */ + def identifyDriver(parentPath: String): String = { + val isSubdataset = PathUtils.isSubdataset(parentPath) + val path = PathUtils.getCleanPath(parentPath, parentPath.endsWith(".zip")) + val readPath = + if (isSubdataset) PathUtils.getSubdatasetPath(path) + else PathUtils.getZipPath(path) + val driver = gdal.IdentifyDriverEx(readPath) + val driverShortName = driver.getShortName + driverShortName + } + + /** + * Creates a MosaicRaster object from a GDAL raster object. + * @param dataset + * The GDAL raster object. + * @param path + * The path to the raster file in vsimem or in temp dir. + * @param isTemp + * A boolean indicating if the raster is temporary. + * @param parentPath + * The path to the file of the raster on disk. + * @param driverShortName + * The driver short name of the raster. + * @param memSize + * The size of the raster in memory. + * @return + * A MosaicRaster object. + */ + def apply( + dataset: => Dataset, + path: String, + isTemp: Boolean, + parentPath: String, + driverShortName: String, + memSize: Long + ): MosaicRasterGDAL = { + val uuid = Murmur3.hash64(path.getBytes()) + val raster = new MosaicRasterGDAL(uuid, dataset, path, isTemp, parentPath, driverShortName, memSize) + raster + } + + /** + * Creates a MosaicRaster object from a file system path. + * @param path + * The path to the raster file. + * @param isTemp + * A boolean indicating if the raster is temporary. + * @param parentPath + * The path to the file of the raster on disk. 
+ * @param driverShortName + * The driver short name of the raster. + * @param memSize + * The size of the raster in memory. + * @return + * A MosaicRaster object. + */ + def apply(path: String, isTemp: Boolean, parentPath: String, driverShortName: String, memSize: Long): MosaicRasterGDAL = { + val uuid = Murmur3.hash64(path.getBytes()) + val dataset = openRaster(path, Some(driverShortName)) + val raster = new MosaicRasterGDAL(uuid, dataset, path, isTemp, parentPath, driverShortName, memSize) + raster + } + + /** + * Reads a raster from a file system path. Reads a subdataset if the path + * is to a subdataset. + * + * @example + * Raster: path = "file:///path/to/file.tif" Subdataset: path = + * "file:///path/to/file.tif:subdataset" + * @param inPath + * The path to the raster file. + * @return + * A MosaicRaster object. + */ + override def readRaster(inPath: String, parentPath: String): MosaicRasterGDAL = { + val isSubdataset = PathUtils.isSubdataset(inPath) + val localCopy = PathUtils.copyToTmp(inPath) + val path = PathUtils.getCleanPath(localCopy, localCopy.endsWith(".zip")) + val uuid = Murmur3.hash64(path.getBytes()) + val readPath = + if (isSubdataset) PathUtils.getSubdatasetPath(path) + else PathUtils.getZipPath(path) + val dataset = openRaster(readPath, None) + val driverShortName = dataset.GetDriver().getShortName + + // Avoid costly IO to compute MEM size here + // It will be available when the raster is serialized for next operation + // If value is needed then it will be computed when getMemSize is called + // We cannot just use memSize value of the parent due to the fact that the raster could be a subdataset + val raster = new MosaicRasterGDAL(uuid, dataset, path, true, parentPath, driverShortName, -1) + raster + } + + /** + * Reads a raster from a byte array. + * @param contentBytes + * The byte array containing the raster data. + * @param driverShortName + * The driver short name of the raster. + * @return + * A MosaicRaster object. 
+ */ + override def readRaster(contentBytes: => Array[Byte], parentPath: String, driverShortName: String): MosaicRasterGDAL = { + if (Option(contentBytes).isEmpty || contentBytes.isEmpty) { + new MosaicRasterGDAL(-1L, null, "", false, parentPath, "", -1) + } else { + // This is a temp UUID for purposes of reading the raster through GDAL from memory + // The stable UUID is kept in metadata of the raster + val uuid = Murmur3.hash64(UUID.randomUUID().toString.getBytes()) + val extension = GDAL.getExtension(driverShortName) + val virtualPath = s"/vsimem/$uuid.$extension" + gdal.FileFromMemBuffer(virtualPath, contentBytes) + // Try reading as a virtual file, if that fails, read as a zipped virtual file + val dataset = Option( + openRaster(virtualPath, Some(driverShortName)) + ).getOrElse({ + // Unlink the previous virtual file + gdal.Unlink(virtualPath) + // Create a virtual zip file + val virtualZipPath = s"/vsimem/$uuid.zip" + val zippedPath = s"/vsizip/$virtualZipPath" + gdal.FileFromMemBuffer(virtualZipPath, contentBytes) + openRaster(zippedPath, Some(driverShortName)) + }) + val raster = new MosaicRasterGDAL(uuid, dataset, virtualPath, false, parentPath, driverShortName, contentBytes.length) + raster + } + } + + /** + * Reads a raster band from a file system path. Reads a subdataset band if + * the path is to a subdataset. + * + * @example + * Raster: path = "file:///path/to/file.tif" Subdataset: path = + * "file:///path/to/file.tif:subdataset" + * @param path + * The path to the raster file. + * @param bandIndex + * The band index to read. + * @return + * A MosaicRaster object. 
+      */
+    override def readBand(path: String, bandIndex: Int, parentPath: String): MosaicRasterBandGDAL = {
+        val raster = readRaster(path, parentPath)
+        // TODO: Raster and Band are coupled, this can cause a pointer leak
+        raster.getBand(bandIndex)
+    }
+
+}
diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/io/RasterCleaner.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/io/RasterCleaner.scala
new file mode 100644
index 000000000..9a8be672a
--- /dev/null
+++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/io/RasterCleaner.scala
@@ -0,0 +1,68 @@
+package com.databricks.labs.mosaic.core.raster.io
+
+import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL
+import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile
+import org.gdal.gdal.Dataset
+
+/** Trait for cleaning up raster objects. */
+trait RasterCleaner {
+
+    /**
+      * Cleans up the rasters from memory or from temp directory. Cleaning up is
+      * destructive and should only be done when the raster is no longer needed.
+      */
+    def cleanUp(): Unit
+
+    /**
+      * Destroys the raster object. Rasters can be recreated from file system
+      * path or from content bytes after destroy.
+      */
+    def destroy(): Unit
+
+}
+
+object RasterCleaner {
+
+    /**
+      * Flushes the cache and deletes the dataset. Note that this does not
+      * unlink virtual files. For that, use gdal.Unlink(path).
+      *
+      * @param ds
+      *   The dataset to destroy. A null dataset is ignored.
+      */
+    def destroy(ds: => Dataset): Unit = {
+        // ds is by-name: evaluate it once so every call below hits the same handle
+        val dataset = ds
+        if (dataset != null) {
+            try {
+                dataset.FlushCache()
+                // Not to be confused with physical deletion, this only deletes the JVM-side object
+                dataset.delete()
+            } catch {
+                // Best-effort clean up: ordinary failures are ignored, fatal JVM errors still propagate
+                case scala.util.control.NonFatal(_) => ()
+            }
+        }
+    }
+
+    /**
+      * Destroys and cleans up the raster object. This is a destructive operation and should
+      * only be done when the raster is no longer needed.
+      *
+      * @param raster
+      *   The raster to destroy and clean up.
+ */ + def dispose(raster: => Any): Unit = { + raster match { + case r: MosaicRasterGDAL => + r.destroy() + r.cleanUp() + case rt: MosaicRasterTile => + rt.getRaster.destroy() + rt.getRaster.cleanUp() + // NOOP for simpler code handling in expressions, removes need for repeated if/else + case _ => () + } + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/io/RasterReader.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/io/RasterReader.scala new file mode 100644 index 000000000..65ef016cc --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/io/RasterReader.scala @@ -0,0 +1,64 @@ +package com.databricks.labs.mosaic.core.raster.io + +import com.databricks.labs.mosaic.core.raster.gdal.{MosaicRasterBandGDAL, MosaicRasterGDAL} +import org.apache.spark.internal.Logging + +/** + * RasterReader is a trait that defines the interface for reading raster data + * from a file system path. It is used by the RasterAPI to read raster and + * raster band data. + * @note + * For subdatasets the path should be the path to the subdataset and not to + * the file. + */ +trait RasterReader extends Logging { + + /** + * Reads a raster from a file system path. Reads a subdataset if the path + * is to a subdataset. + * + * @example + * Raster: path = "/path/to/file.tif" Subdataset: path = + * "FORMAT:/path/to/file.tif:subdataset" + * @param path + * The path to the raster file. + * @param parentPath + * The path of the parent raster file. + * @return + * A MosaicRaster object. + */ + def readRaster(path: String, parentPath: String): MosaicRasterGDAL + + /** + * Reads a raster from an in memory buffer. Use the buffer bytes to produce + * a uuid of the raster. + * + * @param contentBytes + * The file bytes. + * @param parentPath + * The path of the parent raster file. + * @param driverShortName + * The driver short name of the raster file. + * @return + * A MosaicRaster object. 
+ */ + def readRaster(contentBytes: => Array[Byte], parentPath: String, driverShortName: String): MosaicRasterGDAL + + /** + * Reads a raster band from a file system path. Reads a subdataset band if + * the path is to a subdataset. + * @example + * Raster: path = "/path/to/file.tif" Subdataset: path = + * "FORMAT:/path/to/file.tif:subdataset" + * @param path + * The path to the raster file. + * @param bandIndex + * The band index to read. + * @param parentPath + * The path of the parent raster file. + * @return + * A MosaicRaster object. + */ + def readBand(path: String, bandIndex: Int, parentPath: String): MosaicRasterBandGDAL + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/io/RasterWriter.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/io/RasterWriter.scala new file mode 100644 index 000000000..fa6848d65 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/io/RasterWriter.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.core.raster.io + +/** + * RasterWriter is a trait that defines the interface for writing raster data + * to a file system path or as bytes. It is used by the RasterAPI to write + * rasters. + */ +trait RasterWriter { + + /** + * Writes a raster to a file system path. + * + * @param path + * The path to the raster file. + * @param destroy + * A boolean indicating if the raster should be destroyed after writing. + * @return + * A boolean indicating if the write was successful. + */ + def writeToPath(path: String, destroy: Boolean = true): String + + /** + * Writes a raster to a byte array. + * + * @param destroy + * A boolean indicating if the raster should be destroyed after writing. + * @return + * A byte array containing the raster data. 
+ */ + def writeToBytes(destroy: Boolean = true): Array[Byte] + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/CombineAVG.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/CombineAVG.scala new file mode 100644 index 000000000..14647cebb --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/CombineAVG.scala @@ -0,0 +1,40 @@ +package com.databricks.labs.mosaic.core.raster.operator + +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.operator.pixel.PixelCombineRasters + +/** CombineAVG is a helper object for combining rasters using average. */ +object CombineAVG { + + /** + * Creates a new raster using average of input rasters. The average is + * computed as (sum of all rasters) / (number of rasters). It is applied to + * all bands of the input rasters. Please note the data type of the output + * raster is double. + * + * @param rasters + * The rasters to compute result for. + * + * @return + * A new raster with average of input rasters. 
+ */ + def compute(rasters: => Seq[MosaicRasterGDAL]): MosaicRasterGDAL = { + val pythonFunc = """ + |import numpy as np + |import sys + | + |def average(in_ar, out_ar, xoff, yoff, xsize, ysize, raster_xsize,raster_ysize, buf_radius, gt, **kwargs): + | div = np.zeros(in_ar[0].shape) + | for i in range(len(in_ar)): + | div += (in_ar[i] != 0) + | div[div == 0] = 1 + | + | y = np.sum(in_ar, axis = 0, dtype = 'float64') + | y = y / div + | + | np.clip(y,0, sys.float_info.max, out = out_ar) + |""".stripMargin + PixelCombineRasters.combine(rasters, pythonFunc, "average") + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/NDVI.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/NDVI.scala new file mode 100644 index 000000000..e907d1eb7 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/NDVI.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.core.raster.operator + +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALCalc +import com.databricks.labs.mosaic.utils.PathUtils + +/** NDVI is a helper object for computing NDVI. */ +object NDVI { + + /** + * Computes NDVI from a MosaicRasterGDAL. + * + * @param raster + * MosaicRasterGDAL to compute NDVI from. + * @param redIndex + * Index of the red band. + * @param nirIndex + * Index of the near-infrared band. + * @return + * MosaicRasterGDAL with NDVI computed. 
+ */ + def compute(raster: => MosaicRasterGDAL, redIndex: Int, nirIndex: Int): MosaicRasterGDAL = { + val ndviPath = PathUtils.createTmpFilePath(raster.uuid.toString, GDAL.getExtension(raster.getDriversShortName)) + // noinspection ScalaStyle + val gdalCalcCommand = + s"""gdal_calc -A ${raster.getPath} --A_band=$redIndex -B ${raster.getPath} --B_band=$nirIndex --outfile=$ndviPath --calc="(B-A)/(B+A)"""" + + GDALCalc.executeCalc(gdalCalcCommand, ndviPath) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/RasterClipByVector.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/RasterClipByVector.scala new file mode 100644 index 000000000..2c40ab81c --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/RasterClipByVector.scala @@ -0,0 +1,59 @@ +package com.databricks.labs.mosaic.core.raster.operator.clip + +import com.databricks.labs.mosaic.core.geometry.MosaicGeometry +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALWarp +import com.databricks.labs.mosaic.utils.PathUtils +import org.gdal.osr.SpatialReference + +/** + * RasterClipByVector is an object that defines the interface for clipping a + * raster by a vector geometry. + */ +object RasterClipByVector { + + /** + * Clips a raster by a vector geometry. The method handles all the + * abstractions over GDAL Warp. It uses CUTLINE_ALL_TOUCHED=TRUE to ensure + * that all pixels that touch the geometry are included. This will avoid + * the issue of having a pixel that is half in and half out of the + * geometry, important for tessellation. It also uses COMPRESS=DEFLATE to + * ensure that the output is compressed. The method also uses the geometry + * API to generate a shapefile that is used to clip the raster. 
The + * shapefile is deleted after the clip is complete. + * + * @param raster + * The raster to clip. + * @param geometry + * The geometry to clip by. + * @param geomCRS + * The geometry CRS. + * @param geometryAPI + * The geometry API. + * @return + * A clipped raster. + */ + def clip(raster: => MosaicRasterGDAL, geometry: MosaicGeometry, geomCRS: SpatialReference, geometryAPI: GeometryAPI): MosaicRasterGDAL = { + val rasterCRS = raster.getSpatialReference + val outShortName = raster.getDriversShortName + val geomSrcCRS = if (geomCRS == null ) rasterCRS else geomCRS + + val resultFileName = PathUtils.createTmpFilePath(raster.uuid.toString, GDAL.getExtension(outShortName)) + + val shapeFileName = VectorClipper.generateClipper(geometry, geomSrcCRS, rasterCRS, geometryAPI) + + val result = GDALWarp.executeWarp( + resultFileName, + isTemp = true, + Seq(raster), + command = s"gdalwarp -wo CUTLINE_ALL_TOUCHED=TRUE -of $outShortName -cutline $shapeFileName -crop_to_cutline -co COMPRESS=DEFLATE -dstalpha" + ) + + VectorClipper.cleanUpClipper(shapeFileName) + + result + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/VectorClipper.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/VectorClipper.scala new file mode 100644 index 000000000..7c7ea58f2 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/VectorClipper.scala @@ -0,0 +1,96 @@ +package com.databricks.labs.mosaic.core.raster.operator.clip + +import com.databricks.labs.mosaic.core.geometry.MosaicGeometry +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import org.gdal.gdal.gdal +import org.gdal.ogr.ogrConstants.OFTInteger +import org.gdal.ogr.{DataSource, Feature, ogr} +import org.gdal.osr.SpatialReference + +import scala.util.Try + +/** + * VectorClipper is an object that defines the interface for managing a clipper + * shapefile used for clipping a raster by a vector geometry. 
+ */ +object VectorClipper { + + /** + * Generates an in memory shapefile that is used to clip a raster. + * @return + * The shapefile name. + */ + private def getShapefileName: String = { + val uuid = java.util.UUID.randomUUID() + val shapeFileName = s"/vsimem/${uuid.toString}.shp" + shapeFileName + } + + /** + * Generates a shapefile data source that is used to clip a raster. + * @param fileName + * The shapefile data source. + * @return + * The shapefile. + */ + private def getShapefile(fileName: String): DataSource = { + val shpDriver = ogr.GetDriverByName("ESRI Shapefile") + val shpDataSource = shpDriver.CreateDataSource(fileName) + shpDataSource + } + + /** + * Generates a clipper shapefile that is used to clip a raster. The + * shapefile is flushed to disk and then the data source is deleted. The + * shapefile is accessed by gdalwarp by file name. + * @note + * The shapefile is generated in memory. + * + * @param geometry + * The geometry to clip by. + * @param srcCrs + * The geometry CRS. + * @param dstCrs + * The raster CRS. + * @param geometryAPI + * The geometry API. + * @return + * The shapefile name. + */ + def generateClipper(geometry: MosaicGeometry, srcCrs: SpatialReference, dstCrs: SpatialReference, geometryAPI: GeometryAPI): String = { + val shapeFileName = getShapefileName + var shpDataSource = getShapefile(shapeFileName) + + val projectedGeom = geometry.osrTransformCRS(srcCrs, dstCrs, geometryAPI) + + val geom = ogr.CreateGeometryFromWkb(projectedGeom.toWKB) + + val geomLayer = shpDataSource.CreateLayer("geom") + + val idField = new org.gdal.ogr.FieldDefn("id", OFTInteger) + geomLayer.CreateField(idField) + val featureDefn = geomLayer.GetLayerDefn() + val feature = new Feature(featureDefn) + feature.SetGeometry(geom) + feature.SetField("id", 1) + geomLayer.CreateFeature(feature) + + shpDataSource.FlushCache() + shpDataSource.delete() + shpDataSource = null + + shapeFileName + } + + /** + * Cleans up the clipper shapefile. 
+ * + * @param shapeFileName + * The shapefile to clean up. + */ + def cleanUpClipper(shapeFileName: String): Unit = { + Try(ogr.GetDriverByName("ESRI Shapefile").DeleteDataSource(shapeFileName)) + Try(gdal.Unlink(shapeFileName)) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALBuildVRT.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALBuildVRT.scala new file mode 100644 index 000000000..c3b57d5f3 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALBuildVRT.scala @@ -0,0 +1,33 @@ +package com.databricks.labs.mosaic.core.raster.operator.gdal + +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import org.gdal.gdal.{BuildVRTOptions, gdal} + +/** GDALBuildVRT is a wrapper for the GDAL BuildVRT command. */ +object GDALBuildVRT { + + /** + * Executes the GDAL BuildVRT command. + * + * @param outputPath + * The output path of the VRT file. + * @param isTemp + * Whether the output is a temp file. + * @param rasters + * The rasters to build the VRT from. + * @param command + * The GDAL BuildVRT command. + * @return + * A MosaicRaster object. + */ + def executeVRT(outputPath: String, isTemp: Boolean, rasters: => Seq[MosaicRasterGDAL], command: String): MosaicRasterGDAL = { + require(command.startsWith("gdalbuildvrt"), "Not a valid GDAL Build VRT command.") + val vrtOptionsVec = OperatorOptions.parseOptions(command) + val vrtOptions = new BuildVRTOptions(vrtOptionsVec) + val result = gdal.BuildVRT(outputPath, rasters.map(_.getRaster).toArray, vrtOptions) + // TODO: Figure out multiple parents, should this be an array? 
+ // VRT files are just meta files, mem size doesnt make much sense so we keep -1 + MosaicRasterGDAL(result, outputPath, isTemp, rasters.head.getParentPath, "VRT", -1).flushCache() + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala new file mode 100644 index 000000000..cd4d92a97 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala @@ -0,0 +1,29 @@ +package com.databricks.labs.mosaic.core.raster.operator.gdal + +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL + +/** GDALCalc is a helper object for executing GDAL Calc commands. */ +object GDALCalc { + + val gdal_calc = "/usr/lib/python3/dist-packages/osgeo_utils/gdal_calc.py" + + /** + * Executes the GDAL Calc command. + * @param gdalCalcCommand + * The GDAL Calc command to execute. + * @param resultPath + * The path to the result. + * @return + * Returns the result as a [[MosaicRasterGDAL]]. + */ + def executeCalc(gdalCalcCommand: String, resultPath: String): MosaicRasterGDAL = { + require(gdalCalcCommand.startsWith("gdal_calc"), "Not a valid GDAL Calc command.") + import sys.process._ + val toRun = gdalCalcCommand.replace("gdal_calc", gdal_calc) + s"sudo python3 $toRun".!! 
+ val result = GDAL.raster(resultPath, resultPath) + result + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALTranslate.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALTranslate.scala new file mode 100644 index 000000000..0b44e006c --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALTranslate.scala @@ -0,0 +1,34 @@ +package com.databricks.labs.mosaic.core.raster.operator.gdal + +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import org.gdal.gdal.{TranslateOptions, gdal} + +import java.nio.file.{Files, Paths} + +/** GDALTranslate is a wrapper for the GDAL Translate command. */ +object GDALTranslate { + + /** + * Executes the GDAL Translate command. + * + * @param outputPath + * The output path of the translated file. + * @param isTemp + * Whether the output is a temp file. + * @param raster + * The raster to translate. + * @param command + * The GDAL Translate command. + * @return + * A MosaicRaster object. 
+ */ + def executeTranslate(outputPath: String, isTemp: Boolean, raster: => MosaicRasterGDAL, command: String): MosaicRasterGDAL = { + require(command.startsWith("gdal_translate"), "Not a valid GDAL Translate command.") + val translateOptionsVec = OperatorOptions.parseOptions(command) + val translateOptions = new TranslateOptions(translateOptionsVec) + val result = gdal.Translate(outputPath, raster.getRaster, translateOptions) + val size = Files.size(Paths.get(outputPath)) + MosaicRasterGDAL(result, outputPath, isTemp, raster.getParentPath, raster.getDriversShortName, size).flushCache() + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala new file mode 100644 index 000000000..e8393d728 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala @@ -0,0 +1,44 @@ +package com.databricks.labs.mosaic.core.raster.operator.gdal + +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import org.gdal.gdal.{WarpOptions, gdal} + +import java.nio.file.{Files, Paths} + +/** GDALWarp is a wrapper for the GDAL Warp command. */ +object GDALWarp { + + /** + * Executes the GDAL Warp command. + * + * @param outputPath + * The output path of the warped file. + * @param isTemp + * Whether the output is a temp file. + * @param rasters + * The rasters to warp. + * @param command + * The GDAL Warp command. + * @return + * A MosaicRaster object. 
+ */ + def executeWarp(outputPath: String, isTemp: Boolean, rasters: => Seq[MosaicRasterGDAL], command: String): MosaicRasterGDAL = { + require(command.startsWith("gdalwarp"), "Not a valid GDAL Warp command.") + // Test: gdal.ParseCommandLine(command) + val warpOptionsVec = OperatorOptions.parseOptions(command) + val warpOptions = new WarpOptions(warpOptionsVec) + val result = gdal.Warp(outputPath, rasters.map(_.getRaster).toArray, warpOptions) + // TODO: Figure out multiple parents, should this be an array? + // Format will always be the same as the first raster + val size = Files.size(Paths.get(outputPath)) + MosaicRasterGDAL( + result, + outputPath, + isTemp, + rasters.head.getParentPath, + rasters.head.getDriversShortName, + size + ).flushCache() + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/OperatorOptions.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/OperatorOptions.scala new file mode 100644 index 000000000..b1529d3e7 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/OperatorOptions.scala @@ -0,0 +1,21 @@ +package com.databricks.labs.mosaic.core.raster.operator.gdal + +/** OperatorOptions is a helper object for parsing GDAL command options. */ +object OperatorOptions { + + /** + * Parses the options from a GDAL command. + * + * @param command + * The GDAL command. + * @return + * A vector of options. 
+      */
+    def parseOptions(command: String): java.util.Vector[String] = {
+        val args = command.trim.split("\\s+") // split on whitespace runs so repeated spaces never yield empty options
+        val optionsVec = new java.util.Vector[String]()
+        args.drop(1).foreach(optionsVec.add) // first token is the program name, not an option
+        optionsVec
+    }
+
+}
diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeBands.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeBands.scala
new file mode 100644
index 000000000..2bd605445
--- /dev/null
+++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeBands.scala
@@ -0,0 +1,86 @@
+package com.databricks.labs.mosaic.core.raster.operator.merge
+
+import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL
+import com.databricks.labs.mosaic.core.raster.io.RasterCleaner.dispose
+import com.databricks.labs.mosaic.core.raster.operator.gdal.{GDALBuildVRT, GDALTranslate}
+import com.databricks.labs.mosaic.utils.PathUtils
+
+/** MergeBands is a helper object for merging raster bands. */
+object MergeBands {
+
+    /**
+      * Merges the raster bands into a single raster.
+      *
+      * @param rasters
+      *   The rasters to merge.
+      * @param resampling
+      *   The resampling method to use.
+      * @return
+      *   A MosaicRaster object.
+      */
+    def merge(rasters: => Seq[MosaicRasterGDAL], resampling: String): MosaicRasterGDAL = {
+        val rasterUUID = java.util.UUID.randomUUID.toString
+        val outShortName = rasters.head.getRaster.GetDriver.getShortName
+
+        val vrtPath = PathUtils.createTmpFilePath(rasterUUID, "vrt")
+        val rasterPath = PathUtils.createTmpFilePath(rasterUUID, "tif")
+
+        val vrtRaster = GDALBuildVRT.executeVRT(
+          vrtPath,
+          isTemp = true,
+          rasters,
+          command = s"gdalbuildvrt -separate -resolution highest"
+        )
+
+        val result = GDALTranslate.executeTranslate(
+          rasterPath,
+          isTemp = true,
+          vrtRaster,
+          command = s"gdal_translate -r $resampling -of $outShortName -co COMPRESS=DEFLATE"
+        )
+
+        dispose(vrtRaster)
+
+        result
+    }
+
+    /**
+      * Merges the raster bands into a single raster.
This method allows for
+      * custom pixel sizes.
+      *
+      * @param rasters
+      *   The rasters to merge; the first one determines the output driver.
+      * @param pixel
+      *   The output pixel size as (x, y), passed to gdalbuildvrt via -tr.
+      * @param resampling
+      *   The resampling method to use.
+      * @return
+      *   A MosaicRaster object.
+      */
+    def merge(rasters: => Seq[MosaicRasterGDAL], pixel: (Double, Double), resampling: String): MosaicRasterGDAL = {
+        val rasterUUID = java.util.UUID.randomUUID.toString
+        val outShortName = rasters.head.getRaster.GetDriver.getShortName
+
+        val vrtPath = PathUtils.createTmpFilePath(rasterUUID, "vrt")
+        val rasterPath = PathUtils.createTmpFilePath(rasterUUID, "tif")
+
+        val vrtRaster = GDALBuildVRT.executeVRT(
+          vrtPath,
+          isTemp = true,
+          rasters,
+          command = s"gdalbuildvrt -separate -resolution user -tr ${pixel._1} ${pixel._2}"
+        )
+
+        val result = GDALTranslate.executeTranslate(
+          rasterPath,
+          isTemp = true,
+          vrtRaster,
+          command = s"gdal_translate -r $resampling -of $outShortName -co COMPRESS=DEFLATE" // executeTranslate rejects anything but gdal_translate
+        )
+
+        dispose(vrtRaster)
+
+        result
+    }
+
+}
diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeRasters.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeRasters.scala
new file mode 100644
index 000000000..08adf2053
--- /dev/null
+++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeRasters.scala
@@ -0,0 +1,46 @@
+package com.databricks.labs.mosaic.core.raster.operator.merge
+
+import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL
+import com.databricks.labs.mosaic.core.raster.io.RasterCleaner.dispose
+import com.databricks.labs.mosaic.core.raster.operator.gdal.{GDALBuildVRT, GDALTranslate}
+import com.databricks.labs.mosaic.utils.PathUtils
+
+/** MergeRasters is a helper object for merging rasters. */
+object MergeRasters {
+
+    /**
+      * Merges the rasters into a single raster.
+      *
+      * @param rasters
+      *   The rasters to merge.
+      * @return
+      *   A MosaicRaster object.
+ */ + def merge(rasters: => Seq[MosaicRasterGDAL]): MosaicRasterGDAL = { + val rasterUUID = java.util.UUID.randomUUID.toString + val outShortName = rasters.head.getRaster.GetDriver.getShortName + + val vrtPath = PathUtils.createTmpFilePath(rasterUUID, "vrt") + val rasterPath = PathUtils.createTmpFilePath(rasterUUID, "tif") + + val vrtRaster = GDALBuildVRT.executeVRT( + vrtPath, + isTemp = true, + rasters, + command = s"gdalbuildvrt -resolution highest" + ) + + val result = GDALTranslate.executeTranslate( + rasterPath, + isTemp = true, + vrtRaster, + command = s"gdal_translate -r bilinear -of $outShortName -co COMPRESS=DEFLATE" + ) + + dispose(vrtRaster) + + result + } + + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/pixel/PixelCombineRasters.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/pixel/PixelCombineRasters.scala new file mode 100644 index 000000000..4462d01a3 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/pixel/PixelCombineRasters.scala @@ -0,0 +1,93 @@ +package com.databricks.labs.mosaic.core.raster.operator.pixel + +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner.dispose +import com.databricks.labs.mosaic.core.raster.operator.gdal.{GDALBuildVRT, GDALTranslate} +import com.databricks.labs.mosaic.utils.PathUtils + +import java.io.File +import scala.xml.{Elem, UnprefixedAttribute, XML} + +/** MergeRasters is a helper object for merging rasters. */ +object PixelCombineRasters { + + /** + * Merges the rasters into a single raster. + * + * @param rasters + * The rasters to merge. + * @return + * A MosaicRaster object. 
+ */ + def combine(rasters: => Seq[MosaicRasterGDAL], pythonFunc: String, pythonFuncName: String): MosaicRasterGDAL = { + val rasterUUID = java.util.UUID.randomUUID.toString + val outShortName = rasters.head.getRaster.GetDriver.getShortName + + val vrtPath = PathUtils.createTmpFilePath(rasterUUID, "vrt") + val rasterPath = PathUtils.createTmpFilePath(rasterUUID, "tif") + + val vrtRaster = GDALBuildVRT.executeVRT( + vrtPath, + isTemp = true, + rasters, + command = s"gdalbuildvrt -resolution highest" + ) + + addPixelFunction(vrtPath, pythonFunc, pythonFuncName) + + val result = GDALTranslate.executeTranslate( + rasterPath, + isTemp = true, + vrtRaster, + command = s"gdal_translate -r bilinear -of $outShortName -co COMPRESS=DEFLATE" + ) + + dispose(vrtRaster) + + result + } + + /** + * Adds a pixel function to the VRT file. The pixel function is a Python + * function that is applied to each pixel in the VRT file. The pixel + * function is set for all bands in the VRT file. + * + * @param vrtPath + * The path to the VRT file. + * @param pixFuncCode + * The pixel function code. + * @param pixFuncName + * The pixel function name. 
+      */
+    def addPixelFunction(vrtPath: String, pixFuncCode: String, pixFuncName: String): Unit = {
+        val pixFuncTypeEl = <PixelFunctionType>{pixFuncName}</PixelFunctionType>
+        val pixFuncLangEl = <PixelFunctionLanguage>Python</PixelFunctionLanguage>
+        val pixFuncCodeEl = <PixelFunctionCode>
+            {scala.xml.Unparsed(s"<![CDATA[$pixFuncCode]]>")}
+        </PixelFunctionCode>
+
+        val vrtContent = XML.loadFile(new File(vrtPath))
+        val vrtWithPixFunc = vrtContent match {
+            case body @ Elem(_, _, _, _, child @ _*) => body.copy(
+                  child = child.map {
+                      case el @ Elem(_, "VRTRasterBand", _, _, child @ _*) => el
+                              .asInstanceOf[Elem]
+                              .copy(
+                                child = Seq(pixFuncTypeEl, pixFuncLangEl, pixFuncCodeEl) ++ child,
+                                attributes = el
+                                    .asInstanceOf[Elem]
+                                    .attributes
+                                    .append(
+                                      new UnprefixedAttribute("subClass", "VRTDerivedRasterBand", scala.xml.Null)
+                                    )
+                              )
+                      case el                                              => el
+                  }
+                )
+        }
+
+        XML.save(vrtPath, vrtWithPixFunc)
+
+    }
+
+}
diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/proj/RasterProject.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/proj/RasterProject.scala
new file mode 100644
index 000000000..a091c4495
--- /dev/null
+++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/proj/RasterProject.scala
@@ -0,0 +1,47 @@
+package com.databricks.labs.mosaic.core.raster.operator.proj
+
+import com.databricks.labs.mosaic.core.raster.api.GDAL
+import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL
+import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALWarp
+import com.databricks.labs.mosaic.utils.PathUtils
+import org.gdal.osr.SpatialReference
+
+/**
+  * RasterProject is an object that defines the interface for projecting a
+  * raster.
+  */
+object RasterProject {
+
+    /**
+      * Projects a raster to a new CRS. The method handles all the abstractions
+      * over GDAL Warp. It uses cubic resampling to ensure that the output is
+      * smooth. It also uses COMPRESS=DEFLATE to ensure that the output is
+      * compressed.
+      *
+      * @param raster
+      *   The raster to project.
+      * @param destCRS
+      *   The destination CRS.
+      * @return
+      *   A projected raster.
+ */ + def project(raster: => MosaicRasterGDAL, destCRS: SpatialReference): MosaicRasterGDAL = { + val outShortName = raster.getDriversShortName + + val resultFileName = PathUtils.createTmpFilePath(raster.uuid.toString, GDAL.getExtension(outShortName)) + + // Note that Null is the right value here + val authName = destCRS.GetAuthorityName(null) + val authCode = destCRS.GetAuthorityCode(null) + + val result = GDALWarp.executeWarp( + resultFileName, + isTemp = true, + Seq(raster), + command = s"gdalwarp -of $outShortName -t_srs $authName:$authCode -r cubic -overwrite -co COMPRESS=DEFLATE" + ) + + result + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/BalancedSubdivision.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/BalancedSubdivision.scala new file mode 100644 index 000000000..17cb39885 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/BalancedSubdivision.scala @@ -0,0 +1,88 @@ +package com.databricks.labs.mosaic.core.raster.operator.retile + +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile + +/* ReTile is a helper object for retiling rasters. */ +object BalancedSubdivision { + + /** + * Gets the number of splits for a raster. The number of splits is + * determined by the size of the raster and the desired size of the split + * rasters. The number of splits is always a power of 4. This is a + * heuristic method only due to compressions and other factors. + * + * @param raster + * The raster to split. + * @param destSize + * The desired size of the split rasters in MB. + * @return + * The number of splits. 
+ */ + def getNumSplits(raster: => MosaicRasterGDAL, destSize: Int): Int = { + val size = raster.getMemSize + val n = size.toDouble / (destSize.toDouble * 1000 * 1000) // Double arithmetic: the Int product overflows once destSize exceeds 2147 + val nInt = Math.ceil(n).toInt + Math.pow(4, Math.ceil(Math.log(nInt) / Math.log(4))).toInt + } + + /** + * Gets the tile size for a raster. The tile size is determined by the + * number of splits. The tile size is always a power of 4. This is a + * heuristic method only due to compressions and other factors. + * @note + * Power of 2 is used to split the raster in each step but the number of + * splits is always a power of 4. + * + * @param x + * The x dimension of the raster. + * @param y + * The y dimension of the raster. + * @param numSplits + * The number of splits. + * @return + * The tile size. + */ + def getTileSize(x: Int, y: Int, numSplits: Int): (Int, Int) = { + def split(tile: (Int, Int)): (Int, Int) = { + val (a, b) = tile + if (a > b) (a / 2, b) else (a, b / 2) + } + var tile = (x, y) + val originRatio = x.toDouble / y.toDouble + var i = 0 + while (Math.pow(2, i) < numSplits) { + i += 1 + tile = split(tile) + } + val ratio = tile._1.toDouble / tile._2.toDouble + // if the ratio is not maintained, split one more time + // 0.1 is an arbitrary threshold to account for rounding errors + if (Math.abs(originRatio - ratio) > 0.1) tile = split(tile) + tile + } + + /** + * Splits a raster into multiple rasters. The number of splits is + * determined by the size of the raster and the desired size of the split + * rasters. The number of splits is always a power of 4. This is a + * heuristic method only due to compressions and other factors. + * + * @param tile + * The raster to split. + * @param sizeInMb + * The desired size of the split rasters in MB. + * @return + * A sequence of MosaicRasterTile objects. 
+ */ + def splitRaster( + tile: => MosaicRasterTile, + sizeInMb: Int + ): Seq[MosaicRasterTile] = { + val numSplits = getNumSplits(tile.getRaster, sizeInMb) + val (x, y) = tile.getRaster.getDimensions + val (tileX, tileY) = getTileSize(x, y, numSplits) + ReTile.reTile(tile, tileX, tileY) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/OverlappingTiles.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/OverlappingTiles.scala new file mode 100644 index 000000000..5897347ab --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/OverlappingTiles.scala @@ -0,0 +1,77 @@ +package com.databricks.labs.mosaic.core.raster.operator.retile + +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner.dispose +import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALTranslate +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.utils.PathUtils + +import scala.collection.immutable + +/** OverlappingTiles is a helper object for retiling rasters. */ +object OverlappingTiles { + + /** + * Retiles a raster into overlapping tiles. + * @note + * The overlap percentage is a percentage of the tile size. + * + * @param tile + * The raster to retile. + * @param tileWidth + * The width of the tiles. + * @param tileHeight + * The height of the tiles. + * @param overlapPercentage + * The percentage of overlap between tiles. + * @return + * A sequence of MosaicRasterTile objects. 
+ */ + def reTile( + tile: => MosaicRasterTile, + tileWidth: Int, + tileHeight: Int, + overlapPercentage: Int + ): immutable.Seq[MosaicRasterTile] = { + val raster = tile.getRaster + val (xSize, ySize) = raster.getDimensions + + val overlapWidth = Math.ceil(tileWidth * overlapPercentage / 100.0).toInt + val overlapHeight = Math.ceil(tileHeight * overlapPercentage / 100.0).toInt + + val tiles = for (i <- 0 until xSize by (tileWidth - overlapWidth)) yield { + for (j <- 0 until ySize by (tileHeight - overlapHeight)) yield { + val xOff = i + val yOff = j + val width = Math.min(tileWidth, xSize - i) + val height = Math.min(tileHeight, ySize - j) + + val uuid = java.util.UUID.randomUUID.toString + val fileExtension = GDAL.getExtension(tile.getDriver) + val rasterPath = PathUtils.createTmpFilePath(uuid, fileExtension) + val shortName = raster.getRaster.GetDriver.getShortName + + val result = GDALTranslate.executeTranslate( + rasterPath, + isTemp = true, + raster, + command = s"gdal_translate -of $shortName -srcwin $xOff $yOff $width $height" + ) + + val isEmpty = result.isEmpty + + if (isEmpty) dispose(result) + + (isEmpty, result) + } + } + + // TODO: The rasters should not be passed by objects. 
+ + val (_, valid) = tiles.flatten.partition(_._1) + + valid.map(t => new MosaicRasterTile(null, t._2, raster.getParentPath, raster.getDriversShortName)) + + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/RasterTessellate.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/RasterTessellate.scala new file mode 100644 index 000000000..8c5ce4f32 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/RasterTessellate.scala @@ -0,0 +1,59 @@ +package com.databricks.labs.mosaic.core.raster.operator.retile + +import com.databricks.labs.mosaic.core.Mosaic +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner.dispose +import com.databricks.labs.mosaic.core.raster.operator.proj.RasterProject +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile + +/** RasterTessellate is a helper object for tessellating rasters. */ +object RasterTessellate { + + /** + * Tessellates a raster into tiles. The raster is projected into the index + * system and then split into tiles. Each tile corresponds to a cell in the + * index system. + * + * @param raster + * The raster to tessellate. + * @param resolution + * The resolution of the tiles. + * @param indexSystem + * The index system to use. + * @param geometryAPI + * The geometry API to use. + * @return + * A sequence of MosaicRasterTile objects. 
+ */ + def tessellate(raster: => MosaicRasterGDAL, resolution: Int, indexSystem: IndexSystem, geometryAPI: GeometryAPI): Seq[MosaicRasterTile] = { + val indexSR = indexSystem.osrSpatialRef + val bbox = raster.bbox(geometryAPI, indexSR) + val cells = Mosaic.mosaicFill(bbox, resolution, keepCoreGeom = false, indexSystem, geometryAPI) + val tmpRaster = RasterProject.project(raster, indexSR) + + val chips = cells + .map(cell => { + val cellID = cell.cellIdAsLong(indexSystem) + val isValidCell = indexSystem.isValid(cellID) + if (!isValidCell) { + (false, new MosaicRasterTile(cell.index, null, "", "")) + } else { + val cellRaster = tmpRaster.getRasterForCell(cellID, indexSystem, geometryAPI) + val isValidRaster = cellRaster.getBandStats.values.map(_("mean")).sum > 0 && !cellRaster.isEmpty + ( + isValidRaster, + new MosaicRasterTile(cell.index, cellRaster, raster.getParentPath, raster.getDriversShortName) + ) + } + }) + + val (result, invalid) = chips.partition(_._1) + invalid.flatMap(t => Option(t._2.getRaster)).foreach(dispose(_)) + dispose(tmpRaster) + + result.map(_._2) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/ReTile.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/ReTile.scala new file mode 100644 index 000000000..e03712467 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/ReTile.scala @@ -0,0 +1,66 @@ +package com.databricks.labs.mosaic.core.raster.operator.retile + +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner.dispose +import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALTranslate +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.utils.PathUtils + +/** ReTile is a helper object for retiling rasters. */ +object ReTile { + + /** + * Retiles a raster into tiles. Empty tiles are discarded. 
The tile size is + * specified by the user via the tileWidth and tileHeight parameters. + * + * @param tile + * The raster to retile. + * @param tileWidth + * The width of the tiles. + * @param tileHeight + * The height of the tiles. + * @return + * A sequence of MosaicRasterTile objects. + */ + def reTile( + tile: => MosaicRasterTile, + tileWidth: Int, + tileHeight: Int + ): Seq[MosaicRasterTile] = { + val raster = tile.getRaster + val (xR, yR) = raster.getDimensions + val xTiles = Math.ceil(xR.toDouble / tileWidth).toInt // divide as Double: Int division truncates before ceil and drops the partial last column + val yTiles = Math.ceil(yR.toDouble / tileHeight).toInt // same for rows; also yields one tile when the raster is smaller than the tile size + + val tiles = for (x <- 0 until xTiles; y <- 0 until yTiles) yield { + val xMin = if (x == 0) x * tileWidth else x * tileWidth - 1 + val yMin = if (y == 0) y * tileHeight else y * tileHeight - 1 + val xOffset = if (xMin + tileWidth + 1 > xR) xR - xMin else tileWidth + 1 + val yOffset = if (yMin + tileHeight + 1 > yR) yR - yMin else tileHeight + 1 + + val rasterUUID = java.util.UUID.randomUUID.toString + val fileExtension = raster.getRasterFileExtension + val rasterPath = PathUtils.createTmpFilePath(rasterUUID, fileExtension) + val shortDriver = raster.getDriversShortName + + val result = GDALTranslate.executeTranslate( + rasterPath, + isTemp = true, + raster, + command = s"gdal_translate -of $shortDriver -srcwin $xMin $yMin $xOffset $yOffset -co COMPRESS=DEFLATE" + ) + + val isEmpty = result.isEmpty + + if (isEmpty) dispose(result) + + (isEmpty, result) + + } + + val (_, valid) = tiles.partition(_._1) + + valid.map(t => new MosaicRasterTile(null, t._2, raster.getParentPath, raster.getDriversShortName)) + + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/transform/RasterTransform.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/transform/RasterTransform.scala new file mode 100644 index 000000000..d8056a942 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/transform/RasterTransform.scala @@ -0,0 +1,50 @@ +package
com.databricks.labs.mosaic.core.raster.operator.transform + +object RasterTransform { + + /** + * Takes a geo transform matrix and the x and y coordinates of a pixel and + * returns the x and y coordinates in the projection of the raster. As per GDAL + * documentation, the origin is the top left corner of the top left pixel. + * + * @see + * https://gdal.org/tutorials/raster_api_tut.html + * @param geoTransform + * The geo transform matrix of the raster. + * @param x + * The x coordinate of the pixel. + * @param y + * The y coordinate of the pixel. + * @return + * A tuple of doubles with the x and y coordinates in the projection of + * the raster. + */ + def toWorldCoord(geoTransform: Seq[Double], x: Int, y: Int): (Double, Double) = { + val Xp = geoTransform.head + x * geoTransform(1) + y * geoTransform(2) + val Yp = geoTransform(3) + x * geoTransform(4) + y * geoTransform(5) + (Xp, Yp) + } + + /** + * Takes a geo transform matrix and the x and y coordinates of a point and + * returns the x and y coordinates of the raster pixel. Only elements 0, 1, 3 and 5 of the geo transform are used; elements 2 and 4 are not taken into account. + * + * @see + * Reference: + * https://gis.stackexchange.com/questions/221292/retrieve-pixel-value-with-geographic-coordinate-as-input-with-gdal + * @param geoTransform + * The geo transform matrix of the raster. + * @param xGeo + * The x coordinate of the point. + * @param yGeo + * The y coordinate of the point. + * @return + * A tuple of integers with the x and y coordinates of the raster pixel (each quotient is converted with toInt). 
+ */ + def fromWorldCoord(geoTransform: Seq[Double], xGeo: Double, yGeo: Double): (Int, Int) = { + val x = ((xGeo - geoTransform.head) / geoTransform(1)).toInt + val y = ((yGeo - geoTransform(3)) / geoTransform(5)).toInt + (x, y) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/types/RasterTileType.scala b/src/main/scala/com/databricks/labs/mosaic/core/types/RasterTileType.scala new file mode 100644 index 000000000..1cadf2c9a --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/types/RasterTileType.scala @@ -0,0 +1,38 @@ +package com.databricks.labs.mosaic.core.types + +import org.apache.spark.sql.types._ + +/** Type definition for the raster tile. */ +class RasterTileType(fields: Array[StructField]) extends StructType(fields) { + + def rasterType: DataType = fields(1).dataType + + override def simpleString: String = "RASTER_TILE" + + override def typeName: String = "struct" + +} + +object RasterTileType { + + /** + * Creates a new instance of [[RasterTileType]]. + * + * @param idType + * Type of the index ID. + * @return + * An instance of [[RasterTileType]]. 
+ */ + def apply(idType: DataType): RasterTileType = { + require(Seq(LongType, IntegerType, StringType).contains(idType)) + new RasterTileType( + Array( + StructField("index_id", idType), + StructField("raster", BinaryType), + StructField("parentPath", StringType), + StructField("driver", StringType) + ) + ) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicChip.scala b/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicChip.scala index c2dfec505..40fc4da8f 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicChip.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicChip.scala @@ -26,12 +26,14 @@ case class MosaicChip(isCore: Boolean, index: Either[Long, String], geom: Mosaic def isEmpty: Boolean = !isCore & Option(geom).forall(_.isEmpty) /** - * Formats the index ID as the data type supplied by the index system. - * - * @param indexSystem Index system to use for formatting. - * - * @return MosaicChip with formatted index ID. - */ + * Formats the index ID as the data type supplied by the index system. + * + * @param indexSystem + * Index system to use for formatting. + * + * @return + * MosaicChip with formatted index ID. 
+ */ def formatCellId(indexSystem: IndexSystem): MosaicChip = { (indexSystem.getCellIdDataType, index) match { case (_: LongType, Left(value)) => this @@ -42,15 +44,17 @@ case class MosaicChip(isCore: Boolean, index: Either[Long, String], geom: Mosaic } } - def cellIdAsLong(indexSystem: IndexSystem): Long = index match { - case Left(value) => value - case _ => indexSystem.parse(index.right.get) - } + def cellIdAsLong(indexSystem: IndexSystem): Long = + index match { + case Left(value) => value + case _ => indexSystem.parse(index.right.get) + } - def cellIdAsStr(indexSystem: IndexSystem): String = index match { - case Right(value) => value - case _ => indexSystem.format(index.left.get) - } + def cellIdAsStr(indexSystem: IndexSystem): String = + index match { + case Right(value) => value + case _ => indexSystem.format(index.left.get) + } /** * Serialise to spark internal representation. @@ -71,4 +75,9 @@ case class MosaicChip(isCore: Boolean, index: Either[Long, String], geom: Mosaic */ private def encodeGeom: Array[Byte] = Option(geom).map(_.toWKB).orNull + def indexAsLong(indexSystem: IndexSystem): Long = { + if (index.isLeft) index.left.get + else indexSystem.formatCellId(index.right.get, LongType).asInstanceOf[Long] + } + } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala b/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala new file mode 100644 index 000000000..48435d7da --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala @@ -0,0 +1,175 @@ +package com.databricks.labs.mosaic.core.types.model + +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.{BinaryType, DataType, LongType, StringType} +import 
org.apache.spark.unsafe.types.UTF8String + +/** + * A class modeling an instance of a mosaic raster tile. + * + * @param index + * Index ID. + * @param raster + * Raster instance corresponding to the tile. + * @param parentPath + * Parent path of the raster. + * @param driver + * Driver used to read the raster. + */ +class MosaicRasterTile( + index: Either[Long, String], + raster: => MosaicRasterGDAL, + parentPath: String, + driver: String +) { + + def getIndex: Either[Long, String] = index + + def getParentPath: String = parentPath + + def getDriver: String = driver + + def getRaster: MosaicRasterGDAL = raster + + /** + * Indicates whether the tile has no usable raster. + * @return + * True if the raster is null or reports itself as empty, false otherwise. + */ + def isEmpty: Boolean = Option(raster).forall(_.isEmpty) + + /** + * Formats the index ID as the data type supplied by the index system. + * + * @param indexSystem + * Index system to use for formatting. + * + * @return + * MosaicRasterTile with formatted index ID; this tile itself when the index is null or already of the requested type. + */ + def formatCellId(indexSystem: IndexSystem): MosaicRasterTile = { + if (Option(index).isEmpty) return this + (indexSystem.getCellIdDataType, index) match { + case (_: LongType, Left(_)) => this + case (_: StringType, Right(_)) => this + case (_: LongType, Right(value)) => new MosaicRasterTile( + index = Left(indexSystem.parse(value)), + raster = raster, + parentPath = parentPath, + driver = driver + ) + case (_: StringType, Left(value)) => new MosaicRasterTile( + index = Right(indexSystem.format(value)), + raster = raster, + parentPath = parentPath, + driver = driver + ) + case _ => throw new IllegalArgumentException("Invalid cell id data type") + } + } + + /** + * Returns the index ID as a long, parsing it with the index system when it is stored as a string. + * + * @param indexSystem + * Index system to use for parsing. + * + * @return + * The index ID as a long. 
+ */ + def cellIdAsLong(indexSystem: IndexSystem): Long = + index match { + case Left(value) => value + case _ => indexSystem.parse(index.right.get) + } + + /** + * Returns the index ID as a string, formatting it with the index system when it is stored as a long. + * @param indexSystem + * Index system to use for formatting. + * @return + * The index ID as a string. + */ + def cellIdAsStr(indexSystem: IndexSystem): String = + index match { + case Right(value) => value + case _ => indexSystem.format(index.left.get) + } + + /** + * Serialise to spark internal representation. The row holds, in order, the index ID (or null), the encoded raster, the parent path and the driver; rasterDataType and checkpointLocation are forwarded to the raster encoding step. + * + * @return + * An instance of [[InternalRow]]. + */ + def serialize( + rasterDataType: DataType = BinaryType, + checkpointLocation: String = "" + ): InternalRow = { + val parentPathUTF8 = UTF8String.fromString(parentPath) + val driverUTF8 = UTF8String.fromString(driver) + val encodedRaster = encodeRaster(rasterDataType, checkpointLocation) + if (Option(index).isDefined) { + if (index.isLeft) InternalRow.fromSeq( + Seq(index.left.get, encodedRaster, parentPathUTF8, driverUTF8) + ) + else InternalRow.fromSeq( + Seq(UTF8String.fromString(index.right.get), encodedRaster, parentPathUTF8, driverUTF8) + ) + } else { + InternalRow.fromSeq(Seq(null, encodedRaster, parentPathUTF8, driverUTF8)) + } + + } + + /** + * Encodes the raster by delegating to GDAL.writeRasters with the given checkpoint location and raster data type. + * + * @return + * The first element of the sequence returned by GDAL.writeRasters. + */ + private def encodeRaster( + rasterDataType: DataType = BinaryType, + checkpointLocation: String = "" + ): Any = { + GDAL.writeRasters(Seq(raster), checkpointLocation, rasterDataType).head + } + +} + +/** Companion object. */ +object MosaicRasterTile { + + /** + * Smart constructor based on Spark internal instance. + * + * @param row + * An instance of [[InternalRow]]. + * @param idDataType + * The data type of the index ID. + * @return + * An instance of [[MosaicRasterTile]]. 
+ */ + def deserialize(row: InternalRow, idDataType: DataType): MosaicRasterTile = { + val index = row.get(0, idDataType) + val rasterBytes = row.get(1, BinaryType) + val parentPath = row.get(2, StringType).toString + val driver = row.get(3, StringType).toString + val raster = GDAL.readRaster(rasterBytes, parentPath, driver, BinaryType) + // noinspection TypeCheckCanBeMatch + if (Option(index).isDefined) { + if (index.isInstanceOf[Long]) { + new MosaicRasterTile(Left(index.asInstanceOf[Long]), raster, parentPath, driver) + } else { + new MosaicRasterTile(Right(index.asInstanceOf[UTF8String].toString), raster, parentPath, driver) + } + } else { + new MosaicRasterTile(null, raster, parentPath, driver) + } + + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/GDALFileFormat.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/GDALFileFormat.scala deleted file mode 100644 index 2d528210c..000000000 --- a/src/main/scala/com/databricks/labs/mosaic/datasource/GDALFileFormat.scala +++ /dev/null @@ -1,139 +0,0 @@ -package com.databricks.labs.mosaic.datasource - -import com.databricks.labs.mosaic.core.raster.MosaicRasterGDAL -import com.databricks.labs.mosaic.GDAL -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileStatus -import org.apache.hadoop.mapreduce.Job -import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.sources.{DataSourceRegister, Filter} -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types._ - -/** - * A base Spark SQL data source for reading GDAL raster data sources. It reads - * metadata of the raster and exposes the direct paths for the raster files. 
- */ -class GDALFileFormat extends FileFormat with DataSourceRegister with Serializable { - - import GDALFileFormat._ - - override def shortName(): String = "gdal" - - override def inferSchema( - sparkSession: SparkSession, - options: Map[String, String], - files: Seq[FileStatus] - ): Option[StructType] = { - GDAL.enable() - inferSchemaImpl() - } - - override def isSplitable( - sparkSession: SparkSession, - options: Map[String, String], - path: org.apache.hadoop.fs.Path - ): Boolean = false - - override def buildReader( - sparkSession: SparkSession, - dataSchema: StructType, - partitionSchema: StructType, - requiredSchema: StructType, - filters: Seq[Filter], - options: Map[String, String], - hadoopConf: Configuration - ): PartitionedFile => Iterator[InternalRow] = { - val driverName = options.getOrElse("driverName", "") - buildReaderImpl(driverName, options) - } - - override def prepareWrite( - sparkSession: SparkSession, - job: Job, - options: Map[String, String], - dataSchema: StructType - ): OutputWriterFactory = throw new Error("Not implemented") - -} - -object GDALFileFormat extends Serializable { - - /** - * Returns the supported file extension for the driver name. - * - * @param driverName - * the GDAL driver name - * @return - * the file extension - */ - def getFileExtension(driverName: String): String = { - // Not a complete list of GDAL drivers - driverName match { - case "GTiff" => "tif" - case "HDF4" => "hdf" - case "HDF5" => "hdf" - case "JP2ECW" => "jp2" - case "JP2KAK" => "jp2" - case "JP2MrSID" => "jp2" - case "JP2OpenJPEG" => "jp2" - case "NetCDF" => "nc" - case "PDF" => "pdf" - case "PNG" => "png" - case "VRT" => "vrt" - case "XPM" => "xpm" - case "COG" => "tif" - case "GRIB" => "grib" - case "Zarr" => "zarr" - case _ => "UNSUPPORTED" - } - } - - /** GDAL readers have fixed schema. 
*/ - def inferSchemaImpl(): Option[StructType] = { - - Some( - StructType( - Array( - StructField("path", StringType, nullable = false), - StructField("ySize", IntegerType, nullable = false), - StructField("xSize", IntegerType, nullable = false), - StructField("bandCount", IntegerType, nullable = false), - StructField("metadata", MapType(StringType, StringType), nullable = false), - StructField("subdatasets", MapType(StringType, StringType), nullable = false), - StructField("srid", IntegerType, nullable = false), - StructField("proj4Str", StringType, nullable = false) - ) - ) - ) - - } - - def buildReaderImpl( - driverName: String, - options: Map[String, String] - ): PartitionedFile => Iterator[InternalRow] = { file: PartitionedFile => - { - GDAL.enable() - val vsizip = options.getOrElse("vsizip", "false").toBoolean - val path = Utils.getCleanPath(file.filePath, vsizip) - - if (path.endsWith(getFileExtension(driverName)) || path.endsWith("zip")) { - val raster = MosaicRasterGDAL.readRaster(path) - val ySize = raster.ySize - val xSize = raster.xSize - val bandCount = raster.numBands - val metadata = raster.metadata - val subdatasets = raster.subdatasets - val srid = raster.SRID - val proj4Str = raster.proj4String - val row = Utils.createRow(Seq(path, ySize, xSize, bandCount, metadata, subdatasets, srid, proj4Str)) - Seq(row).iterator - } else { - Seq.empty[InternalRow].iterator - } - } - } - -} diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/OGRFileFormat.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/OGRFileFormat.scala index b29acd9c5..b1b6b78b4 100644 --- a/src/main/scala/com/databricks/labs/mosaic/datasource/OGRFileFormat.scala +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/OGRFileFormat.scala @@ -1,5 +1,6 @@ package com.databricks.labs.mosaic.datasource +import com.databricks.labs.mosaic.utils.PathUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileStatus import 
org.apache.hadoop.mapreduce.Job @@ -366,7 +367,7 @@ object OGRFileFormat extends Serializable { * the data source */ def getDataSource(driverName: String, path: String, useZipPath: Boolean): org.gdal.ogr.DataSource = { - val cleanPath = Utils.getCleanPath(path, useZipPath) + val cleanPath = PathUtils.getCleanPath(path, useZipPath) // 0 is for no update driver if (driverName.nonEmpty) { ogr.GetDriverByName(driverName).Open(cleanPath, 0) diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/Utils.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/Utils.scala index 7ad7924a2..5b910697e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/datasource/Utils.scala +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/Utils.scala @@ -29,17 +29,4 @@ object Utils { ) } - def getCleanPath(path: String, useZipPath: Boolean): String = { - val cleanPath = path.replace("file:/", "/").replace("dbfs:/", "/dbfs/") - if (useZipPath && cleanPath.endsWith(".zip")) { - // It is really important that the resulting path is /vsizip// and not /vsizip/ - // /vsizip// is for absolute paths /viszip/ is relative to the current working directory - // /vsizip/ wont work on a cluster - // see: https://gdal.org/user/virtual_file_systems.html#vsizip-zip-archives - s"/vsizip/$cleanPath" - } else { - cleanPath - } - } - } diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/GDALFileFormat.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/GDALFileFormat.scala new file mode 100644 index 000000000..ea2ee40b5 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/GDALFileFormat.scala @@ -0,0 +1,276 @@ +package com.databricks.labs.mosaic.datasource.gdal + +import com.databricks.labs.mosaic.core.index.IndexSystemFactory +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.google.common.io.{ByteStreams, Closeables} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import 
org.apache.hadoop.mapreduce.Job +import org.apache.orc.util.Murmur3 +import org.apache.spark.SparkException +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.datasources.binaryfile.BinaryFileFormat +import org.apache.spark.sql.execution.datasources.{OutputWriterFactory, PartitionedFile} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.SerializableConfiguration + +import java.net.URI +import java.sql.Timestamp +import java.util.Locale + +/** A file format for reading binary files using GDAL. */ +class GDALFileFormat extends BinaryFileFormat { + + import GDALFileFormat._ + + /** + * Infer schema for the raster file. + * @param sparkSession + * Spark session. + * @param options + * Reading options. + * @param files + * List of files. + * @return + * An instance of [[StructType]]. + */ + override def inferSchema(sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { + GDAL.enable() + + val reader = ReadStrategy.getReader(options) + val schema = super + .inferSchema(sparkSession, options, files) + .map(reader.getSchema(options, files, _, sparkSession)) + + schema + } + + /** + * Prepare write is not supported. + * @param sparkSession + * Spark session. + * @param job + * Job. + * @param options + * Writing options. + * @param dataSchema + * Data schema. + * @return + * An instance of [[OutputWriterFactory]]. + */ + override def prepareWrite( + sparkSession: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType + ): OutputWriterFactory = { + throw new Error("Writing to GDALFileFormat is not supported.") + } + + /** + * Indicates whether the file format is splittable. + * @param sparkSession + * Spark session. + * @param options + * Reading options. + * @param path + * Path. + * @return + * True if the file format is splittable, false otherwise. Always false + * for GDAL. 
+ */ + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: org.apache.hadoop.fs.Path + ): Boolean = false + + override def shortName(): String = GDAL_BINARY_FILE + + /** + * Build a reader for the file format. + * @param sparkSession + * Spark session. + * @param dataSchema + * Data schema. + * @param partitionSchema + * Partition schema. + * @param requiredSchema + * Required schema. + * @param filters + * Filters. + * @param options + * Reading options. + * @param hadoopConf + * Hadoop configuration. + * @return + * A function that takes a [[PartitionedFile]] and returns an iterator of + * [[org.apache.spark.sql.catalyst.InternalRow]]. + */ + override def buildReader( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: org.apache.hadoop.conf.Configuration + ): PartitionedFile => Iterator[org.apache.spark.sql.catalyst.InternalRow] = { + GDAL.enable() + + val indexSystem = IndexSystemFactory.getIndexSystem(sparkSession) + + val supportedExtensions = options.getOrElse("extensions", "*").split(";").map(_.trim.toLowerCase(Locale.ROOT)) + + val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val filterFuncs = filters.flatMap(createFilterFunction) + val maxLength = sparkSession.conf.get("spark.sql.sources.binaryFile.maxLength", Int.MaxValue.toString).toInt + + // Identify the reader to use for the file format. + // GDAL supports multiple reading strategies. 
+ val reader = ReadStrategy.getReader(options) + + file: PartitionedFile => { + GDAL.enable() + val path = new Path(new URI(file.filePath)) + val fs = path.getFileSystem(broadcastedHadoopConf.value.value) + val status = fs.getFileStatus(path) + + if (status.getLen > maxLength) throw CantReadBytesException(maxLength, status) + + if (supportedExtensions.contains("*") || supportedExtensions.exists(status.getPath.getName.toLowerCase(Locale.ROOT).endsWith)) { + if (filterFuncs.forall(_.apply(status)) && isAllowedExtension(status, options)) { + reader.read(status, fs, requiredSchema, options, indexSystem) + } else { + Iterator.empty + } + } else { + Iterator.empty + } + } + + } + +} + +object GDALFileFormat { + + val GDAL_BINARY_FILE = "gdal" + val PATH = "path" + val LENGTH = "length" + val MODIFICATION_TIME = "modificationTime" + val TILE = "tile" + val CONTENT = "content" + val X_SIZE = "x_size" + val Y_SIZE = "y_size" + val X_OFFSET = "x_offset" + val Y_OFFSET = "y_offset" + val BAND_COUNT = "bandCount" + val METADATA = "metadata" + val SUBDATASETS: String = "subdatasets" + val SRID = "srid" + val UUID = "uuid" + + /** + * Creates an exception for when the file is too big to read. + * @param maxLength + * Maximum length. + * @param status + * File status. + * @return + * An instance of [[SparkException]]. + */ + def CantReadBytesException(maxLength: Long, status: FileStatus): SparkException = + new SparkException( + s"Can't read binary files bigger than $maxLength bytes. " + + s"File ${status.getPath} is ${status.getLen} bytes" + ) + + /** + * Generates a UUID for the file. + * @param status + * File status. + * @return + * A UUID. + */ + def getUUID(status: FileStatus): Long = { + val uuid = Murmur3.hash64( + status.getPath.toString.getBytes("UTF-8") ++ + status.getLen.toString.getBytes("UTF-8") ++ + status.getModificationTime.toString.getBytes("UTF-8") + ) + uuid + } + + // noinspection UnstableApiUsage + /** + * Reads the content of the file. 
+ * @param fs + * File system. + * @param status + * File status. + * @return + * An array of bytes. + */ + def readContent(fs: FileSystem, status: FileStatus): Array[Byte] = { + val stream = fs.open(status.getPath) + try { // noinspection UnstableApiUsage + ByteStreams.toByteArray(stream) + } finally { // noinspection UnstableApiUsage + Closeables.close(stream, true) + } + } + + /** + * Indicates whether the file extension is allowed. + * @param status + * File status. + * @param options + * Reading options. + * @return + * True if the file extension is allowed, false otherwise. + */ + def isAllowedExtension(status: FileStatus, options: Map[String, String]): Boolean = { + val allowedExtensions = options.getOrElse("extensions", "*").split(";").map(_.trim.toLowerCase(Locale.ROOT)) + val fileExtension = status.getPath.getName.toLowerCase(Locale.ROOT) + allowedExtensions.contains("*") || allowedExtensions.exists(fileExtension.endsWith) + } + + /** + * Creates a filter function for the file. + * @param filter + * Filter. + * @return + * An instance of [[FileStatus]] => [[Boolean]]. 
+ */ + private def createFilterFunction(filter: Filter): Option[FileStatus => Boolean] = { + filter match { + case And(left, right) => (createFilterFunction(left), createFilterFunction(right)) match { + case (Some(leftPred), Some(rightPred)) => Some(s => leftPred(s) && rightPred(s)) + case (Some(leftPred), None) => Some(leftPred) + case (None, Some(rightPred)) => Some(rightPred) + case (None, None) => Some(_ => true) + } + case Or(left, right) => (createFilterFunction(left), createFilterFunction(right)) match { + case (Some(leftPred), Some(rightPred)) => Some(s => leftPred(s) || rightPred(s)) + case _ => Some(_ => true) + } + case Not(child) => createFilterFunction(child) match { + case Some(pred) => Some(s => !pred(s)) + case _ => Some(_ => true) + } + case LessThan(LENGTH, value: Long) => Some(_.getLen < value) + case LessThanOrEqual(LENGTH, value: Long) => Some(_.getLen <= value) + case GreaterThan(LENGTH, value: Long) => Some(_.getLen > value) + case GreaterThanOrEqual(LENGTH, value: Long) => Some(_.getLen >= value) + case EqualTo(LENGTH, value: Long) => Some(_.getLen == value) + case LessThan(MODIFICATION_TIME, value: Timestamp) => Some(_.getModificationTime < value.getTime) + case LessThanOrEqual(MODIFICATION_TIME, value: Timestamp) => Some(_.getModificationTime <= value.getTime) + case GreaterThan(MODIFICATION_TIME, value: Timestamp) => Some(_.getModificationTime > value.getTime) + case GreaterThanOrEqual(MODIFICATION_TIME, value: Timestamp) => Some(_.getModificationTime >= value.getTime) + case EqualTo(MODIFICATION_TIME, value: Timestamp) => Some(_.getModificationTime == value.getTime) + case _ => None + } + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala new file mode 100644 index 000000000..acd53e535 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala @@ -0,0 +1,129 @@ +package 
com.databricks.labs.mosaic.datasource.gdal + +import com.databricks.labs.mosaic.core.index.{IndexSystem, IndexSystemFactory} +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.raster.operator.retile.BalancedSubdivision +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.datasource.Utils +import com.databricks.labs.mosaic.datasource.gdal.GDALFileFormat._ +import org.apache.hadoop.fs.{FileStatus, FileSystem} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ + +/** An object defining the retiling read strategy for the GDAL file format. */ +object ReTileOnRead extends ReadStrategy { + + // noinspection DuplicatedCode + /** + * Returns the schema of the GDAL file format. + * @note + * Different read strategies can have different schemas. + * + * @param options + * Options passed to the reader. + * @param files + * List of files to read. + * @param parentSchema + * Parent schema. + * @param sparkSession + * Spark session. + * + * @return + * Schema of the GDAL file format. 
+ */ + override def getSchema( + options: Map[String, String], + files: Seq[FileStatus], + parentSchema: StructType, + sparkSession: SparkSession + ): StructType = { + val trimmedSchema = parentSchema.filter(field => field.name != CONTENT && field.name != LENGTH) + val indexSystem = IndexSystemFactory.getIndexSystem(sparkSession) + StructType(trimmedSchema) + .add(StructField(UUID, LongType, nullable = false)) + .add(StructField(X_SIZE, IntegerType, nullable = false)) + .add(StructField(Y_SIZE, IntegerType, nullable = false)) + .add(StructField(BAND_COUNT, IntegerType, nullable = false)) + .add(StructField(METADATA, MapType(StringType, StringType), nullable = false)) + .add(StructField(SUBDATASETS, MapType(StringType, StringType), nullable = false)) + .add(StructField(SRID, IntegerType, nullable = false)) + .add(StructField(LENGTH, LongType, nullable = false)) + .add(StructField(TILE, RasterTileType(indexSystem.getCellIdDataType), nullable = false)) + } + + /** + * Reads the content of the file. + * @param status + * File status. + * @param fs + * File system. + * @param requiredSchema + * Required schema. + * @param options + * Options passed to the reader. + * @param indexSystem + * Index system. + * + * @return + * Iterator of internal rows. 
+ */ + override def read( + status: FileStatus, + fs: FileSystem, + requiredSchema: StructType, + options: Map[String, String], + indexSystem: IndexSystem + ): Iterator[InternalRow] = { + val inPath = status.getPath.toString + val uuid = getUUID(status) + val sizeInMB = options.getOrElse("sizeInMB", "16").toInt + + val tiles = localSubdivide(inPath, sizeInMB) + + val rows = tiles.map(tile => { + val trimmedSchema = StructType(requiredSchema.filter(field => field.name != TILE)) + val fields = trimmedSchema.fieldNames.map { + case PATH => status.getPath.toString + case MODIFICATION_TIME => status.getModificationTime + case UUID => uuid + case X_SIZE => tile.getRaster.xSize + case Y_SIZE => tile.getRaster.ySize + case BAND_COUNT => tile.getRaster.numBands + case METADATA => tile.getRaster.metadata + case SUBDATASETS => tile.getRaster.subdatasets + case SRID => tile.getRaster.SRID + case LENGTH => tile.getRaster.getMemSize + case other => throw new RuntimeException(s"Unsupported field name: $other") + } + // Writing to bytes is destructive so we delay reading content and content length until the last possible moment + val row = Utils.createRow(fields ++ Seq(tile.formatCellId(indexSystem).serialize())) + RasterCleaner.dispose(tile) + row + }) + + rows.iterator + } + + /** + * Subdivides a raster into tiles of a given size. + * @param inPath + * Path to the raster. + * @param sizeInMB + * Size of the tiles in MB. + * + * @return + * The tiles the raster was subdivided into.
+ */ + def localSubdivide(inPath: String, sizeInMB: Int): Seq[MosaicRasterTile] = { + val raster = MosaicRasterGDAL.readRaster(inPath, inPath) + val inTile = new MosaicRasterTile(null, raster, inPath, raster.getDriversShortName) + val tiles = BalancedSubdivision.splitRaster(inTile, sizeInMB) + RasterCleaner.dispose(raster) + RasterCleaner.dispose(inTile) + tiles + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadInMemory.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadInMemory.scala new file mode 100644 index 000000000..381804ff9 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadInMemory.scala @@ -0,0 +1,105 @@ +package com.databricks.labs.mosaic.datasource.gdal + +import com.databricks.labs.mosaic.core.index.{IndexSystem, IndexSystemFactory} +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.datasource.Utils +import com.databricks.labs.mosaic.datasource.gdal.GDALFileFormat._ +import org.apache.hadoop.fs.{FileStatus, FileSystem} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** An object defining the in memory read strategy for the GDAL file format. */ +object ReadInMemory extends ReadStrategy { + + // noinspection DuplicatedCode + /** + * Returns the schema of the GDAL file format. + * @note + * Different read strategies can have different schemas. + * + * @param options + * Options passed to the reader. + * @param files + * List of files to read. + * @param parentSchema + * Parent schema. + * @param sparkSession + * Spark session. + * + * @return + * Schema of the GDAL file format. 
+ */ + override def getSchema( + options: Map[String, String], + files: Seq[FileStatus], + parentSchema: StructType, + sparkSession: SparkSession + ): StructType = { + val indexSystem = IndexSystemFactory.getIndexSystem(sparkSession) + StructType(parentSchema.filter(_.name != CONTENT)) + .add(StructField(UUID, LongType, nullable = false)) + .add(StructField(X_SIZE, IntegerType, nullable = false)) + .add(StructField(Y_SIZE, IntegerType, nullable = false)) + .add(StructField(BAND_COUNT, IntegerType, nullable = false)) + .add(StructField(METADATA, MapType(StringType, StringType), nullable = false)) + .add(StructField(SUBDATASETS, MapType(StringType, StringType), nullable = false)) + .add(StructField(SRID, IntegerType, nullable = false)) + .add(StructField(TILE, RasterTileType(indexSystem.getCellIdDataType), nullable = false)) + } + + /** + * Reads the content of the file. + * @param status + * File status. + * @param fs + * File system. + * @param requiredSchema + * Required schema. + * @param options + * Options passed to the reader. + * @param indexSystem + * Index system. + * @return + * Iterator of internal rows. 
+ */ + override def read( + status: FileStatus, + fs: FileSystem, + requiredSchema: StructType, + options: Map[String, String], + indexSystem: IndexSystem + ): Iterator[InternalRow] = { + val inPath = status.getPath.toString + val driverShortName = MosaicRasterGDAL.identifyDriver(inPath) + val contentBytes: Array[Byte] = readContent(fs, status) + val raster = MosaicRasterGDAL.readRaster(contentBytes, inPath, driverShortName) + val uuid = getUUID(status) + + val fields = requiredSchema.fieldNames.filter(_ != TILE).map { + case PATH => status.getPath.toString + case LENGTH => status.getLen + case MODIFICATION_TIME => status.getModificationTime + case UUID => uuid + case X_SIZE => raster.xSize + case Y_SIZE => raster.ySize + case BAND_COUNT => raster.numBands + case METADATA => raster.metadata + case SUBDATASETS => raster.subdatasets + case SRID => raster.SRID + case other => throw new RuntimeException(s"Unsupported field name: $other") + } + val rasterTileSer = InternalRow.fromSeq( + Seq(null, contentBytes, UTF8String.fromString(inPath), UTF8String.fromString(driverShortName)) + ) + val row = Utils.createRow( + fields ++ Seq(rasterTileSer) + ) + RasterCleaner.dispose(raster) + Seq(row).iterator + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadStrategy.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadStrategy.scala new file mode 100644 index 000000000..cacc1c133 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadStrategy.scala @@ -0,0 +1,80 @@ +package com.databricks.labs.mosaic.datasource.gdal + +import com.databricks.labs.mosaic._ +import com.databricks.labs.mosaic.core.index.IndexSystem +import org.apache.hadoop.fs.{FileStatus, FileSystem} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.StructType + +/** A trait defining the read strategy for the GDAL file format. 
*/ +trait ReadStrategy extends Serializable { + + /** + * Returns the schema of the GDAL file format. + * @note + * Different read strategies can have different schemas. + * + * @param options + * Options passed to the reader. + * @param files + * List of files to read. + * @param parentSchema + * Parent schema. + * @param sparkSession + * Spark session. + * + * @return + * Schema of the GDAL file format. + */ + def getSchema(options: Map[String, String], files: Seq[FileStatus], parentSchema: StructType, sparkSession: SparkSession): StructType + + /** + * Reads the content of the file. + * @param status + * File status. + * @param fs + * File system. + * @param requiredSchema + * Required schema. + * @param options + * Options passed to the reader. + * @param indexSystem + * Index system. + * + * @return + * Iterator of internal rows. + */ + def read( + status: FileStatus, + fs: FileSystem, + requiredSchema: StructType, + options: Map[String, String], + indexSystem: IndexSystem + ): Iterator[InternalRow] + +} + +/** Factory object for selecting the read strategy of the GDAL file format from the reader options. */ +object ReadStrategy { + + /** + * Returns the read strategy. + * @param options + * Options passed to the reader. + * + * @return + * Read strategy.
+ */ + def getReader(options: Map[String, String]): ReadStrategy = { + val readStrategy = options.getOrElse(MOSAIC_RASTER_READ_STRATEGY, MOSAIC_RASTER_READ_IN_MEMORY) + + readStrategy match { + case MOSAIC_RASTER_READ_IN_MEMORY => ReadInMemory + case MOSAIC_RASTER_RE_TILE_ON_READ => ReTileOnRead + case _ => ReadInMemory + } + + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala index f2a8a3fdc..4cb39066a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala @@ -1,6 +1,5 @@ package com.databricks.labs.mosaic.datasource.multiread -import com.databricks.labs.mosaic.datasource.GDALFileFormat import com.databricks.labs.mosaic.functions.MosaicContext import org.apache.spark.sql._ import org.apache.spark.sql.functions._ @@ -18,14 +17,14 @@ import org.apache.spark.sql.functions._ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameReader(sparkSession) { private val mc = MosaicContext.context() - mc.getRasterAPI.enable() import mc.functions._ - val vsizipPathColF: Column => Column = (path: Column) => - when( - path.endsWith(".zip"), - concat(lit("/vsizip/"), path) - ).otherwise(path) + val vsizipPathColF: Column => Column = + (path: Column) => + when( + path.endsWith(".zip"), + concat(lit("/vsizip/"), path) + ).otherwise(path) override def load(path: String): DataFrame = load(Seq(path): _*) @@ -35,13 +34,10 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead val resolution = config("resolution").toInt val pathsDf = sparkSession.read - .option("pathGlobFilter", config("fileExtension")) - .format("binaryFile") + .format("gdal") + .option("extensions", config("extensions")) + .option("raster_storage", "in-memory") .load(paths: _*) - 
.select("path") - .select( - vsizipPathColF(col("path")).alias("path") - ) val rasterToGridCombiner = getRasterToGridFunc(config("combiner")) @@ -52,18 +48,16 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead val loadedDf = retiledDf .withColumn( "grid_measures", - rasterToGridCombiner(col("raster"), lit(resolution)) + rasterToGridCombiner(col("tile"), lit(resolution)) ) .select( "grid_measures", - "raster" + "tile" ) .select( - posexplode(col("grid_measures")).as(Seq("band_id", "grid_measures")), - col("raster") + posexplode(col("grid_measures")).as(Seq("band_id", "grid_measures")) ) .select( - col("raster"), col("band_id"), explode(col("grid_measures")).alias("grid_measures") ) @@ -96,8 +90,8 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead if (retile) { rasterDf.withColumn( - "raster", - rst_retile(col("raster"), lit(tileSize), lit(tileSize)) + "tile", + rst_retile(col("tile"), lit(tileSize), lit(tileSize)) ) } else { rasterDf @@ -120,30 +114,14 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead */ private def resolveRaster(pathsDf: DataFrame, config: Map[String, String]) = { val readSubdataset = config("readSubdataset").toBoolean - val subdatasetNumber = config("subdatasetNumber").toInt val subdatasetName = config("subdatasetName") if (readSubdataset) { pathsDf - .withColumn( - "subdatasets", - rst_subdatasets(col("path")) - ) - .withColumn( - "subdataset", - if (subdatasetName.isEmpty) { - element_at(map_keys(col("subdatasets")), subdatasetNumber) - } else { - element_at(col("subdatasets"), subdatasetName) - } - ) - .select( - vsizipPathColF(col("subdataset")).alias("raster") - ) + .withColumn("subdatasets", rst_subdatasets(col("tile"))) + .withColumn("tile", rst_getsubdataset(col("tile"), lit(subdatasetName))) } else { - pathsDf.select( - col("path").alias("raster") - ) + pathsDf.select(col("tile")) } } @@ -207,7 +185,7 @@ class 
RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead */ private def getConfig: Map[String, String] = { Map( - "fileExtension" -> this.extraOptions.getOrElse("fileExtension", "*"), + "extensions" -> this.extraOptions.getOrElse("extensions", "*"), "readSubdataset" -> this.extraOptions.getOrElse("readSubdataset", "false"), "vsizip" -> this.extraOptions.getOrElse("vsizip", "false"), "subdatasetNumber" -> this.extraOptions.getOrElse("subdatasetNumber", "0"), diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_BufferCapStyle.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_BufferCapStyle.scala new file mode 100644 index 000000000..c814108b8 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_BufferCapStyle.scala @@ -0,0 +1,65 @@ +package com.databricks.labs.mosaic.expressions.geometry + +import com.databricks.labs.mosaic.core.geometry.MosaicGeometry +import com.databricks.labs.mosaic.expressions.base.WithExpressionInfo +import com.databricks.labs.mosaic.expressions.geometry.base.UnaryVector2ArgExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.adapters.Column +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.types.DataType +import org.apache.spark.unsafe.types.UTF8String + +/** + * SQL expression that returns the input geometry buffered by the radius. + * @param inputGeom + * Expression containing the geometry. + * @param radiusExpr + * The radius of the buffer. + * @param expressionConfig + * Mosaic execution context, e.g. geometryAPI, indexSystem, etc. Additional + * arguments for the expression (expressionConfigs). 
+ */ +case class ST_BufferCapStyle( + inputGeom: Expression, + radiusExpr: Expression, + capStyleExpr: Expression, + expressionConfig: MosaicExpressionConfig +) extends UnaryVector2ArgExpression[ST_BufferCapStyle](inputGeom, radiusExpr, capStyleExpr, returnsGeometry = true, expressionConfig) { + + override def dataType: DataType = inputGeom.dataType + + override def geometryTransform(geometry: MosaicGeometry, arg1: Any, arg2: Any): Any = { + val radius = arg1.asInstanceOf[Double] + val capStyle = arg2.asInstanceOf[UTF8String].toString + geometry.bufferCapStyle(radius, capStyle) + } + + override def geometryCodeGen(geometryRef: String, argRef1: String, argRef2: String, ctx: CodegenContext): (String, String) = { + val resultRef = ctx.freshName("result") + val code = s"""$mosaicGeomClass $resultRef = $geometryRef.bufferCapStyle($argRef1, $argRef2.toString());""" + (code, resultRef) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object ST_BufferCapStyle extends WithExpressionInfo { + + override def name: String = "st_buffer_cap_style" + + override def usage: String = "_FUNC_(expr1, expr2, expr3) - Returns the geometry buffered by the radius using the given cap style." + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(a, b, c); + | POLYGON (...)
+ | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { (children: Seq[Expression]) => + ST_BufferCapStyle(children.head, Column(children(1)).cast("double").expr, children(2), expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_IntersectionAggregate.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_IntersectionAggregate.scala index cf2b3b531..29fecdf24 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_IntersectionAggregate.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_IntersectionAggregate.scala @@ -1,11 +1,11 @@ package com.databricks.labs.mosaic.expressions.geometry import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI -import com.databricks.labs.mosaic.core.index.{IndexSystem, IndexSystemFactory} +import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.expressions.index.IndexGeometry import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.types._ diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BandMetaData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BandMetaData.scala index c5dc6d40f..cf9bd60ba 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BandMetaData.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BandMetaData.scala @@ -1,25 +1,34 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.{MosaicRaster, MosaicRasterBand} +import 
com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterBandGDAL +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterBandExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ /** * The expression for extracting metadata from a raster band. - * @param path - * The path to the raster. + * @param raster + * The expression for the raster. If the raster is stored on disk, the path + * to the raster is provided. If the raster is stored in memory, the bytes of + * the raster are provided. * @param band * The band index. * @param expressionConfig * Additional arguments for the expression (expressionConfigs). */ -case class RST_BandMetaData(path: Expression, band: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterBandExpression[RST_BandMetaData](path, band, MapType(StringType, StringType), expressionConfig) +case class RST_BandMetaData(raster: Expression, band: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterBandExpression[RST_BandMetaData]( + raster, + band, + MapType(StringType, StringType), + returnsRaster = false, + expressionConfig = expressionConfig + ) with NullIntolerant with CodegenFallback { @@ -31,9 +40,8 @@ case class RST_BandMetaData(path: Expression, band: Expression, expressionConfig * @return * The band metadata of the band as a map type result. 
*/ - override def bandTransform(raster: MosaicRaster, band: MosaicRasterBand): Any = { - val metaData = band.metadata - buildMapString(metaData) + override def bandTransform(raster: => MosaicRasterTile, band: MosaicRasterBandGDAL): Any = { + buildMapString(band.metadata) } } @@ -47,7 +55,7 @@ object RST_BandMetaData extends WithExpressionInfo { override def example: String = """ | Examples: - | > SELECT _FUNC_(a, 1); + | > SELECT _FUNC_(raster_tile, 1); | {"NC_GLOBAL#acknowledgement":"NOAA Coral Reef Watch Program","NC_GLOBAL#cdm_data_type":"Grid"} | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBox.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBox.scala new file mode 100644 index 000000000..397d3ee8e --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBox.scala @@ -0,0 +1,73 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.types.model.{GeometryTypeEnum, MosaicRasterTile} +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types._ + +/** The expression for extracting the bounding box of a raster. 
*/ +case class RST_BoundingBox( + raster: Expression, + expressionConfig: MosaicExpressionConfig +) extends RasterExpression[RST_BoundingBox](raster, BinaryType, returnsRaster = false, expressionConfig = expressionConfig) + with NullIntolerant + with CodegenFallback { + + /** + * Computes the bounding box of the raster. The bbox is returned as a WKB + * polygon. + * + * @param tile + * The raster tile to be used. + * @return + * The bounding box of the raster as a WKB polygon. + */ + override def rasterTransform(tile: => MosaicRasterTile): Any = { + val raster = tile.getRaster + val gt = raster.getRaster.GetGeoTransform() + val (originX, originY) = GDAL.toWorldCoord(gt, 0, 0) + val (endX, endY) = GDAL.toWorldCoord(gt, raster.xSize, raster.ySize) + val geometryAPI = GeometryAPI(expressionConfig.getGeometryAPI) + val bboxPolygon = geometryAPI.geometry( + Seq( + Seq(originX, originY), + Seq(originX, endY), + Seq(endX, endY), + Seq(endX, originY), + Seq(originX, originY) + ).map(geometryAPI.fromCoords), + GeometryTypeEnum.POLYGON + ) + bboxPolygon.toWKB + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_BoundingBox extends WithExpressionInfo { + + override def name: String = "rst_boundingbox" + + override def usage: String = + """ + |_FUNC_(expr1) - Returns the bounding box of the raster. 
+ |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile); + | POLYGON ((-180 -90, -180 90, 180 90, 180 -90, -180 -90)) + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_BoundingBox](1, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Clip.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Clip.scala new file mode 100644 index 000000000..7bad9b5d1 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Clip.scala @@ -0,0 +1,78 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.raster.operator.clip.RasterClipByVector +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.Raster1ArgExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} + +/** The expression for clipping a raster by a vector. 
*/ +case class RST_Clip( + rastersExpr: Expression, + geometryExpr: Expression, + expressionConfig: MosaicExpressionConfig +) extends Raster1ArgExpression[RST_Clip]( + rastersExpr, + geometryExpr, + RasterTileType(expressionConfig.getCellIdType), + returnsRaster = true, + expressionConfig = expressionConfig + ) + with NullIntolerant + with CodegenFallback { + + val geometryAPI: GeometryAPI = GeometryAPI(expressionConfig.getGeometryAPI) + + /** + * Clips a raster by a vector. + * + * @param tile + * The raster tile to be clipped. + * @param arg1 + * The vector geometry to clip by, in the data type of the geometry expression. + * @return + * A new raster tile holding the clipped raster, keeping the index, parent path and driver of the input tile. + */ + override def rasterTransform(tile: => MosaicRasterTile, arg1: Any): Any = { + val geometry = geometryAPI.geometry(arg1, geometryExpr.dataType) + val geomCRS = geometry.getSpatialReferenceOSR + val clipped = RasterClipByVector.clip(tile.getRaster, geometry, geomCRS, geometryAPI) + new MosaicRasterTile( + tile.getIndex, + clipped, + tile.getParentPath, + tile.getDriver + ) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_Clip extends WithExpressionInfo { + + override def name: String = "rst_clip" + + override def usage: String = + """ + |_FUNC_(expr1, expr2) - Returns a raster clipped by provided vector. + |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster, vector); + | {index_id, clipped_raster, parentPath, driver} + | {index_id, clipped_raster, parentPath, driver} + | ...
+ | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_Clip](2, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvg.scala new file mode 100644 index 000000000..adb4974ee --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvg.scala @@ -0,0 +1,62 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.raster.operator.CombineAVG +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.RasterArrayExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} + +/** Expression for combining rasters using average of pixels. */ +case class RST_CombineAvg( + rastersExpr: Expression, + expressionConfig: MosaicExpressionConfig +) extends RasterArrayExpression[RST_CombineAvg]( + rastersExpr, + RasterTileType(expressionConfig.getCellIdType), + returnsRaster = true, + expressionConfig = expressionConfig + ) + with NullIntolerant + with CodegenFallback { + + /** Combines the rasters using average of pixels. 
*/ + override def rasterTransform(tiles: => Seq[MosaicRasterTile]): Any = { + val index = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null + new MosaicRasterTile( + index, + CombineAVG.compute(tiles.map(_.getRaster)), + tiles.head.getParentPath, + tiles.head.getDriver + ) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_CombineAvg extends WithExpressionInfo { + + override def name: String = "rst_combine_avg" + + override def usage: String = + """ + |_FUNC_(expr1) - Returns a raster that is a result of combining an array of rasters using average of pixels. + |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(array(raster_tile_1, raster_tile_2, raster_tile_3)); + | {index_id, raster, parent_path, driver} + | {index_id, raster, parent_path, driver} + | ... + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_CombineAvg](1, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAgg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAgg.scala new file mode 100644 index 000000000..767275953 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAgg.scala @@ -0,0 +1,133 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.index.IndexSystemFactory +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.raster.operator.CombineAVG +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.raster.base.RasterExpressionSerialization +import 
com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.trees.UnaryLike +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType} + +import scala.collection.mutable.ArrayBuffer + +/** + * Returns a new raster that is a result of combining an array of rasters using + * average of pixels. + */ +//noinspection DuplicatedCode +case class RST_CombineAvgAgg( + rasterExpr: Expression, + expressionConfig: MosaicExpressionConfig, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0 +) extends TypedImperativeAggregate[ArrayBuffer[Any]] + with UnaryLike[Expression] + with RasterExpressionSerialization { + + GDAL.enable() + + override lazy val deterministic: Boolean = true + override val child: Expression = rasterExpr + override val nullable: Boolean = false + override val dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + override def prettyName: String = "rst_combine_avg_agg" + + private lazy val projection = UnsafeProjection.create(Array[DataType](ArrayType(elementType = dataType, containsNull = false))) + private lazy val row = new UnsafeRow(1) + + def update(buffer: ArrayBuffer[Any], input: InternalRow): ArrayBuffer[Any] = { + val value = child.eval(input) + buffer += InternalRow.copyValue(value) + buffer + } + + def merge(buffer: ArrayBuffer[Any], input: ArrayBuffer[Any]): ArrayBuffer[Any] = { + buffer ++= input + } + + override def createAggregationBuffer(): ArrayBuffer[Any] = ArrayBuffer.empty + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def 
withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def eval(buffer: ArrayBuffer[Any]): Any = { + GDAL.enable() + + if (buffer.isEmpty) { + null + } else if (buffer.size == 1) { + buffer.head + } else { + + // Do do move the expression + val tiles = buffer.map(row => MosaicRasterTile.deserialize(row.asInstanceOf[InternalRow], expressionConfig.getCellIdType)) + + // If merging multiple index rasters, the index value is dropped + val idx = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null + val combined = CombineAVG.compute(tiles.map(_.getRaster)).flushCache() + // TODO: should parent path be an array? + val parentPath = tiles.head.getParentPath + val driver = tiles.head.getDriver + + val result = new MosaicRasterTile(idx, combined, parentPath, driver) + .formatCellId(IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem)) + .serialize(BinaryType, expressionConfig.getRasterCheckpoint) + + tiles.foreach(RasterCleaner.dispose(_)) + RasterCleaner.dispose(result) + + result + } + } + + override def serialize(obj: ArrayBuffer[Any]): Array[Byte] = { + val array = new GenericArrayData(obj.toArray) + projection.apply(InternalRow.apply(array)).getBytes + } + + override def deserialize(bytes: Array[Byte]): ArrayBuffer[Any] = { + val buffer = createAggregationBuffer() + row.pointTo(bytes, bytes.length) + row.getArray(0).foreach(dataType, (_, x: Any) => buffer += x) + buffer + } + + override protected def withNewChildInternal(newChild: Expression): RST_CombineAvgAgg = copy(rasterExpr = newChild) + +} + +/** Expression info required for the expression registration for spark SQL. 
*/ +object RST_CombineAvgAgg { + + def registryExpressionInfo(db: Option[String]): ExpressionInfo = + new ExpressionInfo( + classOf[RST_CombineAvgAgg].getCanonicalName, + db.orNull, + "rst_combine_avg_agg", + """ + | _FUNC_(tiles) - Combines rasters into a single raster using average. + """.stripMargin, + "", + """ + | Examples: + | > SELECT _FUNC_(raster_tile); + | {index_id, raster, parent_path, driver} + | """.stripMargin, + "", + "agg_funcs", + "1.0", + "", + "built-in" + ) + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromBands.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromBands.scala new file mode 100644 index 000000000..47271a35b --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromBands.scala @@ -0,0 +1,60 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.raster.operator.merge.MergeBands +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.RasterArrayExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.BinaryType + +/** The expression for stacking and resampling input bands. */ +case class RST_FromBands( + bandsExpr: Expression, + expressionConfig: MosaicExpressionConfig +) extends RasterArrayExpression[RST_FromBands]( + bandsExpr, + BinaryType, + returnsRaster = true, + expressionConfig = expressionConfig + ) + with NullIntolerant + with CodegenFallback { + + /** + * Stacks and resamples input bands. 
+ * @param rasters + * The rasters to be used. + * @return + * The stacked and resampled raster. + */ + override def rasterTransform(rasters: => Seq[MosaicRasterTile]): Any = MergeBands.merge(rasters.map(_.getRaster), "bilinear") + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_FromBands extends WithExpressionInfo { + + override def name: String = "rst_frombands" + + override def usage: String = + """ + |_FUNC_(expr1) - Returns a raster that is a result of stacking and resampling input bands. + |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(array(band1, band2, band3)); + | {index_id, raster, parent_path, driver} + | {index_id, raster, parent_path, driver} + | ... + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_FromBands](1, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala new file mode 100644 index 000000000..fa69cfcfa --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala @@ -0,0 +1,110 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.{IndexSystem, IndexSystemFactory} +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.datasource.gdal.ReTileOnRead +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import 
com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{CollectionGenerator, Expression, Literal, NullIntolerant} +import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * The raster for construction of a raster tile. This should be the first + * expression in the expression tree for a raster tile. + */ +case class RST_FromFile( + rasterPathExpr: Expression, + sizeInMB: Expression, + expressionConfig: MosaicExpressionConfig +) extends CollectionGenerator + with Serializable + with NullIntolerant + with CodegenFallback { + + GDAL.enable() + + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + + protected val geometryAPI: GeometryAPI = GeometryAPI.apply(expressionConfig.getGeometryAPI) + + protected val indexSystem: IndexSystem = IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem) + + protected val cellIdDataType: DataType = indexSystem.getCellIdDataType + + override def position: Boolean = false + + override def inline: Boolean = false + + override def children: Seq[Expression] = Seq(rasterPathExpr, sizeInMB) + + override def elementSchema: StructType = StructType(Array(StructField("tile", dataType))) + + /** + * Loads a raster from a file and subdivides it into tiles of the specified + * size (in MB). + * @param input + * The input file path. + * @return + * The tiles. 
+ */ + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { + GDAL.enable() + val path = rasterPathExpr.eval(input).asInstanceOf[UTF8String].toString + val targetSize = sizeInMB.eval(input).asInstanceOf[Int] + if (targetSize <= 0) { + val raster = MosaicRasterGDAL.readRaster(path, path) + val tile = new MosaicRasterTile(null, raster, path, raster.getDriversShortName) + val row = tile.formatCellId(indexSystem).serialize() + RasterCleaner.dispose(raster) + RasterCleaner.dispose(tile) + Seq(InternalRow.fromSeq(Seq(row))) + } else { + val tiles = ReTileOnRead.localSubdivide(path, targetSize) + val rows = tiles.map(_.formatCellId(indexSystem).serialize()) + tiles.foreach(RasterCleaner.dispose(_)) + rows.map(row => InternalRow.fromSeq(Seq(row))) + } + } + + override def makeCopy(newArgs: Array[AnyRef]): Expression = + GenericExpressionFactory.makeCopyImpl[RST_FromFile](this, newArgs, children.length, expressionConfig) + + override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = makeCopy(newChildren.toArray) + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_FromFile extends WithExpressionInfo { + + override def name: String = "rst_fromfile" + + override def usage: String = + """ + |_FUNC_(expr1[, expr2]) - Returns raster tiles loaded from the file at path expr1; if expr2 (size in MB) is positive the raster is subdivided into tiles of that size, otherwise a single tile is returned. + |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_path); + | {index_id, raster, parent_path, driver} + | ... 
+ | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + (children: Seq[Expression]) => { + val sizeExpr = if (children.length == 1) new Literal(-1, IntegerType) else children(1) + RST_FromFile(children.head, sizeExpr, expressionConfig) + } + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GeoReference.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GeoReference.scala index 24ab19081..72a33e41c 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GeoReference.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GeoReference.scala @@ -1,32 +1,34 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ /** Returns the georeference of the raster. 
*/ -case class RST_GeoReference(path: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_GeoReference](path, MapType(StringType, DoubleType), expressionConfig) +case class RST_GeoReference(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_GeoReference](raster, MapType(StringType, DoubleType), returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { /** Returns the georeference of the raster. */ - override def rasterTransform(raster: MosaicRaster): Any = { + override def rasterTransform(tile: => MosaicRasterTile): Any = { + val raster = tile.getRaster val geoTransform = raster.getRaster.GetGeoTransform() - val geoReference = Map( - "upperLeftX" -> geoTransform(0), - "upperLeftY" -> geoTransform(3), - "scaleX" -> geoTransform(1), - "scaleY" -> geoTransform(5), - "skewX" -> geoTransform(2), - "skewY" -> geoTransform(4) + buildMapDouble( + Map( + "upperLeftX" -> geoTransform(0), + "upperLeftY" -> geoTransform(3), + "scaleX" -> geoTransform(1), + "scaleY" -> geoTransform(5), + "skewX" -> geoTransform(2), + "skewY" -> geoTransform(4) + ) ) - buildMapDouble(geoReference) } } @@ -35,12 +37,12 @@ object RST_GeoReference extends WithExpressionInfo { override def name: String = "rst_georeference" - override def usage: String = "_FUNC_(expr1, expr2) - Extracts geo reference from a raster." + override def usage: String = "_FUNC_(expr1) - Extracts geo reference from a raster." 
override def example: String = """ | Examples: - | > SELECT _FUNC_(a, 1); + | > SELECT _FUNC_(raster_tile); | {"upper_left_x": 1.0, "upper_left_y": 1.0, "scale_x": 1.0, "scale_y": 1.0, "skew_x": 1.0, "skew_y": 1.0} | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetNoData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetNoData.scala new file mode 100644 index 000000000..7a45a3eaa --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetNoData.scala @@ -0,0 +1,62 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.{ArrayType, DoubleType} + +/** The expression for extracting the no data value of a raster. */ +case class RST_GetNoData( + rastersExpr: Expression, + expressionConfig: MosaicExpressionConfig +) extends RasterExpression[RST_GetNoData]( + rastersExpr, + ArrayType(DoubleType), + returnsRaster = false, + expressionConfig = expressionConfig + ) + with NullIntolerant + with CodegenFallback { + + /** + * Extracts the no data value of a raster. + * + * @param tile + * The raster to be used. + * @return + * The no data value of the raster. + */ + override def rasterTransform(tile: => MosaicRasterTile): Any = { + tile.getRaster.getBands.map(_.noDataValue) + } + +} + +/** Expression info required for the expression registration for spark SQL. 
*/ +object RST_GetNoData extends WithExpressionInfo { + + override def name: String = "rst_get_no_data" + + override def usage: String = + """ + |_FUNC_(expr1) - Returns the no data values of the raster bands as an array of doubles. + |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile); + | [0.0, -9999.0, ...] + | [0.0, -9999.0, ...] + | ... + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_GetNoData](1, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetSubdataset.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetSubdataset.scala new file mode 100644 index 000000000..1449bf6f3 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetSubdataset.scala @@ -0,0 +1,52 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.Raster1ArgExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.unsafe.types.UTF8String + +/** Returns the subdatasets of the raster. 
*/ +case class RST_GetSubdataset(raster: Expression, subsetName: Expression, expressionConfig: MosaicExpressionConfig) + extends Raster1ArgExpression[RST_GetSubdataset]( + raster, + subsetName, + RasterTileType(expressionConfig.getCellIdType), + returnsRaster = true, + expressionConfig + ) + with NullIntolerant + with CodegenFallback { + + /** Returns the subdatasets of the raster. */ + override def rasterTransform(tile: => MosaicRasterTile, arg1: Any): Any = { + val subsetName = arg1.asInstanceOf[UTF8String].toString + val subdataset = tile.getRaster.getSubdataset(subsetName) + new MosaicRasterTile(tile.getIndex, subdataset, tile.getParentPath, tile.getDriver) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_GetSubdataset extends WithExpressionInfo { + + override def name: String = "rst_getsubdataset" + + override def usage: String = "_FUNC_(expr1, expr2) - Extracts subdataset raster." + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile, 'SUBDATASET_1_NAME'); + | {index_id, raster, parent_path, driver} + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_GetSubdataset](2, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Height.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Height.scala index 92a8c493f..02a6da249 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Height.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Height.scala @@ -1,26 +1,25 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import 
com.databricks.labs.mosaic.expressions.raster.base.RasterExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ /** Returns the width of the raster. */ -case class RST_Height(path: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Height](path, IntegerType, expressionConfig) +case class RST_Height(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_Height](raster, IntegerType, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { /** Returns the width of the raster. */ - override def rasterTransform(raster: MosaicRaster): Any = raster.ySize + override def rasterTransform(tile: => MosaicRasterTile): Any = tile.getRaster.ySize } - /** Expression info required for the expression registration for spark SQL. 
*/ object RST_Height extends WithExpressionInfo { @@ -31,7 +30,7 @@ object RST_Height extends WithExpressionInfo { override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile); | 512 | """.stripMargin @@ -40,4 +39,3 @@ object RST_Height extends WithExpressionInfo { } } - diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_InitNoData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_InitNoData.scala new file mode 100644 index 000000000..604bba92a --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_InitNoData.scala @@ -0,0 +1,82 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALWarp +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import com.databricks.labs.mosaic.utils.PathUtils +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} + +/** The expression that initializes no data values of a raster. */ +case class RST_InitNoData( + rastersExpr: Expression, + expressionConfig: MosaicExpressionConfig +) extends RasterExpression[RST_InitNoData]( + rastersExpr, + RasterTileType(expressionConfig.getCellIdType), + returnsRaster = true, + expressionConfig = expressionConfig + ) + with NullIntolerant + with CodegenFallback { + + /** + * Initializes no data values of a raster. + * + * @param tile + * The raster to be used. 
+ * @return + * The raster with initialized no data values. + */ + override def rasterTransform(tile: => MosaicRasterTile): Any = { + val noDataValues = tile.getRaster.getBands.map(_.noDataValue).mkString(" ") + val dstNoDataValues = tile.getRaster.getBands + .map(_.getBand.getDataType) + .map(GDAL.getNoDataConstant) + .mkString(" ") + val resultPath = PathUtils.createTmpFilePath(tile.getRaster.uuid.toString, GDAL.getExtension(tile.getDriver)) + val result = GDALWarp.executeWarp( + resultPath, + isTemp = true, + Seq(tile.getRaster), + command = s"""gdalwarp -of ${tile.getDriver} -dstnodata "$dstNoDataValues" -srcnodata "$noDataValues"""" + ) + new MosaicRasterTile( + tile.getIndex, + result, + tile.getParentPath, + tile.getDriver + ) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_InitNoData extends WithExpressionInfo { + + override def name: String = "rst_init_no_data" + + override def usage: String = + """ + |_FUNC_(expr1) - Returns a raster with initialized no data values. + |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile); + | {index_id, raster, parent_path, driver} + | {index_id, raster, parent_path, driver} + | ... 
+ | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_InitNoData](1, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_IsEmpty.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_IsEmpty.scala index 79d2d7101..b54b63b55 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_IsEmpty.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_IsEmpty.scala @@ -1,22 +1,25 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ /** Returns true if the raster is empty. */ -case class RST_IsEmpty(path: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_IsEmpty](path, BooleanType, expressionConfig) +case class RST_IsEmpty(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_IsEmpty](raster, BooleanType, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { /** Returns true if the raster is empty. 
*/ - override def rasterTransform(raster: MosaicRaster): Any = raster.ySize == 0 && raster.xSize == 0 + override def rasterTransform(tile: => MosaicRasterTile): Any = { + val raster = tile.getRaster + (raster.ySize == 0 && raster.xSize == 0) || raster.isEmpty + } } @@ -30,12 +33,12 @@ object RST_IsEmpty extends WithExpressionInfo { override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile); | false | """.stripMargin override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { - GenericExpressionFactory.getBaseBuilder[RST_IsEmpty](2, expressionConfig) + GenericExpressionFactory.getBaseBuilder[RST_IsEmpty](1, expressionConfig) } } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MemSize.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MemSize.scala index a6d822403..eeffa8814 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MemSize.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MemSize.scala @@ -1,22 +1,22 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ /** Returns the memory size of the raster in bytes. 
*/ -case class RST_MemSize(path: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_MemSize](path, LongType, expressionConfig) +case class RST_MemSize(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_MemSize](raster, LongType, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { /** Returns the memory size of the raster in bytes. */ - override def rasterTransform(raster: MosaicRaster): Any = raster.getMemSize + override def rasterTransform(tile: => MosaicRasterTile): Any = tile.getRaster.getMemSize } @@ -30,7 +30,7 @@ object RST_MemSize extends WithExpressionInfo { override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile); | 228743 | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Merge.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Merge.scala new file mode 100644 index 000000000..54870aa65 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Merge.scala @@ -0,0 +1,69 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.raster.operator.merge.MergeRasters +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.RasterArrayExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} + +/** Returns a raster that is a result of merging an array of rasters. 
*/ +case class RST_Merge( + rastersExpr: Expression, + expressionConfig: MosaicExpressionConfig +) extends RasterArrayExpression[RST_Merge]( + rastersExpr, + RasterTileType(expressionConfig.getCellIdType), + returnsRaster = true, + expressionConfig = expressionConfig + ) + with NullIntolerant + with CodegenFallback { + + /** + * Merges an array of rasters. + * @param tiles + * The rasters to be used. + * @return + * The merged raster. + */ + override def rasterTransform(tiles: => Seq[MosaicRasterTile]): Any = { + val index = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null + val raster = MergeRasters.merge(tiles.map(_.getRaster)) + new MosaicRasterTile( + index, + raster, + tiles.head.getParentPath, + tiles.head.getDriver + ) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_Merge extends WithExpressionInfo { + + override def name: String = "rst_merge" + + override def usage: String = + """ + |_FUNC_(expr1) - Returns a raster that is a result of merging an array of rasters. + |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(array(raster_tile_1, raster_tile_2, raster_tile_3)); + | {index_id, raster, parent_path, driver} + | {index_id, raster, parent_path, driver} + | ... 
+ | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_Merge](1, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAgg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAgg.scala new file mode 100644 index 000000000..552feb0b5 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAgg.scala @@ -0,0 +1,133 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.index.IndexSystemFactory +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.raster.operator.merge.MergeRasters +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.raster.base.RasterExpressionSerialization +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.trees.UnaryLike +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType} + +import scala.collection.mutable.ArrayBuffer + +/** Merges rasters into a single raster. 
*/ +//noinspection DuplicatedCode +case class RST_MergeAgg( + rasterExpr: Expression, + expressionConfig: MosaicExpressionConfig, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0 +) extends TypedImperativeAggregate[ArrayBuffer[Any]] + with UnaryLike[Expression] + with RasterExpressionSerialization { + + GDAL.enable() + + override lazy val deterministic: Boolean = true + override val child: Expression = rasterExpr + override val nullable: Boolean = false + override val dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + override def prettyName: String = "rst_merge_agg" + + private lazy val projection = UnsafeProjection.create(Array[DataType](ArrayType(elementType = dataType, containsNull = false))) + private lazy val row = new UnsafeRow(1) + + def update(buffer: ArrayBuffer[Any], input: InternalRow): ArrayBuffer[Any] = { + val value = child.eval(input) + buffer += InternalRow.copyValue(value) + buffer + } + + def merge(buffer: ArrayBuffer[Any], input: ArrayBuffer[Any]): ArrayBuffer[Any] = { + buffer ++= input + } + + override def createAggregationBuffer(): ArrayBuffer[Any] = ArrayBuffer.empty + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def eval(buffer: ArrayBuffer[Any]): Any = { + GDAL.enable() + + if (buffer.isEmpty) { + null + } else if (buffer.size == 1) { + buffer.head + } else { + + // This is a trick to get the rasters sorted by their parent path to ensure more consistent results + // when merging rasters with large overlaps + val tiles = buffer + .map(row => MosaicRasterTile.deserialize(row.asInstanceOf[InternalRow], expressionConfig.getCellIdType)) + .sortBy(_.getParentPath) + + // If merging multiple index rasters, the index value is dropped + 
val idx = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null + val merged = MergeRasters.merge(tiles.map(_.getRaster)).flushCache() + // TODO: should parent path be an array? + val parentPath = tiles.head.getParentPath + val driver = tiles.head.getDriver + + val result = new MosaicRasterTile(idx, merged, parentPath, driver) + .formatCellId(IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem)) + .serialize(BinaryType, expressionConfig.getRasterCheckpoint) + + tiles.foreach(RasterCleaner.dispose(_)) + RasterCleaner.dispose(merged) + + result + } + } + + override def serialize(obj: ArrayBuffer[Any]): Array[Byte] = { + val array = new GenericArrayData(obj.toArray) + projection.apply(InternalRow.apply(array)).getBytes + } + + override def deserialize(bytes: Array[Byte]): ArrayBuffer[Any] = { + val buffer = createAggregationBuffer() + row.pointTo(bytes, bytes.length) + row.getArray(0).foreach(dataType, (_, x: Any) => buffer += x) + buffer + } + + override protected def withNewChildInternal(newChild: Expression): RST_MergeAgg = copy(rasterExpr = newChild) + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_MergeAgg { + + def registryExpressionInfo(db: Option[String]): ExpressionInfo = + new ExpressionInfo( + classOf[RST_MergeAgg].getCanonicalName, + db.orNull, + "rst_merge_agg", + """ + | _FUNC_(tiles)) - Merges rasters into a single raster. 
+ """.stripMargin, + "", + """ + | Examples: + | > SELECT _FUNC_(raster_tile); + | {index_id, raster, parent_path, driver} + | """.stripMargin, + "", + "agg_funcs", + "1.0", + "", + "built-in" + ) + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MetaData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MetaData.scala index c7d5ee10b..1e62808f7 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MetaData.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MetaData.scala @@ -1,25 +1,23 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ /** Returns the metadata of the raster. */ -case class RST_MetaData(path: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_MetaData](path, MapType(StringType, StringType), expressionConfig) +case class RST_MetaData(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_MetaData](raster, MapType(StringType, StringType), returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { /** Returns the metadata of the raster. 
*/ - override def rasterTransform(raster: MosaicRaster): Any = { - val metaData = raster.metadata - buildMapString(metaData) - } + override def rasterTransform(tile: => MosaicRasterTile): Any = buildMapString(tile.getRaster.metadata) + } /** Expression info required for the expression registration for spark SQL. */ @@ -32,7 +30,7 @@ object RST_MetaData extends WithExpressionInfo { override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile); | {"NC_GLOBAL#acknowledgement":"NOAA Coral Reef Watch Program","NC_GLOBAL#cdm_data_type":"Grid"} | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NDVI.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NDVI.scala new file mode 100644 index 000000000..85621055e --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NDVI.scala @@ -0,0 +1,71 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.operator.NDVI +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.Raster2ArgExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.BinaryType + +/** The expression for computing NDVI index. 
*/ +case class RST_NDVI( + rastersExpr: Expression, + redIndex: Expression, + nirIndex: Expression, + expressionConfig: MosaicExpressionConfig +) extends Raster2ArgExpression[RST_NDVI]( + rastersExpr, + redIndex, + nirIndex, + BinaryType, + returnsRaster = true, + expressionConfig = expressionConfig + ) + with NullIntolerant + with CodegenFallback { + + /** + * Computes NDVI index. + * @param raster + * The raster to be used. + * @param arg1 + * The red band index. + * @param arg2 + * The nir band index. + * @return + * The raster contains NDVI index. + */ + override def rasterTransform(raster: => MosaicRasterGDAL, arg1: Any, arg2: Any): Any = { + val redInd = arg1.asInstanceOf[Int] + val nirInd = arg2.asInstanceOf[Int] + NDVI.compute(raster, redInd, nirInd) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_NDVI extends WithExpressionInfo { + + override def name: String = "rst_ndvi" + + override def usage: String = + """ + |_FUNC_(expr1, expr2, expr3) - Returns a raster contains NDVI index computed by bands provided by red_index and nir_index. + |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile, 1, 2); + | {index_id, raster, parent_path, driver} + | ... 
+ | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_NDVI](3, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NumBands.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NumBands.scala index df4944967..b4694821d 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NumBands.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NumBands.scala @@ -1,22 +1,22 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ /** Returns the number of bands in the raster. */ -case class RST_NumBands(path: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_NumBands](path, IntegerType, expressionConfig) +case class RST_NumBands(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_NumBands](raster, IntegerType, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { /** Returns the number of bands in the raster. 
*/ - override def rasterTransform(raster: MosaicRaster): Any = raster.numBands + override def rasterTransform(tile: => MosaicRasterTile): Any = tile.getRaster.numBands } @@ -30,7 +30,7 @@ object RST_NumBands extends WithExpressionInfo { override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile); | 4 | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelHeight.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelHeight.scala index 6cbac4455..1aad1085c 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelHeight.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelHeight.scala @@ -1,27 +1,32 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ /** Returns the pixel height of the raster. 
*/ -case class RST_PixelHeight(path: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_PixelHeight](path, DoubleType, expressionConfig) +case class RST_PixelHeight(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_PixelHeight](raster, DoubleType, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { /** Returns the pixel height of the raster. */ - override def rasterTransform(raster: MosaicRaster): Any = { + override def rasterTransform(tile: => MosaicRasterTile): Any = { + val raster = tile.getRaster val scaleY = raster.getRaster.GetGeoTransform()(5) val skewX = raster.getRaster.GetGeoTransform()(2) // when there is no skew the height is scaleY, but we cant assume 0-only skew // skew is not to be confused with rotation - math.sqrt(scaleY * scaleY + skewX * skewX) + // TODO - check if this is correct + val result = math.sqrt(scaleY * scaleY + skewX * skewX) + RasterCleaner.dispose(raster) + result } } @@ -40,7 +45,7 @@ object RST_PixelHeight extends WithExpressionInfo { override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile); | 1.123 | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelWidth.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelWidth.scala index d87c92bfd..b623c303e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelWidth.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelWidth.scala @@ -1,27 +1,32 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import 
com.databricks.labs.mosaic.expressions.raster.base.RasterExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ /** Returns the pixel width of the raster. */ -case class RST_PixelWidth(path: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_PixelWidth](path, DoubleType, expressionConfig) +case class RST_PixelWidth(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_PixelWidth](raster, DoubleType, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { /** Returns the pixel width of the raster. */ - override def rasterTransform(raster: MosaicRaster): Any = { + override def rasterTransform(tile: => MosaicRasterTile): Any = { + val raster = tile.getRaster val scaleX = raster.getRaster.GetGeoTransform()(1) val skewY = raster.getRaster.GetGeoTransform()(4) // when there is no skew width is scaleX, but we cant assume 0-only skew // skew is not to be confused with rotation - math.sqrt(scaleX * scaleX + skewY * skewY) + // TODO check if this is correct + val result = math.sqrt(scaleX * scaleX + skewY * skewY) + RasterCleaner.dispose(raster) + result } } @@ -40,7 +45,7 @@ object RST_PixelWidth extends WithExpressionInfo { override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile); | 1.123 | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridAvg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridAvg.scala index ef0e3dc79..5c0f2ba4a 100644 --- 
a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridAvg.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridAvg.scala @@ -10,11 +10,11 @@ import org.apache.spark.sql.types.DoubleType /** Returns the average value of the raster within the grid cell. */ case class RST_RasterToGridAvg( - path: Expression, + raster: Expression, resolution: Expression, expressionConfig: MosaicExpressionConfig ) extends RasterToGridExpression[RST_RasterToGridAvg, Double]( - path, + raster, resolution, DoubleType, expressionConfig @@ -34,15 +34,15 @@ object RST_RasterToGridAvg extends WithExpressionInfo { override def usage: String = """ - |_FUNC_(expr1) - Returns a collection of grid index cells with the average pixel value for each band of the raster. - | The output type is array>>. - | Raster mask is taken into account and only valid pixels are used for the calculation. + |_FUNC_(expr1, expr2) - Returns a collection of grid index cells with the average pixel value for each band of the raster. + | The output type is array>>. + | Raster mask is taken into account and only valid pixels are used for the calculation. |""".stripMargin override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile, 3); | [[(11223344, 123.4), (11223345, 125.4), ...], [(11223344, 123.1), (11223344, 123.6) ...], ...] | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridCount.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridCount.scala index 9de6a4483..2fd36a986 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridCount.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridCount.scala @@ -10,11 +10,11 @@ import org.apache.spark.sql.types.IntegerType /** Returns the number of cells in the raster. 
*/ case class RST_RasterToGridCount( - path: Expression, + raster: Expression, resolution: Expression, expressionConfig: MosaicExpressionConfig ) extends RasterToGridExpression[RST_RasterToGridCount, Int]( - path, + raster, resolution, IntegerType, expressionConfig @@ -34,15 +34,15 @@ object RST_RasterToGridCount extends WithExpressionInfo { override def usage: String = """ - |_FUNC_(expr1) - Returns a collection of grid index cells with the number of pixels per cell for each band of the raster. - | The output type is array>>. - | Raster mask is taken into account and only valid pixels are used for the calculation. + |_FUNC_(expr1, expr2) - Returns a collection of grid index cells with the number of pixels per cell for each band of the raster. + | The output type is array>>. + | Raster mask is taken into account and only valid pixels are used for the calculation. |""".stripMargin override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile, 3); | [[(11223344, 123.4), (11223345, 125.4), ...], [(11223344, 123.1), (11223344, 123.6) ...], ...] | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMax.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMax.scala index e2283e33c..55cc88b2b 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMax.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMax.scala @@ -10,11 +10,11 @@ import org.apache.spark.sql.types.DoubleType /** Returns the maximum value of the raster in the grid cell. 
*/ case class RST_RasterToGridMax( - path: Expression, + raster: Expression, resolution: Expression, expressionConfig: MosaicExpressionConfig ) extends RasterToGridExpression[RST_RasterToGridMax, Double]( - path, + raster, resolution, DoubleType, expressionConfig @@ -34,15 +34,15 @@ object RST_RasterToGridMax extends WithExpressionInfo { override def usage: String = """ - |_FUNC_(expr1) - Returns a collection of grid index cells with the max pixel value per cell for each band of the raster. - | The output type is array>>. - | Raster mask is taken into account and only valid pixels are used for the calculation. + |_FUNC_(expr1, expr2) - Returns a collection of grid index cells with the max pixel value per cell for each band of the raster. + | The output type is array>>. + | Raster mask is taken into account and only valid pixels are used for the calculation. |""".stripMargin override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile, 3); | [[(11223344, 123.4), (11223345, 125.4), ...], [(11223344, 123.1), (11223344, 123.6) ...], ...] | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMedian.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMedian.scala index 8dfe179a3..4f799d273 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMedian.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMedian.scala @@ -10,11 +10,11 @@ import org.apache.spark.sql.types.DoubleType /** Returns the median value of the raster. 
*/ case class RST_RasterToGridMedian( - path: Expression, + raster: Expression, resolution: Expression, expressionConfig: MosaicExpressionConfig ) extends RasterToGridExpression[RST_RasterToGridMedian, Double]( - path, + raster, resolution, DoubleType, expressionConfig @@ -36,15 +36,15 @@ object RST_RasterToGridMedian extends WithExpressionInfo { override def usage: String = """ - |_FUNC_(expr1) - Returns a collection of grid index cells with the median pixel value per cell for each band of the raster. - | The output type is array>>. - | Raster mask is taken into account and only valid pixels are used for the calculation. + |_FUNC_(expr1, expr2) - Returns a collection of grid index cells with the median pixel value per cell for each band of the raster. + | The output type is array>>. + | Raster mask is taken into account and only valid pixels are used for the calculation. |""".stripMargin override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile, 3); | [[(11223344, 123.4), (11223345, 125.4), ...], [(11223344, 123.1), (11223344, 123.6) ...], ...] | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMin.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMin.scala index 955b49004..541cbbdab 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMin.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMin.scala @@ -10,11 +10,11 @@ import org.apache.spark.sql.types.DoubleType /** Returns the minimum value of the raster in the grid cell. 
*/ case class RST_RasterToGridMin( - path: Expression, + raster: Expression, resolution: Expression, expressionConfig: MosaicExpressionConfig ) extends RasterToGridExpression[RST_RasterToGridMin, Double]( - path, + raster, resolution, DoubleType, expressionConfig @@ -34,15 +34,15 @@ object RST_RasterToGridMin extends WithExpressionInfo { override def usage: String = """ - |_FUNC_(expr1) - Returns a collection of grid index cells with the min pixel value per cell for each band of the raster. - | The output type is array>>. - | Raster mask is taken into account and only valid pixels are used for the calculation. + |_FUNC_(expr1, expr2) - Returns a collection of grid index cells with the min pixel value per cell for each band of the raster. + | The output type is array>>. + | Raster mask is taken into account and only valid pixels are used for the calculation. |""".stripMargin override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile, 3); | [[(11223344, 123.4), (11223345, 125.4), ...], [(11223344, 123.1), (11223344, 123.6) ...], ...] 
| """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoord.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoord.scala index 2f69a5f96..e349d032c 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoord.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoord.scala @@ -1,22 +1,23 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.Raster2ArgExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ /** Returns the world coordinates of the raster (x,y) pixel. */ case class RST_RasterToWorldCoord( - path: Expression, + raster: Expression, x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_RasterToWorldCoord](path, x, y, StringType, expressionConfig) +) extends Raster2ArgExpression[RST_RasterToWorldCoord](raster, x, y, StringType, returnsRaster = false, expressionConfig = expressionConfig) with NullIntolerant with CodegenFallback { @@ -25,12 +26,12 @@ case class RST_RasterToWorldCoord( * GeoTransform. This ensures the projection of the raster is respected. 
* The output is a WKT point. */ - override def rasterTransform(raster: MosaicRaster, arg1: Any, arg2: Any): Any = { + override def rasterTransform(raster: => MosaicRasterGDAL, arg1: Any, arg2: Any): Any = { val x = arg1.asInstanceOf[Int] val y = arg2.asInstanceOf[Int] val gt = raster.getRaster.GetGeoTransform() - val (xGeo, yGeo) = rasterAPI.toWorldCoord(gt, x, y) + val (xGeo, yGeo) = GDAL.toWorldCoord(gt, x, y) val geometryAPI = GeometryAPI(expressionConfig.getGeometryAPI) val point = geometryAPI.fromCoords(Seq(xGeo, yGeo)) @@ -46,13 +47,13 @@ object RST_RasterToWorldCoord extends WithExpressionInfo { override def usage: String = """ - |_FUNC_(expr1) - Returns the (x, y) pixel in world coordinates using geo transform of the raster. + |_FUNC_(expr1, expr2, expr3) - Returns the (x, y) pixel in world coordinates using geo transform of the raster. |""".stripMargin override def example: String = """ | Examples: - | > SELECT _FUNC_(a, b, c); + | > SELECT _FUNC_(raster_tile, x, y); | (11.2, 12.3) | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordX.scala index cfd74109a..1b6787088 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordX.scala @@ -1,21 +1,22 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.Raster2ArgExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import 
org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ /** Returns the world coordinates of the raster (x,y) pixel. */ case class RST_RasterToWorldCoordX( - path: Expression, + raster: Expression, x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_RasterToWorldCoordX](path, x, y, DoubleType, expressionConfig) +) extends Raster2ArgExpression[RST_RasterToWorldCoordX](raster, x, y, DoubleType, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { @@ -23,12 +24,12 @@ case class RST_RasterToWorldCoordX( * Returns the world coordinates of the raster x pixel by applying * GeoTransform. This ensures the projection of the raster is respected. */ - override def rasterTransform(raster: MosaicRaster, arg1: Any, arg2: Any): Any = { + override def rasterTransform(raster: => MosaicRasterGDAL, arg1: Any, arg2: Any): Any = { val x = arg1.asInstanceOf[Int] val y = arg2.asInstanceOf[Int] val gt = raster.getRaster.GetGeoTransform() - val (xGeo, _) = rasterAPI.toWorldCoord(gt, x, y) + val (xGeo, _) = GDAL.toWorldCoord(gt, x, y) xGeo } @@ -41,13 +42,13 @@ object RST_RasterToWorldCoordX extends WithExpressionInfo { override def usage: String = """ - |_FUNC_(expr1) - Returns the x coordinate of the pixel in world coordinates using geo transform of the raster. + |_FUNC_(expr1, expr2, expr3) - Returns the x coordinate of the pixel in world coordinates using geo transform of the raster. 
|""".stripMargin override def example: String = """ | Examples: - | > SELECT _FUNC_(a, b); + | > SELECT _FUNC_(raster_tile, x, y); | 11.2 | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordY.scala index 88067c513..65ac5decc 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordY.scala @@ -1,21 +1,22 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.Raster2ArgExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ /** Returns the world coordinates of the raster (x,y) pixel. 
*/ case class RST_RasterToWorldCoordY( - path: Expression, + raster: Expression, x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_RasterToWorldCoordY](path, x, y, DoubleType, expressionConfig) +) extends Raster2ArgExpression[RST_RasterToWorldCoordY](raster, x, y, DoubleType, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { @@ -23,12 +24,12 @@ case class RST_RasterToWorldCoordY( * Returns the world coordinates of the raster y pixel by applying * GeoTransform. This ensures the projection of the raster is respected. */ - override def rasterTransform(raster: MosaicRaster, arg1: Any, arg2: Any): Any = { + override def rasterTransform(raster: => MosaicRasterGDAL, arg1: Any, arg2: Any): Any = { val x = arg1.asInstanceOf[Int] val y = arg2.asInstanceOf[Int] val gt = raster.getRaster.GetGeoTransform() - val (_, yGeo) = rasterAPI.toWorldCoord(gt, x, y) + val (_, yGeo) = GDAL.toWorldCoord(gt, x, y) yGeo } @@ -41,13 +42,13 @@ object RST_RasterToWorldCoordY extends WithExpressionInfo { override def usage: String = """ - |_FUNC_(expr1) - Returns the y coordinate of the pixel in world coordinates using geo transform of the raster. + |_FUNC_(expr1, expr2, expr3) - Returns the y coordinate of the pixel in world coordinates using geo transform of the raster. 
|""".stripMargin override def example: String = """ | Examples: - | > SELECT _FUNC_(a, b); + | > SELECT _FUNC_(raster_tile, x, y); | 11.2 | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ReTile.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ReTile.scala index 346a5a700..1d5fdedec 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ReTile.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ReTile.scala @@ -1,24 +1,24 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.raster.operator.retile.ReTile +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterGeneratorExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig -import org.apache.hive.common.util.Murmur3 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} /** * Returns a set of new rasters with the specified tile size (tileWidth x * tileHeight). */ case class RST_ReTile( - pathExpr: Expression, + rasterExpr: Expression, tileWidthExpr: Expression, tileHeightExpr: Expression, expressionConfig: MosaicExpressionConfig -) extends RasterGeneratorExpression[RST_ReTile](pathExpr, expressionConfig) +) extends RasterGeneratorExpression[RST_ReTile](rasterExpr, expressionConfig) with NullIntolerant with CodegenFallback { @@ -26,29 +26,13 @@ case class RST_ReTile( * Returns a set of new rasters with the specified tile size (tileWidth x * tileHeight). 
*/ - override def rasterGenerator(raster: MosaicRaster): Seq[(Long, (Int, Int, Int, Int))] = { + override def rasterGenerator(tile: => MosaicRasterTile): Seq[MosaicRasterTile] = { val tileWidthValue = tileWidthExpr.eval().asInstanceOf[Int] val tileHeightValue = tileHeightExpr.eval().asInstanceOf[Int] - - val xSize = raster.xSize - val ySize = raster.ySize - - val xTiles = Math.ceil(xSize / tileWidthValue).toInt - val yTiles = Math.ceil(ySize / tileHeightValue).toInt - - val tiles = for (x <- 0 until xTiles; y <- 0 until yTiles) yield { - val xMin = x * tileWidthValue - val yMin = y * tileHeightValue - val xMax = Math.min(xMin + tileWidthValue, xSize) - val yMax = Math.min(yMin + tileHeightValue, ySize) - val id = Murmur3.hash64(s"${raster.toString}$xMin$yMin$xMax$yMax".getBytes) - (id, (xMin, yMin, xMax, yMax)) - } - - tiles + ReTile.reTile(tile, tileWidthValue, tileHeightValue) } - override def children: Seq[Expression] = Seq(pathExpr, tileWidthExpr, tileHeightExpr) + override def children: Seq[Expression] = Seq(rasterExpr, tileWidthExpr, tileHeightExpr) } @@ -59,16 +43,16 @@ object RST_ReTile extends WithExpressionInfo { override def usage: String = """ - |_FUNC_(expr1) - Returns a set of new rasters with the specified tile size (tileWidth x tileHeight). + |_FUNC_(expr1, expr2, expr3) - Returns a set of new rasters with the specified tile size (tileWidth x tileHeight). |""".stripMargin override def example: String = """ | Examples: - | > SELECT _FUNC_(a, b); - | /path/to/raster_tile_1.tif - | /path/to/raster_tile_2.tif - | /path/to/raster_tile_3.tif + | > SELECT _FUNC_(raster_tile, 256, 256); + | {index_id, raster_tile, tile_width, tile_height} + | {index_id, raster_tile, tile_width, tile_height} + | {index_id, raster_tile, tile_width, tile_height} | ... 
| """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Rotation.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Rotation.scala index 92f3058c7..b54506882 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Rotation.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Rotation.scala @@ -1,23 +1,23 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ /** Returns the rotation angle of the raster. */ -case class RST_Rotation(path: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Rotation](path, DoubleType, expressionConfig) +case class RST_Rotation(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_Rotation](raster, DoubleType, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { /** Returns the rotation angle of the raster. 
*/ - override def rasterTransform(raster: MosaicRaster): Any = { - val gt = raster.getRaster.GetGeoTransform() + override def rasterTransform(tile: => MosaicRasterTile): Any = { + val gt = tile.getRaster.getRaster.GetGeoTransform() // arctan of y_skew and x_scale math.atan(gt(4) / gt(1)) } @@ -37,7 +37,7 @@ object RST_Rotation extends WithExpressionInfo { override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile); | 11.2 | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SRID.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SRID.scala index fe60d905c..293227d37 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SRID.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SRID.scala @@ -1,27 +1,27 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ import org.gdal.osr.SpatialReference import scala.util.Try /** Returns the SRID of the raster. 
*/ -case class RST_SRID(path: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_SRID](path, IntegerType, expressionConfig) +case class RST_SRID(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_SRID](raster, IntegerType, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { /** Returns the SRID of the raster. */ - override def rasterTransform(raster: MosaicRaster): Any = { + override def rasterTransform(tile: => MosaicRasterTile): Any = { // Reference: https://gis.stackexchange.com/questions/267321/extracting-epsg-from-a-raster-using-gdal-bindings-in-python - val proj = new SpatialReference(raster.getRaster.GetProjection()) + val proj = new SpatialReference(tile.getRaster.getRaster.GetProjection()) Try(proj.AutoIdentifyEPSG()) Try(proj.GetAttrValue("AUTHORITY", 1).toInt).getOrElse(0) } @@ -41,7 +41,7 @@ object RST_SRID extends WithExpressionInfo { override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile); | 4326 | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleX.scala index 05882c14a..4deaca6fd 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleX.scala @@ -1,24 +1,23 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{expressions, 
FunctionBuilder} -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ /** Returns the scale x of the raster. */ -case class RST_ScaleX(path: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_ScaleX](path, DoubleType, expressionConfig) +case class RST_ScaleX(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_ScaleX](raster, DoubleType, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { /** Returns the scale x of the raster. */ - override def rasterTransform(raster: MosaicRaster): Any = { - val scaleX = raster.getRaster.GetGeoTransform()(1) - scaleX + override def rasterTransform(tile: => MosaicRasterTile): Any = { + tile.getRaster.getRaster.GetGeoTransform()(1) } } @@ -36,7 +35,7 @@ object RST_ScaleX extends WithExpressionInfo { override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile); | 1.123 | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleY.scala index 8761eca66..5875bbf7a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleY.scala @@ -1,24 +1,23 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import 
com.databricks.labs.mosaic.expressions.raster.base.RasterExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ /** Returns the scale y of the raster. */ -case class RST_ScaleY(path: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_ScaleY](path, DoubleType, expressionConfig) +case class RST_ScaleY(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_ScaleY](raster, DoubleType, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { /** Returns the scale y of the raster. */ - override def rasterTransform(raster: MosaicRaster): Any = { - val scaleY = raster.getRaster.GetGeoTransform()(5) - scaleY + override def rasterTransform(tile: => MosaicRasterTile): Any = { + tile.getRaster.getRaster.GetGeoTransform()(5) } } @@ -36,7 +35,7 @@ object RST_ScaleY extends WithExpressionInfo { override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile); | 1.123 | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetNoData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetNoData.scala new file mode 100644 index 000000000..268091a10 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetNoData.scala @@ -0,0 +1,90 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALWarp +import com.databricks.labs.mosaic.core.types.RasterTileType +import 
com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.Raster1ArgExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import com.databricks.labs.mosaic.utils.PathUtils +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} + +/** Returns a raster with the specified no data values. */ +case class RST_SetNoData( + rastersExpr: Expression, + noDataExpr: Expression, + expressionConfig: MosaicExpressionConfig +) extends Raster1ArgExpression[RST_SetNoData]( + rastersExpr, + noDataExpr, + RasterTileType(expressionConfig.getCellIdType), + returnsRaster = true, + expressionConfig = expressionConfig + ) + with NullIntolerant + with CodegenFallback { + + /** + * Returns a raster with the specified no data values. + * @param tile + * The input raster tile. + * @param arg1 + * The no data values. + * @return + * The raster with the specified no data values. 
+ */ + override def rasterTransform(tile: => MosaicRasterTile, arg1: Any): Any = { + val noDataValues = tile.getRaster.getBands.map(_.noDataValue).mkString(" ") + val dstNoDataValues = (arg1 match { + case doubles: Array[Double] => doubles /* NOTE(review): Spark Catalyst usually passes array arguments as ArrayData rather than Array[Double], and numeric literals may arrive as Int or Decimal - confirm these inputs reach the intended branch */ + case d: Double => Array.fill[Double](tile.getRaster.numBands)(d) + case _ => throw new IllegalArgumentException("No data values must be an array of doubles or a double") + }).mkString(" ") + val resultPath = PathUtils.createTmpFilePath(tile.getRaster.uuid.toString, GDAL.getExtension(tile.getDriver)) + val result = GDALWarp.executeWarp( + resultPath, + isTemp = true, + Seq(tile.getRaster), + command = s"""gdalwarp -of ${tile.getDriver} -dstnodata "$dstNoDataValues" -srcnodata "$noDataValues"""" + ) + new MosaicRasterTile( + tile.getIndex, + result, + tile.getParentPath, + tile.getDriver + ) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_SetNoData extends WithExpressionInfo { + + override def name: String = "rst_set_no_data" + + override def usage: String = + """ + |_FUNC_(expr1, expr2) - Returns a raster with the specified no data values. + |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile, 0.0); + | {index_id, raster, parentPath, driver} + | {index_id, raster, parentPath, driver} + | ... + | > SELECT _FUNC_(raster_tile, array(0.0, 0.0, 0.0)); + | {index_id, raster, parentPath, driver} + | {index_id, raster, parentPath, driver} + | ... 
+ | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_SetNoData](2, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewX.scala index 297118c25..697b758da 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewX.scala @@ -1,24 +1,23 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ /** Returns the skew x of the raster. */ -case class RST_SkewX(path: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_SkewX](path, DoubleType, expressionConfig) +case class RST_SkewX(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_SkewX](raster, DoubleType, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { /** Returns the skew x of the raster. 
*/ - override def rasterTransform(raster: MosaicRaster): Any = { - val skewX = raster.getRaster.GetGeoTransform()(2) - skewX + override def rasterTransform(tile: => MosaicRasterTile): Any = { + tile.getRaster.getRaster.GetGeoTransform()(2) } } @@ -36,7 +35,7 @@ object RST_SkewX extends WithExpressionInfo { override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile); | 1.123 | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewY.scala index 382dd0a72..1fe4893c5 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewY.scala @@ -1,24 +1,23 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ /** Returns the skew y of the raster. 
*/ -case class RST_SkewY(path: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_SkewY](path, DoubleType, expressionConfig) +case class RST_SkewY(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_SkewY](raster, DoubleType, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { /** Returns the skew y of the raster. */ - override def rasterTransform(raster: MosaicRaster): Any = { - val skewY = raster.getRaster.GetGeoTransform()(4) - skewY + override def rasterTransform(tile: => MosaicRasterTile): Any = { + tile.getRaster.getRaster.GetGeoTransform()(4) } } @@ -36,7 +35,7 @@ object RST_SkewY extends WithExpressionInfo { override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile); | 1.123 | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Subdatasets.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Subdatasets.scala index eb4170529..c90a967c3 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Subdatasets.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Subdatasets.scala @@ -1,25 +1,27 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, 
NullIntolerant} import org.apache.spark.sql.types._ /** Returns the subdatasets of the raster. */ -case class RST_Subdatasets(path: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Subdatasets](path, MapType(StringType, StringType), expressionConfig) +case class RST_Subdatasets(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_Subdatasets]( + raster, + MapType(StringType, StringType), + returnsRaster = false, + expressionConfig + ) with NullIntolerant with CodegenFallback { /** Returns the subdatasets of the raster. */ - override def rasterTransform(raster: MosaicRaster): Any = { - val subdatasets = raster.subdatasets - buildMapString(subdatasets) - } + override def rasterTransform(tile: => MosaicRasterTile): Any = buildMapString(tile.getRaster.subdatasets) } @@ -33,7 +35,7 @@ object RST_Subdatasets extends WithExpressionInfo { override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile); | {"NETCDF:"ct5km_baa-max-7d_v3.1_20220101.nc":bleaching_alert_area":"[1x3600x7200] N/A (8-bit unsigned integer)", | "NETCDF:"ct5km_baa-max-7d_v3.1_20220101.nc":mask":"[1x3600x7200] mask (8-bit unsigned integer)"} | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Subdivide.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Subdivide.scala new file mode 100644 index 000000000..f0756d6fc --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Subdivide.scala @@ -0,0 +1,54 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.raster.operator.retile.BalancedSubdivision +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.RasterGeneratorExpression +import 
com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} + +/** Returns a set of new rasters with the specified tile size (In MB). */ +case class RST_Subdivide( + rasterExpr: Expression, + sizeInMB: Expression, + expressionConfig: MosaicExpressionConfig +) extends RasterGeneratorExpression[RST_Subdivide](rasterExpr, expressionConfig) + with NullIntolerant + with CodegenFallback { + + /** Returns a set of new rasters with the specified tile size (In MB). */ + override def rasterGenerator(tile: => MosaicRasterTile): Seq[MosaicRasterTile] = { + val targetSize = sizeInMB.eval().asInstanceOf[Int] + BalancedSubdivision.splitRaster(tile, targetSize) + } + + override def children: Seq[Expression] = Seq(rasterExpr, sizeInMB) + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_Subdivide extends WithExpressionInfo { + + override def name: String = "rst_subdivide" + + override def usage: String = + """ + |_FUNC_(expr1, expr2) - Returns a set of new rasters with same aspect ratio that are not larger than the threshold memory footprint. + |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile, 256); + | {index_id, raster_tile, tile_width, tile_height} + | {index_id, raster_tile, tile_width, tile_height} + | ... 
+ | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_Subdivide](2, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Summary.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Summary.scala index 6385b1d61..bcc296afa 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Summary.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Summary.scala @@ -1,27 +1,28 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.gdal.gdal.gdal.GDALInfo import org.gdal.gdal.InfoOptions +import org.gdal.gdal.gdal.GDALInfo import java.util.{Vector => JVector} /** Returns the summary info the raster. 
*/ -case class RST_Summary(path: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Summary](path, StringType, expressionConfig: MosaicExpressionConfig) +case class RST_Summary(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_Summary](raster, StringType, returnsRaster = false, expressionConfig: MosaicExpressionConfig) with NullIntolerant with CodegenFallback { /** Returns the summary info the raster. */ - override def rasterTransform(raster: MosaicRaster): Any = { + override def rasterTransform(tile: => MosaicRasterTile): Any = { + val raster = tile.getRaster val vector = new JVector[String]() // For other flags check the way gdalinfo.py script is called, InfoOptions expects a collection of same flags. // https://gdal.org/programs/gdalinfo.html @@ -43,7 +44,7 @@ object RST_Summary extends WithExpressionInfo { override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile); | { | "description":"byte.tif", | "driverShortName":"GTiff", diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Tessellate.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Tessellate.scala new file mode 100644 index 000000000..bb22cdc5b --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Tessellate.scala @@ -0,0 +1,64 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.raster.operator.retile.RasterTessellate +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.RasterTessellateGeneratorExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import 
org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} + +/** + * Returns a set of new rasters which are the result of the tessellation of the + * input raster. + */ +case class RST_Tessellate( + rasterExpr: Expression, + resolutionExpr: Expression, + expressionConfig: MosaicExpressionConfig +) extends RasterTessellateGeneratorExpression[RST_Tessellate](rasterExpr, resolutionExpr, expressionConfig) + with NullIntolerant + with CodegenFallback { + + /** + * Returns a set of new rasters which are the result of the tessellation of + * the input raster. + */ + override def rasterGenerator(tile: => MosaicRasterTile, resolution: Int): Seq[MosaicRasterTile] = { + RasterTessellate.tessellate( + tile.getRaster, + resolution, + indexSystem, + geometryAPI + ) + } + + override def children: Seq[Expression] = Seq(rasterExpr, resolutionExpr) + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_Tessellate extends WithExpressionInfo { + + override def name: String = "rst_tessellate" + + override def usage: String = + """ + |_FUNC_(expr1, expr2) - Returns a set of new rasters with the specified resolution within configured grid. + |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile, 3); + | {index_id, raster_tile, tile_width, tile_height} + | {index_id, raster_tile, tile_width, tile_height} + | ... 
+ | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_Tessellate](2, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ToOverlappingTiles.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ToOverlappingTiles.scala new file mode 100644 index 000000000..287d2389d --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ToOverlappingTiles.scala @@ -0,0 +1,65 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.raster.operator.retile.OverlappingTiles +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.RasterGeneratorExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} + +/** + * Returns a set of new rasters which are the result of a rolling window over + * the input raster. + */ +case class RST_ToOverlappingTiles( + rasterExpr: Expression, + tileWidthExpr: Expression, + tileHeightExpr: Expression, + overlapExpr: Expression, + expressionConfig: MosaicExpressionConfig +) extends RasterGeneratorExpression[RST_ToOverlappingTiles](rasterExpr, expressionConfig) + with NullIntolerant + with CodegenFallback { + + /** + * Returns a set of new rasters which are the result of a rolling window + * over the input raster. 
+ */ + override def rasterGenerator(tile: => MosaicRasterTile): Seq[MosaicRasterTile] = { + val tileWidthValue = tileWidthExpr.eval().asInstanceOf[Int] + val tileHeightValue = tileHeightExpr.eval().asInstanceOf[Int] + val overlapValue = overlapExpr.eval().asInstanceOf[Int] + OverlappingTiles.reTile(tile, tileWidthValue, tileHeightValue, overlapValue) + } + + override def children: Seq[Expression] = Seq(rasterExpr, tileWidthExpr, tileHeightExpr, overlapExpr) + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_ToOverlappingTiles extends WithExpressionInfo { + + override def name: String = "rst_to_overlapping_tiles" + + override def usage: String = + """ + |_FUNC_(expr1, expr2, expr3, expr4) - Returns a set of new rasters with the specified tile size (tileWidth x tileHeight). + | The tiles will overlap by the specified amount. + |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile, 256, 256, 10); + | {index_id, raster_tile, tile_width, tile_height} + | {index_id, raster_tile, tile_width, tile_height} + | ... 
+ | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_ToOverlappingTiles](4, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_TryOpen.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_TryOpen.scala new file mode 100644 index 000000000..00d4ac7e0 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_TryOpen.scala @@ -0,0 +1,43 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types._ + +/** Returns true if the raster can be opened. */ +case class RST_TryOpen(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_TryOpen](raster, BooleanType, returnsRaster = false, expressionConfig) + with NullIntolerant + with CodegenFallback { + + /** Returns true if the underlying raster of the tile is not null, i.e. it can be opened. */ + override def rasterTransform(tile: => MosaicRasterTile): Any = { + Option(tile.getRaster.getRaster).isDefined + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_TryOpen extends WithExpressionInfo { + + override def name: String = "rst_tryopen" + + override def usage: String = "_FUNC_(expr1) - Returns true if the raster can be opened." 
+ + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile); + | false + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_TryOpen](1, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftX.scala index ceb34c6fe..7a53e488a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftX.scala @@ -1,24 +1,23 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ /** Returns the upper left x of the raster. */ -case class RST_UpperLeftX(path: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_UpperLeftX](path, DoubleType, expressionConfig) +case class RST_UpperLeftX(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_UpperLeftX](raster, DoubleType, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { /** Returns the upper left x of the raster. 
*/ - override def rasterTransform(raster: MosaicRaster): Any = { - val upperLeftX = raster.getRaster.GetGeoTransform()(0) - upperLeftX + override def rasterTransform(tile: => MosaicRasterTile): Any = { + tile.getRaster.getRaster.GetGeoTransform()(0) } } @@ -33,7 +32,7 @@ object RST_UpperLeftX extends WithExpressionInfo { override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile); | 1.123 | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftY.scala index 984a5aed2..8e6525bab 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftY.scala @@ -1,24 +1,23 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ /** Returns the upper left y of the raster. 
*/ -case class RST_UpperLeftY(path: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_UpperLeftY](path, DoubleType, expressionConfig) +case class RST_UpperLeftY(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_UpperLeftY](raster, DoubleType, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { /** Returns the upper left y of the raster. */ - override def rasterTransform(raster: MosaicRaster): Any = { - val upperLeftY = raster.getRaster.GetGeoTransform()(3) - upperLeftY + override def rasterTransform(tile: => MosaicRasterTile): Any = { + tile.getRaster.getRaster.GetGeoTransform()(3) } } @@ -33,7 +32,7 @@ object RST_UpperLeftY extends WithExpressionInfo { override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile); | 1.123 | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Width.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Width.scala index 2f017d86b..a8a9a280d 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Width.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Width.scala @@ -1,22 +1,22 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import 
org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types._ /** Returns the width of the raster. */ -case class RST_Width(path: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Width](path, IntegerType, expressionConfig) +case class RST_Width(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_Width](raster, IntegerType, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { /** Returns the width of the raster. */ - override def rasterTransform(raster: MosaicRaster): Any = raster.xSize + override def rasterTransform(tile: => MosaicRasterTile): Any = tile.getRaster.xSize } @@ -30,7 +30,7 @@ object RST_Width extends WithExpressionInfo { override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile); | 512 | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoord.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoord.scala index 0ff1ca136..adf8b7c19 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoord.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoord.scala @@ -1,21 +1,22 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.Raster2ArgExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback 
import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} /** Returns the world coordinate of the raster. */ case class RST_WorldToRasterCoord( - path: Expression, + raster: Expression, x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_WorldToRasterCoord](path, x, y, PixelCoordsType, expressionConfig) +) extends Raster2ArgExpression[RST_WorldToRasterCoord](raster, x, y, PixelCoordsType, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { @@ -23,13 +24,12 @@ case class RST_WorldToRasterCoord( * Returns the x and y of the raster by applying GeoTransform as a tuple of * Integers. This will ensure projection of the raster is respected. */ - override def rasterTransform(raster: MosaicRaster, arg1: Any, arg2: Any): Any = { + override def rasterTransform(raster: => MosaicRasterGDAL, arg1: Any, arg2: Any): Any = { val xGeo = arg1.asInstanceOf[Double] val yGeo = arg2.asInstanceOf[Double] val gt = raster.getRaster.GetGeoTransform() - val (x, y) = rasterAPI.fromWorldCoord(gt, xGeo, yGeo) - + val (x, y) = GDAL.fromWorldCoord(gt, xGeo, yGeo) InternalRow.fromSeq(Seq(x, y)) } @@ -40,12 +40,12 @@ object RST_WorldToRasterCoord extends WithExpressionInfo { override def name: String = "rst_worldtorastercoord" - override def usage: String = "_FUNC_(expr1) - Returns x and y coordinates (pixel, line) of the pixel." + override def usage: String = "_FUNC_(expr1, expr2, expr3) - Returns x and y coordinates (pixel, line) of the pixel." 
override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile, 1.123, 1.123); | (11, 12) | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordX.scala index ddf15879e..00de3af69 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordX.scala @@ -1,21 +1,22 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.Raster2ArgExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types.IntegerType /** Returns the x coordinate of the raster. 
*/ case class RST_WorldToRasterCoordX( - path: Expression, + raster: Expression, x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_WorldToRasterCoordX](path, x, y, IntegerType, expressionConfig) +) extends Raster2ArgExpression[RST_WorldToRasterCoordX](raster, x, y, IntegerType, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { @@ -23,10 +24,10 @@ case class RST_WorldToRasterCoordX( * Returns the x coordinate of the raster by applying GeoTransform. This * will ensure projection of the raster is respected. */ - override def rasterTransform(raster: MosaicRaster, arg1: Any, arg2: Any): Any = { + override def rasterTransform(raster: => MosaicRasterGDAL, arg1: Any, arg2: Any): Any = { val xGeo = arg1.asInstanceOf[Double] val gt = raster.getRaster.GetGeoTransform() - rasterAPI.fromWorldCoord(gt, xGeo, 0)._1 + GDAL.fromWorldCoord(gt, xGeo, 0)._1 } } @@ -36,12 +37,12 @@ object RST_WorldToRasterCoordX extends WithExpressionInfo { override def name: String = "rst_worldtorastercoordx" - override def usage: String = "_FUNC_(expr1) - Returns x coordinate (pixel, line) of the pixel." + override def usage: String = "_FUNC_(expr1, expr2, expr3) - Returns x coordinate (pixel, line) of the pixel." 
override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile, 1.123, 1.123); | 11 | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordY.scala index 0249ad1bb..5b3bf09a7 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordY.scala @@ -1,31 +1,33 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.MosaicRaster +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.Raster2ArgExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.types.IntegerType /** Returns the Y coordinate of the raster. */ case class RST_WorldToRasterCoordY( - path: Expression, + raster: Expression, x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_WorldToRasterCoordY](path, x, y, IntegerType, expressionConfig) +) extends Raster2ArgExpression[RST_WorldToRasterCoordY](raster, x, y, IntegerType, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { /** - * Returns the y coordinate of the raster by applying GeoTransform. 
This - * will ensure projection of the raster is respected. - */ override def rasterTransform(raster: MosaicRaster, arg1: Any, arg2: Any): Any = { + * Returns the y (line) pixel coordinate for the world point (arg1 = x, + * arg2 = y) by applying GeoTransform, so the raster projection is respected. + */ + override def rasterTransform(raster: => MosaicRasterGDAL, arg1: Any, arg2: Any): Any = { val xGeo = arg1.asInstanceOf[Double] val gt = raster.getRaster.GetGeoTransform() - rasterAPI.fromWorldCoord(gt, xGeo, 0)._2 + GDAL.fromWorldCoord(gt, xGeo, arg2.asInstanceOf[Double])._2 } } @@ -35,12 +37,12 @@ object RST_WorldToRasterCoordY extends WithExpressionInfo { override def name: String = "rst_worldtorastercoordy" - override def usage: String = "_FUNC_(expr1) - Returns y coordinate (pixel, line) of the pixel." + override def usage: String = "_FUNC_(expr1, expr2, expr3) - Returns y coordinate (pixel, line) of the pixel." override def example: String = """ | Examples: - | > SELECT _FUNC_(a); + | > SELECT _FUNC_(raster_tile, 1.123, 1.123); | 12 | """.stripMargin diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster1ArgExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster1ArgExpression.scala index 3a1ffa79b..df8bd761e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster1ArgExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster1ArgExpression.scala @@ -1,12 +1,13 @@ package com.databricks.labs.mosaic.expressions.raster.base -import com.databricks.labs.mosaic.core.raster.MosaicRaster -import com.databricks.labs.mosaic.core.raster.api.RasterAPI +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import 
org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, NullIntolerant} import org.apache.spark.sql.types.DataType -import org.apache.spark.unsafe.types.UTF8String import scala.reflect.ClassTag @@ -15,8 +16,9 @@ import scala.reflect.ClassTag * the boilerplate code needed to create a function builder for a given * expression. It minimises amount of code needed to create a new expression. * - * @param pathExpr - * The expression for the raster path. + * @param rasterExpr + * The raster expression. It can be a path to a raster file or a byte array + * containing the raster file content. * @param arg1Expr * The expression for the first argument. * @param outputType @@ -27,22 +29,19 @@ import scala.reflect.ClassTag * The type of the extending class. */ abstract class Raster1ArgExpression[T <: Expression: ClassTag]( - pathExpr: Expression, + rasterExpr: Expression, arg1Expr: Expression, outputType: DataType, + returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends BinaryExpression with NullIntolerant - with Serializable { + with Serializable + with RasterExpressionSerialization { - /** - * The raster API to be used. Enable the raster so that subclasses dont - * need to worry about this. - */ - protected val rasterAPI: RasterAPI = RasterAPI(expressionConfig.getRasterAPI) - rasterAPI.enable() + GDAL.enable() - override def left: Expression = pathExpr + override def left: Expression = rasterExpr override def right: Expression = arg1Expr @@ -50,7 +49,7 @@ abstract class Raster1ArgExpression[T <: Expression: ClassTag]( override def dataType: DataType = outputType /** - * The function to be overriden by the extending class. It is called when + * The function to be overridden by the extending class. It is called when * the expression is evaluated. It provides the raster and the arguments to * the expression. It abstracts spark serialization from the caller. 
* @param raster @@ -60,29 +59,30 @@ abstract class Raster1ArgExpression[T <: Expression: ClassTag]( * @return * A result of the expression. */ - def rasterTransform(raster: MosaicRaster, arg1: Any): Any + def rasterTransform(raster: => MosaicRasterTile, arg1: Any): Any /** * Evaluation of the expression. It evaluates the raster path and the loads * the raster from the path. It handles the clean up of the raster before * returning the results. - * @param inputPath - * The path to the raster. It is a UTF8String. * + * @param input + * The input to the expression. It can be a path to a raster file or a + * byte array containing the raster file content. * @param arg1 * The first argument. - * * @return * The result of the expression. */ - override def nullSafeEval(inputPath: Any, arg1: Any): Any = { - val path = inputPath.asInstanceOf[UTF8String].toString - - val raster = rasterAPI.raster(path) - val result = rasterTransform(raster, arg1) - - raster.cleanUp() - result + //noinspection DuplicatedCode + override def nullSafeEval(input: Any, arg1: Any): Any = { + GDAL.enable() + val tile = MosaicRasterTile.deserialize(input.asInstanceOf[InternalRow], expressionConfig.getCellIdType) + val result = rasterTransform(tile, arg1) + val serialized = serialize(result, returnsRaster, outputType, expressionConfig) + // passed by name makes things re-evaluated + RasterCleaner.dispose(tile) + serialized } override def makeCopy(newArgs: Array[AnyRef]): Expression = GenericExpressionFactory.makeCopyImpl[T](this, newArgs, 2, expressionConfig) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster2ArgExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster2ArgExpression.scala index 338dab4ec..0146b9380 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster2ArgExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster2ArgExpression.scala @@ -1,12 +1,14 @@ 
package com.databricks.labs.mosaic.expressions.raster.base -import com.databricks.labs.mosaic.core.raster.MosaicRaster -import com.databricks.labs.mosaic.core.raster.api.RasterAPI +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant, TernaryExpression} import org.apache.spark.sql.types.DataType -import org.apache.spark.unsafe.types.UTF8String import scala.reflect.ClassTag @@ -14,8 +16,9 @@ import scala.reflect.ClassTag * Base class for all raster expressions that take two arguments. It provides * the boilerplate code needed to create a function builder for a given * expression. It minimises amount of code needed to create a new expression. - * @param pathExpr - * The expression for the raster path. + * @param rasterExpr + * The raster expression. It can be a path to a raster file or a byte array + * containing the raster file content. * @param arg1Expr * The expression for the first argument. * @param arg2Expr @@ -28,23 +31,20 @@ import scala.reflect.ClassTag * The type of the extending class. */ abstract class Raster2ArgExpression[T <: Expression: ClassTag]( - pathExpr: Expression, + rasterExpr: Expression, arg1Expr: Expression, arg2Expr: Expression, outputType: DataType, + returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends TernaryExpression with NullIntolerant - with Serializable { + with Serializable + with RasterExpressionSerialization { - /** - * The raster API to be used. Enable the raster so that subclasses dont - * need to worry about this. 
- */ - protected val rasterAPI: RasterAPI = RasterAPI(expressionConfig.getRasterAPI) - rasterAPI.enable() + GDAL.enable() - override def first: Expression = pathExpr + override def first: Expression = rasterExpr override def second: Expression = arg1Expr @@ -54,7 +54,7 @@ abstract class Raster2ArgExpression[T <: Expression: ClassTag]( override def dataType: DataType = outputType /** - * The function to be overriden by the extending class. It is called when + * The function to be overridden by the extending class. It is called when * the expression is evaluated. It provides the raster and the arguments to * the expression. It abstracts spark serialization from the caller. * @param raster @@ -66,31 +66,32 @@ abstract class Raster2ArgExpression[T <: Expression: ClassTag]( * @return * A result of the expression. */ - def rasterTransform(raster: MosaicRaster, arg1: Any, arg2: Any): Any + def rasterTransform(raster: => MosaicRasterGDAL, arg1: Any, arg2: Any): Any /** * Evaluation of the expression. It evaluates the raster path and the loads * the raster from the path. It handles the clean up of the raster before * returning the results. - * @param inputPath - * The path to the raster. It is a UTF8String. * + * @param input + * The input raster. It can be a path to a raster file or a byte array + * containing the raster file content. * @param arg1 * The first argument. * @param arg2 * The second argument. - * * @return * The result of the expression. 
*/ - override def nullSafeEval(inputPath: Any, arg1: Any, arg2: Any): Any = { - val path = inputPath.asInstanceOf[UTF8String].toString - - val raster = rasterAPI.raster(path) + //noinspection DuplicatedCode + override def nullSafeEval(input: Any, arg1: Any, arg2: Any): Any = { + GDAL.enable() + val tile = MosaicRasterTile.deserialize(input.asInstanceOf[InternalRow], expressionConfig.getCellIdType) + val raster = tile.getRaster val result = rasterTransform(raster, arg1, arg2) - - raster.cleanUp() - result + val serialized = serialize(result, returnsRaster, dataType, expressionConfig) + RasterCleaner.dispose(tile) + serialized } override def makeCopy(newArgs: Array[AnyRef]): Expression = GenericExpressionFactory.makeCopyImpl[T](this, newArgs, 3, expressionConfig) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayExpression.scala new file mode 100644 index 000000000..928b994b6 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayExpression.scala @@ -0,0 +1,92 @@ +package com.databricks.labs.mosaic.expressions.raster.base + +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant, UnaryExpression} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types.{ArrayType, DataType} + +import scala.reflect.ClassTag + +/** + * Base class for all raster expressions that take an array of rasters. 
It provides + * the boilerplate code needed to create a function builder for a given + * expression. It minimises amount of code needed to create a new expression. + * + * @param rastersExpr + * The rasters expression. It is an array column containing rasters as either + * paths or as content byte arrays. + * @param outputType + * The output type of the result. + * @param expressionConfig + * Additional arguments for the expression (expressionConfigs). + * @tparam T + * The type of the extending class. + */ +abstract class RasterArrayExpression[T <: Expression: ClassTag]( + rastersExpr: Expression, + outputType: DataType, + returnsRaster: Boolean, + expressionConfig: MosaicExpressionConfig +) extends UnaryExpression + with NullIntolerant + with Serializable + with RasterExpressionSerialization { + + GDAL.enable() + + override def child: Expression = rastersExpr + + /** Output Data Type */ + override def dataType: DataType = if (returnsRaster) rastersExpr.dataType.asInstanceOf[ArrayType].elementType else outputType + + /** + * The function to be overridden by the extending class. It is called when + * the expression is evaluated. It provides the rasters to the expression. + * It abstracts spark serialization from the caller. + * @param rasters + * The sequence of rasters to be used. + * @return + * A result of the expression. + */ + def rasterTransform(rasters: => Seq[MosaicRasterTile]): Any + + /** + * Evaluation of the expression. It evaluates the raster path and the loads + * the raster from the path. It handles the clean up of the raster before + * returning the results. + * @param input + * The input to the expression. It is an array containing paths to raster + * files or byte arrays containing the raster files contents. + * + * @return + * The result of the expression. 
+ */ + override def nullSafeEval(input: Any): Any = { + GDAL.enable() + val rasterDT = rastersExpr.dataType.asInstanceOf[ArrayType].elementType + val arrayData = input.asInstanceOf[ArrayData] + val n = arrayData.numElements() + val tiles = (0 until n) + .map(i => + MosaicRasterTile + .deserialize(arrayData.get(i, rasterDT).asInstanceOf[InternalRow], expressionConfig.getCellIdType) + ) + + val result = rasterTransform(tiles) + val serialized = serialize(result, returnsRaster, dataType, expressionConfig) + tiles.foreach(t => RasterCleaner.dispose(t)) + serialized + } + + override def makeCopy(newArgs: Array[AnyRef]): Expression = GenericExpressionFactory.makeCopyImpl[T](this, newArgs, 1, expressionConfig) + + override def withNewChildInternal( + newFirst: Expression + ): Expression = makeCopy(Array(newFirst)) + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterBandExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterBandExpression.scala index 04312de68..1efcfb553 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterBandExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterBandExpression.scala @@ -1,12 +1,14 @@ package com.databricks.labs.mosaic.expressions.raster.base -import com.databricks.labs.mosaic.core.raster.{MosaicRaster, MosaicRasterBand} -import com.databricks.labs.mosaic.core.raster.api.RasterAPI +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterBandGDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, 
Expression, NullIntolerant} import org.apache.spark.sql.types.DataType -import org.apache.spark.unsafe.types.UTF8String import scala.reflect.ClassTag @@ -15,8 +17,10 @@ import scala.reflect.ClassTag * provides the boilerplate code needed to create a function builder for a * given expression. It minimises amount of code needed to create a new * expression. - * @param pathExpr - * The expression for the raster path. + * @param rasterExpr + * The path to the raster if MOSAIC_RASTER_STORAGE is set to + * MOSAIC_RASTER_STORAGE_DISK. The bytes of the raster if + * MOSAIC_RASTER_STORAGE is set to MOSAIC_RASTER_STORAGE_BYTE. * @param bandExpr * The expression for the band index. * @param outputType @@ -27,22 +31,19 @@ import scala.reflect.ClassTag * The type of the extending class. */ abstract class RasterBandExpression[T <: Expression: ClassTag]( - pathExpr: Expression, + rasterExpr: Expression, bandExpr: Expression, outputType: DataType, + returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends BinaryExpression with NullIntolerant - with Serializable { + with Serializable + with RasterExpressionSerialization { - /** - * The raster API to be used. Enable the raster so that subclasses dont - * need to worry about this. - */ - protected val rasterAPI: RasterAPI = RasterAPI(expressionConfig.getRasterAPI) - rasterAPI.enable() + GDAL.enable() - override def left: Expression = pathExpr + override def left: Expression = rasterExpr override def right: Expression = bandExpr @@ -50,7 +51,7 @@ abstract class RasterBandExpression[T <: Expression: ClassTag]( override def dataType: DataType = outputType /** - * The function to be overriden by the extending class. It is called when + * The function to be overridden by the extending class. It is called when * the expression is evaluated. It provides the raster band to the * expression. It abstracts spark serialization from the caller. 
* @param raster @@ -60,7 +61,7 @@ abstract class RasterBandExpression[T <: Expression: ClassTag]( * @return * The result of the expression. */ - def bandTransform(raster: MosaicRaster, band: MosaicRasterBand): Any + def bandTransform(raster: => MosaicRasterTile, band: MosaicRasterBandGDAL): Any /** * Evaluation of the expression. It evaluates the raster path and the loads @@ -68,26 +69,27 @@ abstract class RasterBandExpression[T <: Expression: ClassTag]( * specified band. It handles the clean up of the raster before returning * the results. * - * @param inputPath - * The path to the raster. It is a UTF8String. - * + * @param inputRaster + * The path to the raster if MOSAIC_RASTER_STORAGE is set to + * MOSAIC_RASTER_STORAGE_DISK. The bytes of the raster if + * MOSAIC_RASTER_STORAGE is set to MOSAIC_RASTER_STORAGE_BYTE. * @param inputBand * The band index to be used. It is an Int. - * * @return * The result of the expression. */ - override def nullSafeEval(inputPath: Any, inputBand: Any): Any = { - val path = inputPath.asInstanceOf[UTF8String].toString + // noinspection DuplicatedCode + override def nullSafeEval(inputRaster: Any, inputBand: Any): Any = { + GDAL.enable() + val tile = MosaicRasterTile.deserialize(inputRaster.asInstanceOf[InternalRow], expressionConfig.getCellIdType) val bandIndex = inputBand.asInstanceOf[Int] - val raster = rasterAPI.raster(path) - val band = raster.getBand(bandIndex) - val result = bandTransform(raster, band) - - raster.cleanUp() + val band = tile.getRaster.getBand(bandIndex) + val result = bandTransform(tile, band) - result + val serialized = serialize(result, returnsRaster, dataType, expressionConfig) + RasterCleaner.dispose(tile) + serialized } override def makeCopy(newArgs: Array[AnyRef]): Expression = GenericExpressionFactory.makeCopyImpl[T](this, newArgs, 2, expressionConfig) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpression.scala 
b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpression.scala index fe44baf12..2207424a5 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpression.scala @@ -1,12 +1,14 @@ package com.databricks.labs.mosaic.expressions.raster.base -import com.databricks.labs.mosaic.core.raster.MosaicRaster -import com.databricks.labs.mosaic.core.raster.api.RasterAPI +import com.databricks.labs.mosaic.core.index.{IndexSystem, IndexSystemFactory} +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant, UnaryExpression} import org.apache.spark.sql.types.DataType -import org.apache.spark.unsafe.types.UTF8String import scala.reflect.ClassTag @@ -14,8 +16,10 @@ import scala.reflect.ClassTag * Base class for all raster expressions that take no arguments. It provides * the boilerplate code needed to create a function builder for a given * expression. It minimises amount of code needed to create a new expression. - * @param pathExpr - * The expression for the raster path. + * @param rasterExpr + * The expression for the raster. If the raster is stored on disc, the path + * to the raster is provided. If the raster is stored in memory, the bytes of + * the raster are provided. * @param outputType * The output type of the result. * @param expressionConfig @@ -24,27 +28,32 @@ import scala.reflect.ClassTag * The type of the extending class. 
*/ abstract class RasterExpression[T <: Expression: ClassTag]( - pathExpr: Expression, + rasterExpr: Expression, outputType: DataType, + returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends UnaryExpression with NullIntolerant - with Serializable { + with Serializable + with RasterExpressionSerialization { /** * The raster API to be used. Enable the raster so that subclasses dont * need to worry about this. */ - protected val rasterAPI: RasterAPI = RasterAPI(expressionConfig.getRasterAPI) - rasterAPI.enable() + GDAL.enable() - override def child: Expression = pathExpr + protected val indexSystem: IndexSystem = IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem) + + protected val cellIdDataType: DataType = indexSystem.getCellIdDataType + + override def child: Expression = rasterExpr /** Output Data Type */ override def dataType: DataType = outputType /** - * The function to be overriden by the extending class. It is called when + * The function to be overridden by the extending class. It is called when * the expression is evaluated. It provides the raster to the expression. * It abstracts spark serialization from the caller. * @param raster @@ -52,26 +61,25 @@ abstract class RasterExpression[T <: Expression: ClassTag]( * @return * The result of the expression. */ - def rasterTransform(raster: MosaicRaster): Any + def rasterTransform(raster: => MosaicRasterTile): Any /** * Evaluation of the expression. It evaluates the raster path and the loads * the raster from the path. It handles the clean up of the raster before * returning the results. - * @param inputPath - * The path to the raster. It is a UTF8String. + * @param input + * The input raster as either a path or bytes. * * @return * The result of the expression. 
*/ - override def nullSafeEval(inputPath: Any): Any = { - val path = inputPath.asInstanceOf[UTF8String].toString - - val raster = rasterAPI.raster(path) - val result = rasterTransform(raster) - - raster.cleanUp() - result + override def nullSafeEval(input: Any): Any = { + GDAL.enable() + val tile = MosaicRasterTile.deserialize(input.asInstanceOf[InternalRow], cellIdDataType) + val result = rasterTransform(tile) + val serialized = serialize(result, returnsRaster, dataType, expressionConfig) + RasterCleaner.dispose(tile) + serialized } override def makeCopy(newArgs: Array[AnyRef]): Expression = GenericExpressionFactory.makeCopyImpl[T](this, newArgs, 1, expressionConfig) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpressionSerialization.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpressionSerialization.scala new file mode 100644 index 000000000..0087314f9 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpressionSerialization.scala @@ -0,0 +1,50 @@ +package com.databricks.labs.mosaic.expressions.raster.base + +import com.databricks.labs.mosaic.core.index.IndexSystemFactory +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.types.{DataType, StructType} + +/** + * Base trait for raster serialization. It is used to serialize the result of + * the expression. + */ +trait RasterExpressionSerialization { + + /** + * Serializes the result of the expression. If the expression returns a + * raster, the raster is serialized. If the expression returns a scalar, + * the scalar is returned. + * @param data + * The result of the expression. + * @param returnsRaster + * Whether the expression returns a raster. + * @param outputDataType + * The output data type of the expression. 
+ * @param expressionConfig + * Additional arguments for the expression (expressionConfigs). + * @return + * The serialized result of the expression. + */ + def serialize( + data: => Any, + returnsRaster: Boolean, + outputDataType: DataType, + expressionConfig: MosaicExpressionConfig + ): Any = { + if (returnsRaster) { + val tile = data.asInstanceOf[MosaicRasterTile] + val checkpoint = expressionConfig.getRasterCheckpoint + val rasterType = outputDataType.asInstanceOf[StructType].fields(1).dataType + val result = tile + .formatCellId(IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem)) + .serialize(rasterType, checkpoint) + RasterCleaner.dispose(tile) + result + } else { + data + } + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterGeneratorExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterGeneratorExpression.scala index c52dc6d56..8cb68c49d 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterGeneratorExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterGeneratorExpression.scala @@ -1,13 +1,16 @@ package com.databricks.labs.mosaic.expressions.raster.base -import com.databricks.labs.mosaic.core.raster.MosaicRaster -import com.databricks.labs.mosaic.core.raster.api.RasterAPI +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.{IndexSystem, IndexSystemFactory} +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig -import org.apache.spark.sql.catalyst.expressions.{CollectionGenerator, Expression, 
NullIntolerant} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{CollectionGenerator, Expression, NullIntolerant} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String import scala.reflect.ClassTag @@ -19,28 +22,33 @@ import scala.reflect.ClassTag * rasters based on the input raster. The new rasters are written in the * checkpoint directory. The files are written as GeoTiffs. Subdatasets are not * supported, please flatten beforehand. - * @param inPathExpr - * The expression for the raster path. + * @param rasterExpr + * The expression for the raster. If the raster is stored on disc, the path + * to the raster is provided. If the raster is stored in memory, the bytes of + * the raster are provided. * @param expressionConfig * Additional arguments for the expression (expressionConfigs). * @tparam T * The type of the extending class. */ abstract class RasterGeneratorExpression[T <: Expression: ClassTag]( - inPathExpr: Expression, + rasterExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends CollectionGenerator with NullIntolerant with Serializable { + GDAL.enable() + + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + val uuid: String = java.util.UUID.randomUUID().toString.replace("-", "_") - /** - * The raster API to be used. Enable the raster so that subclasses dont - * need to worry about this. 
- */ - protected val rasterAPI: RasterAPI = RasterAPI(expressionConfig.getRasterAPI) - rasterAPI.enable() + protected val geometryAPI: GeometryAPI = GeometryAPI.apply(expressionConfig.getGeometryAPI) + + protected val indexSystem: IndexSystem = IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem) + + protected val cellIdDataType: DataType = indexSystem.getCellIdDataType override def position: Boolean = false @@ -51,32 +59,30 @@ abstract class RasterGeneratorExpression[T <: Expression: ClassTag]( * needs to be wrapped in a StructType. The actually type is that of the * structs element. */ - override def elementSchema: StructType = StructType(Array(StructField("path", StringType))) + override def elementSchema: StructType = StructType(Array(StructField("tile", dataType))) /** - * The function to be overriden by the extending class. It is called when + * The function to be overridden by the extending class. It is called when * the expression is evaluated. It provides the raster band to the * expression. It abstracts spark serialization from the caller. * @param raster * The raster to be used. * @return - * Sequence of subrasters = (id, reference to the input raster, extent of - * the output raster, unified mask for all bands). + * Sequence of generated new rasters to be written. 
*/ - def rasterGenerator(raster: MosaicRaster): Seq[(Long, (Int, Int, Int, Int))] + def rasterGenerator(raster: => MosaicRasterTile): Seq[MosaicRasterTile] override def eval(input: InternalRow): TraversableOnce[InternalRow] = { - val inPath = inPathExpr.eval(input).asInstanceOf[UTF8String].toString - val checkpointPath = expressionConfig.getRasterCheckpoint - - val raster = rasterAPI.raster(inPath) - val result = rasterGenerator(raster) + GDAL.enable() + val tile = MosaicRasterTile.deserialize(rasterExpr.eval(input).asInstanceOf[InternalRow], cellIdDataType) + val generatedRasters = rasterGenerator(tile) - for ((id, extent) <- result) yield { - val outPath = raster.saveCheckpoint(uuid, id, extent, checkpointPath) - InternalRow.fromSeq(Seq(UTF8String.fromString(outPath))) - } + // Writing rasters disposes of the written raster + val rows = generatedRasters.map(_.formatCellId(indexSystem).serialize()) + generatedRasters.foreach(gr => RasterCleaner.dispose(gr)) + RasterCleaner.dispose(tile) + rows.map(row => InternalRow.fromSeq(Seq(row))) } override def makeCopy(newArgs: Array[AnyRef]): Expression = diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterGridExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterGridExpression.scala new file mode 100644 index 000000000..b1e83de1b --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterGridExpression.scala @@ -0,0 +1,77 @@ +package com.databricks.labs.mosaic.expressions.raster.base + +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.core.raster.gdal.{MosaicRasterBandGDAL, MosaicRasterGDAL} + +/** + * Base trait for raster grid expressions. It provides the boilerplate code + * needed to create a function builder for a given expression. It minimises + * amount of code needed to create a new expression. 
+ */ +trait RasterGridExpression { + + /** + * Transforms a pixel to a cell ID and a value. + * @param gt + * The geotransform of the raster. + * @param indexSystem + * The index system to be used. + * @param resolution + * The resolution of the index system. + * @param x + * X coordinate of the pixel. + * @param y + * Y coordinate of the pixel. + * @param value + * The value of the pixel. + * @return + * A tuple containing the cell ID and the value. + */ + def pixelTransformer( + gt: Seq[Double], + indexSystem: IndexSystem, + resolution: Int + )(x: Int, y: Int, value: Double): (Long, Double) = { + val offset = 0.5 // This centers the point to the pixel centroid + val xOffset = offset + x + val yOffset = offset + y + val xGeo = gt.head + xOffset * gt(1) + yOffset * gt(2) + val yGeo = gt(3) + xOffset * gt(4) + yOffset * gt(5) + val cellID = indexSystem.pointToIndex(xGeo, yGeo, resolution) + (cellID, value) + } + + /** + * Transforms a raster to a sequence of maps. Each map contains cell IDs + * and values for a given band. + * @param raster + * The raster to be transformed. + * @param indexSystem + * The index system to be used. + * @param resolution + * The resolution of the index system. + * @return + * A sequence of maps. Each map contains cell IDs and values for a given + * band. + */ + def griddedPixels( + raster: => MosaicRasterGDAL, + indexSystem: IndexSystem, + resolution: Int + ): Seq[Map[Long, Seq[Double]]] = { + val gt = raster.getRaster.GetGeoTransform() + val bandTransform = (band: MosaicRasterBandGDAL) => { + val results = band.transformValues[(Long, Double)](pixelTransformer(gt, indexSystem, resolution), (0L, -1.0)) + results + // Filter out default cells. We don't want to return them since they are masked in original raster. + // We use 0L as a dummy cell ID for default cells. + .map(row => row.filter(_._1 != 0L)) + .filterNot(_.isEmpty) + .flatten + .groupBy(_._1) // Group by cell ID. 
+ } + val transformed = raster.transformBands(bandTransform) + transformed.map(band => band.mapValues(values => values.map(_._2))) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterTessellateGeneratorExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterTessellateGeneratorExpression.scala new file mode 100644 index 000000000..8e02543c6 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterTessellateGeneratorExpression.scala @@ -0,0 +1,99 @@ +package com.databricks.labs.mosaic.expressions.raster.base + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.{IndexSystem, IndexSystemFactory} +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{CollectionGenerator, Expression, NullIntolerant} +import org.apache.spark.sql.types._ + +import scala.reflect.ClassTag + +/** + * Base class for all raster generator expressions that take no arguments. It + * provides the boilerplate code needed to create a function builder for a + * given expression. It minimises amount of code needed to create a new + * expression. These expressions are used to generate a collection of new + * rasters based on the input raster. The new rasters are written in the + * checkpoint directory. The files are written as GeoTiffs. Subdatasets are not + * supported, please flatten beforehand. + * + * @param rasterExpr + * The expression for the raster. 
If the raster is stored on disc, the path + * to the raster is provided. If the raster is stored in memory, the bytes of + * the raster are provided. + * @param expressionConfig + * Additional arguments for the expression (expressionConfigs). + * @tparam T + * The type of the extending class. + */ +abstract class RasterTessellateGeneratorExpression[T <: Expression: ClassTag]( + rasterExpr: Expression, + resolutionExpr: Expression, + expressionConfig: MosaicExpressionConfig +) extends CollectionGenerator + with NullIntolerant + with Serializable { + + GDAL.enable() + + val uuid: String = java.util.UUID.randomUUID().toString.replace("-", "_") + + val indexSystem: IndexSystem = IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem) + + protected val geometryAPI: GeometryAPI = GeometryAPI.apply(expressionConfig.getGeometryAPI) + + override def position: Boolean = false + + override def inline: Boolean = false + + /** + * Generators expressions require an abstraction for element type. Always + * needs to be wrapped in a StructType. The actually type is that of the + * structs element. + */ + override def elementSchema: StructType = StructType(Array(StructField("element", RasterTileType(indexSystem.getCellIdDataType)))) + + /** + * The function to be overridden by the extending class. It is called when + * the expression is evaluated. It provides the raster band to the + * expression. It abstracts spark serialization from the caller. + * @param raster + * The raster to be used. + * @return + * Sequence of generated new rasters to be written. 
+ */ + def rasterGenerator(raster: => MosaicRasterTile, resolution: Int): Seq[MosaicRasterTile] + + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { + GDAL.enable() + val tile = MosaicRasterTile + .deserialize( + rasterExpr.eval(input).asInstanceOf[InternalRow], + indexSystem.getCellIdDataType + ) + val inResolution: Int = indexSystem.getResolution(resolutionExpr.eval(input)) + val generatedChips = rasterGenerator(tile, inResolution) + .map(chip => chip.formatCellId(indexSystem)) + + val rows = generatedChips + .map(chip => InternalRow.fromSeq(Seq(chip.formatCellId(indexSystem).serialize()))) + + RasterCleaner.dispose(tile) + generatedChips.foreach(chip => RasterCleaner.dispose(chip)) + generatedChips.foreach(chip => RasterCleaner.dispose(chip.getRaster)) + + rows.iterator + } + + override def makeCopy(newArgs: Array[AnyRef]): Expression = + GenericExpressionFactory.makeCopyImpl[T](this, newArgs, children.length, expressionConfig) + + override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = makeCopy(newChildren.toArray) + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterToGridExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterToGridExpression.scala index ab9d8d2e9..4d39c35cb 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterToGridExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterToGridExpression.scala @@ -2,12 +2,14 @@ package com.databricks.labs.mosaic.expressions.raster.base import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.{IndexSystem, IndexSystemFactory} -import com.databricks.labs.mosaic.core.raster.{MosaicRaster, MosaicRasterBand} +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import 
com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.raster.RasterToGridType import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.DataType import scala.reflect.ClassTag @@ -20,8 +22,9 @@ import scala.reflect.ClassTag * Mosaic. All cells are projected to spatial coordinates and then to grid * index system. The pixels are grouped by cell ids and then combined to form a * grid -> value/measure collection per band of the raster. - * @param pathExpr - * The expression for the raster path. + * @param rasterExpr + * The raster expression. It can be a path to a raster file or a byte array + * containing the raster file content. * @param measureType * The output type of the result. * @param expressionConfig @@ -30,11 +33,12 @@ import scala.reflect.ClassTag * The type of the extending class. */ abstract class RasterToGridExpression[T <: Expression: ClassTag, P]( - pathExpr: Expression, + rasterExpr: Expression, resolution: Expression, measureType: DataType, expressionConfig: MosaicExpressionConfig -) extends Raster1ArgExpression[T](pathExpr, resolution, RasterToGridType(expressionConfig.getCellIdType, measureType), expressionConfig) +) extends Raster1ArgExpression[T](rasterExpr, resolution, RasterToGridType(expressionConfig.getCellIdType, measureType), returnsRaster = false, expressionConfig) + with RasterGridExpression with NullIntolerant with Serializable { @@ -47,32 +51,22 @@ abstract class RasterToGridExpression[T <: Expression: ClassTag, P]( * result is a Sequence of (cellId, measure) of each band of the raster. It * applies the values combiner on the measures of each cell. For no * combine, use the identity function. 
- * @param raster + * @param tile * The raster to be used. * @return * Sequence of (cellId, measure) of each band of the raster. */ - override def rasterTransform(raster: MosaicRaster, arg1: Any): Any = { - val gt = raster.getRaster.GetGeoTransform() + override def rasterTransform(tile: => MosaicRasterTile, arg1: Any): Any = { + GDAL.enable() val resolution = arg1.asInstanceOf[Int] - val bandTransform = (band: MosaicRasterBand) => { - val results = band.transformValues[(Long, Double)](pixelTransformer(gt, resolution), (0L, -1.0)) - results - // Filter out default cells. We don't want to return them since they are masked in original raster. - // We use 0L as a dummy cell ID for default cells. - .map(row => row.filter(_._1 != 0L)) - .filterNot(_.isEmpty) - .flatten - .groupBy(_._1) // Group by cell ID. - .mapValues(values => valuesCombiner(values.map(_._2))) // Apply combiner that is overridden in subclasses. - } - val transformed = raster.transformBands(bandTransform) - - serialize(transformed) + val transformed = griddedPixels(tile.getRaster, indexSystem, resolution) + val results = transformed.map(_.mapValues(valuesCombiner)) + RasterCleaner.dispose(tile) + serialize(results) } /** - * The method to be overriden to specify how the pixel values are combined + * The method to be overridden to specify how the pixel values are combined * within a cell. * @param values * The values to be combined. 
@@ -81,16 +75,6 @@ abstract class RasterToGridExpression[T <: Expression: ClassTag, P]( */ def valuesCombiner(values: Seq[Double]): P - private def pixelTransformer(gt: Seq[Double], resolution: Int)(x: Int, y: Int, value: Double): (Long, Double) = { - val offset = 0.5 // This centers the point to the pixel centroid - val xOffset = offset + x - val yOffset = offset + y - val xGeo = gt(0) + xOffset * gt(1) + yOffset * gt(2) - val yGeo = gt(3) + xOffset * gt(4) + yOffset * gt(5) - val cellID = indexSystem.pointToIndex(xGeo, yGeo, resolution) - (cellID, value) - } - /** * Serializes the result of the raster transform to the desired output * type. diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/package.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/package.scala index fe03cb672..a229aae89 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/package.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/package.scala @@ -21,8 +21,8 @@ package object raster { * The measure type of the resulting pixel value. * * @return - * The datatype to be used for serialization of the result of - * [[RasterToGridExpression]]. + * The datatype to be used for serialization of the result of + * [[com.databricks.labs.mosaic.expressions.raster.base.RasterToGridExpression]]. 
*/ def RasterToGridType(cellIDType: DataType, measureType: DataType): DataType = { ArrayType( diff --git a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala index 8f9bf92d7..88c53d6fb 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala @@ -4,7 +4,6 @@ import com.databricks.labs.mosaic._ import com.databricks.labs.mosaic.core.crs.CRSBoundsProvider import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem -import com.databricks.labs.mosaic.core.raster.api.RasterAPI import com.databricks.labs.mosaic.core.types.ChipType import com.databricks.labs.mosaic.datasource.multiread.MosaicDataFrameReader import com.databricks.labs.mosaic.expressions.constructors._ @@ -25,7 +24,7 @@ import org.apache.spark.sql.types.{LongType, StringType} import scala.reflect.runtime.universe //noinspection DuplicatedCode -class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI, rasterAPI: RasterAPI) extends Serializable with Logging { +class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends Serializable with Logging { // Make spark aware of the mosaic setup // Check the DBR type and raise appropriate warnings @@ -36,13 +35,12 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI, rasterAP spark.conf.set(MOSAIC_INDEX_SYSTEM, indexSystem.name) spark.conf.set(MOSAIC_GEOMETRY_API, geometryAPI.name) - spark.conf.set(MOSAIC_RASTER_API, rasterAPI.name) import org.apache.spark.sql.adapters.{Column => ColumnAdapter} + // noinspection ScalaWeakerAccess val mirror: universe.Mirror = universe.runtimeMirror(getClass.getClassLoader) val expressionConfig: MosaicExpressionConfig = MosaicExpressionConfig(spark) - def setCellIdDataType(dataType: String): Unit = if (dataType == "string") { 
indexSystem.setCellIdDataType(StringType) @@ -52,6 +50,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI, rasterAP throw new Error(s"Unsupported data type: $dataType") } + // noinspection ScalaWeakerAccess def registerProductH3(registry: FunctionRegistry, dbName: Option[String]): Unit = { aliasFunction(registry, "grid_longlatascellid", dbName, "h3_longlatash3", None) aliasFunction(registry, "grid_polyfill", dbName, "h3_polyfillash3", None) @@ -59,6 +58,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI, rasterAP aliasFunction(registry, "grid_distance", dbName, "h3_distance", None) } + // noinspection ScalaWeakerAccess def aliasFunction( registry: FunctionRegistry, alias: String, @@ -144,12 +144,12 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI, rasterAP mosaicRegistry.registerExpression[ST_Area](expressionConfig) mosaicRegistry.registerExpression[ST_Buffer](expressionConfig) mosaicRegistry.registerExpression[ST_BufferLoop](expressionConfig) + mosaicRegistry.registerExpression[ST_BufferCapStyle](expressionConfig) mosaicRegistry.registerExpression[ST_Centroid](expressionConfig) mosaicRegistry.registerExpression[ST_Contains](expressionConfig) mosaicRegistry.registerExpression[ST_ConvexHull](expressionConfig) mosaicRegistry.registerExpression[ST_Distance](expressionConfig) mosaicRegistry.registerExpression[ST_Difference](expressionConfig) - mosaicRegistry.registerExpression[ST_Buffer](expressionConfig) mosaicRegistry.registerExpression[ST_Envelope](expressionConfig) mosaicRegistry.registerExpression[ST_GeometryType](expressionConfig) mosaicRegistry.registerExpression[ST_HasValidCoordinates](expressionConfig) @@ -179,6 +179,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI, rasterAP mosaicRegistry.registerExpression[ST_Y](expressionConfig) mosaicRegistry.registerExpression[ST_Haversine](expressionConfig) + // noinspection ScalaDeprecation registry.registerFunction( 
FunctionIdentifier("st_centroid2D", database), ST_Centroid.legacyInfo(database, "st_centroid2D"), @@ -253,10 +254,20 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI, rasterAP /** RasterAPI dependent functions */ mosaicRegistry.registerExpression[RST_BandMetaData](expressionConfig) + mosaicRegistry.registerExpression[RST_BoundingBox](expressionConfig) + mosaicRegistry.registerExpression[RST_Clip](expressionConfig) + mosaicRegistry.registerExpression[RST_CombineAvg](expressionConfig) mosaicRegistry.registerExpression[RST_GeoReference](expressionConfig) + mosaicRegistry.registerExpression[RST_GetNoData](expressionConfig) + mosaicRegistry.registerExpression[RST_GetSubdataset](expressionConfig) + mosaicRegistry.registerExpression[RST_Height](expressionConfig) + mosaicRegistry.registerExpression[RST_InitNoData](expressionConfig) mosaicRegistry.registerExpression[RST_IsEmpty](expressionConfig) mosaicRegistry.registerExpression[RST_MemSize](expressionConfig) + mosaicRegistry.registerExpression[RST_Merge](expressionConfig) + mosaicRegistry.registerExpression[RST_FromBands](expressionConfig) mosaicRegistry.registerExpression[RST_MetaData](expressionConfig) + mosaicRegistry.registerExpression[RST_NDVI](expressionConfig) mosaicRegistry.registerExpression[RST_NumBands](expressionConfig) mosaicRegistry.registerExpression[RST_PixelWidth](expressionConfig) mosaicRegistry.registerExpression[RST_PixelHeight](expressionConfig) @@ -272,15 +283,20 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI, rasterAP mosaicRegistry.registerExpression[RST_Rotation](expressionConfig) mosaicRegistry.registerExpression[RST_ScaleX](expressionConfig) mosaicRegistry.registerExpression[RST_ScaleY](expressionConfig) + mosaicRegistry.registerExpression[RST_SetNoData](expressionConfig) mosaicRegistry.registerExpression[RST_SkewX](expressionConfig) mosaicRegistry.registerExpression[RST_SkewY](expressionConfig) 
mosaicRegistry.registerExpression[RST_SRID](expressionConfig) mosaicRegistry.registerExpression[RST_Subdatasets](expressionConfig) mosaicRegistry.registerExpression[RST_Summary](expressionConfig) + mosaicRegistry.registerExpression[RST_Tessellate](expressionConfig) + mosaicRegistry.registerExpression[RST_FromFile](expressionConfig) + mosaicRegistry.registerExpression[RST_ToOverlappingTiles](expressionConfig) + mosaicRegistry.registerExpression[RST_TryOpen](expressionConfig) + mosaicRegistry.registerExpression[RST_Subdivide](expressionConfig) mosaicRegistry.registerExpression[RST_UpperLeftX](expressionConfig) mosaicRegistry.registerExpression[RST_UpperLeftY](expressionConfig) mosaicRegistry.registerExpression[RST_Width](expressionConfig) - mosaicRegistry.registerExpression[RST_Height](expressionConfig) mosaicRegistry.registerExpression[RST_WorldToRasterCoord](expressionConfig) mosaicRegistry.registerExpression[RST_WorldToRasterCoordX](expressionConfig) mosaicRegistry.registerExpression[RST_WorldToRasterCoordY](expressionConfig) @@ -291,16 +307,36 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI, rasterAP ST_IntersectionAggregate.registryExpressionInfo(database), (exprs: Seq[Expression]) => ST_IntersectionAggregate(exprs(0), exprs(1), geometryAPI.name, indexSystem, 0, 0) ) + registry.registerFunction( + FunctionIdentifier("st_intersection_agg", database), + ST_IntersectionAggregate.registryExpressionInfo(database), + (exprs: Seq[Expression]) => ST_IntersectionAggregate(exprs(0), exprs(1), geometryAPI.name, indexSystem, 0, 0) + ) registry.registerFunction( FunctionIdentifier("st_intersects_aggregate", database), ST_IntersectsAggregate.registryExpressionInfo(database), (exprs: Seq[Expression]) => ST_IntersectsAggregate(exprs(0), exprs(1), geometryAPI.name) ) + registry.registerFunction( + FunctionIdentifier("st_intersects_agg", database), + ST_IntersectsAggregate.registryExpressionInfo(database), + (exprs: Seq[Expression]) => 
ST_IntersectsAggregate(exprs(0), exprs(1), geometryAPI.name) + ) registry.registerFunction( FunctionIdentifier("st_union_agg", database), ST_UnionAgg.registryExpressionInfo(database), (exprs: Seq[Expression]) => ST_UnionAgg(exprs(0), geometryAPI.name) ) + registry.registerFunction( + FunctionIdentifier("rst_merge_agg", database), + RST_MergeAgg.registryExpressionInfo(database), + (exprs: Seq[Expression]) => RST_MergeAgg(exprs(0), expressionConfig) + ) + registry.registerFunction( + FunctionIdentifier("rst_combineavg_agg", database), + RST_CombineAvgAgg.registryExpressionInfo(database), + (exprs: Seq[Expression]) => RST_CombineAvgAgg(exprs(0), expressionConfig) + ) /** IndexSystem and GeometryAPI Specific methods */ registry.registerFunction( @@ -464,8 +500,6 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI, rasterAP def getGeometryAPI: GeometryAPI = this.geometryAPI - def getRasterAPI: RasterAPI = this.rasterAPI - def getIndexSystem: IndexSystem = this.indexSystem def getProductMethod(methodName: String): universe.MethodMirror = { @@ -509,6 +543,10 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI, rasterAP ColumnAdapter(ST_BufferLoop(geom.expr, r1.cast("double").expr, r2.cast("double").expr, expressionConfig)) def st_bufferloop(geom: Column, r1: Double, r2: Double): Column = ColumnAdapter(ST_BufferLoop(geom.expr, lit(r1).cast("double").expr, lit(r2).cast("double").expr, expressionConfig)) + def st_buffer_cap_style(geom: Column, radius: Column, capStyle: Column): Column = + ColumnAdapter(ST_BufferCapStyle(geom.expr, radius.cast("double").expr, capStyle.expr, expressionConfig)) + def st_buffer_cap_style(geom: Column, radius: Double, capStyle: String): Column = + ColumnAdapter(ST_BufferCapStyle(geom.expr, lit(radius).cast("double").expr, lit(capStyle).expr, expressionConfig)) def st_centroid(geom: Column): Column = ColumnAdapter(ST_Centroid(geom.expr, expressionConfig)) def st_convexhull(geom: Column): Column = 
ColumnAdapter(ST_ConvexHull(geom.expr, expressionConfig)) def st_difference(geom1: Column, geom2: Column): Column = ColumnAdapter(ST_Difference(geom1.expr, geom2.expr, expressionConfig)) @@ -587,121 +625,119 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI, rasterAP ColumnAdapter(RST_BandMetaData(raster.expr, band.expr, expressionConfig)) def rst_bandmetadata(raster: Column, band: Int): Column = ColumnAdapter(RST_BandMetaData(raster.expr, lit(band).expr, expressionConfig)) - def rst_bandmetadata(raster: String, band: Int): Column = - ColumnAdapter(RST_BandMetaData(lit(raster).expr, lit(band).expr, expressionConfig)) + def rst_boundingbox(raster: Column): Column = ColumnAdapter(RST_BoundingBox(raster.expr, expressionConfig)) + def rst_clip(raster: Column, geometry: Column): Column = ColumnAdapter(RST_Clip(raster.expr, geometry.expr, expressionConfig)) + def rst_combineavg(rasterArray: Column): Column = ColumnAdapter(RST_CombineAvg(rasterArray.expr, expressionConfig)) def rst_georeference(raster: Column): Column = ColumnAdapter(RST_GeoReference(raster.expr, expressionConfig)) - def rst_georeference(raster: String): Column = ColumnAdapter(RST_GeoReference(lit(raster).expr, expressionConfig)) + def rst_getnodata(raster: Column): Column = ColumnAdapter(RST_GetNoData(raster.expr, expressionConfig)) + def rst_getsubdataset(raster: Column, subdatasetName: Column): Column = + ColumnAdapter(RST_GetSubdataset(raster.expr, subdatasetName.expr, expressionConfig)) + def rst_getsubdataset(raster: Column, subdatasetName: String): Column = + ColumnAdapter(RST_GetSubdataset(raster.expr, lit(subdatasetName).expr, expressionConfig)) def rst_height(raster: Column): Column = ColumnAdapter(RST_Height(raster.expr, expressionConfig)) - def rst_height(raster: String): Column = ColumnAdapter(RST_Height(lit(raster).expr, expressionConfig)) + def rst_initnodata(raster: Column): Column = ColumnAdapter(RST_InitNoData(raster.expr, expressionConfig)) def 
rst_isempty(raster: Column): Column = ColumnAdapter(RST_IsEmpty(raster.expr, expressionConfig)) - def rst_isempty(raster: String): Column = ColumnAdapter(RST_IsEmpty(lit(raster).expr, expressionConfig)) def rst_memsize(raster: Column): Column = ColumnAdapter(RST_MemSize(raster.expr, expressionConfig)) - def rst_memsize(raster: String): Column = ColumnAdapter(RST_MemSize(lit(raster).expr, expressionConfig)) + def rst_frombands(bandsArray: Column): Column = ColumnAdapter(RST_FromBands(bandsArray.expr, expressionConfig)) + def rst_merge(rasterArray: Column): Column = ColumnAdapter(RST_Merge(rasterArray.expr, expressionConfig)) def rst_metadata(raster: Column): Column = ColumnAdapter(RST_MetaData(raster.expr, expressionConfig)) - def rst_metadata(raster: String): Column = ColumnAdapter(RST_MetaData(lit(raster).expr, expressionConfig)) + def rst_ndvi(raster: Column, band1: Column, band2: Column): Column = + ColumnAdapter(RST_NDVI(raster.expr, band1.expr, band2.expr, expressionConfig)) + def rst_ndvi(raster: Column, band1: Int, band2: Int): Column = + ColumnAdapter(RST_NDVI(raster.expr, lit(band1).expr, lit(band2).expr, expressionConfig)) def rst_numbands(raster: Column): Column = ColumnAdapter(RST_NumBands(raster.expr, expressionConfig)) - def rst_numbands(raster: String): Column = ColumnAdapter(RST_NumBands(lit(raster).expr, expressionConfig)) def rst_pixelheight(raster: Column): Column = ColumnAdapter(RST_PixelHeight(raster.expr, expressionConfig)) - def rst_pixelheight(raster: String): Column = ColumnAdapter(RST_PixelHeight(lit(raster).expr, expressionConfig)) def rst_pixelwidth(raster: Column): Column = ColumnAdapter(RST_PixelWidth(raster.expr, expressionConfig)) - def rst_pixelwidth(raster: String): Column = ColumnAdapter(RST_PixelWidth(lit(raster).expr, expressionConfig)) def rst_rastertogridavg(raster: Column, resolution: Column): Column = ColumnAdapter(RST_RasterToGridAvg(raster.expr, resolution.expr, expressionConfig)) - def rst_rastertogridavg(raster: String, 
resolution: Column): Column = - ColumnAdapter(RST_RasterToGridAvg(lit(raster).expr, resolution.expr, expressionConfig)) def rst_rastertogridcount(raster: Column, resolution: Column): Column = ColumnAdapter(RST_RasterToGridCount(raster.expr, resolution.expr, expressionConfig)) - def rst_rastertogridcount(raster: String, resolution: Column): Column = - ColumnAdapter(RST_RasterToGridCount(lit(raster).expr, resolution.expr, expressionConfig)) def rst_rastertogridmax(raster: Column, resolution: Column): Column = ColumnAdapter(RST_RasterToGridMax(raster.expr, resolution.expr, expressionConfig)) - def rst_rastertogridmax(raster: String, resolution: Column): Column = - ColumnAdapter(RST_RasterToGridMax(lit(raster).expr, resolution.expr, expressionConfig)) def rst_rastertogridmedian(raster: Column, resolution: Column): Column = ColumnAdapter(RST_RasterToGridMedian(raster.expr, resolution.expr, expressionConfig)) - def rst_rastertogridmedian(raster: String, resolution: Column): Column = - ColumnAdapter(RST_RasterToGridMedian(lit(raster).expr, resolution.expr, expressionConfig)) def rst_rastertogridmin(raster: Column, resolution: Column): Column = ColumnAdapter(RST_RasterToGridMin(raster.expr, resolution.expr, expressionConfig)) - def rst_rastertogridmin(raster: String, resolution: Column): Column = - ColumnAdapter(RST_RasterToGridMin(lit(raster).expr, resolution.expr, expressionConfig)) def rst_rastertoworldcoord(raster: Column, x: Column, y: Column): Column = ColumnAdapter(RST_RasterToWorldCoord(raster.expr, x.expr, y.expr, expressionConfig)) - def rst_rastertoworldcoord(raster: String, x: Column, y: Column): Column = - ColumnAdapter(RST_RasterToWorldCoord(lit(raster).expr, x.expr, y.expr, expressionConfig)) def rst_rastertoworldcoord(raster: Column, x: Int, y: Int): Column = ColumnAdapter(RST_RasterToWorldCoord(raster.expr, lit(x).expr, lit(y).expr, expressionConfig)) def rst_rastertoworldcoordx(raster: Column, x: Column, y: Column): Column = 
ColumnAdapter(RST_RasterToWorldCoordX(raster.expr, x.expr, y.expr, expressionConfig)) - def rst_rastertoworldcoordx(raster: String, x: Column, y: Column): Column = - ColumnAdapter(RST_RasterToWorldCoordX(lit(raster).expr, x.expr, y.expr, expressionConfig)) def rst_rastertoworldcoordx(raster: Column, x: Int, y: Int): Column = ColumnAdapter(RST_RasterToWorldCoordX(raster.expr, lit(x).expr, lit(y).expr, expressionConfig)) def rst_rastertoworldcoordy(raster: Column, x: Column, y: Column): Column = ColumnAdapter(RST_RasterToWorldCoordY(raster.expr, x.expr, y.expr, expressionConfig)) - def rst_rastertoworldcoordy(raster: String, x: Column, y: Column): Column = - ColumnAdapter(RST_RasterToWorldCoordY(lit(raster).expr, x.expr, y.expr, expressionConfig)) def rst_rastertoworldcoordy(raster: Column, x: Int, y: Int): Column = ColumnAdapter(RST_RasterToWorldCoordY(raster.expr, lit(x).expr, lit(y).expr, expressionConfig)) def rst_retile(raster: Column, tileWidth: Column, tileHeight: Column): Column = ColumnAdapter(RST_ReTile(raster.expr, tileWidth.expr, tileHeight.expr, expressionConfig)) def rst_retile(raster: Column, tileWidth: Int, tileHeight: Int): Column = ColumnAdapter(RST_ReTile(raster.expr, lit(tileWidth).expr, lit(tileHeight).expr, expressionConfig)) - def rst_retile(raster: String, tileWidth: Int, tileHeight: Int): Column = - ColumnAdapter(RST_ReTile(lit(raster).expr, lit(tileWidth).expr, lit(tileHeight).expr, expressionConfig)) def rst_rotation(raster: Column): Column = ColumnAdapter(RST_Rotation(raster.expr, expressionConfig)) - def rst_rotation(raster: String): Column = ColumnAdapter(RST_Rotation(lit(raster).expr, expressionConfig)) def rst_scalex(raster: Column): Column = ColumnAdapter(RST_ScaleX(raster.expr, expressionConfig)) - def rst_scalex(raster: String): Column = ColumnAdapter(RST_ScaleX(lit(raster).expr, expressionConfig)) def rst_scaley(raster: Column): Column = ColumnAdapter(RST_ScaleY(raster.expr, expressionConfig)) - def rst_scaley(raster: String): 
Column = ColumnAdapter(RST_ScaleY(lit(raster).expr, expressionConfig)) + def rst_setnodata(raster: Column, nodata: Column): Column = ColumnAdapter(RST_SetNoData(raster.expr, nodata.expr, expressionConfig)) + def rst_setnodata(raster: Column, nodata: Double): Column = + ColumnAdapter(RST_SetNoData(raster.expr, lit(nodata).expr, expressionConfig)) def rst_skewx(raster: Column): Column = ColumnAdapter(RST_SkewX(raster.expr, expressionConfig)) - def rst_skewx(raster: String): Column = ColumnAdapter(RST_SkewX(lit(raster).expr, expressionConfig)) def rst_skewy(raster: Column): Column = ColumnAdapter(RST_SkewY(raster.expr, expressionConfig)) - def rst_skewy(raster: String): Column = ColumnAdapter(RST_SkewY(lit(raster).expr, expressionConfig)) def rst_srid(raster: Column): Column = ColumnAdapter(RST_SRID(raster.expr, expressionConfig)) - def rst_srid(raster: String): Column = ColumnAdapter(RST_SRID(lit(raster).expr, expressionConfig)) def rst_subdatasets(raster: Column): Column = ColumnAdapter(RST_Subdatasets(raster.expr, expressionConfig)) - def rst_subdatasets(raster: String): Column = ColumnAdapter(RST_Subdatasets(lit(raster).expr, expressionConfig)) def rst_summary(raster: Column): Column = ColumnAdapter(RST_Summary(raster.expr, expressionConfig)) - def rst_summary(raster: String): Column = ColumnAdapter(RST_Summary(lit(raster).expr, expressionConfig)) + def rst_tessellate(raster: Column, resolution: Column): Column = + ColumnAdapter(RST_Tessellate(raster.expr, resolution.expr, expressionConfig)) + def rst_tessellate(raster: Column, resolution: Int): Column = + ColumnAdapter(RST_Tessellate(raster.expr, lit(resolution).expr, expressionConfig)) + def rst_fromfile(raster: Column): Column = ColumnAdapter(RST_FromFile(raster.expr, lit(-1).expr, expressionConfig)) + def rst_fromfile(raster: Column, sizeInMB: Column): Column = + ColumnAdapter(RST_FromFile(raster.expr, sizeInMB.expr, expressionConfig)) + def rst_fromfile(raster: Column, sizeInMB: Int): Column = + 
ColumnAdapter(RST_FromFile(raster.expr, lit(sizeInMB).expr, expressionConfig)) + def rst_to_overlapping_tiles(raster: Column, width: Int, height: Int, overlap: Int): Column = + ColumnAdapter(RST_ToOverlappingTiles(raster.expr, lit(width).expr, lit(height).expr, lit(overlap).expr, expressionConfig)) + def rst_to_overlapping_tiles(raster: Column, width: Column, height: Column, overlap: Column): Column = + ColumnAdapter(RST_ToOverlappingTiles(raster.expr, width.expr, height.expr, overlap.expr, expressionConfig)) + def rst_tryopen(raster: Column): Column = ColumnAdapter(RST_TryOpen(raster.expr, expressionConfig)) + def rst_subdivide(raster: Column, sizeInMB: Column): Column = + ColumnAdapter(RST_Subdivide(raster.expr, sizeInMB.expr, expressionConfig)) + def rst_subdivide(raster: Column, sizeInMB: Int): Column = + ColumnAdapter(RST_Subdivide(raster.expr, lit(sizeInMB).expr, expressionConfig)) def rst_upperleftx(raster: Column): Column = ColumnAdapter(RST_UpperLeftX(raster.expr, expressionConfig)) - def rst_upperleftx(raster: String): Column = ColumnAdapter(RST_UpperLeftX(lit(raster).expr, expressionConfig)) def rst_upperlefty(raster: Column): Column = ColumnAdapter(RST_UpperLeftY(raster.expr, expressionConfig)) - def rst_upperlefty(raster: String): Column = ColumnAdapter(RST_UpperLeftY(lit(raster).expr, expressionConfig)) def rst_width(raster: Column): Column = ColumnAdapter(RST_Width(raster.expr, expressionConfig)) - def rst_width(raster: String): Column = ColumnAdapter(RST_Width(lit(raster).expr, expressionConfig)) def rst_worldtorastercoord(raster: Column, x: Column, y: Column): Column = ColumnAdapter(RST_WorldToRasterCoord(raster.expr, x.expr, y.expr, expressionConfig)) def rst_worldtorastercoord(raster: Column, x: Double, y: Double): Column = ColumnAdapter(RST_WorldToRasterCoord(raster.expr, lit(x).expr, lit(y).expr, expressionConfig)) - def rst_worldtorastercoord(raster: String, x: Double, y: Double): Column = - 
ColumnAdapter(RST_WorldToRasterCoord(lit(raster).expr, lit(x).expr, lit(y).expr, expressionConfig)) def rst_worldtorastercoordx(raster: Column, x: Column, y: Column): Column = ColumnAdapter(RST_WorldToRasterCoordX(raster.expr, x.expr, y.expr, expressionConfig)) def rst_worldtorastercoordx(raster: Column, x: Double, y: Double): Column = ColumnAdapter(RST_WorldToRasterCoordX(raster.expr, lit(x).expr, lit(y).expr, expressionConfig)) - def rst_worldtorastercoordx(raster: String, x: Double, y: Double): Column = - ColumnAdapter(RST_WorldToRasterCoordX(lit(raster).expr, lit(x).expr, lit(y).expr, expressionConfig)) def rst_worldtorastercoordy(raster: Column, x: Column, y: Column): Column = ColumnAdapter(RST_WorldToRasterCoordY(raster.expr, x.expr, y.expr, expressionConfig)) def rst_worldtorastercoordy(raster: Column, x: Double, y: Double): Column = ColumnAdapter(RST_WorldToRasterCoordY(raster.expr, lit(x).expr, lit(y).expr, expressionConfig)) - def rst_worldtorastercoordy(raster: String, x: Double, y: Double): Column = - ColumnAdapter(RST_WorldToRasterCoordY(lit(raster).expr, lit(x).expr, lit(y).expr, expressionConfig)) /** Aggregators */ def st_intersects_aggregate(leftIndex: Column, rightIndex: Column): Column = ColumnAdapter( ST_IntersectsAggregate(leftIndex.expr, rightIndex.expr, geometryAPI.name).toAggregateExpression(isDistinct = false) ) + + def st_intersects_agg(leftIndex: Column, rightIndex: Column): Column = st_intersects_aggregate(leftIndex, rightIndex) def st_intersection_aggregate(leftIndex: Column, rightIndex: Column): Column = ColumnAdapter( ST_IntersectionAggregate(leftIndex.expr, rightIndex.expr, geometryAPI.name, indexSystem, 0, 0) .toAggregateExpression(isDistinct = false) ) + def st_intersection_agg(leftIndex: Column, rightIndex: Column): Column = st_intersection_aggregate(leftIndex, rightIndex) def st_union_agg(geom: Column): Column = ColumnAdapter(ST_UnionAgg(geom.expr, geometryAPI.name).toAggregateExpression(isDistinct = false)) + def 
rst_merge_agg(raster: Column): Column = + ColumnAdapter(RST_MergeAgg(raster.expr, expressionConfig).toAggregateExpression(isDistinct = false)) + def rst_combineavg_agg(raster: Column): Column = + ColumnAdapter(RST_CombineAvgAgg(raster.expr, expressionConfig).toAggregateExpression(isDistinct = false)) /** IndexSystem Specific */ @@ -898,8 +934,8 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI, rasterAP @deprecated("Please use 'st_centroid' expressions instead.") def st_centroid2D(geom: Column): Column = { struct( - ColumnAdapter(ST_X(ST_Centroid(geom.expr, expressionConfig), expressionConfig)), - ColumnAdapter(ST_Y(ST_Centroid(geom.expr, expressionConfig), expressionConfig)) + ColumnAdapter(ST_X(ST_Centroid(geom.expr, expressionConfig), expressionConfig)), + ColumnAdapter(ST_Y(ST_Centroid(geom.expr, expressionConfig), expressionConfig)) ) } @@ -913,8 +949,8 @@ object MosaicContext extends Logging { private var instance: Option[MosaicContext] = None - def build(indexSystem: IndexSystem, geometryAPI: GeometryAPI, rasterAPI: RasterAPI = GDAL): MosaicContext = { - instance = Some(new MosaicContext(indexSystem, geometryAPI, rasterAPI)) + def build(indexSystem: IndexSystem, geometryAPI: GeometryAPI): MosaicContext = { + instance = Some(new MosaicContext(indexSystem, geometryAPI)) instance.get.setCellIdDataType(indexSystem.getCellIdDataType.typeName) context() } @@ -923,8 +959,6 @@ object MosaicContext extends Logging { def geometryAPI: GeometryAPI = context().getGeometryAPI - def rasterAPI: RasterAPI = context().getRasterAPI - def indexSystem: IndexSystem = context().getIndexSystem def context(): MosaicContext = @@ -935,16 +969,17 @@ object MosaicContext extends Logging { def reset(): Unit = instance = None - // noinspection ScalaStyle + // noinspection ScalaStyle,ScalaWeakerAccess def checkDBR(spark: SparkSession): Boolean = { val sparkVersion = spark.conf.get("spark.databricks.clusterUsageTags.sparkVersion", "") val isML = 
sparkVersion.contains("-ml-") val isPhoton = spark.conf.getOption("spark.databricks.photon.enabled").getOrElse("false").toBoolean + val isTest = spark.conf.getOption("spark.databricks.clusterUsageTags.clusterType").isEmpty - if (!isML && !isPhoton) { + if (!isML && !isPhoton && !isTest) { // Print out the warnings both to the log and to the console logWarning("DEPRECATION WARNING: Mosaic is not supported on the selected Databricks Runtime") - logWarning("DEPRECATION WARNING: Mosaic will stop working on this cluster after v0.3.x.") + logWarning("DEPRECATION WARNING: Mosaic will stop working on this cluster after v0.3.x.") logWarning("Please use a Databricks Photon-enabled Runtime (for performance benefits) or Runtime ML (for spatial AI benefits).") println("DEPRECATION WARNING: Mosaic is not supported on the selected Databricks Runtime") println("DEPRECATION WARNING: Mosaic will stop working on this cluster after v0.3.x.") diff --git a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicExpressionConfig.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicExpressionConfig.scala index aae083258..4eab02c4d 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicExpressionConfig.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicExpressionConfig.scala @@ -26,8 +26,6 @@ case class MosaicExpressionConfig(configs: Map[String, String]) { def getIndexSystem: String = configs.getOrElse(MOSAIC_INDEX_SYSTEM, H3.name) - def getRasterAPI: String = configs.getOrElse(MOSAIC_RASTER_API, GDAL.name) - def getRasterCheckpoint: String = configs.getOrElse(MOSAIC_RASTER_CHECKPOINT, MOSAIC_RASTER_CHECKPOINT_DEFAULT) def getCellIdType: DataType = IndexSystemFactory.getIndexSystem(getIndexSystem).cellIdType @@ -65,8 +63,8 @@ object MosaicExpressionConfig { expressionConfig .setGeometryAPI(spark.conf.get(MOSAIC_GEOMETRY_API, JTS.name)) .setIndexSystem(spark.conf.get(MOSAIC_INDEX_SYSTEM, H3.name)) - 
.setRasterAPI(spark.conf.get(MOSAIC_RASTER_API, GDAL.name)) .setRasterCheckpoint(spark.conf.get(MOSAIC_RASTER_CHECKPOINT, MOSAIC_RASTER_CHECKPOINT_DEFAULT)) + } } diff --git a/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala index 99d80704b..4b26bc472 100644 --- a/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala @@ -7,40 +7,58 @@ import org.gdal.gdal.gdal import java.io.{BufferedInputStream, File, PrintWriter} import java.nio.file.{Files, Paths} import scala.language.postfixOps -import scala.sys.process._ import scala.util.Try //noinspection DuplicatedCode +/** GDAL environment preparation and configuration. */ object MosaicGDAL extends Logging { + private val usrlibsoPath = "/usr/lib/libgdal.so" + private val usrlibso30Path = "/usr/lib/libgdal.so.30" + private val usrlibso3003Path = "/usr/lib/libgdal.so.30.0.3" + private val libjnisoPath = "/lib/jni/libgdalalljni.so" + private val libjniso30Path = "/lib/jni/libgdalalljni.so.30" + private val libogdisoPath = "/lib/ogdi/libgdal.so" + // noinspection ScalaWeakerAccess val GDAL_ENABLED = "spark.mosaic.gdal.native.enabled" - private val mosaicGDALPath = Files.createTempDirectory("mosaic-gdal") - private val mosaicGDALAbsolutePath = mosaicGDALPath.toAbsolutePath.toString var isEnabled = false - def wasEnabled(spark: SparkSession): Boolean = spark.conf.get(GDAL_ENABLED, "false").toBoolean + /** Returns true if GDAL is enabled. */ + def wasEnabled(spark: SparkSession): Boolean = + spark.conf.get(GDAL_ENABLED, "false").toBoolean || sys.env.getOrElse("GDAL_ENABLED", "false").toBoolean - def prepareEnvironment(spark: SparkSession, initScriptPath: String, sharedObjectsPath: String): Unit = { + /** Prepares the GDAL environment. 
*/ + def prepareEnvironment(spark: SparkSession, initScriptPath: String): Unit = { if (!wasEnabled(spark) && !isEnabled) { Try { copyInitScript(initScriptPath) - copySharedObjects(sharedObjectsPath) } match { - case scala.util.Success(_) => - logInfo("GDAL environment prepared successfully.") - case scala.util.Failure(exception) => - logError("GDAL environment preparation failed.", exception) - throw exception + case scala.util.Success(_) => logInfo("GDAL environment prepared successfully.") + case scala.util.Failure(exception) => logWarning("GDAL environment preparation failed.", exception) } } } + /** Configures the GDAL environment. */ + def configureGDAL(): Unit = { + val CPL_TMPDIR = Files.createTempDirectory("mosaic-gdal-tmp").toAbsolutePath.toString + val GDAL_PAM_PROXY_DIR = Files.createTempDirectory("mosaic-gdal-tmp").toAbsolutePath.toString + gdal.SetConfigOption("GDAL_VRT_ENABLE_PYTHON", "YES") + gdal.SetConfigOption("GDAL_DISABLE_READDIR_ON_OPEN", "EMPTY_DIR") + gdal.SetConfigOption("CPL_TMPDIR", CPL_TMPDIR) + gdal.SetConfigOption("GDAL_PAM_PROXY_DIR", GDAL_PAM_PROXY_DIR) + gdal.SetConfigOption("GDAL_PAM_ENABLED", "NO") + gdal.SetConfigOption("CPL_VSIL_USE_TEMP_FILE_FOR_RANDOM_WRITE", "NO") + } + + /** Enables the GDAL environment. 
*/ def enableGDAL(spark: SparkSession): Unit = { if (!wasEnabled(spark) && !isEnabled) { Try { isEnabled = true loadSharedObjects() + configureGDAL() gdal.AllRegister() spark.conf.set(GDAL_ENABLED, "true") } match { @@ -56,54 +74,51 @@ object MosaicGDAL extends Logging { } } - private def copySharedObjects(path: String): Unit = { - val so = readResourceBytes("/gdal/ubuntu/libgdalalljni.so") - val so30 = readResourceBytes("/gdal/ubuntu/libgdalalljni.so.30") - - val usrGDALPath = Paths.get("/usr/lib/jni/") - if (!Files.exists(mosaicGDALPath)) Files.createDirectories(mosaicGDALPath) - if (!Files.exists(usrGDALPath)) Files.createDirectories(usrGDALPath) - Files.write(Paths.get(s"$mosaicGDALAbsolutePath/libgdalalljni.so"), so) - Files.write(Paths.get(s"$mosaicGDALAbsolutePath/libgdalalljni.so.30"), so30) - - s"sudo cp $mosaicGDALAbsolutePath/libgdalalljni.so $path/libgdalalljni.so".!! - s"sudo cp $mosaicGDALAbsolutePath/libgdalalljni.so.30 $path/libgdalalljni.so.30".!! - } - - //noinspection ScalaStyle + // noinspection ScalaStyle private def copyInitScript(path: String): Unit = { val destPath = Paths.get(path) - if (!Files.exists(mosaicGDALPath)) Files.createDirectories(mosaicGDALPath) if (!Files.exists(destPath)) Files.createDirectories(destPath) - val w = new PrintWriter(new File(s"$mosaicGDALAbsolutePath/mosaic-gdal-init.sh")) + val w = new PrintWriter(new File(s"$path/mosaic-gdal-init.sh")) val scriptLines = readResourceLines("/scripts/install-gdal-databricks.sh") scriptLines - .map { x => if (x.contains("__DEFAULT_JNI_PATH__")) x.replace("__DEFAULT_JNI_PATH__", path) else x } - .foreach(x => w.println(x)) + .map { x => if (x.contains("__DEFAULT_JNI_PATH__")) x.replace("__DEFAULT_JNI_PATH__", path) else x } + .foreach(x => w.println(x)) w.close() - - s"sudo cp $mosaicGDALAbsolutePath/mosaic-gdal-init.sh $path/mosaic-gdal-init.sh".!! } + /** Loads the shared objects required for GDAL. 
*/ private def loadSharedObjects(): Unit = { - System.load("/usr/lib/libgdal.so.30") - if (!Files.exists(Paths.get("/usr/lib/libgdal.so"))) { - "sudo cp /usr/lib/libgdal.so.30 /usr/lib/libgdal.so".!! + loadOrNOOP(usrlibsoPath) + loadOrNOOP(usrlibso30Path) + loadOrNOOP(usrlibso3003Path) + loadOrNOOP(libjnisoPath) + loadOrNOOP(libjniso30Path) + loadOrNOOP(libogdisoPath) + } + + /** Loads the shared object if it exists. */ + // noinspection ScalaStyle + private def loadOrNOOP(path: String): Unit = { + try { + if (Files.exists(Paths.get(path))) System.load(path) + } catch { + case t: Throwable => + println(t.toString) + println(s"Failed to load $path") + logWarning(s"Failed to load $path", t) } - System.load("/usr/lib/libgdal.so") - System.load("/usr/lib/libgdal.so.30.0.3") - System.load("/usr/lib/jni/libgdalalljni.so.30") - System.load("/usr/lib/ogdi/libgdal.so") } + /** Reads the resource bytes. */ private def readResourceBytes(name: String): Array[Byte] = { val bis = new BufferedInputStream(getClass.getResourceAsStream(name)) - try Stream.continually(bis.read()).takeWhile(-1 !=).map(_.toByte).toArray + try { Stream.continually(bis.read()).takeWhile(-1 !=).map(_.toByte).toArray } finally bis.close() } - //noinspection SameParameterValue + /** Reads the resource lines. 
*/ + // noinspection SameParameterValue private def readResourceLines(name: String): Array[String] = { val bytes = readResourceBytes(name) val lines = new String(bytes).split("\n") diff --git a/src/main/scala/com/databricks/labs/mosaic/package.scala b/src/main/scala/com/databricks/labs/mosaic/package.scala index 83aba0d93..29278a6d5 100644 --- a/src/main/scala/com/databricks/labs/mosaic/package.scala +++ b/src/main/scala/com/databricks/labs/mosaic/package.scala @@ -2,15 +2,14 @@ package com.databricks.labs import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem -import com.databricks.labs.mosaic.core.raster.api.RasterAPI import com.databricks.labs.mosaic.datasource.multiread.MosaicDataFrameReader import org.apache.spark.sql.SparkSession +//noinspection ScalaWeakerAccess package object mosaic { val JTS: GeometryAPI = mosaic.core.geometry.api.JTS val ESRI: GeometryAPI = mosaic.core.geometry.api.ESRI - val GDAL: RasterAPI = mosaic.core.raster.api.RasterAPI.GDAL val H3: IndexSystem = mosaic.core.index.H3IndexSystem val BNG: IndexSystem = mosaic.core.index.BNGIndexSystem @@ -24,6 +23,11 @@ package object mosaic { val MOSAIC_RASTER_CHECKPOINT = "spark.databricks.labs.mosaic.raster.checkpoint" val MOSAIC_RASTER_CHECKPOINT_DEFAULT = "dbfs:/tmp/mosaic/raster/checkpoint" + val MOSAIC_RASTER_READ_STRATEGY = "raster.read.strategy" + val MOSAIC_RASTER_READ_IN_MEMORY = "in_memory" + val MOSAIC_RASTER_READ_AS_PATH = "as_path" + val MOSAIC_RASTER_RE_TILE_ON_READ = "retile_on_read" + def read: MosaicDataFrameReader = new MosaicDataFrameReader(SparkSession.builder().getOrCreate()) } diff --git a/src/main/scala/com/databricks/labs/mosaic/sql/extensions/MosaicSQL.scala b/src/main/scala/com/databricks/labs/mosaic/sql/extensions/MosaicSQL.scala index 73e2ba802..d153272f8 100644 --- a/src/main/scala/com/databricks/labs/mosaic/sql/extensions/MosaicSQL.scala +++ 
b/src/main/scala/com/databricks/labs/mosaic/sql/extensions/MosaicSQL.scala @@ -3,7 +3,7 @@ package com.databricks.labs.mosaic.sql.extensions import com.databricks.labs.mosaic._ import com.databricks.labs.mosaic.core.geometry.api.{ESRI, JTS} import com.databricks.labs.mosaic.core.index.{BNGIndexSystem, H3IndexSystem} -import com.databricks.labs.mosaic.core.raster.api.RasterAPI.GDAL +import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.functions.MosaicContext import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSessionExtensions @@ -32,15 +32,14 @@ class MosaicSQL extends (SparkSessionExtensions => Unit) with Logging { // spark.conf.get will throw an Exception if the key is not found. // Since GDAL is optional, we need to handle the case where the key is not found. // Fixes issue #297. - val rasterAPI = spark.conf.get(MOSAIC_RASTER_API, "GDAL") - val mosaicContext = (indexSystem, geometryAPI, rasterAPI) match { - case ("H3", "JTS", "GDAL") => MosaicContext.build(H3IndexSystem, JTS, GDAL) - case ("H3", "ESRI", "GDAL") => MosaicContext.build(H3IndexSystem, ESRI, GDAL) - case ("BNG", "JTS", "GDAL") => MosaicContext.build(BNGIndexSystem, JTS, GDAL) - case ("BNG", "ESRI", "GDAL") => MosaicContext.build(BNGIndexSystem, ESRI, GDAL) - case (is, gapi, rapi) => throw new Error(s"Index system, geometry API and rasterAPI: ($is, $gapi, $rapi) not supported.") + val mosaicContext = (indexSystem, geometryAPI) match { + case ("H3", "JTS") => MosaicContext.build(H3IndexSystem, JTS) + case ("H3", "ESRI") => MosaicContext.build(H3IndexSystem, ESRI) + case ("BNG", "JTS") => MosaicContext.build(BNGIndexSystem, JTS) + case ("BNG", "ESRI") => MosaicContext.build(BNGIndexSystem, ESRI) + case (is, gapi) => throw new Error(s"Index system, geometry API and rasterAPI: ($is, $gapi) not supported.") } - logInfo(s"Registering Mosaic SQL Extensions ($indexSystem, $geometryAPI, $rasterAPI).") + logInfo(s"Registering Mosaic SQL Extensions 
($indexSystem, $geometryAPI).") mosaicContext.register(spark) // NOP rule. This rule is specified only to respect syntax. _ => () diff --git a/src/main/scala/com/databricks/labs/mosaic/sql/extensions/MosaicSQLDefault.scala b/src/main/scala/com/databricks/labs/mosaic/sql/extensions/MosaicSQLDefault.scala index d888a9d22..55bd6e762 100644 --- a/src/main/scala/com/databricks/labs/mosaic/sql/extensions/MosaicSQLDefault.scala +++ b/src/main/scala/com/databricks/labs/mosaic/sql/extensions/MosaicSQLDefault.scala @@ -2,7 +2,7 @@ package com.databricks.labs.mosaic.sql.extensions import com.databricks.labs.mosaic.core.geometry.api.JTS import com.databricks.labs.mosaic.core.index.H3IndexSystem -import com.databricks.labs.mosaic.core.raster.api.RasterAPI.GDAL +import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.functions.MosaicContext import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSessionExtensions @@ -24,7 +24,7 @@ class MosaicSQLDefault extends (SparkSessionExtensions => Unit) with Logging { */ override def apply(ext: SparkSessionExtensions): Unit = { ext.injectCheckRule(spark => { - val mosaicContext = MosaicContext.build(H3IndexSystem, JTS, GDAL) + val mosaicContext = MosaicContext.build(H3IndexSystem, JTS) logInfo(s"Registering Mosaic SQL Extensions (H3, JTS, GDAL).") mosaicContext.register(spark) // NOP rule. This rule is specified only to respect syntax. 
diff --git a/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala b/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala new file mode 100644 index 000000000..718f892f2 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala @@ -0,0 +1,104 @@ +package com.databricks.labs.mosaic.utils + +import java.nio.file.{Files, Paths} +import java.util.UUID + +object PathUtils { + + def getFormatExtension(rawPath: String): String = { + val path: String = resolvePath(rawPath) + val fileName = path.split("/").last + val extension = fileName.split("\\.").last + extension + } + + private def resolvePath(rawPath: String): String = { + val path = + if (isSubdataset(rawPath)) { + val _ :: filePath :: _ :: Nil = rawPath.split(":").toList + filePath + } else { + rawPath + } + path + } + + def getCleanPath(path: String, useZipPath: Boolean): String = { + val cleanPath = path.replace("file:/", "/").replace("dbfs:/", "/dbfs/") + if (useZipPath && cleanPath.endsWith(".zip")) { + getZipPath(cleanPath) + } else { + cleanPath + } + } + + def isSubdataset(path: String): Boolean = { + path.split(":").length == 3 + } + + def isInMemory(path: String): Boolean = { + path.startsWith("/vsimem/") || path.contains("/vsimem/") + } + + def getSubdatasetPath(path: String): String = { + // Subdatasets are paths with a colon in them. + // We need to check for this condition and handle it. 
+ // Subdatasets paths are formatted as: "FORMAT:/path/to/file.tif:subdataset" + val format :: filePath :: subdataset :: Nil = path.split(":").toList + val isZip = filePath.endsWith(".zip") + val vsiPrefix = if (isZip) "/vsizip/" else "" + s"$format:$vsiPrefix$filePath:$subdataset" + } + + def getZipPath(path: String): String = { + // It is really important that the resulting path is /vsizip// and not /vsizip/ + // /vsizip// is for absolute paths /viszip/ is relative to the current working directory + // /vsizip/ wont work on a cluster + // see: https://gdal.org/user/virtual_file_systems.html#vsizip-zip-archives + val isZip = path.endsWith(".zip") + val readPath = if (path.startsWith("/vsizip/")) path else if (isZip) s"/vsizip/$path" else path + readPath + } + + def copyToTmp(rawPath: String): String = { + try { + val path: String = resolvePath(rawPath) + + val fileName = path.split("/").last + val extension = getFormatExtension(path) + + val inPath = getCleanPath(path, useZipPath = extension == "zip") + + val randomID = UUID.randomUUID().toString + val tmpDir = Files.createTempDirectory(s"mosaic_local_$randomID").toFile.getAbsolutePath + + val outPath = s"$tmpDir/$fileName" + + Files.createDirectories(Paths.get(tmpDir)) + Files.copy(Paths.get(inPath), Paths.get(outPath)) + + if (isSubdataset(rawPath)) { + val format :: _ :: subdataset :: Nil = rawPath.split(":").toList + getSubdatasetPath(s"$format:$outPath:$subdataset") + } else { + outPath + } + } catch { + case _: Throwable => rawPath + } + } + + def createTmpFilePath(uuid: String, extension: String): String = { + val randomID = UUID.randomUUID() + val tmpDir = Files.createTempDirectory(s"mosaic_tmp_$randomID").toFile.getAbsolutePath + val outPath = s"$tmpDir/raster_${uuid.replace("-", "_")}.$extension" + Files.createDirectories(Paths.get(outPath).getParent) + outPath + } + + def fromSubdatasetPath(path: String): String = { + val _ :: filePath :: _ :: Nil = path.split(":").toList + filePath + } + +} diff --git 
a/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grib.aux.xml b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grib.aux.xml new file mode 100644 index 000000000..8b9237893 --- /dev/null +++ b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grib.aux.xml @@ -0,0 +1,342 @@ + + + 1[-] HYBL="Hybrid level" + + + 1.136666106290528e-06 + 1.200369887769461e-06 + 196 + 0 + 0 + 1|1|0|1|1|2|7|1|0|0|0|1|1|1|2|1|3|3|3|4|6|2|1|0|0|0|0|0|0|0|0|0|0|0|0|1|0|1|1|2|2|6|5|1|1|2|0|1|0|1|1|0|1|1|1|1|0|1|0|1|1|2|7|1|0|0|0|0|0|0|0|0|0|1|0|1|0|1|1|1|6|2|1|0|0|0|0|0|0|1|0|1|0|1|1|1|6|2|0|1|0|0|0|0|0|1|0|1|0|1|1|1|5|2|1|1|0|0|0|0|0|1|0|1|0|1|1|1|4|3|1|1|0|0|0|0|1|0|0|1|1|0|1|2|3|3|1|1|0|0|0|0|1|0|1|0|1|1|1|1|2|4|1|1|0|0|0|1|0|1|0|1|0|1|1|1|2|3|2|1|0|0|1|0|1|0|1|0|1|0|1|1|2|3|2|1 + + + + (prodType 192, cat 210, subcat 203) [-] + 192 + unknown + 0 sec + CENTER=98(ECMWF) SUBCENTER=0 MASTER_TABLE=5 LOCAL_TABLE=0 SIGNF_REF_TIME=1(Start_of_Forecast) REF_TIME=2021-06-03T00:00:00Z PROD_STATUS=0(Operational) TYPE=0(Analysis) + 0 + 210 203 0 255 146 65535 255 1 0 105 0 1 255 -127 -2147483647 + 210 203 0 255 146 255 255 255 1 0 0 0 0 105 0 0 0 0 1 255 255 255 255 255 255 + 1622678400 sec UTC + 1-HYBL + [-] + 1622678400 sec UTC + 1.2002082030449e-06 + 1.1662431895312e-06 + 1.1368277910151e-06 + 1.9422780853555e-08 + 100 + + + -47 + 0 + 0 + 24 + 0.0000011368 + + + + 1[-] HYBL="Hybrid level" + + + 1.143818000822908e-06 + 1.20314421841223e-06 + 196 + 0 + 0 + 
3|2|1|1|2|5|0|0|0|0|1|3|1|2|2|4|1|0|3|2|1|1|3|4|0|0|0|0|0|0|0|0|4|1|1|1|2|4|1|0|0|0|0|0|0|0|0|0|0|0|0|3|3|5|7|1|1|0|2|2|4|0|0|0|0|0|0|0|0|0|0|0|3|1|1|1|0|1|1|2|4|0|0|0|0|0|0|0|3|1|1|0|1|0|1|1|2|2|2|0|0|0|0|0|3|1|1|0|1|0|1|1|0|1|3|2|0|0|0|0|0|3|1|1|1|0|1|0|1|1|0|4|1|0|0|0|0|1|3|1|0|1|1|0|1|1|0|2|3|0|0|0|0|0|3|2|0|1|1|0|1|1|0|2|3|0|0|0|0|0|0|4|1|1|0|1|1|0|1|4|1|0|0|0|0|0|0|4|1|1|0|1|1|0|3|3 + + + + (prodType 192, cat 210, subcat 203) [-] + 192 + unknown + 0 sec + CENTER=98(ECMWF) SUBCENTER=0 MASTER_TABLE=5 LOCAL_TABLE=0 SIGNF_REF_TIME=1(Start_of_Forecast) REF_TIME=2021-06-04T00:00:00Z PROD_STATUS=0(Operational) TYPE=0(Analysis) + 0 + 210 203 0 255 146 65535 255 1 0 105 0 1 255 -127 -2147483647 + 210 203 0 255 146 255 255 255 1 0 0 0 0 105 0 0 0 0 1 255 255 255 255 255 255 + 1622764800 sec UTC + 1-HYBL + [-] + 1622764800 sec UTC + 1.2029936442559e-06 + 1.1711197875953e-06 + 1.1439685749792e-06 + 1.824681247154e-08 + 100 + + + -48 + 0 + 0 + 24 + 0.0000011440 + + + + 10[-] HYBL="Hybrid level" + + + 1.144498880922602e-05 + 1.174589644914685e-05 + 196 + 0 + 0 + 1|0|0|0|3|0|0|3|2|2|0|0|2|0|1|2|2|2|1|1|3|3|2|1|1|3|5|2|4|5|3|4|5|2|4|3|6|0|2|4|2|1|0|1|3|3|1|3|1|4|2|2|2|2|1|2|4|0|2|1|3|0|0|2|3|2|1|1|0|0|1|2|1|3|0|1|1|1|0|1|1|1|0|2|1|0|0|0|1|1|2|2|1|0|1|1|0|1|1|2|0|2|0|2|0|3|0|0|1|0|1|0|0|0|1|0|1|0|0|0|0|0|0|0|1|0|0|1|1|0|0|0|0|0|0|0|0|0|2|0|1|0|0|0|0|0|1|1|0|0|1|0|0|0|0|0|0|0|1|0|0|0|0|1|1|1|0|0|1|0|0|0|1|1|0|0|0|0|0|1|1|1|0|0|0|0|0|1|2|0|0|0|0|0|0|1 + + + + (prodType 192, cat 210, subcat 203) [-] + 192 + unknown + 0 sec + CENTER=98(ECMWF) SUBCENTER=0 MASTER_TABLE=5 LOCAL_TABLE=0 SIGNF_REF_TIME=1(Start_of_Forecast) REF_TIME=2021-06-03T00:00:00Z PROD_STATUS=0(Operational) TYPE=0(Analysis) + 0 + 210 203 0 255 146 65535 255 1 0 105 0 10 255 -127 -2147483647 + 210 203 0 255 146 255 255 255 1 0 0 0 0 105 0 0 0 0 10 255 255 255 255 255 255 + 1622678400 sec UTC + 10-HYBL + [-] + 1622678400 sec UTC + 1.1745132724172e-05 + 1.1539705123401e-05 + 1.14457525342e-05 + 6.9678470128824e-08 + 
100 + + + -45 + 0 + 0 + 24 + 0.0000114458 + + + + 30[-] HYBL="Hybrid level" + + (prodType 192, cat 210, subcat 203) [-] + 192 + unknown + 0 sec + CENTER=98(ECMWF) SUBCENTER=0 MASTER_TABLE=5 LOCAL_TABLE=0 SIGNF_REF_TIME=1(Start_of_Forecast) REF_TIME=2021-06-03T00:00:00Z PROD_STATUS=0(Operational) TYPE=0(Analysis) + 0 + 210 203 0 255 146 65535 255 1 0 105 0 30 255 -127 -2147483647 + 210 203 0 255 146 255 255 255 1 0 0 0 0 105 0 0 0 0 30 255 255 255 255 255 255 + 1622678400 sec UTC + 30-HYBL + [-] + 1622678400 sec UTC + 1.068031849627e-07 + 8.8144559515281e-08 + 7.4302164421169e-08 + 9.4481458808801e-09 + 100 + + + -48 + 0 + 0 + 24 + 0.0000000743 + + + + 10[-] HYBL="Hybrid level" + + (prodType 192, cat 210, subcat 203) [-] + 192 + unknown + 0 sec + CENTER=98(ECMWF) SUBCENTER=0 MASTER_TABLE=5 LOCAL_TABLE=0 SIGNF_REF_TIME=1(Start_of_Forecast) REF_TIME=2021-06-04T00:00:00Z PROD_STATUS=0(Operational) TYPE=0(Analysis) + 0 + 210 203 0 255 146 65535 255 1 0 105 0 10 255 -127 -2147483647 + 210 203 0 255 146 255 255 255 1 0 0 0 0 105 0 0 0 0 10 255 255 255 255 255 255 + 1622764800 sec UTC + 10-HYBL + [-] + 1622764800 sec UTC + 1.2193680959172e-05 + 1.1760300362674e-05 + 1.147888997366e-05 + 1.7730224129066e-07 + 100 + + + -44 + 0 + 0 + 24 + 0.0000114789 + + + + 30[-] HYBL="Hybrid level" + + (prodType 192, cat 210, subcat 203) [-] + 192 + unknown + 0 sec + CENTER=98(ECMWF) SUBCENTER=0 MASTER_TABLE=5 LOCAL_TABLE=0 SIGNF_REF_TIME=1(Start_of_Forecast) REF_TIME=2021-06-04T00:00:00Z PROD_STATUS=0(Operational) TYPE=0(Analysis) + 0 + 210 203 0 255 146 65535 255 1 0 105 0 30 255 -127 -2147483647 + 210 203 0 255 146 255 255 255 1 0 0 0 0 105 0 0 0 0 30 255 255 255 255 255 255 + 1622764800 sec UTC + 30-HYBL + [-] + 1622764800 sec UTC + 1.1384071285647e-07 + 9.2106310315715e-08 + 7.2270665896212e-08 + 9.5383389050812e-09 + 100 + + + -48 + 0 + 0 + 24 + 0.0000000723 + + + + 10[hPa] ISBL (Isobaric surface) + + undefined [-] + var203 of table 210 of center ECMWF + 0 sec + 1622678400 sec UTC + 
10-ISBL + [-] + 1622678400 sec UTC + 1.6276792848657e-05 + 1.6106599578423e-05 + 1.583885295986e-05 + 1.0153528902132e-07 + 100 + + + + 20[hPa] ISBL (Isobaric surface) + + undefined [-] + var203 of table 210 of center ECMWF + 0 sec + 1622678400 sec UTC + 20-ISBL + [-] + 1622678400 sec UTC + 1.2929541298945e-05 + 1.2611470742276e-05 + 1.2212967703817e-05 + 1.4723476239413e-07 + 100 + + + + 50[hPa] ISBL (Isobaric surface) + + undefined [-] + var203 of table 210 of center ECMWF + 0 sec + 1622678400 sec UTC + 50-ISBL + [-] + 1622678400 sec UTC + 2.8687002213701e-06 + 2.5890412616161e-06 + 2.299082780155e-06 + 1.428912787031e-07 + 100 + + + + 100[hPa] ISBL (Isobaric surface) + + undefined [-] + var203 of table 210 of center ECMWF + 0 sec + 1622678400 sec UTC + 100-ISBL + [-] + 1622678400 sec UTC + 2.502025040485e-07 + 1.9998846863352e-07 + 1.6797713442429e-07 + 1.9060562971876e-08 + 100 + + + + 10[hPa] ISBL (Isobaric surface) + + undefined [-] + var203 of table 210 of center ECMWF + 0 sec + 1622764800 sec UTC + 10-ISBL + [-] + 1622764800 sec UTC + 1.6031418454077e-05 + 1.5874708642328e-05 + 1.5749257727293e-05 + 7.265758657701e-08 + 100 + + + + 20[hPa] ISBL (Isobaric surface) + + undefined [-] + var203 of table 210 of center ECMWF + 0 sec + 1622764800 sec UTC + 20-ISBL + [-] + 1622764800 sec UTC + 1.3027401109866e-05 + 1.2695418569578e-05 + 1.1947801795031e-05 + 2.1390172203242e-07 + 100 + + + + 50[hPa] ISBL (Isobaric surface) + + undefined [-] + var203 of table 210 of center ECMWF + 0 sec + 1622764800 sec UTC + 50-ISBL + [-] + 1622764800 sec UTC + 2.9717652978434e-06 + 2.6961222537076e-06 + 2.4221099010902e-06 + 1.215710670366e-07 + 100 + + + + 100[hPa] ISBL (Isobaric surface) + + undefined [-] + var203 of table 210 of center ECMWF + 0 sec + 1622764800 sec UTC + 100-ISBL + [-] + 1622764800 sec UTC + 2.741275579865e-07 + 2.0168293846781e-07 + 1.650793706176e-07 + 2.4385349641867e-08 + 100 + + + diff --git 
a/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grib.aux.xml b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grib.aux.xml new file mode 100644 index 000000000..8b9237893 --- /dev/null +++ b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grib.aux.xml @@ -0,0 +1,342 @@ + + + 1[-] HYBL="Hybrid level" + + + 1.136666106290528e-06 + 1.200369887769461e-06 + 196 + 0 + 0 + 1|1|0|1|1|2|7|1|0|0|0|1|1|1|2|1|3|3|3|4|6|2|1|0|0|0|0|0|0|0|0|0|0|0|0|1|0|1|1|2|2|6|5|1|1|2|0|1|0|1|1|0|1|1|1|1|0|1|0|1|1|2|7|1|0|0|0|0|0|0|0|0|0|1|0|1|0|1|1|1|6|2|1|0|0|0|0|0|0|1|0|1|0|1|1|1|6|2|0|1|0|0|0|0|0|1|0|1|0|1|1|1|5|2|1|1|0|0|0|0|0|1|0|1|0|1|1|1|4|3|1|1|0|0|0|0|1|0|0|1|1|0|1|2|3|3|1|1|0|0|0|0|1|0|1|0|1|1|1|1|2|4|1|1|0|0|0|1|0|1|0|1|0|1|1|1|2|3|2|1|0|0|1|0|1|0|1|0|1|0|1|1|2|3|2|1 + + + + (prodType 192, cat 210, subcat 203) [-] + 192 + unknown + 0 sec + CENTER=98(ECMWF) SUBCENTER=0 MASTER_TABLE=5 LOCAL_TABLE=0 SIGNF_REF_TIME=1(Start_of_Forecast) REF_TIME=2021-06-03T00:00:00Z PROD_STATUS=0(Operational) TYPE=0(Analysis) + 0 + 210 203 0 255 146 65535 255 1 0 105 0 1 255 -127 -2147483647 + 210 203 0 255 146 255 255 255 1 0 0 0 0 105 0 0 0 0 1 255 255 255 255 255 255 + 1622678400 sec UTC + 1-HYBL + [-] + 1622678400 sec UTC + 1.2002082030449e-06 + 1.1662431895312e-06 + 1.1368277910151e-06 + 1.9422780853555e-08 + 100 + + + -47 + 0 + 0 + 24 + 0.0000011368 + + + + 1[-] HYBL="Hybrid level" + + + 1.143818000822908e-06 + 1.20314421841223e-06 + 196 + 0 + 0 + 
3|2|1|1|2|5|0|0|0|0|1|3|1|2|2|4|1|0|3|2|1|1|3|4|0|0|0|0|0|0|0|0|4|1|1|1|2|4|1|0|0|0|0|0|0|0|0|0|0|0|0|3|3|5|7|1|1|0|2|2|4|0|0|0|0|0|0|0|0|0|0|0|3|1|1|1|0|1|1|2|4|0|0|0|0|0|0|0|3|1|1|0|1|0|1|1|2|2|2|0|0|0|0|0|3|1|1|0|1|0|1|1|0|1|3|2|0|0|0|0|0|3|1|1|1|0|1|0|1|1|0|4|1|0|0|0|0|1|3|1|0|1|1|0|1|1|0|2|3|0|0|0|0|0|3|2|0|1|1|0|1|1|0|2|3|0|0|0|0|0|0|4|1|1|0|1|1|0|1|4|1|0|0|0|0|0|0|4|1|1|0|1|1|0|3|3 + + + + (prodType 192, cat 210, subcat 203) [-] + 192 + unknown + 0 sec + CENTER=98(ECMWF) SUBCENTER=0 MASTER_TABLE=5 LOCAL_TABLE=0 SIGNF_REF_TIME=1(Start_of_Forecast) REF_TIME=2021-06-04T00:00:00Z PROD_STATUS=0(Operational) TYPE=0(Analysis) + 0 + 210 203 0 255 146 65535 255 1 0 105 0 1 255 -127 -2147483647 + 210 203 0 255 146 255 255 255 1 0 0 0 0 105 0 0 0 0 1 255 255 255 255 255 255 + 1622764800 sec UTC + 1-HYBL + [-] + 1622764800 sec UTC + 1.2029936442559e-06 + 1.1711197875953e-06 + 1.1439685749792e-06 + 1.824681247154e-08 + 100 + + + -48 + 0 + 0 + 24 + 0.0000011440 + + + + 10[-] HYBL="Hybrid level" + + + 1.144498880922602e-05 + 1.174589644914685e-05 + 196 + 0 + 0 + 1|0|0|0|3|0|0|3|2|2|0|0|2|0|1|2|2|2|1|1|3|3|2|1|1|3|5|2|4|5|3|4|5|2|4|3|6|0|2|4|2|1|0|1|3|3|1|3|1|4|2|2|2|2|1|2|4|0|2|1|3|0|0|2|3|2|1|1|0|0|1|2|1|3|0|1|1|1|0|1|1|1|0|2|1|0|0|0|1|1|2|2|1|0|1|1|0|1|1|2|0|2|0|2|0|3|0|0|1|0|1|0|0|0|1|0|1|0|0|0|0|0|0|0|1|0|0|1|1|0|0|0|0|0|0|0|0|0|2|0|1|0|0|0|0|0|1|1|0|0|1|0|0|0|0|0|0|0|1|0|0|0|0|1|1|1|0|0|1|0|0|0|1|1|0|0|0|0|0|1|1|1|0|0|0|0|0|1|2|0|0|0|0|0|0|1 + + + + (prodType 192, cat 210, subcat 203) [-] + 192 + unknown + 0 sec + CENTER=98(ECMWF) SUBCENTER=0 MASTER_TABLE=5 LOCAL_TABLE=0 SIGNF_REF_TIME=1(Start_of_Forecast) REF_TIME=2021-06-03T00:00:00Z PROD_STATUS=0(Operational) TYPE=0(Analysis) + 0 + 210 203 0 255 146 65535 255 1 0 105 0 10 255 -127 -2147483647 + 210 203 0 255 146 255 255 255 1 0 0 0 0 105 0 0 0 0 10 255 255 255 255 255 255 + 1622678400 sec UTC + 10-HYBL + [-] + 1622678400 sec UTC + 1.1745132724172e-05 + 1.1539705123401e-05 + 1.14457525342e-05 + 6.9678470128824e-08 + 
100 + + + -45 + 0 + 0 + 24 + 0.0000114458 + + + + 30[-] HYBL="Hybrid level" + + (prodType 192, cat 210, subcat 203) [-] + 192 + unknown + 0 sec + CENTER=98(ECMWF) SUBCENTER=0 MASTER_TABLE=5 LOCAL_TABLE=0 SIGNF_REF_TIME=1(Start_of_Forecast) REF_TIME=2021-06-03T00:00:00Z PROD_STATUS=0(Operational) TYPE=0(Analysis) + 0 + 210 203 0 255 146 65535 255 1 0 105 0 30 255 -127 -2147483647 + 210 203 0 255 146 255 255 255 1 0 0 0 0 105 0 0 0 0 30 255 255 255 255 255 255 + 1622678400 sec UTC + 30-HYBL + [-] + 1622678400 sec UTC + 1.068031849627e-07 + 8.8144559515281e-08 + 7.4302164421169e-08 + 9.4481458808801e-09 + 100 + + + -48 + 0 + 0 + 24 + 0.0000000743 + + + + 10[-] HYBL="Hybrid level" + + (prodType 192, cat 210, subcat 203) [-] + 192 + unknown + 0 sec + CENTER=98(ECMWF) SUBCENTER=0 MASTER_TABLE=5 LOCAL_TABLE=0 SIGNF_REF_TIME=1(Start_of_Forecast) REF_TIME=2021-06-04T00:00:00Z PROD_STATUS=0(Operational) TYPE=0(Analysis) + 0 + 210 203 0 255 146 65535 255 1 0 105 0 10 255 -127 -2147483647 + 210 203 0 255 146 255 255 255 1 0 0 0 0 105 0 0 0 0 10 255 255 255 255 255 255 + 1622764800 sec UTC + 10-HYBL + [-] + 1622764800 sec UTC + 1.2193680959172e-05 + 1.1760300362674e-05 + 1.147888997366e-05 + 1.7730224129066e-07 + 100 + + + -44 + 0 + 0 + 24 + 0.0000114789 + + + + 30[-] HYBL="Hybrid level" + + (prodType 192, cat 210, subcat 203) [-] + 192 + unknown + 0 sec + CENTER=98(ECMWF) SUBCENTER=0 MASTER_TABLE=5 LOCAL_TABLE=0 SIGNF_REF_TIME=1(Start_of_Forecast) REF_TIME=2021-06-04T00:00:00Z PROD_STATUS=0(Operational) TYPE=0(Analysis) + 0 + 210 203 0 255 146 65535 255 1 0 105 0 30 255 -127 -2147483647 + 210 203 0 255 146 255 255 255 1 0 0 0 0 105 0 0 0 0 30 255 255 255 255 255 255 + 1622764800 sec UTC + 30-HYBL + [-] + 1622764800 sec UTC + 1.1384071285647e-07 + 9.2106310315715e-08 + 7.2270665896212e-08 + 9.5383389050812e-09 + 100 + + + -48 + 0 + 0 + 24 + 0.0000000723 + + + + 10[hPa] ISBL (Isobaric surface) + + undefined [-] + var203 of table 210 of center ECMWF + 0 sec + 1622678400 sec UTC + 
10-ISBL + [-] + 1622678400 sec UTC + 1.6276792848657e-05 + 1.6106599578423e-05 + 1.583885295986e-05 + 1.0153528902132e-07 + 100 + + + + 20[hPa] ISBL (Isobaric surface) + + undefined [-] + var203 of table 210 of center ECMWF + 0 sec + 1622678400 sec UTC + 20-ISBL + [-] + 1622678400 sec UTC + 1.2929541298945e-05 + 1.2611470742276e-05 + 1.2212967703817e-05 + 1.4723476239413e-07 + 100 + + + + 50[hPa] ISBL (Isobaric surface) + + undefined [-] + var203 of table 210 of center ECMWF + 0 sec + 1622678400 sec UTC + 50-ISBL + [-] + 1622678400 sec UTC + 2.8687002213701e-06 + 2.5890412616161e-06 + 2.299082780155e-06 + 1.428912787031e-07 + 100 + + + + 100[hPa] ISBL (Isobaric surface) + + undefined [-] + var203 of table 210 of center ECMWF + 0 sec + 1622678400 sec UTC + 100-ISBL + [-] + 1622678400 sec UTC + 2.502025040485e-07 + 1.9998846863352e-07 + 1.6797713442429e-07 + 1.9060562971876e-08 + 100 + + + + 10[hPa] ISBL (Isobaric surface) + + undefined [-] + var203 of table 210 of center ECMWF + 0 sec + 1622764800 sec UTC + 10-ISBL + [-] + 1622764800 sec UTC + 1.6031418454077e-05 + 1.5874708642328e-05 + 1.5749257727293e-05 + 7.265758657701e-08 + 100 + + + + 20[hPa] ISBL (Isobaric surface) + + undefined [-] + var203 of table 210 of center ECMWF + 0 sec + 1622764800 sec UTC + 20-ISBL + [-] + 1622764800 sec UTC + 1.3027401109866e-05 + 1.2695418569578e-05 + 1.1947801795031e-05 + 2.1390172203242e-07 + 100 + + + + 50[hPa] ISBL (Isobaric surface) + + undefined [-] + var203 of table 210 of center ECMWF + 0 sec + 1622764800 sec UTC + 50-ISBL + [-] + 1622764800 sec UTC + 2.9717652978434e-06 + 2.6961222537076e-06 + 2.4221099010902e-06 + 1.215710670366e-07 + 100 + + + + 100[hPa] ISBL (Isobaric surface) + + undefined [-] + var203 of table 210 of center ECMWF + 0 sec + 1622764800 sec UTC + 100-ISBL + [-] + 1622764800 sec UTC + 2.741275579865e-07 + 2.0168293846781e-07 + 1.650793706176e-07 + 2.4385349641867e-08 + 100 + + + diff --git 
a/src/test/scala/com/databricks/labs/mosaic/codegen/TestAsHexCodegen.scala b/src/test/scala/com/databricks/labs/mosaic/codegen/TestAsHexCodegen.scala index 0749ebee8..e412730e2 100644 --- a/src/test/scala/com/databricks/labs/mosaic/codegen/TestAsHexCodegen.scala +++ b/src/test/scala/com/databricks/labs/mosaic/codegen/TestAsHexCodegen.scala @@ -1,7 +1,6 @@ package com.databricks.labs.mosaic.codegen import com.databricks.labs.mosaic.core.geometry.api.{ESRI, JTS} -import com.databricks.labs.mosaic.core.raster.api.RasterAPI.GDAL import com.databricks.labs.mosaic.core.index.H3IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext import com.databricks.labs.mosaic.test.SparkCodeGenSuite @@ -10,8 +9,8 @@ import org.scalatest.flatspec.AnyFlatSpec class TestAsHexCodegen extends AnyFlatSpec with AsHexCodegenBehaviors with SparkCodeGenSuite { "AsHex Expression" should "do codegen for any index system and any geometry API" in { - it should behave like codeGeneration(MosaicContext.build(H3IndexSystem, ESRI, GDAL)) - it should behave like codeGeneration(MosaicContext.build(H3IndexSystem, JTS, GDAL)) + it should behave like codeGeneration(MosaicContext.build(H3IndexSystem, ESRI)) + it should behave like codeGeneration(MosaicContext.build(H3IndexSystem, JTS)) } } diff --git a/src/test/scala/com/databricks/labs/mosaic/codegen/TestAsJSONCodegen.scala b/src/test/scala/com/databricks/labs/mosaic/codegen/TestAsJSONCodegen.scala index 148ba49d2..b7fb2e5db 100644 --- a/src/test/scala/com/databricks/labs/mosaic/codegen/TestAsJSONCodegen.scala +++ b/src/test/scala/com/databricks/labs/mosaic/codegen/TestAsJSONCodegen.scala @@ -1,7 +1,7 @@ package com.databricks.labs.mosaic.codegen import com.databricks.labs.mosaic.core.geometry.api.{JTS, ESRI} -import com.databricks.labs.mosaic.core.raster.api.RasterAPI.GDAL +import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.index.H3IndexSystem import 
com.databricks.labs.mosaic.functions.MosaicContext import com.databricks.labs.mosaic.test.SparkCodeGenSuite @@ -10,8 +10,8 @@ import org.scalatest.flatspec.AnyFlatSpec class TestAsJSONCodegen extends AnyFlatSpec with AsJSONCodegenBehaviors with SparkCodeGenSuite { "AsJson Expression" should "do codegen for any index system and any geometry API" in { - it should behave like codeGeneration(MosaicContext.build(H3IndexSystem, ESRI, GDAL)) - it should behave like codeGeneration(MosaicContext.build(H3IndexSystem, JTS, GDAL)) + it should behave like codeGeneration(MosaicContext.build(H3IndexSystem, ESRI)) + it should behave like codeGeneration(MosaicContext.build(H3IndexSystem, JTS)) } } diff --git a/src/test/scala/com/databricks/labs/mosaic/core/TestMosaic.scala b/src/test/scala/com/databricks/labs/mosaic/core/TestMosaic.scala index 769b2d7b0..49ef57c4b 100644 --- a/src/test/scala/com/databricks/labs/mosaic/core/TestMosaic.scala +++ b/src/test/scala/com/databricks/labs/mosaic/core/TestMosaic.scala @@ -1,6 +1,6 @@ package com.databricks.labs.mosaic.core -import com.databricks.labs.mosaic.{ESRI, JTS} +import com.databricks.labs.mosaic.{ESRI, H3, JTS} import com.databricks.labs.mosaic.core.index._ import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.must.Matchers.be @@ -65,4 +65,21 @@ class TestMosaic extends AnyFunSuite { } + test("MosaicFill should not return empty set for bounding box.") { + val wkt = "POLYGON ((-0.0011265364650581968 -127.48832860406948, " + + "-0.0010988881177292488 -127.48832851537821, -0.0010984513060565016 -127.4883092423591," + + " -0.0011272500508798704 -127.48830933794375, -0.0011265364650581968 -127.48832860406948))" + val wkt2 = "POLYGON (( -127.48832860406948 -0.0011265364650581968, " + + "-127.48832851537821 -0.0010988881177292488, -127.4883092423591 -0.0010984513060565016," + + " -127.48830933794375 -0.0011272500508798704, -127.48832860406948 -0.0011265364650581968))" + + val bbox = JTS.geometry(wkt2, "WKT") + + val cells 
= Mosaic + .mosaicFill(bbox, 6, keepCoreGeom = false, H3, JTS) + .map(_.indexAsLong(H3)) + + cells.length should be > 0 + } + } diff --git a/src/test/scala/com/databricks/labs/mosaic/core/index/IndexSystemIDTest.scala b/src/test/scala/com/databricks/labs/mosaic/core/index/IndexSystemIDTest.scala index b0f490df3..68c7b2dfc 100644 --- a/src/test/scala/com/databricks/labs/mosaic/core/index/IndexSystemIDTest.scala +++ b/src/test/scala/com/databricks/labs/mosaic/core/index/IndexSystemIDTest.scala @@ -1,6 +1,7 @@ package com.databricks.labs.mosaic.core.index import com.databricks.labs.mosaic.{BNG, H3} +import org.apache.spark.sql.SparkSession import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers._ @@ -15,7 +16,7 @@ class IndexSystemIDTest extends AnyFunSuite { test("IndexSystemID getIndexSystem from ID") { IndexSystemFactory.getIndexSystem(H3.name) shouldEqual H3IndexSystem IndexSystemFactory.getIndexSystem(BNG.name) shouldEqual BNGIndexSystem - an[Error] should be thrownBy IndexSystemFactory.getIndexSystem(null) + an[Throwable] should be thrownBy IndexSystemFactory.getIndexSystem(null: SparkSession) } } diff --git a/src/test/scala/com/databricks/labs/mosaic/core/index/TestBNGIndexSystem.scala b/src/test/scala/com/databricks/labs/mosaic/core/index/TestBNGIndexSystem.scala index 7c4db1574..f06a1ec3d 100644 --- a/src/test/scala/com/databricks/labs/mosaic/core/index/TestBNGIndexSystem.scala +++ b/src/test/scala/com/databricks/labs/mosaic/core/index/TestBNGIndexSystem.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.core.index import com.databricks.labs.mosaic.core.geometry.{MosaicGeometryESRI, MosaicGeometryJTS} import com.databricks.labs.mosaic.core.types.model.GeometryTypeEnum import com.databricks.labs.mosaic.core.types.model.GeometryTypeEnum._ -import org.apache.spark.sql.types.StringType import org.apache.spark.unsafe.types.UTF8String import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers._ @@ 
-246,7 +245,7 @@ class TestBNGIndexSystem extends AnyFunSuite { test("Issue 354: KRing should work near the edge of the grid") { val kring = BNGIndexSystem.kRing("TM99", 1) - kring should contain theSameElementsAs(Seq("TM99", "TM88", "TM98", "TG90", "TG80", "TM89")) + kring should contain theSameElementsAs Seq("TM99", "TM88", "TM98", "TN08", "TN09", "TH00", "TG90", "TG80", "TM89") } } diff --git a/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterBandGDAL.scala b/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterBandGDAL.scala index 8f6d12898..34d743ac1 100644 --- a/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterBandGDAL.scala +++ b/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterBandGDAL.scala @@ -1,5 +1,6 @@ package com.databricks.labs.mosaic.core.raster +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL import com.databricks.labs.mosaic.test.mocks.filePath import org.apache.spark.sql.test.SharedSparkSessionGDAL import org.scalatest.matchers.should.Matchers._ @@ -9,9 +10,12 @@ class TestRasterBandGDAL extends SharedSparkSessionGDAL { test("Read band metadata and pixel data from GeoTIFF file.") { assume(System.getProperty("os.name") == "Linux") - val testRaster = MosaicRasterGDAL.readRaster(filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF")) + val testRaster = MosaicRasterGDAL.readRaster( + filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF"), + filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") + ) val testBand = testRaster.getBand(1) - testBand.asInstanceOf[MosaicRasterBandGDAL].band + testBand.getBand testBand.index shouldBe 1 testBand.units shouldBe "" testBand.description shouldBe "Nadir_Reflectance_Band1" @@ -33,6 +37,7 @@ class TestRasterBandGDAL extends SharedSparkSessionGDAL { assume(System.getProperty("os.name") == "Linux") val testRaster = MosaicRasterGDAL.readRaster( + 
filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib"), filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib") ) val testBand = testRaster.getBand(1) @@ -50,9 +55,15 @@ class TestRasterBandGDAL extends SharedSparkSessionGDAL { test("Read band metadata and pixel data from a NetCDF file.") { assume(System.getProperty("os.name") == "Linux") - val superRaster = MosaicRasterGDAL.readRaster(filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc")) + val superRaster = MosaicRasterGDAL.readRaster( + filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc"), + filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc") + ) val subdatasetPath = superRaster.subdatasets("bleaching_alert_area") - val testRaster = MosaicRasterGDAL.readRaster(subdatasetPath) + val testRaster = MosaicRasterGDAL.readRaster( + subdatasetPath, + subdatasetPath + ) val testBand = testRaster.getBand(1) testBand.dataType shouldBe 1 diff --git a/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterGDAL.scala b/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterGDAL.scala index f37b5b390..a0055f9d9 100644 --- a/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterGDAL.scala +++ b/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterGDAL.scala @@ -1,11 +1,8 @@ package com.databricks.labs.mosaic.core.raster -import com.databricks.labs.mosaic.core.raster.api.RasterAPI -import com.databricks.labs.mosaic.GDAL -import com.databricks.labs.mosaic.sql.extensions.MosaicGDAL +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL import com.databricks.labs.mosaic.test.mocks.filePath import org.apache.spark.sql.test.SharedSparkSessionGDAL -import org.apache.spark.sql.SparkSessionExtensions import org.scalatest.matchers.should.Matchers._ import scala.sys.process._ @@ -35,7 +32,10 @@ class TestRasterGDAL 
extends SharedSparkSessionGDAL { test("Read raster metadata from GeoTIFF file.") { assume(System.getProperty("os.name") == "Linux") - val testRaster = MosaicRasterGDAL.readRaster(filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF")) + val testRaster = MosaicRasterGDAL.readRaster( + filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF"), + filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") + ) testRaster.xSize shouldBe 2400 testRaster.ySize shouldBe 2400 testRaster.numBands shouldBe 1 @@ -43,7 +43,7 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { testRaster.SRID shouldBe 0 testRaster.extent shouldBe Seq(-8895604.157333, 1111950.519667, -7783653.637667, 2223901.039333) testRaster.getRaster.GetProjection() - noException should be thrownBy testRaster.asInstanceOf[MosaicRasterGDAL].spatialRef + noException should be thrownBy testRaster.spatialRef an[Exception] should be thrownBy testRaster.getBand(-1) an[Exception] should be thrownBy testRaster.getBand(Int.MaxValue) testRaster.cleanUp() @@ -53,6 +53,7 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { assume(System.getProperty("os.name") == "Linux") val testRaster = MosaicRasterGDAL.readRaster( + filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib"), filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib") ) testRaster.xSize shouldBe 14 @@ -67,10 +68,16 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { test("Read raster metadata from a NetCDF file.") { assume(System.getProperty("os.name") == "Linux") - val superRaster = MosaicRasterGDAL.readRaster(filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc")) + val superRaster = MosaicRasterGDAL.readRaster( + filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc"), + filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc") + ) val subdatasetPath = 
superRaster.subdatasets("bleaching_alert_area") - val testRaster = MosaicRasterGDAL.readRaster(subdatasetPath) + val testRaster = MosaicRasterGDAL.readRaster( + subdatasetPath, + subdatasetPath + ) testRaster.xSize shouldBe 7200 testRaster.ySize shouldBe 3600 @@ -83,14 +90,27 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { superRaster.cleanUp() } - test("Auxiliary logic") { + test("Raster pixel and extent sizes are correct.") { assume(System.getProperty("os.name") == "Linux") - RasterAPI.apply("GDAL") shouldBe GDAL - RasterAPI.getReader("GDAL") shouldBe MosaicRasterGDAL - GDAL.name shouldBe "GDAL" - val extension = new MosaicGDAL() - noException should be thrownBy extension.apply(new SparkSessionExtensions) + val testRaster = MosaicRasterGDAL.readRaster( + filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF"), + filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") + ) + + testRaster.pixelXSize - 463.312716527 < 0.0000001 shouldBe true + testRaster.pixelYSize - -463.312716527 < 0.0000001 shouldBe true + testRaster.pixelDiagSize - 655.22312733 < 0.0000001 shouldBe true + + testRaster.diagSize - 3394.1125496954 < 0.0000001 shouldBe true + testRaster.originX - -8895604.157333 < 0.0000001 shouldBe true + testRaster.originY - 2223901.039333 < 0.0000001 shouldBe true + testRaster.xMax - -7783653.637667 < 0.0000001 shouldBe true + testRaster.yMax - 1111950.519667 < 0.0000001 shouldBe true + testRaster.xMin - -8895604.157333 < 0.0000001 shouldBe true + testRaster.yMin - 2223901.039333 < 0.0000001 shouldBe true + + testRaster.cleanUp() } } diff --git a/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala b/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala index ae48e0454..99a1563ca 100644 --- a/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala +++ b/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala @@ -1,11 +1,13 @@ package 
com.databricks.labs.mosaic.datasource +import com.databricks.labs.mosaic.MOSAIC_RASTER_READ_STRATEGY +import com.databricks.labs.mosaic.datasource.gdal.GDALFileFormat import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.test.SharedSparkSession -import org.scalatest.matchers.must.Matchers.{be, noException, not} -import org.scalatest.matchers.should.Matchers.{an, convertToAnyShouldWrapper} +import org.apache.spark.sql.test.SharedSparkSessionGDAL +import org.scalatest.matchers.must.Matchers.{be, noException} +import org.scalatest.matchers.should.Matchers.an -class GDALFileFormatTest extends QueryTest with SharedSparkSession { +class GDALFileFormatTest extends QueryTest with SharedSparkSessionGDAL { test("Read netcdf with GDALFileFormat") { assume(System.getProperty("os.name") == "Linux") @@ -28,7 +30,7 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSession { .format("gdal") .option("driverName", "NetCDF") .load(filePath) - .select("proj4Str") + .select("metadata") .take(1) } @@ -41,20 +43,27 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSession { noException should be thrownBy spark.read .format("gdal") + .option("extensions", "grib") + .option("raster_storage", "disk") + .option("extensions", "grib") .load(filePath) .take(1) noException should be thrownBy spark.read .format("gdal") - .option("driverName", "NetCDF") + .option("extensions", "grib") + .option("raster_storage", "disk") + .option("extensions", "grib") .load(filePath) .take(1) noException should be thrownBy spark.read .format("gdal") - .option("driverName", "NetCDF") + .option("extensions", "grib") + .option("raster_storage", "disk") + .option("extensions", "grib") .load(filePath) - .select("proj4Str") + .select("metadata") .take(1) } @@ -80,9 +89,15 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSession { .format("gdal") .option("driverName", "TIF") .load(filePath) - .select("proj4Str") + .select("metadata") .take(1) + noException should be 
thrownBy spark.read + .format("gdal") + .option(MOSAIC_RASTER_READ_STRATEGY, "retile_on_read") + .load(filePath) + .collect() + } test("Read zarr with GDALFileFormat") { @@ -109,7 +124,7 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSession { .option("driverName", "Zarr") .option("vsizip", "true") .load(filePath) - .select("proj4Str") + .select("metadata") .take(1) } @@ -118,30 +133,6 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSession { val reader = new GDALFileFormat() an[Error] should be thrownBy reader.prepareWrite(spark, null, null, null) - for ( - driver <- Seq( - "GTiff", - "HDF4", - "HDF5", - "JP2ECW", - "JP2KAK", - "JP2MrSID", - "JP2OpenJPEG", - "NetCDF", - "PDF", - "PNG", - "VRT", - "XPM", - "COG", - "GRIB", - "Zarr" - ) - ) { - GDALFileFormat.getFileExtension(driver) should not be "UNSUPPORTED" - } - - GDALFileFormat.getFileExtension("NotADriver") should be("UNSUPPORTED") - noException should be thrownBy Utils.createRow(Array(null)) noException should be thrownBy Utils.createRow(Array(1, 2, 3)) noException should be thrownBy Utils.createRow(Array(1.toByte)) diff --git a/src/test/scala/com/databricks/labs/mosaic/datasource/OGRFileFormatTest.scala b/src/test/scala/com/databricks/labs/mosaic/datasource/OGRFileFormatTest.scala index 5b8e90b1d..001642880 100644 --- a/src/test/scala/com/databricks/labs/mosaic/datasource/OGRFileFormatTest.scala +++ b/src/test/scala/com/databricks/labs/mosaic/datasource/OGRFileFormatTest.scala @@ -1,9 +1,10 @@ package com.databricks.labs.mosaic.datasource import com.databricks.labs.mosaic.{H3, JTS} -import com.databricks.labs.mosaic.core.raster.api.RasterAPI.GDAL +import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.expressions.util.OGRReadeWithOffset import com.databricks.labs.mosaic.functions.MosaicContext +import com.databricks.labs.mosaic.utils.PathUtils import org.apache.spark.sql.QueryTest import org.apache.spark.sql.functions.{col, lit} import 
org.apache.spark.sql.test.SharedSparkSessionGDAL @@ -80,8 +81,8 @@ class OGRFileFormatTest extends QueryTest with SharedSparkSessionGDAL { noException should be thrownBy OGRFileFormat.enableOGRDrivers(force = true) - val path = getClass.getResource("/binary/geodb/bridges.gdb.zip").getPath.replace("file:", "") - val ds = ogr.Open(s"/vsizip/$path", 0) + val path = PathUtils.getCleanPath(getClass.getResource("/binary/geodb/bridges.gdb.zip").getPath, useZipPath = true) + val ds = ogr.Open(path, 0) noException should be thrownBy OGRFileFormat.getLayer(ds, 0, "layer2") @@ -145,7 +146,7 @@ class OGRFileFormatTest extends QueryTest with SharedSparkSessionGDAL { test("OGRFileFormat should handle partial schema: ISSUE 351") { assume(System.getProperty("os.name") == "Linux") - val mc = MosaicContext.build(H3, JTS, GDAL) + val mc = MosaicContext.build(H3, JTS) import mc.functions._ val issue351 = "/binary/issue351/" diff --git a/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala b/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala index 6d8a52551..1f7b4008b 100644 --- a/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala +++ b/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala @@ -16,8 +16,8 @@ class RasterAsGridReaderTest extends MosaicSpatialQueryTest with SharedSparkSess assume(System.getProperty("os.name") == "Linux") MosaicContext.build(H3IndexSystem, JTS) - val grib = "/binary/netcdf-coral/" - val filePath = getClass.getResource(grib).getPath + val netcdf = "/binary/netcdf-coral/" + val filePath = getClass.getResource(netcdf).getPath noException should be thrownBy MosaicContext.read .format("raster_to_grid") @@ -42,7 +42,7 @@ class RasterAsGridReaderTest extends MosaicSpatialQueryTest with SharedSparkSess noException should be thrownBy MosaicContext.read .format("raster_to_grid") - .option("fileExtension", 
"grib") + .option("extensions", "grib") .option("combiner", "min") .option("retile", "true") .option("tileSize", "10") diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/constructors/TestConstructors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/constructors/TestConstructors.scala index a997f9b68..78ca9bb7e 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/constructors/TestConstructors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/constructors/TestConstructors.scala @@ -1,7 +1,7 @@ package com.databricks.labs.mosaic.expressions.constructors import com.databricks.labs.mosaic.core.geometry.api.{ESRI, JTS} -import com.databricks.labs.mosaic.core.raster.api.RasterAPI.GDAL +import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.index.H3IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext import com.databricks.labs.mosaic.test.SparkSuite @@ -10,33 +10,33 @@ import org.scalatest.flatspec.AnyFlatSpec class TestConstructors extends AnyFlatSpec with ConstructorsBehaviors with SparkSuite { "ST_Point" should "construct a point geometry for any index system and any geometry API" in { - it should behave like createST_Point(MosaicContext.build(H3IndexSystem, ESRI, GDAL), spark) - it should behave like createST_Point(MosaicContext.build(H3IndexSystem, JTS, GDAL), spark) + it should behave like createST_Point(MosaicContext.build(H3IndexSystem, ESRI), spark) + it should behave like createST_Point(MosaicContext.build(H3IndexSystem, JTS), spark) } "ST_MakeLine" should "construct a line geometry from an array of points for any index system and any geometry API" in { - it should behave like createST_MakeLineSimple(MosaicContext.build(H3IndexSystem, ESRI, GDAL), spark) - it should behave like createST_MakeLineSimple(MosaicContext.build(H3IndexSystem, JTS, GDAL), spark) + it should behave like createST_MakeLineSimple(MosaicContext.build(H3IndexSystem, ESRI), spark) 
+ it should behave like createST_MakeLineSimple(MosaicContext.build(H3IndexSystem, JTS), spark) } "ST_MakeLine" should "construct a line geometry from a set of geometries for any index system and any geometry API" in { - it should behave like createST_MakeLineComplex(MosaicContext.build(H3IndexSystem, ESRI, GDAL), spark) - it should behave like createST_MakeLineComplex(MosaicContext.build(H3IndexSystem, JTS, GDAL), spark) + it should behave like createST_MakeLineComplex(MosaicContext.build(H3IndexSystem, ESRI), spark) + it should behave like createST_MakeLineComplex(MosaicContext.build(H3IndexSystem, JTS), spark) } "ST_MakeLine" should "return null if any input is null" in { - it should behave like createST_MakeLineAnyNull(MosaicContext.build(H3IndexSystem, ESRI, GDAL), spark) - it should behave like createST_MakeLineAnyNull(MosaicContext.build(H3IndexSystem, JTS, GDAL), spark) + it should behave like createST_MakeLineAnyNull(MosaicContext.build(H3IndexSystem, ESRI), spark) + it should behave like createST_MakeLineAnyNull(MosaicContext.build(H3IndexSystem, JTS), spark) } "ST_MakePolygon" should "construct a polygon geometry without holes for any index system and any geometry API" in { - it should behave like createST_MakePolygonNoHoles(MosaicContext.build(H3IndexSystem, ESRI, GDAL), spark) - it should behave like createST_MakePolygonNoHoles(MosaicContext.build(H3IndexSystem, JTS, GDAL), spark) + it should behave like createST_MakePolygonNoHoles(MosaicContext.build(H3IndexSystem, ESRI), spark) + it should behave like createST_MakePolygonNoHoles(MosaicContext.build(H3IndexSystem, JTS), spark) } "ST_MakePolygon" should "construct a polygon geometry with holes for any index system and any geometry API" in { - it should behave like createST_MakePolygonWithHoles(MosaicContext.build(H3IndexSystem, ESRI, GDAL), spark) - it should behave like createST_MakePolygonWithHoles(MosaicContext.build(H3IndexSystem, JTS, GDAL), spark) + it should behave like 
createST_MakePolygonWithHoles(MosaicContext.build(H3IndexSystem, ESRI), spark) + it should behave like createST_MakePolygonWithHoles(MosaicContext.build(H3IndexSystem, JTS), spark) } } diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_BufferBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_BufferBehaviors.scala index 10558c760..674238e43 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_BufferBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_BufferBehaviors.scala @@ -50,6 +50,9 @@ trait ST_BufferBehaviors extends QueryTest { .collect() sqlResult.zip(expected).foreach { case (l, r) => math.abs(l - r) should be < 1e-8 } + + mocks.getWKTRowsDf().select(st_buffer_cap_style($"wkt", lit(1), lit("round"))).collect() + mocks.getWKTRowsDf().select(st_buffer_cap_style($"wkt", 1, "round")).collect() } def bufferCodegen(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_CentroidBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_CentroidBehaviors.scala index ce596f37b..055df10db 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_CentroidBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_CentroidBehaviors.scala @@ -47,6 +47,18 @@ trait ST_CentroidBehaviors extends MosaicSpatialQueryTest { .collect() sqlResult.zip(expected).foreach { case (l, r) => l.equals(r) shouldEqual true } + + val sqlResult2 = spark + .sql( + """with subquery ( + | select st_centroid2D(wkt) as coord from source + |) select coord.col1, coord.col2 from subquery""".stripMargin) + .as[(Double, Double)] + .collect() + + sqlResult2.zip(expected).foreach { case (l, r) => l.equals(r) shouldEqual true } + + noException should be thrownBy st_centroid2D(lit("POLYGON (1 1, 2 2, 3 3, 1 1)")) } def 
centroidCodegen(mosaicContext: MosaicContext): Unit = { diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_IntersectionBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_IntersectionBehaviors.scala index bf39c79eb..2d88efd07 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_IntersectionBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_IntersectionBehaviors.scala @@ -128,6 +128,41 @@ trait ST_IntersectionBehaviors extends QueryTest { (results.select("area").as[Double].collect().head - indexPolygon1.union(indexChip1).union(indexPolygon2).union(indexChip2).getArea) should be < 10e-8 + + left.createOrReplaceTempView("left") + right.createOrReplaceTempView("right") + + val sqlResults = spark.sql( + """ + |SELECT + | left_row_id, + | ST_Area(ST_Intersection_Aggregate(left_index, right_index)) AS area + |FROM left + |JOIN right + |ON left_index.index_id = right_index.index_id + |GROUP BY left_row_id + |""".stripMargin + ) + + (sqlResults.select("area").as[Double].collect().head - + indexPolygon1.union(indexChip1).union(indexPolygon2).union(indexChip2).getArea) should be < 10e-8 + + val sqlResults2 = spark.sql( + """ + |SELECT + | left_row_id, + | ST_Area(ST_Intersection_Agg(left_index, right_index)) AS area + |FROM left + |JOIN right + |ON left_index.index_id = right_index.index_id + |GROUP BY left_row_id + |""".stripMargin + ) + + (sqlResults2.select("area").as[Double].collect().head - + indexPolygon1.union(indexChip1).union(indexPolygon2).union(indexChip2).getArea) should be < 10e-8 + + noException should be thrownBy st_intersection_agg(lit("POLYGON (1 1, 2 2, 3 3, 1 1)"), lit("POLYGON (1 1, 2 2, 3 3, 1 1)")) } def selfIntersectionBehaviour(indexSystem: IndexSystem, geometryAPI: GeometryAPI, resolution: Int): Unit = { diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_IntersectsBehaviors.scala 
b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_IntersectsBehaviors.scala index 36213630b..4a480dfd3 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_IntersectsBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_IntersectsBehaviors.scala @@ -78,6 +78,17 @@ trait ST_IntersectsBehaviors extends QueryTest { |""".stripMargin) result2.collect().length should be > 0 + + val result3 = spark.sql(""" + |SELECT ST_INTERSECTS_AGG(LEFT_INDEX, RIGHT_INDEX) + |FROM LEFT + |INNER JOIN RIGHT ON LEFT_INDEX.INDEX_ID == RIGHT_INDEX.INDEX_ID + |GROUP BY LEFT_ID, RIGHT_ID + |""".stripMargin) + + result3.collect().length should be > 0 + + noException should be thrownBy st_intersects_agg(lit("POLYGON (1 1, 2 2, 3 3, 1 1)"), lit("POLYGON (1 1, 2 2, 3 3, 1 1)")) } def intersectsAggBehaviour(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/index/CellAreaBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/index/CellAreaBehaviors.scala index 4939563e9..7f75c2098 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/index/CellAreaBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/index/CellAreaBehaviors.scala @@ -26,6 +26,17 @@ trait CellAreaBehaviors extends MosaicSpatialQueryTest { val result = Seq(cellId).toDF("cellId").select(grid_cellarea($"cellId")).collect() math.abs(result.head.getDouble(0) - Row(area).getDouble(0)) < 1e-6 shouldEqual true + + Seq(cellId).toDF("cellId").createOrReplaceTempView("cellId") + + val sqlResult = spark + .sql("""with subquery ( + | select grid_cellarea(cellId) as area from cellId + |) select * from subquery""".stripMargin) + .as[Double] + .collect() + + math.abs(sqlResult.head - area) < 1e-6 shouldEqual true } def columnFunctionSignatures(mosaicContext: MosaicContext): Unit = { diff --git 
a/src/test/scala/com/databricks/labs/mosaic/expressions/index/CellIntersectionAggBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/index/CellIntersectionAggBehaviors.scala index c8caee100..84f3ddec9 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/index/CellIntersectionAggBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/index/CellIntersectionAggBehaviors.scala @@ -68,6 +68,17 @@ trait CellIntersectionAggBehaviors extends MosaicSpatialQueryTest { res.foreach { case (actual, expected) => actual.equalsTopo(expected) shouldEqual true } + in_df.createOrReplaceTempView("source") + + noException should be thrownBy spark + .sql("""with subquery ( + | select grid_cell_intersection_agg(chip) as intersection_chips from source + | group by case_id, chip.index_id + |) select st_aswkt(intersection_chips.wkb) from subquery""".stripMargin) + .as[String] + .collect() + .map(wkt => mc.getGeometryAPI.geometry(wkt, "WKT")) + } def columnFunctionSignatures(mosaicContext: MosaicContext): Unit = { diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/index/CellIntersectionBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/index/CellIntersectionBehaviors.scala index 7a2a69e42..c8dce2f01 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/index/CellIntersectionBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/index/CellIntersectionBehaviors.scala @@ -71,6 +71,19 @@ trait CellIntersectionBehaviors extends MosaicSpatialQueryTest { .map(r => (mc.getGeometryAPI.geometry(r._1, "WKT"), mc.getGeometryAPI.geometry(r._2, "WKT"))) res.foreach { case (actual, expected) => actual.equalsTopo(expected) shouldEqual true } + + + leftDf.createOrReplaceTempView("left") + + val sqlResult = spark + .sql("""with subquery ( + | select grid_cell_intersection(left_chip, left_chip) as intersection from left + |) select st_aswkt(intersection.wkb) from subquery""".stripMargin) 
+ .as[String] + .collect() + .map(r => mc.getGeometryAPI.geometry(r, "WKT")) + + sqlResult.foreach(actual => actual.equalsTopo(mc.getGeometryAPI.geometry("POLYGON ((0 0, 2 0, 2 1, 0 1, 0 0))", "WKT")) shouldEqual true) } def columnFunctionSignatures(mosaicContext: MosaicContext): Unit = { diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/index/CellUnionAggBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/index/CellUnionAggBehaviors.scala index 055d1cc17..e34df0b15 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/index/CellUnionAggBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/index/CellUnionAggBehaviors.scala @@ -68,6 +68,16 @@ trait CellUnionAggBehaviors extends MosaicSpatialQueryTest { res.foreach { case (actual, expected) => actual.equalsTopo(expected) shouldEqual true } + in_df.createOrReplaceTempView("source") + + //noException should be thrownBy spark + spark.sql("""with subquery ( + | select grid_cell_union_agg(chip) as union_chip from source + | group by case_id, chip.index_id + |) select st_aswkt(union_chip.wkb) from subquery""".stripMargin) + .as[String] + .collect() + } def columnFunctionSignatures(mosaicContext: MosaicContext): Unit = { diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/index/CellUnionBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/index/CellUnionBehaviors.scala index 51885be62..f758d93ea 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/index/CellUnionBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/index/CellUnionBehaviors.scala @@ -71,6 +71,22 @@ trait CellUnionBehaviors extends MosaicSpatialQueryTest { .map(r => (mc.getGeometryAPI.geometry(r._1, "WKT"), mc.getGeometryAPI.geometry(r._2, "WKT"))) res.foreach { case (actual, expected) => actual.equalsTopo(expected) shouldEqual true } + + leftDf.createOrReplaceTempView("left") + + val sqlResult = spark + .sql("""with 
subquery ( + | select grid_cell_union(left_chip, left_chip) as chip from left + |) select st_aswkt(chip.wkb) from subquery""".stripMargin) + .as[String] + .collect() + + sqlResult.exists { r => + mc.getGeometryAPI + .geometry(r, "WKT") + .equalsTopo(mc.getGeometryAPI.geometry("POLYGON ((0 0, 2 0, 2 1, 0 1, 0 0))", "WKT")) + } shouldEqual true + } def columnFunctionSignatures(mosaicContext: MosaicContext): Unit = { diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/index/GridDistanceBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/index/GridDistanceBehaviors.scala index fb400fab5..28d02b90b 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/index/GridDistanceBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/index/GridDistanceBehaviors.scala @@ -16,6 +16,8 @@ trait GridDistanceBehaviors extends MosaicSpatialQueryTest { spark.sparkContext.setLogLevel("FATAL") val mc = mosaicContext import mc.functions._ + val sc = spark + import sc.implicits._ mc.register(spark) val resolution = 4 @@ -40,6 +42,18 @@ trait GridDistanceBehaviors extends MosaicSpatialQueryTest { ) cellPairs.where(col("grid_distance") =!= 0).count() shouldEqual cellPairs.count() + + boroughs.createOrReplaceTempView("boroughs") + + val sqlResult = spark + .sql("""with subquery ( + | select grid_distance(grid_pointascellid(st_centroid(wkt), 4), grid_pointascellid(st_centroid(wkt), 4)) as dist from boroughs + |) select * from subquery""".stripMargin) + .as[Long] + .collect() + + sqlResult.foreach(_ shouldEqual 0) + } def auxiliaryMethods(mosaicContext: MosaicContext): Unit = { @@ -50,7 +64,6 @@ trait GridDistanceBehaviors extends MosaicSpatialQueryTest { mc.register(spark) val wkt = mocks.getWKTRowsDf(mc.getIndexSystem).limit(1).select("wkt").as[String].collect().head - val k = 4 val gridDistanceExpr = GridDistance( mc.functions.grid_pointascellid(mc.functions.st_centroid(lit(wkt)), lit(4)).expr, diff --git 
a/src/test/scala/com/databricks/labs/mosaic/expressions/index/MosaicFillBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/index/MosaicFillBehaviors.scala index 8b3f60284..d4fbf78dd 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/index/MosaicFillBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/index/MosaicFillBehaviors.scala @@ -44,6 +44,8 @@ trait MosaicFillBehaviors extends MosaicSpatialQueryTest { .collect() boroughs.collect().length shouldEqual mosaics2.length + + noException should be thrownBy mosaicfill(col("wkt"), resolution, lit(true)) } def wkbMosaicFill(mosaicContext: MosaicContext): Unit = { diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_BandMetadataBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_BandMetadataBehaviors.scala index d76ffb4cb..6051ccc8e 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_BandMetadataBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_BandMetadataBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.apache.spark.sql.functions._ import org.scalatest.matchers.should.Matchers._ @@ -11,42 +10,46 @@ import org.scalatest.matchers.should.Matchers._ trait RST_BandMetadataBehaviors extends QueryTest { def bandMetadataBehavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val sc = spark val mc = MosaicContext.build(indexSystem, geometryAPI) mc.register() - val sc = spark + import mc.functions._ import sc.implicits._ - noException should be thrownBy mc.getRasterAPI noException should be thrownBy MosaicContext.geometryAPI - val 
rasterDfWithBandMetadata = mocks - .getNetCDFBinaryDf(spark) - .withColumn("subdatasets", rst_subdatasets($"path")) - .withColumn("bleachingSubdataset", element_at($"subdatasets", "bleaching_alert_area")) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val rasterDfWithBandMetadata = rastersInMemory + .withColumn("subdatasets", rst_subdatasets($"tile")) + .withColumn("tile", rst_getsubdataset($"tile", lit("bleaching_alert_area"))) + .withColumn("tile", rst_subdivide($"tile", 100)) .select( - rst_bandmetadata($"bleachingSubdataset", lit(1)) + rst_bandmetadata($"tile", lit(1)) .alias("metadata") ) - mocks - .getNetCDFBinaryDf(spark) - .withColumn("subdatasets", rst_subdatasets($"path")) - .withColumn("bleachingSubdataset", element_at($"subdatasets", "bleaching_alert_area")) + rastersInMemory + .withColumn("subdatasets", rst_subdatasets($"tile")) + .withColumn("tile", rst_getsubdataset($"tile", lit("bleaching_alert_area"))) .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_bandmetadata(bleachingSubdataset, 1) from source + |select rst_bandmetadata(tile, 1) from source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - .withColumn("subdatasets", rst_subdatasets($"path")) - .withColumn("bleachingSubdataset", element_at($"subdatasets", "bleaching_alert_area")) + noException should be thrownBy rastersInMemory + .withColumn("subdatasets", rst_subdatasets($"tile")) + .withColumn("tile", rst_getsubdataset($"tile", lit("bleaching_alert_area"))) .select( - rst_bandmetadata($"bleachingSubdataset", lit(1)) + rst_bandmetadata($"tile", lit(1)) .alias("metadata") ) + .collect() val result = rasterDfWithBandMetadata.as[Map[String, String]].collect() @@ -58,7 +61,6 @@ trait RST_BandMetadataBehaviors extends QueryTest { noException should be thrownBy rst_bandmetadata($"bleachingSubdataset", lit(1)) 
noException should be thrownBy rst_bandmetadata($"bleachingSubdataset", 1) - noException should be thrownBy rst_bandmetadata("bleachingSubdataset", 1) } diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBoxBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBoxBehaviors.scala new file mode 100644 index 000000000..9478411bd --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBoxBehaviors.scala @@ -0,0 +1,46 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.scalatest.matchers.should.Matchers._ + +trait RST_BoundingBoxBehaviors extends QueryTest { + + // noinspection MapGetGet + def behaviors(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val gridTiles = rastersInMemory + .withColumn("bbox", rst_boundingbox($"tile")) + .select(st_area($"bbox").as("area")) + .as[Double] + .collect() + + gridTiles.forall(_ > 0.0) should be(true) + + rastersInMemory.createOrReplaceTempView("source") + + val gridTilesSQL = spark + .sql(""" + |SELECT ST_Area(RST_BoundingBox(tile)) AS area + |FROM source + |""".stripMargin) + .as[Double] + .collect() + + gridTilesSQL.forall(_ > 0.0) should be(true) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBoxTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBoxTest.scala new file mode 100644 index 000000000..28f393031 --- /dev/null 
+++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBoxTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_BoundingBoxTest extends QueryTest with SharedSparkSessionGDAL with RST_BoundingBoxBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. 
+ test("Testing RST_BoundingBox with manual GDAL registration (H3, JTS).") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behaviors(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ClipBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ClipBehaviors.scala new file mode 100644 index 000000000..dbc0b35e9 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ClipBehaviors.scala @@ -0,0 +1,49 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.scalatest.matchers.should.Matchers._ + +trait RST_ClipBehaviors extends QueryTest { + + // noinspection MapGetGet + def behaviors(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val gridTiles = rastersInMemory + .withColumn("bbox", rst_boundingbox($"tile")) + .withColumn("cent", st_centroid($"bbox")) + .withColumn("clip_region", st_buffer($"cent", 0.1)) + .withColumn("clip", rst_clip($"tile", $"clip_region")) + .withColumn("bbox2", rst_boundingbox($"clip")) + .withColumn("result", st_area($"bbox") =!= st_area($"bbox2")) + .select("result") + .as[Boolean] + .collect() + + gridTiles.forall(identity) should be(true) + + rastersInMemory.createOrReplaceTempView("source") + + noException should be thrownBy spark + .sql(""" + |select + | rst_clip(tile, st_buffer(st_centroid(rst_boundingbox(tile)), 0.1)) as tile + |from source + |""".stripMargin) + .collect() + + } + +} diff 
--git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ClipTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ClipTest.scala new file mode 100644 index 000000000..e3597141d --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ClipTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_ClipTest extends QueryTest with SharedSparkSessionGDAL with RST_ClipBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. 
+ test("Testing RST_Clip with manual GDAL registration (H3, JTS).") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behaviors(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAggBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAggBehaviors.scala new file mode 100644 index 000000000..5ed81f8f1 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAggBehaviors.scala @@ -0,0 +1,60 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.scalatest.matchers.should.Matchers._ + +trait RST_CombineAvgAggBehaviors extends QueryTest { + + // noinspection MapGetGet + def behaviors(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val gridTiles = rastersInMemory.union(rastersInMemory) + .withColumn("tiles", rst_tessellate($"tile", 2)) + .select("path", "tiles") + .groupBy("path") + .agg( + rst_combineavg_agg($"tiles").as("tiles") + ) + .select("tiles") + + rastersInMemory.union(rastersInMemory) + .createOrReplaceTempView("source") + + spark.sql(""" + |select rst_combineavg_agg(tiles) as tiles + |from ( + | select path, rst_tessellate(tile, 2) as tiles + | from source + |) + |group by path + |""".stripMargin) + + noException should be thrownBy rastersInMemory + .withColumn("tiles", rst_tessellate($"tile", 2)) + .select("path", "tiles") + .groupBy("path") + .agg( + 
rst_combineavg_agg($"tiles").as("tiles") + ) + .select("tiles") + + val result = gridTiles.collect() + + result.length should be(rastersInMemory.count()) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAggTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAggTest.scala new file mode 100644 index 000000000..222b89ed2 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAggTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_CombineAvgAggTest extends QueryTest with SharedSparkSessionGDAL with RST_CombineAvgAggBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. 
+ test("Testing RST_CombineAvgAgg with manual GDAL registration (H3, JTS).") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behaviors(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFileBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFileBehaviors.scala new file mode 100644 index 000000000..f61fe174d --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFileBehaviors.scala @@ -0,0 +1,69 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.scalatest.matchers.should.Matchers._ + +trait RST_FromFileBehaviors extends QueryTest { + + // noinspection MapGetGet + def behaviors(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("binaryFile") + .load("src/test/resources/modis") + + val gridTiles = rastersInMemory + .withColumn("tile", rst_fromfile($"path")) + .withColumn("bbox", rst_boundingbox($"tile")) + .withColumn("cent", st_centroid($"bbox")) + .withColumn("clip_region", st_buffer($"cent", 0.1)) + .withColumn("clip", rst_clip($"tile", $"clip_region")) + .withColumn("bbox2", rst_boundingbox($"clip")) + .withColumn("result", st_area($"bbox") =!= st_area($"bbox2")) + .select("result") + .as[Boolean] + .collect() + + gridTiles.forall(identity) should be(true) + + rastersInMemory.createOrReplaceTempView("source") + + val gridTilesSQL = spark + .sql(""" + |with subquery as ( + | select rst_fromfile(path) as tile from source + |) + |select st_area(rst_boundingbox(tile)) != 
st_area(rst_boundingbox(rst_clip(tile, st_buffer(st_centroid(rst_boundingbox(tile)), 0.1)))) as result + |from subquery + |""".stripMargin) + .as[Boolean] + .collect() + + gridTilesSQL.forall(identity) should be(true) + + + val gridTilesSQL2 = spark + .sql( + """ + |with subquery as ( + | select rst_fromfile(path, 4) as tile from source + |) + |select st_area(rst_boundingbox(tile)) != st_area(rst_boundingbox(rst_clip(tile, st_buffer(st_centroid(rst_boundingbox(tile)), 0.1)))) as result + |from subquery + |""".stripMargin) + .as[Boolean] + .collect() + + gridTilesSQL2.forall(identity) should be(true) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFileTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFileTest.scala new file mode 100644 index 000000000..f595693d5 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFileTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_FromFileTest extends QueryTest with SharedSparkSessionGDAL with RST_FromFileBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. 
+ test("Testing RST_FromFile with manual GDAL registration (H3, JTS).") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behaviors(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_GeoReferenceBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_GeoReferenceBehaviors.scala index 09756115a..6e698426d 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_GeoReferenceBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_GeoReferenceBehaviors.scala @@ -3,8 +3,8 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions.lit import org.scalatest.matchers.should.Matchers._ trait RST_GeoReferenceBehaviors extends QueryTest { @@ -17,22 +17,24 @@ trait RST_GeoReferenceBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val geoReferenceDf = mocks - .getNetCDFBinaryDf(spark) - .withColumn("georeference", rst_georeference($"path")) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val geoReferenceDf = rastersInMemory + .withColumn("georeference", rst_georeference($"tile")) .select("georeference") - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_georeference(path) from source + |select rst_georeference(tile) from source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - .withColumn("georeference", rst_georeference("/dummy/path")) + noException should be thrownBy 
rastersInMemory + .withColumn("georeference", rst_georeference($"tile")) .select("georeference") val result = geoReferenceDf.as[Map[String, Double]].collect() diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_HeightBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_HeightBehaviors.scala index 3e707cebc..7effc2e14 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_HeightBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_HeightBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.scalatest.matchers.should.Matchers._ @@ -16,22 +15,24 @@ trait RST_HeightBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_height($"path")) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val df = rastersInMemory + .withColumn("result", rst_height($"tile")) .select("result") - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_height(path) from source + |select rst_height(tile) from source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_height("/dummy/path")) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_height($"tile")) .select("result") val result = df.as[Int].collect() diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_IsEmptyBehaviors.scala 
b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_IsEmptyBehaviors.scala index f19e19bba..0db36ec39 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_IsEmptyBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_IsEmptyBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.scalatest.matchers.should.Matchers._ @@ -17,30 +16,39 @@ trait RST_IsEmptyBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_isempty($"path")) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val df = rastersInMemory + .withColumn("result", rst_isempty($"tile")) + .select("result") + + val df2 = rastersInMemory + .withColumn("tile", rst_getsubdataset($"tile", "bleaching_alert_area")) + .withColumn("result", rst_isempty($"tile")) .select("result") - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_isempty(path) from source + |select rst_isempty(tile) from source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_isempty("/dummy/path")) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_isempty($"tile")) .select("result") val result = df.as[Boolean].collect() + val result2 = df2.as[Boolean].collect() result.head shouldBe false + result2.head shouldBe false an[Exception] should be thrownBy spark.sql(""" - |select rst_isempty(path, 1, 1) from source + 
|select rst_isempty() from source |""".stripMargin) } diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MemSizeBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MemSizeBehaviors.scala index cf3a11278..741fad613 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MemSizeBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MemSizeBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.scalatest.matchers.should.Matchers._ @@ -16,22 +15,24 @@ trait RST_MemSizeBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_memsize($"path")) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val df = rastersInMemory + .withColumn("result", rst_memsize($"tile")) .select("result") - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") - noException should be thrownBy spark.sql(""" - |select rst_memsize(path) from source - |""".stripMargin) + spark.sql(""" + |select rst_memsize(tile) from source + |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_memsize("/dummy/path")) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_memsize($"tile")) .select("result") val result = df.as[Long].collect() diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAggBehaviors.scala 
b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAggBehaviors.scala new file mode 100644 index 000000000..0533eafee --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAggBehaviors.scala @@ -0,0 +1,60 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.scalatest.matchers.should.Matchers._ + +trait RST_MergeAggBehaviors extends QueryTest { + + // noinspection MapGetGet + def behaviors(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val gridTiles = rastersInMemory + .withColumn("tiles", rst_tessellate($"tile", 3)) + .select("path", "tiles") + .groupBy("path") + .agg( + rst_merge_agg($"tiles").as("tiles") + ) + .select("tiles") + + rastersInMemory + .createOrReplaceTempView("source") + + spark.sql(""" + |select rst_merge_agg(tiles) as tiles + |from ( + | select path, rst_tessellate(tile, 3) as tiles + | from source + |) + |group by path + |""".stripMargin) + + noException should be thrownBy rastersInMemory + .withColumn("tiles", rst_tessellate($"tile", 3)) + .select("path", "tiles") + .groupBy("path") + .agg( + rst_merge_agg($"tiles").as("tiles") + ) + .select("tiles") + + val result = gridTiles.collect() + + result.length should be(rastersInMemory.count()) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAggTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAggTest.scala new file mode 100644 index 
000000000..7689d7685 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAggTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_MergeAggTest extends QueryTest with SharedSparkSessionGDAL with RST_MergeAggBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. 
+ test("Testing RST_MergeAgg with manual GDAL registration (H3, JTS).") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behaviors(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeBehaviors.scala new file mode 100644 index 000000000..f4b17ce83 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeBehaviors.scala @@ -0,0 +1,68 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions.collect_set +import org.scalatest.matchers.should.Matchers._ + +trait RST_MergeBehaviors extends QueryTest { + + // noinspection MapGetGet + def behaviors(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val gridTiles = rastersInMemory + .withColumn("tile", rst_tessellate($"tile", 3)) + .select("path", "tile") + .groupBy("path") + .agg( + collect_set("tile").as("tiles") + ) + .select( + rst_merge($"tiles").as("tile") + ) + + rastersInMemory + .createOrReplaceTempView("source") + + spark.sql(""" + |select rst_merge(tiles) as tile + |from ( + | select collect_set(tile) as tiles + | from ( + | select path, rst_tessellate(tile, 3) as tile + | from source + | ) + | group by path + |) + |""".stripMargin) + + noException should be thrownBy rastersInMemory + .withColumn("tile", rst_tessellate($"tile", 3)) + .select("path", "tile") + .groupBy("path") 
+ .agg( + collect_set("tile").as("tiles") + ) + .select( + rst_merge($"tiles").as("tile") + ) + + val result = gridTiles.collect() + + result.length should be(rastersInMemory.count()) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeTest.scala new file mode 100644 index 000000000..3174a5070 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_MergeTest extends QueryTest with SharedSparkSessionGDAL with RST_MergeBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. 
+ test("Testing RST_Merge with manual GDAL registration (H3, JTS).") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behaviors(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MetadataBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MetadataBehaviors.scala index 7ee3f4246..d6869fce7 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MetadataBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MetadataBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.scalatest.matchers.should.Matchers._ @@ -16,21 +15,24 @@ trait RST_MetadataBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val rasterDfWithMetadata = mocks - .getGeotiffBinaryDf(spark) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val rasterDfWithMetadata = rastersInMemory .select( - rst_metadata($"path").alias("metadata") + rst_metadata($"tile").alias("metadata") ) .select("metadata") val result = rasterDfWithMetadata.as[Map[String, String]].collect() - mocks - .getGeotiffBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_metadata(path) from source + |select rst_metadata(tile) from source |""".stripMargin) result.head.getOrElse("SHORTNAME", "") shouldBe "MCD43A4" @@ -40,7 +42,6 @@ trait RST_MetadataBehaviors extends QueryTest { result.head.getOrElse("TileID", "") shouldBe "51010007" noException should be thrownBy rst_metadata($"path") - noException 
should be thrownBy rst_metadata("path") } diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_NumBandsBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_NumBandsBehaviors.scala index 2992c4073..711cab7ce 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_NumBandsBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_NumBandsBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.scalatest.matchers.should.Matchers._ @@ -16,22 +15,24 @@ trait RST_NumBandsBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getGeotiffBinaryDf(spark) - .withColumn("result", rst_numbands($"path")) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val df = rastersInMemory + .withColumn("result", rst_numbands($"tile")) .select("result") - mocks - .getGeotiffBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_numbands(path) from source + |select rst_numbands(tile) from source |""".stripMargin) - noException should be thrownBy mocks - .getGeotiffBinaryDf(spark) - .withColumn("result", rst_numbands("/dummy/path")) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_numbands($"tile")) .select("result") val result = df.as[Int].collect().max diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelHeightBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelHeightBehaviors.scala index 
1463fa42a..d9f0c66f1 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelHeightBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelHeightBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.scalatest.matchers.should.Matchers._ @@ -16,23 +15,26 @@ trait RST_PixelHeightBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_pixelheight($"path")) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val df = rastersInMemory + .withColumn("result", rst_pixelheight($"tile")) .select("result") - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_pixelheight(path) from source + |select rst_pixelheight(tile) from source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_pixelheight("/dummy/path")) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_pixelheight($"tile")) .select("result") + .collect() val result = df.as[Double].collect() diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelWidthBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelWidthBehaviors.scala index 6ce29bb6b..895c12a52 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelWidthBehaviors.scala +++ 
b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelWidthBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.scalatest.matchers.should.Matchers._ @@ -16,22 +15,24 @@ trait RST_PixelWidthBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_pixelwidth($"path")) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val df = rastersInMemory + .withColumn("result", rst_pixelwidth($"tile")) .select("result") - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_pixelwidth(path) from source + |select rst_pixelwidth(tile) from source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_pixelwidth("/dummy/path")) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_pixelwidth($"tile")) .select("result") val result = df.as[Double].collect() @@ -39,7 +40,7 @@ trait RST_PixelWidthBehaviors extends QueryTest { result.head > 0 shouldBe true an[Exception] should be thrownBy spark.sql(""" - |select rst_pixelwidth(path, 1, 1) from source + |select rst_pixelwidth() from source |""".stripMargin) } diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridAvgBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridAvgBehaviors.scala index ec0569cb9..2a08fe559 100644 --- 
a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridAvgBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridAvgBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.apache.spark.sql.functions._ import org.scalatest.matchers.should.Matchers._ @@ -17,25 +16,27 @@ trait RST_RasterToGridAvgBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getGeotiffBinaryDf(spark) - .withColumn("result", rst_rastertogridavg($"path", lit(3))) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val df = rastersInMemory + .withColumn("result", rst_rastertogridavg($"tile", lit(3))) .select("result") .select(explode($"result").as("result")) .select(explode($"result").as("result")) .select($"result".getItem("measure").as("result")) - mocks - .getGeotiffBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_rastertogridavg(path, 3) from source + |select rst_rastertogridavg(tile, 3) from source |""".stripMargin) - noException should be thrownBy mocks - .getGeotiffBinaryDf(spark) - .withColumn("result", rst_rastertogridavg("/dummy/path", lit(3))) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_rastertogridavg($"tile", lit(3))) .select("result") val result = df.as[Double].collect().max diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridCountBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridCountBehaviors.scala index 
207cc6a19..2d1eca342 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridCountBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridCountBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.apache.spark.sql.functions._ import org.scalatest.matchers.should.Matchers._ @@ -17,25 +16,27 @@ trait RST_RasterToGridCountBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getGeotiffBinaryDf(spark) - .withColumn("result", rst_rastertogridcount($"path", lit(3))) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val df = rastersInMemory + .withColumn("result", rst_rastertogridcount($"tile", lit(3))) .select("result") .select(explode($"result").as("result")) .select(explode($"result").as("result")) .select($"result".getItem("measure").as("result")) - mocks - .getGeotiffBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_rastertogridcount(path, 3) from source + |select rst_rastertogridcount(tile, 3) from source |""".stripMargin) - noException should be thrownBy mocks - .getGeotiffBinaryDf(spark) - .withColumn("result", rst_rastertogridcount("/dummy/path", lit(3))) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_rastertogridcount($"tile", lit(3))) .select("result") val result = df.as[Int].collect().max diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMaxBehaviors.scala 
b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMaxBehaviors.scala index 37cc6e84e..f150abdf9 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMaxBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMaxBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.apache.spark.sql.functions._ import org.scalatest.matchers.should.Matchers._ @@ -17,25 +16,27 @@ trait RST_RasterToGridMaxBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getGeotiffBinaryDf(spark) - .withColumn("result", rst_rastertogridmax($"path", lit(3))) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val df = rastersInMemory + .withColumn("result", rst_rastertogridmax($"tile", lit(3))) .select("result") .select(explode($"result").as("result")) .select(explode($"result").as("result")) .select($"result".getItem("measure").as("result")) - mocks - .getGeotiffBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_rastertogridmax(path, 3) from source + |select rst_rastertogridmax(tile, 3) from source |""".stripMargin) - noException should be thrownBy mocks - .getGeotiffBinaryDf(spark) - .withColumn("result", rst_rastertogridmax("/dummy/path", lit(3))) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_rastertogridmax($"tile", lit(3))) .select("result") val result = df.as[Double].collect().max diff --git 
a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMedianBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMedianBehaviors.scala index fc6c02c2c..49ca59dd3 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMedianBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMedianBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.apache.spark.sql.functions._ import org.scalatest.matchers.should.Matchers._ @@ -17,25 +16,27 @@ trait RST_RasterToGridMedianBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getGeotiffBinaryDf(spark) - .withColumn("result", rst_rastertogridmedian($"path", lit(3))) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val df = rastersInMemory + .withColumn("result", rst_rastertogridmedian($"tile", lit(3))) .select("result") .select(explode($"result").as("result")) .select(explode($"result").as("result")) .select($"result".getItem("measure").as("result")) - mocks - .getGeotiffBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_rastertogridmedian(path, 3) from source + |select rst_rastertogridmedian(tile, 3) from source |""".stripMargin) - noException should be thrownBy mocks - .getGeotiffBinaryDf(spark) - .withColumn("result", rst_rastertogridmedian("/dummy/path", lit(3))) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_rastertogridmedian($"tile", lit(3))) 
.select("result") val result = df.as[Double].collect().max diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMinBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMinBehaviors.scala index 9999ddbcf..134f0bfa4 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMinBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToGridMinBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.apache.spark.sql.functions._ import org.scalatest.matchers.should.Matchers._ @@ -17,25 +16,27 @@ trait RST_RasterToGridMinBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getGeotiffBinaryDf(spark) - .withColumn("result", rst_rastertogridmin($"path", lit(3))) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val df = rastersInMemory + .withColumn("result", rst_rastertogridmin($"tile", lit(3))) .select("result") .select(explode($"result").as("result")) .select(explode($"result").as("result")) .select($"result".getItem("measure").as("result")) - mocks - .getGeotiffBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_rastertogridmin(path, 3) from source + |select rst_rastertogridmin(tile, 3) from source |""".stripMargin) - noException should be thrownBy mocks - .getGeotiffBinaryDf(spark) - .withColumn("result", rst_rastertogridmin("/dummy/path", lit(3))) + noException should be thrownBy rastersInMemory + .withColumn("result", 
rst_rastertogridmin($"tile", lit(3))) .select("result") val result = df.as[Double].collect().max diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordBehaviors.scala index 805b0887a..8265e745a 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.apache.spark.sql.functions.lit import org.scalatest.matchers.should.Matchers._ @@ -17,23 +16,25 @@ trait RST_RasterToWorldCoordBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_rastertoworldcoord($"path", lit(2), lit(2))) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val df = rastersInMemory + .withColumn("result", rst_rastertoworldcoord($"tile", lit(2), lit(2))) .select("result") - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_rastertoworldcoord(path, 2, 2) from source + |select rst_rastertoworldcoord(tile, 2, 2) from source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_rastertoworldcoord(lit("/dummy/path"), 2, 2)) - .withColumn("result", rst_rastertoworldcoord("/dummy/path", lit(2), lit(2))) + noException should be thrownBy 
rastersInMemory + .withColumn("result", rst_rastertoworldcoord(lit($"tile"), 2, 2)) + .withColumn("result", rst_rastertoworldcoord($"tile", lit(2), lit(2))) .select("result") val result = df.as[String].collect().max diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordXBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordXBehaviors.scala index f900384ee..079e0839b 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordXBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordXBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.apache.spark.sql.functions.lit import org.scalatest.matchers.should.Matchers._ @@ -17,23 +16,25 @@ trait RST_RasterToWorldCoordXBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_rastertoworldcoordx($"path", lit(2), lit(2))) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val df = rastersInMemory + .withColumn("result", rst_rastertoworldcoordx($"tile", lit(2), lit(2))) .select("result") - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_rastertoworldcoordx(path, 2, 2) from source + |select rst_rastertoworldcoordx(tile, 2, 2) from source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", 
rst_rastertoworldcoordx(lit("/dummy/path"), 2, 2)) - .withColumn("result", rst_rastertoworldcoordx("/dummy/path", lit(2), lit(2))) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_rastertoworldcoordx(lit($"tile"), 2, 2)) + .withColumn("result", rst_rastertoworldcoordx($"tile", lit(2), lit(2))) .select("result") val result = df.as[Double].collect().max @@ -45,7 +46,6 @@ trait RST_RasterToWorldCoordXBehaviors extends QueryTest { |""".stripMargin) noException should be thrownBy rst_rastertoworldcoordx(lit("/dummy/path"), 2, 2) - noException should be thrownBy rst_rastertoworldcoordx("/dummy/path", lit(2), lit(2)) noException should be thrownBy rst_rastertoworldcoordx(lit("/dummy/path"), lit(2), lit(2)) } diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordYBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordYBehaviors.scala index 9ba8e2c11..c27722e8b 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordYBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordYBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.apache.spark.sql.functions.lit import org.scalatest.matchers.should.Matchers._ @@ -17,23 +16,25 @@ trait RST_RasterToWorldCoordYBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_rastertoworldcoordy($"path", lit(2), lit(2))) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + 
.load("src/test/resources/binary/netcdf-coral") + + val df = rastersInMemory + .withColumn("result", rst_rastertoworldcoordy($"tile", lit(2), lit(2))) .select("result") - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_rastertoworldcoordy(path, 2, 2) from source + |select rst_rastertoworldcoordy(tile, 2, 2) from source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) + noException should be thrownBy rastersInMemory .withColumn("result", rst_rastertoworldcoordy(lit("/dummy/path"), 2, 2)) - .withColumn("result", rst_rastertoworldcoordy("/dummy/path", lit(2), lit(2))) + .withColumn("result", rst_rastertoworldcoordy($"tile", lit(2), lit(2))) .select("result") val result = df.as[Double].collect().max @@ -45,7 +46,6 @@ trait RST_RasterToWorldCoordYBehaviors extends QueryTest { |""".stripMargin) noException should be thrownBy rst_rastertoworldcoordy(lit("/dummy/path"), 2, 2) - noException should be thrownBy rst_rastertoworldcoordy("/dummy/path", lit(2), lit(2)) noException should be thrownBy rst_rastertoworldcoordy(lit("/dummy/path"), lit(2), lit(2)) } diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ReTileBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ReTileBehaviors.scala index 7c96cea26..608c3de85 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ReTileBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ReTileBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.apache.spark.sql.functions._ import 
org.scalatest.matchers.should.Matchers._ @@ -17,26 +16,28 @@ trait RST_ReTileBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getGeotiffBinaryDf(spark) - .withColumn("result", rst_retile($"path", lit(100), lit(100))) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val df = rastersInMemory + .withColumn("result", rst_retile($"tile", lit(400), lit(400))) .select("result") - mocks - .getGeotiffBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_retile(path, 100, 100) from source + |select rst_retile(tile, 400, 400) from source |""".stripMargin) - noException should be thrownBy mocks - .getGeotiffBinaryDf(spark) - .withColumn("result", rst_retile($"path", 100, 100)) - .withColumn("result", rst_retile("/dummy/path", 100, 100)) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_retile($"tile", 400, 400)) + .withColumn("result", rst_retile($"tile", 400, 400)) .select("result") - val result = df.as[String].collect().length + val result = df.collect().length result should be > 0 diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RotationBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RotationBehaviors.scala index 24ca88f51..6469d7292 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RotationBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_RotationBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import 
org.scalatest.matchers.should.Matchers._ @@ -16,22 +15,24 @@ trait RST_RotationBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_rotation($"path")) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val df = rastersInMemory + .withColumn("result", rst_rotation($"tile")) .select("result") - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_rotation(path) from source + |select rst_rotation(tile) from source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_rotation("/dummy/path")) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_rotation($"tile")) .select("result") val result = df.as[Double].collect().max diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_SRIDBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_SRIDBehaviors.scala index 7b3b8d028..debe3d0a1 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_SRIDBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_SRIDBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.scalatest.matchers.should.Matchers._ @@ -16,22 +15,24 @@ trait RST_SRIDBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getGeotiffBinaryDf(spark) - .withColumn("result", rst_srid($"path")) + val rastersInMemory = 
spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val df = rastersInMemory + .withColumn("result", rst_srid($"tile")) .select("result") - mocks - .getGeotiffBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_srid(path) from source + |select rst_srid(tile) from source |""".stripMargin) - noException should be thrownBy mocks - .getGeotiffBinaryDf(spark) - .withColumn("result", rst_srid("/dummy/path")) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_srid($"tile")) .select("result") val result = df.as[Double].collect().max diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleXBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleXBehaviors.scala index 12cd90f12..e12dca7fe 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleXBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleXBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.scalatest.matchers.should.Matchers._ @@ -16,22 +15,24 @@ trait RST_ScaleXBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_scalex($"path")) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val df = rastersInMemory + .withColumn("result", rst_scalex($"tile")) .select("result") - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory 
.createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_scalex(path) from source + |select rst_scalex(tile) from source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_scalex("/dummy/path")) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_scalex($"tile")) .select("result") val result = df.as[Double].collect().max diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleYBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleYBehaviors.scala index fca203e78..e264199b1 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleYBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleYBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.scalatest.matchers.should.Matchers._ @@ -16,22 +15,24 @@ trait RST_ScaleYBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_scaley($"path")) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val df = rastersInMemory + .withColumn("result", rst_scaley($"tile")) .select("result") - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_scaley(path) from source + |select rst_scaley(tile) from source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", 
rst_scaley("/dummy/path")) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_scaley($"tile")) .select("result") val result = df.as[Double].collect().max diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewXBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewXBehaviors.scala index 322911aea..2a5b5e3db 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewXBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewXBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.scalatest.matchers.should.Matchers._ @@ -16,22 +15,24 @@ trait RST_SkewXBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_skewx($"path")) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val df = rastersInMemory + .withColumn("result", rst_skewx($"tile")) .select("result") - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_skewx(path) from source + |select rst_skewx(tile) from source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_skewx("/dummy/path")) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_skewx($"tile")) .select("result") val result = df.as[Double].collect().max diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewYBehaviors.scala 
b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewYBehaviors.scala index 4f5b495b6..294157065 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewYBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewYBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.scalatest.matchers.should.Matchers._ @@ -16,22 +15,24 @@ trait RST_SkewYBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_skewy($"path")) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val df = rastersInMemory + .withColumn("result", rst_skewy($"tile")) .select("result") - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_skewy(path) from source + |select rst_skewy(tile) from source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_skewy("/dummy/path")) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_skewy($"tile")) .select("result") val result = df.as[Double].collect().max diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_SubdatasetsBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_SubdatasetsBehaviors.scala index 50c04ec6f..ad713f17c 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_SubdatasetsBehaviors.scala +++ 
b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_SubdatasetsBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.scalatest.matchers.should.Matchers._ @@ -16,38 +15,41 @@ trait RST_SubdatasetsBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val rasterDfWithSubdatasets = mocks - .getNetCDFBinaryDf(spark) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val rasterDfWithSubdatasets = rastersInMemory .select( - rst_subdatasets($"path") + rst_subdatasets($"tile") .alias("subdatasets") ) val result = rasterDfWithSubdatasets.as[Map[String, String]].collect() - mocks - .getNetCDFBinaryDf(spark) - .orderBy("path") + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_subdatasets(path) from source + |select rst_subdatasets(tile) from source |""".stripMargin) an[Exception] should be thrownBy spark.sql(""" |select rst_subdatasets() from source |""".stripMargin) - noException should be thrownBy spark.sql(""" - |select rst_subdatasets("dummy/path") from source - |""".stripMargin) - result.head.keys.toList.length shouldBe 6 - result.head.values.toList.map(_.nonEmpty).reduce(_ && _) shouldBe true + noException should be thrownBy rastersInMemory + .select( + rst_subdatasets($"tile") + .alias("subdatasets") + ) + .take(1) - noException should be thrownBy rst_subdatasets($"path") - noException should be thrownBy rst_subdatasets("path") + result.head.values.toList.map(_.nonEmpty).reduce(_ && _) shouldBe true + noException should be thrownBy rst_subdatasets($"tile") } } diff --git 
a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_SummaryBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_SummaryBehaviors.scala index d654851ad..1d53cdb4a 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_SummaryBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_SummaryBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.scalatest.matchers.should.Matchers._ @@ -16,22 +15,24 @@ trait RST_SummaryBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_summary($"path")) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val df = rastersInMemory + .withColumn("result", rst_summary($"tile")) .select("result") - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_summary(path) from source + |select rst_summary(tile) from source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_summary("/dummy/path")) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_summary($"tile")) .select("result") val result = df.as[String].collect().head.length diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateBehaviors.scala new file mode 100644 index 000000000..daad95af0 --- /dev/null +++ 
b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateBehaviors.scala @@ -0,0 +1,45 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.scalatest.matchers.should.Matchers._ + +trait RST_TessellateBehaviors extends QueryTest { + + // noinspection MapGetGet + def tessellateBehavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val gridTiles = rastersInMemory + .withColumn("tiles", rst_tessellate($"tile", 3)) + .select("tiles") + + rastersInMemory + .createOrReplaceTempView("source") + + noException should be thrownBy spark.sql(""" + |select rst_tessellate(tile, 3) from source + |""".stripMargin) + + noException should be thrownBy rastersInMemory + .withColumn("tiles", rst_tessellate($"tile", 3)) + .select("tiles") + + val result = gridTiles.collect() + + result.length should be(380) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateTest.scala new file mode 100644 index 000000000..5bc0eae57 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import 
org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_TessellateTest extends QueryTest with SharedSparkSessionGDAL with RST_TessellateBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. + test("Testing RST_GridTiles with manual GDAL registration (H3, JTS).") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + tessellateBehavior(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftXBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftXBehaviors.scala index e40c68d6e..88e5ecd3a 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftXBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftXBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.scalatest.matchers.should.Matchers._ @@ -16,22 +15,24 @@ trait RST_UpperLeftXBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_upperleftx($"path")) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", 
"in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val df = rastersInMemory + .withColumn("result", rst_upperleftx($"tile")) .select("result") - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_upperleftx(path) from source + |select rst_upperleftx(tile) from source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_upperleftx("/dummy/path")) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_upperleftx($"tile")) .select("result") val result = df.as[String].collect().head.length diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftYBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftYBehaviors.scala index 7e8993512..fc83d11d2 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftYBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftYBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.scalatest.matchers.should.Matchers._ @@ -16,22 +15,24 @@ trait RST_UpperLeftYBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_upperlefty($"path")) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val df = rastersInMemory + .withColumn("result", rst_upperlefty($"tile")) .select("result") - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory 
.createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_upperlefty(path) from source + |select rst_upperlefty(tile) from source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_upperlefty("/dummy/path")) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_upperlefty($"tile")) .select("result") val result = df.as[String].collect().head.length diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_WidthBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_WidthBehaviors.scala index f65c2ff09..885a3e05a 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_WidthBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_WidthBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.scalatest.matchers.should.Matchers._ @@ -16,22 +15,24 @@ trait RST_WidthBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_width($"path")) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val df = rastersInMemory + .withColumn("result", rst_width($"tile")) .select("result") - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_width(path) from source + |select rst_width(tile) from source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - 
.withColumn("result", rst_width("/dummy/path")) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_width($"tile")) .select("result") val result = df.as[String].collect().head.length diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordBehaviors.scala index 7dc2d11d2..4aaf86b3e 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.apache.spark.sql.functions._ import org.scalatest.matchers.should.Matchers._ @@ -17,23 +16,25 @@ trait RST_WorldToRasterCoordBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_worldtorastercoord($"path", 0, 0)) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val df = rastersInMemory + .withColumn("result", rst_worldtorastercoord($"tile", 0, 0)) .select($"result".getItem("x").as("x"), $"result".getItem("y").as("y")) - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_worldtorastercoord(path, 1, 1) from source + |select rst_worldtorastercoord(tile, 1, 1) from source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", 
rst_worldtorastercoord("/dummy/path", 0, 0)) - .withColumn("result", rst_worldtorastercoord($"path", lit(0), lit(0))) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_worldtorastercoord($"tile", 0, 0)) + .withColumn("result", rst_worldtorastercoord($"tile", lit(0), lit(0))) .select("result") noException should be thrownBy df.as[(Int, Int)].collect().head diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordXBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordXBehaviors.scala index 2c400fd8c..9dc26422a 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordXBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordXBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.apache.spark.sql.functions._ import org.scalatest.matchers.should.Matchers._ @@ -17,23 +16,25 @@ trait RST_WorldToRasterCoordXBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_worldtorastercoordx($"path", 0, 0)) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val df = rastersInMemory + .withColumn("result", rst_worldtorastercoordx($"tile", 0, 0)) .select("result") - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - |select rst_worldtorastercoordx(path, 1, 1) from source + |select rst_worldtorastercoordx(tile, 1, 1) from 
source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_worldtorastercoordx("/dummy/path", 0, 0)) - .withColumn("result", rst_worldtorastercoordx($"path", lit(0), lit(0))) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_worldtorastercoordx($"tile", 0, 0)) + .withColumn("result", rst_worldtorastercoordx($"tile", lit(0), lit(0))) .select("result") val result = df.as[String].collect().head.length diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordYBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordYBehaviors.scala index 0d2bc552d..e2a259b55 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordYBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordYBehaviors.scala @@ -3,7 +3,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.mocks import org.apache.spark.sql.QueryTest import org.apache.spark.sql.functions._ import org.scalatest.matchers.should.Matchers._ @@ -17,23 +16,25 @@ trait RST_WorldToRasterCoordYBehaviors extends QueryTest { import mc.functions._ import sc.implicits._ - val df = mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_worldtorastercoordy($"path", 0, 0)) + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/binary/netcdf-coral") + + val df = rastersInMemory + .withColumn("result", rst_worldtorastercoordy($"tile", 0, 0)) .select("result") - mocks - .getNetCDFBinaryDf(spark) + rastersInMemory .createOrReplaceTempView("source") noException should be thrownBy spark.sql(""" - 
|select rst_worldtorastercoordy(path, 1, 1) from source + |select rst_worldtorastercoordy(tile, 1, 1) from source |""".stripMargin) - noException should be thrownBy mocks - .getNetCDFBinaryDf(spark) - .withColumn("result", rst_worldtorastercoordy("/dummy/path", 0, 0)) - .withColumn("result", rst_worldtorastercoordy($"path", lit(0), lit(0))) + noException should be thrownBy rastersInMemory + .withColumn("result", rst_worldtorastercoordy($"tile", 0, 0)) + .withColumn("result", rst_worldtorastercoordy($"tile", lit(0), lit(0))) .select("result") val result = df.as[Double].collect().head diff --git a/src/test/scala/com/databricks/labs/mosaic/functions/MosaicContextBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/functions/MosaicContextBehaviors.scala index ef87b5947..96ef8e559 100644 --- a/src/test/scala/com/databricks/labs/mosaic/functions/MosaicContextBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/functions/MosaicContextBehaviors.scala @@ -33,7 +33,7 @@ trait MosaicContextBehaviors extends MosaicSpatialQueryTest { MosaicContext.indexSystem match { case BNGIndexSystem => mc.getIndexSystem.getCellIdDataType shouldEqual StringType case H3IndexSystem => mc.getIndexSystem.getCellIdDataType shouldEqual LongType - case _ => mc.getIndexSystem.getCellIdDataType shouldEqual LongType + case _ => mc.getIndexSystem.getCellIdDataType shouldEqual LongType } an[Error] should be thrownBy mc.setCellIdDataType("binary") } @@ -58,12 +58,12 @@ trait MosaicContextBehaviors extends MosaicSpatialQueryTest { val gridCellLong = indexSystem match { case BNGIndexSystem => lit(1050138790).expr case H3IndexSystem => lit(623060282076758015L).expr - case _ => lit(0L).expr + case _ => lit(0L).expr } val gridCellStr = indexSystem match { case BNGIndexSystem => lit("TQ388791").expr case H3IndexSystem => lit("8a58e0682d6ffff").expr - case _ => lit("0").expr + case _ => lit("0").expr } noException should be thrownBy getFunc("as_hex").apply(Seq(pointWkt)) @@ -212,9 +212,9 @@ 
trait MosaicContextBehaviors extends MosaicSpatialQueryTest { functionBuilder ) registry.registerFunction( - FunctionIdentifier("h3_distance", None), - new ExpressionInfo("product", "h3_distance"), - functionBuilder + FunctionIdentifier("h3_distance", None), + new ExpressionInfo("product", "h3_distance"), + functionBuilder ) mc.register(spark) @@ -248,6 +248,13 @@ trait MosaicContextBehaviors extends MosaicSpatialQueryTest { method.apply(1).asInstanceOf[Int] shouldBe 2 } + def printWarnings(): Unit = { + spark.conf.set("spark.databricks.clusterUsageTags.sparkVersion", "x") + spark.conf.set("spark.databricks.photon.enabled", "false") + spark.conf.set("spark.databricks.clusterUsageTags.clusterType", "x") + MosaicContext.checkDBR(spark) should be(false) + } + } object MosaicContextBehaviors extends MockFactory { diff --git a/src/test/scala/com/databricks/labs/mosaic/functions/MosaicContextTest.scala b/src/test/scala/com/databricks/labs/mosaic/functions/MosaicContextTest.scala index 193f637df..22a4e2112 100644 --- a/src/test/scala/com/databricks/labs/mosaic/functions/MosaicContextTest.scala +++ b/src/test/scala/com/databricks/labs/mosaic/functions/MosaicContextTest.scala @@ -17,5 +17,6 @@ class MosaicContextTest extends MosaicSpatialQueryTest with SharedSparkSession w test("MosaicContext lookup correct sql functions") { sqlFunctionLookup() } test("MosaicContext should use databricks h3") { callDatabricksH3() } test("MosaicContext should correctly reflect functions") { reflectedMethods() } + test("MosaicContext should printWarning") { printWarnings() } } diff --git a/src/test/scala/com/databricks/labs/mosaic/functions/MosaicRegistryBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/functions/MosaicRegistryBehaviors.scala index edf85a2b7..e47f3c797 100644 --- a/src/test/scala/com/databricks/labs/mosaic/functions/MosaicRegistryBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/functions/MosaicRegistryBehaviors.scala @@ -38,7 +38,7 @@ object 
MosaicRegistryBehaviors extends MockFactory { ix.name _ when () returns H3.name val gapi = stub[GeometryAPI] gapi.name _ when () returns JTS.name - MosaicContext.build(ix, gapi, GDAL) + MosaicContext.build(ix, gapi) } } diff --git a/src/test/scala/com/databricks/labs/mosaic/functions/auxiliary/BadIndexSystem.scala b/src/test/scala/com/databricks/labs/mosaic/functions/auxiliary/BadIndexSystem.scala index 3ae228faf..162224242 100644 --- a/src/test/scala/com/databricks/labs/mosaic/functions/auxiliary/BadIndexSystem.scala +++ b/src/test/scala/com/databricks/labs/mosaic/functions/auxiliary/BadIndexSystem.scala @@ -10,6 +10,8 @@ import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataType} // Used for testing only object BadIndexSystem extends IndexSystem(BooleanType) { + override def crsID: Int = throw new UnsupportedOperationException + val name = "BadIndexSystem" override def getResolutionStr(resolution: Int): String = throw new UnsupportedOperationException diff --git a/src/test/scala/com/databricks/labs/mosaic/sql/extensions/TestSQLExtensions.scala b/src/test/scala/com/databricks/labs/mosaic/sql/extensions/TestSQLExtensions.scala index d1d697006..5c332ae79 100644 --- a/src/test/scala/com/databricks/labs/mosaic/sql/extensions/TestSQLExtensions.scala +++ b/src/test/scala/com/databricks/labs/mosaic/sql/extensions/TestSQLExtensions.scala @@ -3,7 +3,7 @@ package com.databricks.labs.mosaic.sql.extensions import com.databricks.labs.mosaic._ import com.databricks.labs.mosaic.core.geometry.api.{ESRI, JTS} import com.databricks.labs.mosaic.core.index.{BNGIndexSystem, H3IndexSystem} -import com.databricks.labs.mosaic.core.raster.api.RasterAPI.GDAL +import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.functions.MosaicContext import com.databricks.labs.mosaic.test.SparkSuite import org.apache.spark.SparkConf @@ -19,7 +19,7 @@ class TestSQLExtensions extends AnyFlatSpec with SQLExtensionsBehaviors with Spa 
.set(MOSAIC_RASTER_API, "GDAL") .set("spark.sql.extensions", "com.databricks.labs.mosaic.sql.extensions.MosaicSQL") var spark = withConf(conf) - it should behave like sqlRegister(MosaicContext.build(H3IndexSystem, JTS, GDAL), spark) + it should behave like sqlRegister(MosaicContext.build(H3IndexSystem, JTS), spark) conf = new SparkConf(false) .set(MOSAIC_INDEX_SYSTEM, "H3") @@ -27,7 +27,7 @@ class TestSQLExtensions extends AnyFlatSpec with SQLExtensionsBehaviors with Spa .set(MOSAIC_RASTER_API, "GDAL") .set("spark.sql.extensions", "com.databricks.labs.mosaic.sql.extensions.MosaicSQL") spark = withConf(conf) - it should behave like sqlRegister(MosaicContext.build(H3IndexSystem, ESRI, GDAL), spark) + it should behave like sqlRegister(MosaicContext.build(H3IndexSystem, ESRI), spark) conf = new SparkConf(false) .set(MOSAIC_INDEX_SYSTEM, "BNG") @@ -35,7 +35,7 @@ class TestSQLExtensions extends AnyFlatSpec with SQLExtensionsBehaviors with Spa .set(MOSAIC_RASTER_API, "GDAL") .set("spark.sql.extensions", "com.databricks.labs.mosaic.sql.extensions.MosaicSQL") spark = withConf(conf) - it should behave like sqlRegister(MosaicContext.build(BNGIndexSystem, JTS, GDAL), spark) + it should behave like sqlRegister(MosaicContext.build(BNGIndexSystem, JTS), spark) conf = new SparkConf(false) .set(MOSAIC_INDEX_SYSTEM, "BNG") @@ -43,7 +43,7 @@ class TestSQLExtensions extends AnyFlatSpec with SQLExtensionsBehaviors with Spa .set(MOSAIC_RASTER_API, "GDAL") .set("spark.sql.extensions", "com.databricks.labs.mosaic.sql.extensions.MosaicSQL") spark = withConf(conf) - it should behave like sqlRegister(MosaicContext.build(BNGIndexSystem, ESRI, GDAL), spark) + it should behave like sqlRegister(MosaicContext.build(BNGIndexSystem, ESRI), spark) conf = new SparkConf(false) .set(MOSAIC_INDEX_SYSTEM, "DummyIndex") @@ -58,7 +58,7 @@ class TestSQLExtensions extends AnyFlatSpec with SQLExtensionsBehaviors with Spa conf = new SparkConf(false) .set("spark.sql.extensions", 
"com.databricks.labs.mosaic.sql.extensions.MosaicSQLDefault") spark = withConf(conf) - it should behave like sqlRegister(MosaicContext.build(H3IndexSystem, ESRI, GDAL), spark) + it should behave like sqlRegister(MosaicContext.build(H3IndexSystem, ESRI), spark) } @@ -69,7 +69,7 @@ class TestSQLExtensions extends AnyFlatSpec with SQLExtensionsBehaviors with Spa .set("spark.sql.extensions", "com.databricks.labs.mosaic.sql.extensions.MosaicGDAL") .set(MOSAIC_GDAL_NATIVE, "true") val spark = withConf(conf) - it should behave like mosaicGDAL(MosaicContext.build(H3IndexSystem, ESRI, GDAL), spark) + it should behave like mosaicGDAL(MosaicContext.build(H3IndexSystem, ESRI), spark) } diff --git a/src/test/scala/com/databricks/labs/mosaic/test/MosaicGDAL.scala b/src/test/scala/com/databricks/labs/mosaic/test/MosaicGDAL.scala deleted file mode 100644 index ddbd561f4..000000000 --- a/src/test/scala/com/databricks/labs/mosaic/test/MosaicGDAL.scala +++ /dev/null @@ -1,42 +0,0 @@ -package com.databricks.labs.mosaic.test - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession - -import scala.io.{BufferedSource, Source} -import scala.sys.process._ - -object MosaicGDAL extends Logging { - - def installGDAL(spark: SparkSession): Unit = installGDAL(Some(spark)) - - def installGDAL(spark: Option[SparkSession]): Unit = { - val sc = spark.map(_.sparkContext) - val numExecutors = sc.map(_.getExecutorMemoryStatus.size - 1) - val script = getScript - for (cmd <- script.getLines.toList) { - try { - if (!cmd.startsWith("#") || cmd.nonEmpty) cmd.!! 
- sc.map { sparkContext => - if (!sparkContext.isLocal) { - sparkContext.parallelize(1 to numExecutors.get).pipe(cmd).collect - } - } - } catch { - case e: Throwable => logError(e.getMessage) - } finally { - script.close - } - } - } - - private def getScript: BufferedSource = { - val scriptPath = System.getProperty("os.name").toLowerCase() match { - case o: String if o.contains("nux") => "/scripts/install-gdal-databricks.sh" - case _ => throw new UnsupportedOperationException("This method only supports Ubuntu Linux with `apt`.") - } - val script = Source.fromInputStream(getClass.getResourceAsStream(scriptPath)) - script - } - -} diff --git a/src/test/scala/com/databricks/labs/mosaic/test/SparkSuite.scala b/src/test/scala/com/databricks/labs/mosaic/test/SparkSuite.scala index a470c9112..a1c05d893 100644 --- a/src/test/scala/com/databricks/labs/mosaic/test/SparkSuite.scala +++ b/src/test/scala/com/databricks/labs/mosaic/test/SparkSuite.scala @@ -9,7 +9,7 @@ trait SparkSuite extends TestSuite with BeforeAndAfterAll { var sparkConf: SparkConf = new SparkConf(false) - .set("spark.executor.extraLibraryPath", "/usr/local/lib/gdal") + .set("spark.executor.extraLibraryPath", "/usr/lib/gdal") @transient private var _sc: SparkContext = _ @transient private var _spark: SparkSession = _ diff --git a/src/test/scala/com/databricks/labs/mosaic/test/TestMosaicGDAL.scala b/src/test/scala/com/databricks/labs/mosaic/test/TestMosaicGDAL.scala deleted file mode 100644 index e486247ad..000000000 --- a/src/test/scala/com/databricks/labs/mosaic/test/TestMosaicGDAL.scala +++ /dev/null @@ -1,54 +0,0 @@ -package com.databricks.labs.mosaic.test - -import com.databricks.labs.mosaic.gdal.MosaicGDAL._ -import com.twitter.chill.Base64.InputStream -import org.apache.spark.sql.SparkSession -import org.apache.spark.SparkException -import org.apache.spark.internal.Logging - -import java.io.{ByteArrayInputStream, IOException} -import scala.io.{BufferedSource, Source} -import scala.sys.process._ - 
-object TestMosaicGDAL extends Logging { - - def installGDAL(spark: SparkSession): Unit = { - if (!wasEnabled(spark) && !isEnabled) installGDAL(Some(spark)) - } - - def installGDAL(spark: Option[SparkSession]): Unit = { - val sc = spark.map(_.sparkContext) - val numExecutors = sc.map(_.getExecutorMemoryStatus.size - 1) - val script = getScript - for (cmd <- script.getLines.toList) { - try { - if (!cmd.startsWith("#") || cmd.nonEmpty) cmd.!! - sc.map { sparkContext => - if (!sparkContext.isLocal) { - sparkContext.parallelize(1 to numExecutors.get).pipe(cmd).collect - } - } - } catch { - case e: IOException => logError(e.getMessage) - case e: IllegalStateException => logError(e.getMessage) - case e: SparkException => logError(e.getMessage) - case e: Throwable => logError(e.getMessage) - } finally { - script.close - } - } - } - - def getScript: BufferedSource = { - System.getProperty("os.name").toLowerCase() match { - case o: String if o.contains("nux") => - val script = Source.fromInputStream(getClass.getResourceAsStream("/scripts/install-gdal-databricks.sh")) - script - case _ => - logInfo("This method only supports Ubuntu Linux with `apt`.") - Source.fromInputStream(getClass.getResourceAsStream("")) - } - - } - -} diff --git a/src/test/scala/com/databricks/labs/mosaic/test/package.scala b/src/test/scala/com/databricks/labs/mosaic/test/package.scala index fd0e886cd..435ee552c 100644 --- a/src/test/scala/com/databricks/labs/mosaic/test/package.scala +++ b/src/test/scala/com/databricks/labs/mosaic/test/package.scala @@ -18,7 +18,7 @@ package object test { import org.apache.spark.sql._ import org.apache.spark.sql.types.{StructField, StructType} - val hex_rows_epsg4326 = + val hex_rows_epsg4326: List[List[Any]] = List( List( 1, @@ -52,7 +52,7 @@ package object test { case id :: hex :: _ => List(id, JTS.geometry(hex, "HEX").mapXY((x, y) => (math.abs(x) * 1000, math.abs(y) * 1000)).toHEX) case _ => throw new Error("Unexpected test data format!") } - val wkt_rows_epsg4326 
= + val wkt_rows_epsg4326: List[List[Any]] = List( List(1, "POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))"), List(2, "MULTIPOLYGON (((0 0, 0 1, 2 2, 0 0)))"), @@ -73,7 +73,7 @@ package object test { List(7, "LINESTRING (30 10, 10 30, 40 40)"), List(8, "MULTILINESTRING ((10 10, 20 20, 10 40), (40 40, 30 30, 40 20, 30 10))") ) - val wkt_rows_epsg27700 = + val wkt_rows_epsg27700: List[List[Any]] = List( List(1, "POLYGON ((30000 10000, 40000 40000, 20000 40000, 10000 20000, 30000 10000))"), List(2, "MULTIPOLYGON (((0 0, 0 1000, 2000 2000, 0 0)))"), @@ -94,7 +94,7 @@ package object test { List(7, "LINESTRING (30000 10000, 10000 30000, 40000 40000)"), List(8, "MULTILINESTRING ((10000 10000, 20000 20000, 10000 40000), (40000 40000, 30000 30000, 40000 20000, 30000 10000))") ) - val geoJSON_rows = + val geoJSON_rows: List[List[Any]] = List( List( 1, @@ -126,7 +126,7 @@ package object test { """{"type":"MultiLineString","coordinates":[[[10,10],[20,20],[10,40]],[[40,40],[30,30],[40,20],[30,10]]],"crs":{"type":"name","properties":{"name":"EPSG:0"}}}""" ) ) - val wkt_rows_boroughs_epsg4326 = + val wkt_rows_boroughs_epsg4326: List[List[Any]] = List( List( 1, @@ -344,21 +344,13 @@ package object test { Paths.get(inFile.getPath).toAbsolutePath.toString } - // noinspection ScalaCustomHdfsFormat - def getBinaryDf(spark: SparkSession, resourcePath: String, pathGlobFilter: String): DataFrame = - spark.read.format("binaryFile").option("pathGlobFilter", pathGlobFilter).load(resourcePath) - - def getGeotiffBinaryDf(spark: SparkSession): DataFrame = getBinaryDf(spark, "src/test/resources/modis/", "*.TIF") - - def getGribBinaryDf(spark: SparkSession): DataFrame = getBinaryDf(spark, "src/test/resources/binary/grib-cams", "*.grib") - - def getNetCDFBinaryDf(spark: SparkSession): DataFrame = getBinaryDf(spark, "src/test/resources/binary/netcdf-coral", "*.nc") - } // noinspection NotImplementedCode, ScalaStyle object MockIndexSystem extends IndexSystem(LongType) { + override def crsID: Int = 
??? + override def name: String = "MOCK" override def polyfill(geometry: MosaicGeometry, resolution: Int, geometryAPI: Option[GeometryAPI]): Seq[Long] = ??? diff --git a/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala b/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala index 537333908..a666e0578 100644 --- a/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala +++ b/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala @@ -1,10 +1,10 @@ package org.apache.spark.sql.test import com.databricks.labs.mosaic.gdal.MosaicGDAL -import com.databricks.labs.mosaic.test.TestMosaicGDAL import com.databricks.labs.mosaic.{MOSAIC_GDAL_NATIVE, MOSAIC_RASTER_CHECKPOINT} import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession +import org.gdal.gdal.gdal import java.nio.file.Files import scala.util.Try @@ -21,15 +21,19 @@ trait SharedSparkSessionGDAL extends SharedSparkSession { conf.set(MOSAIC_RASTER_CHECKPOINT, Files.createTempDirectory("mosaic").toFile.getAbsolutePath) SparkSession.cleanupAnyExistingSession() val session = new TestSparkSession(conf) - if (conf.get(MOSAIC_GDAL_NATIVE, "false").toBoolean) { - Try { - TestMosaicGDAL.installGDAL(session) - val tempPath = Files.createTempDirectory("mosaic-gdal") - MosaicGDAL.prepareEnvironment(session, tempPath.toAbsolutePath.toString, "/usr/lib/jni") - MosaicGDAL.enableGDAL(session) - } + session.sparkContext.setLogLevel("FATAL") + Try { + val tempPath = Files.createTempDirectory("mosaic-gdal") + MosaicGDAL.prepareEnvironment(session, tempPath.toAbsolutePath.toString) + MosaicGDAL.enableGDAL(session) } session } + override def beforeEach(): Unit = { + super.beforeEach() + MosaicGDAL.enableGDAL(this.spark) + gdal.AllRegister() + } + }