diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala
index b399fa0a2..40697f061 100644
--- a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala
+++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala
@@ -58,6 +58,8 @@ case class MosaicRasterGDAL(
             result
         }
         if (spatialRef == null) {
+            // Avoids null-CRS rasters
+            raster.SetSpatialRef(MosaicGDAL.WSG84)
             MosaicGDAL.WSG84
         } else {
             spatialRef
@@ -480,7 +482,9 @@ case class MosaicRasterGDAL(
                 tmpPath
             }
         }
-        val byteArray = FileUtils.readBytes(readPath)
+        // For corrupted files, return empty byte array
+        // We will have the reason for corruption in the last_error field
+        val byteArray = Try(FileUtils.readBytes(readPath)).getOrElse(Array.empty[Byte])
         if (dispose) RasterCleaner.dispose(this)
         if (readPath != PathUtils.getCleanPath(parentPath)) {
             Files.deleteIfExists(Paths.get(readPath))
@@ -608,7 +612,7 @@ case class MosaicRasterGDAL(
             .delete()
 
         val outputRaster = gdal.Open(resultRasterPath, GF_Write)
-
+
         for (bandIndex <- 1 to this.numBands) {
             val band = this.getBand(bandIndex)
             val outputBand = outputRaster.GetRasterBand(bandIndex)
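Reviewer note, not part of the patch: the Try-wrapped read above lets a corrupted raster degrade to an empty payload instead of failing the whole Spark task, with the cause preserved in the tile's last_error field. A minimal standalone sketch of the same pattern, using plain java.nio in place of Mosaic's FileUtils (SafeRead and readBytesOrEmpty are invented names for illustration):

    import java.nio.file.{Files, Paths}
    import scala.util.Try

    object SafeRead {
        // Mirrors Try(FileUtils.readBytes(readPath)).getOrElse(Array.empty[Byte])
        def readBytesOrEmpty(path: String): Array[Byte] =
            Try(Files.readAllBytes(Paths.get(path))).getOrElse(Array.empty[Byte])

        def main(args: Array[String]): Unit = {
            // A missing or unreadable file yields an empty array rather than an exception
            println(readBytesOrEmpty("/no/such/file").length) // 0
        }
    }

The same getOrElse shape shows up in the GDALWarp hunk below, where a failed warp leaves no output file and Files.size would otherwise throw.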
diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala
index 516560a76..f3d4e5dc1 100644
--- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala
+++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala
@@ -4,6 +4,7 @@ import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL
 import org.gdal.gdal.{WarpOptions, gdal}
 
 import java.nio.file.{Files, Paths}
+import scala.util.Try
 
 /** GDALWarp is a wrapper for the GDAL Warp command. */
 object GDALWarp {
@@ -29,7 +30,7 @@ object GDALWarp {
         val result = gdal.Warp(outputPath, rasters.map(_.getRaster).toArray, warpOptions)
         // Format will always be the same as the first raster
         val errorMsg = gdal.GetLastErrorMsg
-        val size = Files.size(Paths.get(outputPath))
+        val size = Try(Files.size(Paths.get(outputPath))).getOrElse(-1L)
         val createInfo = Map(
             "path" -> outputPath,
             "parentPath" -> rasters.head.getParentPath,
diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/proj/RasterProject.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/proj/RasterProject.scala
index 5d7c5f5f2..7c7727cdf 100644
--- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/proj/RasterProject.scala
+++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/proj/RasterProject.scala
@@ -32,7 +32,15 @@ object RasterProject {
         // Note that Null is the right value here
         val authName = destCRS.GetAuthorityName(null)
         val authCode = destCRS.GetAuthorityCode(null)
-
+
+        val srcAuthName = raster.getSpatialReference.GetAuthorityName(null)
+        val srcAuthCode = raster.getSpatialReference.GetAuthorityCode(null)
+
+        // There is no need to translate if the CRSs match
+        if (authName == srcAuthName && authCode == srcAuthCode) {
+            return raster
+        }
+
         val result = GDALWarp.executeWarp(
             resultFileName,
             Seq(raster),
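Reviewer note, not part of the patch: the early return above skips the warp whenever source and destination CRS resolve to the same authority name/code pair. A small self-contained sketch of that comparison against the GDAL Java bindings (CrsMatch and sameAuthority are invented names; EPSG:4326 is chosen arbitrarily for the demo):

    import org.gdal.osr.SpatialReference

    object CrsMatch {
        // Passing null asks for the authority of the root node, as in the patch
        def sameAuthority(src: SpatialReference, dst: SpatialReference): Boolean =
            src.GetAuthorityName(null) == dst.GetAuthorityName(null) &&
                src.GetAuthorityCode(null) == dst.GetAuthorityCode(null)

        def main(args: Array[String]): Unit = {
            val a = new SpatialReference()
            a.ImportFromEPSG(4326)
            val b = new SpatialReference()
            b.ImportFromEPSG(4326)
            println(sameAuthority(a, b)) // true: RasterProject would skip the warp
        }
    }

One caveat worth keeping in mind: the check compares authority identifiers only, so two equivalent CRSs registered under different authorities would still go through the warp.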
diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/separate/SeparateBands.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/separate/SeparateBands.scala
index 9580cc441..73a415d26 100644
--- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/separate/SeparateBands.scala
+++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/separate/SeparateBands.scala
@@ -22,17 +22,20 @@ object SeparateBands {
     def separate(
         tile: => MosaicRasterTile
     ): Seq[MosaicRasterTile] = {
-        val raster = tile.getRaster
+        val raster = if (tile.getRaster.getWriteOptions.format == "Zarr") {
+            zarrToNetCDF(tile).getRaster
+        } else {
+            tile.getRaster
+        }
         val tiles = for (i <- 0 until raster.numBands) yield {
             val fileExtension = raster.getRasterFileExtension
             val rasterPath = PathUtils.createTmpFilePath(fileExtension)
-            val shortDriver = raster.getDriversShortName
             val outOptions = raster.getWriteOptions
 
             val result = GDALTranslate.executeTranslate(
                 rasterPath,
                 raster,
-                command = s"gdal_translate -of $shortDriver -b ${i + 1}",
+                command = s"gdal_translate -b ${i + 1}",
                 writeOptions = outOptions
             )
 
@@ -49,8 +52,36 @@ object SeparateBands {
 
         val (_, valid) = tiles.partition(_._1)
 
+        if (tile.getRaster.getWriteOptions.format == "Zarr") dispose(raster)
+
+        for (elem <- valid) { elem._2.raster.SetSpatialRef(raster.getSpatialReference) }
         valid.map(t => new MosaicRasterTile(null, t._2))
 
     }
 
+    def zarrToNetCDF(
+        tile: => MosaicRasterTile
+    ): MosaicRasterTile = {
+        val raster = tile.getRaster
+        val fileExtension = "nc"
+        val rasterPath = PathUtils.createTmpFilePath(fileExtension)
+        val outOptions = raster.getWriteOptions.copy(
+            format = "NetCDF"
+        )
+
+        val result = GDALTranslate.executeTranslate(
+            rasterPath,
+            raster,
+            command = s"gdal_translate",
+            writeOptions = outOptions
+        )
+        result.raster.SetSpatialRef(raster.getSpatialReference)
+        result.raster.FlushCache()
+
+        val isEmpty = result.isEmpty
+        if (isEmpty) dispose(result)
+
+        new MosaicRasterTile(tile.index, result)
+    }
+
 }
diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala
index 2f5bf39b6..8d58d45a7 100644
--- a/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala
+++ b/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala
@@ -20,7 +20,7 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead
     private val mc = MosaicContext.context()
     import mc.functions._
 
-    def getNPartitions(config: Map[String, String]): Int = {
+    private def getNPartitions(config: Map[String, String]): Int = {
         val shufflePartitions = sparkSession.conf.get("spark.sql.shuffle.partitions")
         val nPartitions = config.getOrElse("nPartitions", shufflePartitions).toInt
         nPartitions
@@ -75,7 +75,11 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead
                 .agg(rst_combineavg_agg(col("tile")).alias("tile"))
                 .withColumn(
                     "grid_measures",
-                    rasterToGridCombiner(col("tile"))
+                    // when tessellation fails the last_error will be populated
+                    // we should surface the error, but we can't aggregate it
+                    // so we force a null value
+                    when(col("tile.metadata.last_error").isNotNull, lit(null))
+                        .otherwise(rasterToGridCombiner(col("tile")))
                 )
                 .select(
                     "grid_measures",
@@ -148,13 +152,17 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead
         val readSubdataset = config("readSubdataset").toBoolean
         val subdatasetName = config("subdatasetName")
 
-        if (readSubdataset) {
-            pathsDf
-                .withColumn("subdatasets", rst_subdatasets(col("tile")))
-                .withColumn("tile", rst_getsubdataset(col("tile"), lit(subdatasetName)))
-        } else {
-            pathsDf.select(col("tile"))
-        }
+        val resolved =
+            if (readSubdataset) {
+                pathsDf
+                    .withColumn("subdatasets", rst_subdatasets(col("tile")))
+                    .withColumn("tile", rst_getsubdataset(col("tile"), lit(subdatasetName)))
+            } else {
+                pathsDf.select(col("tile"))
+            }
+        resolved
+            .withColumn("tile", rst_separatebands(col("tile")))
+            .where(rst_pixelcount(col("tile")).getItem(0) > 0)
     }
 
     /**
diff --git a/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala b/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala
index 12dcac6f3..d611ba81a 100644
--- a/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala
+++ b/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala
@@ -18,7 +18,7 @@ trait SharedSparkSessionGDAL extends SharedSparkSession {
 
     override def createSparkSession: TestSparkSession = {
         val conf = sparkConf
-        conf.set(MOSAIC_RASTER_CHECKPOINT, FileUtils.createMosaicTempDir(prefix = "/mnt/"))
+        conf.set(MOSAIC_RASTER_CHECKPOINT, FileUtils.createMosaicTempDir(prefix = "/tmp/tmp"))
         SparkSession.cleanupAnyExistingSession()
         val session = new MosaicTestSparkSession(conf)
         session.sparkContext.setLogLevel("FATAL")
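Reviewer note, not part of the patch: in the RasterAsGridReader hunk above, tiles whose tessellation failed carry a populated last_error in their metadata, and those rows must not reach the aggregating combiner, hence the when/otherwise null guard. A minimal sketch of the same guard on a hypothetical flat schema (the tile/last_error columns and LastErrorGuard are invented for the demo; the real reader inspects tile.metadata.last_error and calls rasterToGridCombiner instead of length):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{col, length, lit, when}

    object LastErrorGuard {
        def main(args: Array[String]): Unit = {
            val spark = SparkSession.builder().master("local[1]").appName("guard").getOrCreate()
            import spark.implicits._

            val df = Seq(
                ("tile-a", null.asInstanceOf[String]),
                ("tile-b", "Warp failed: corrupt block")
            ).toDF("tile", "last_error")

            df.withColumn(
                "grid_measures",
                // errored rows surface as null instead of reaching the combiner
                when(col("last_error").isNotNull, lit(null)).otherwise(length(col("tile")))
            ).show(false)

            spark.stop()
        }
    }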