Skip to content

Commit

Permalink
Merge pull request #478 from databrickslabs/feature/remove_vsimem
Browse files Browse the repository at this point in the history
Add ExpressionConfig to enableGdal() to allow fo customization of GDA…
  • Loading branch information
Milos Colic authored Dec 11, 2023
2 parents 1aaea1c + cb13a17 commit bdf6826
Show file tree
Hide file tree
Showing 44 changed files with 203 additions and 184 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package com.databricks.labs.mosaic.core.raster.api
import com.databricks.labs.mosaic.core.raster.gdal.{MosaicRasterBandGDAL, MosaicRasterGDAL}
import com.databricks.labs.mosaic.core.raster.io.RasterCleaner
import com.databricks.labs.mosaic.core.raster.operator.transform.RasterTransform
import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
import com.databricks.labs.mosaic.gdal.MosaicGDAL.configureGDAL
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{BinaryType, DataType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.gdal.gdal.gdal
Expand All @@ -17,6 +19,14 @@ import java.util.UUID
*/
object GDAL {

def dropDrivers(): Unit = {
val n = gdal.GetDriverCount()
for (i <- 0 until n) {
val driver = gdal.GetDriver(i)
driver.delete()
}
}

/**
* Returns the no data value for the given GDAL data type. For non-numeric
* data types, it returns 0.0. For numeric data types, it returns the
Expand Down Expand Up @@ -52,12 +62,17 @@ object GDAL {
* on the worker nodes. This method registers all the drivers on the worker
* nodes.
*/
def enable(): Unit = {
configureGDAL()
def enable(mosaicConfig: MosaicExpressionConfig): Unit = {
configureGDAL(mosaicConfig)
gdal.UseExceptions()
gdal.AllRegister()
}

def enable(spark: SparkSession): Unit = {
val mosaicConfig = MosaicExpressionConfig(spark)
enable(mosaicConfig)
}

/**
* Returns the extension of the given driver.
* @param driverShortName
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,12 @@ case class MosaicRasterGDAL(
val isSubdataset = PathUtils.isSubdataset(path)
val filePath = if (isSubdataset) PathUtils.fromSubdatasetPath(path) else path
val pamFilePath = s"$filePath.aux.xml"
Try(gdal.GetDriverByName(driverShortName).Delete(path))
Try(Files.deleteIfExists(Paths.get(path)))
Try(Files.deleteIfExists(Paths.get(filePath)))
Try(Files.deleteIfExists(Paths.get(pamFilePath)))
if (path != PathUtils.getCleanPath(parentPath)) {
Try(gdal.GetDriverByName(driverShortName).Delete(path))
Try(Files.deleteIfExists(Paths.get(path)))
Try(Files.deleteIfExists(Paths.get(filePath)))
Try(Files.deleteIfExists(Paths.get(pamFilePath)))
}
}

/**
Expand Down Expand Up @@ -405,7 +407,9 @@ case class MosaicRasterGDAL(
}
val byteArray = Files.readAllBytes(Paths.get(readPath))
if (dispose) RasterCleaner.dispose(this)
Files.deleteIfExists(Paths.get(readPath))
if (readPath != PathUtils.getCleanPath(parentPath)) {
Files.deleteIfExists(Paths.get(readPath))
}
byteArray
}

Expand Down Expand Up @@ -506,6 +510,7 @@ object MosaicRasterGDAL extends RasterReader {
case Some(driverShortName) =>
val drivers = new JVector[String]()
drivers.add(driverShortName)
gdal.GetDriverByName(driverShortName).Register()
gdal.OpenEx(path, GA_ReadOnly, drivers)
case None => gdal.Open(path, GA_ReadOnly)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.databricks.labs.mosaic.datasource.gdal

import com.databricks.labs.mosaic.core.index.IndexSystemFactory
import com.databricks.labs.mosaic.core.raster.api.GDAL
import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
import com.google.common.io.{ByteStreams, Closeables}
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.mapreduce.Job
Expand Down Expand Up @@ -35,7 +36,7 @@ class GDALFileFormat extends BinaryFileFormat {
* An instance of [[StructType]].
*/
override def inferSchema(sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = {
GDAL.enable()
GDAL.enable(sparkSession)

val reader = ReadStrategy.getReader(options)
val schema = super
Expand Down Expand Up @@ -116,28 +117,26 @@ class GDALFileFormat extends BinaryFileFormat {
options: Map[String, String],
hadoopConf: org.apache.hadoop.conf.Configuration
): PartitionedFile => Iterator[org.apache.spark.sql.catalyst.InternalRow] = {
GDAL.enable()
GDAL.enable(sparkSession)

val indexSystem = IndexSystemFactory.getIndexSystem(sparkSession)
val expressionConfig = MosaicExpressionConfig(sparkSession)

val supportedExtensions = options.getOrElse("extensions", "*").split(";").map(_.trim.toLowerCase(Locale.ROOT))

val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
val filterFuncs = filters.flatMap(createFilterFunction)
val maxLength = sparkSession.conf.get("spark.sql.sources.binaryFile.maxLength", Int.MaxValue.toString).toInt

// Identify the reader to use for the file format.
// GDAL supports multiple reading strategies.
val reader = ReadStrategy.getReader(options)

file: PartitionedFile => {
GDAL.enable()
GDAL.enable(expressionConfig)
val path = new Path(new URI(file.filePath))
val fs = path.getFileSystem(broadcastedHadoopConf.value.value)
val status = fs.getFileStatus(path)

if (status.getLen > maxLength) throw CantReadBytesException(maxLength, status)

if (supportedExtensions.contains("*") || supportedExtensions.exists(status.getPath.getName.toLowerCase(Locale.ROOT).endsWith)) {
if (filterFuncs.forall(_.apply(status)) && isAllowedExtension(status, options)) {
reader.read(status, fs, requiredSchema, options, indexSystem)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ object ReadInMemory extends ReadStrategy {
indexSystem: IndexSystem
): Iterator[InternalRow] = {
val inPath = status.getPath.toString
val driverShortName = MosaicRasterGDAL.identifyDriver(inPath)
val readPath = PathUtils.getCleanPath(inPath)
val driverShortName = MosaicRasterGDAL.identifyDriver(readPath)
val contentBytes: Array[Byte] = readContent(fs, status)
val raster = MosaicRasterGDAL.readRaster(contentBytes, inPath, driverShortName)
val raster = MosaicRasterGDAL.readRaster(readPath, inPath)
val uuid = getUUID(status)

val fields = requiredSchema.fieldNames.filter(_ != TILE).map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ case class RST_BoundingBox(
* The bounding box of the raster as a WKB polygon.
*/
override def rasterTransform(tile: MosaicRasterTile): Any = {
val raster = tile.getRaster
var raster = tile.getRaster
val gt = raster.getRaster.GetGeoTransform()
val (originX, originY) = GDAL.toWorldCoord(gt, 0, 0)
val (endX, endY) = GDAL.toWorldCoord(gt, raster.xSize, raster.ySize)
Expand All @@ -44,6 +44,7 @@ case class RST_BoundingBox(
).map(geometryAPI.fromCoords),
GeometryTypeEnum.POLYGON
)
raster = null
bboxPolygon.toWKB
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ case class RST_Clip(
override def rasterTransform(tile: MosaicRasterTile, arg1: Any): Any = {
val geometry = geometryAPI.geometry(arg1, geometryExpr.dataType)
val geomCRS = geometry.getSpatialReferenceOSR
val clipped = RasterClipByVector.clip(tile.getRaster, geometry, geomCRS, geometryAPI)
tile.copy(raster = clipped)
tile.copy(
raster = RasterClipByVector.clip(tile.getRaster, geometry, geomCRS, geometryAPI)
)
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ case class RST_CombineAvgAgg(
with UnaryLike[Expression]
with RasterExpressionSerialization {

GDAL.enable()

override lazy val deterministic: Boolean = true
override val child: Expression = rasterExpr
override val nullable: Boolean = false
Expand Down Expand Up @@ -61,18 +59,18 @@ case class RST_CombineAvgAgg(
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override def eval(buffer: ArrayBuffer[Any]): Any = {
GDAL.enable()
GDAL.enable(expressionConfig)

if (buffer.isEmpty) {
null
} else {

// Do do move the expression
val tiles = buffer.map(row => MosaicRasterTile.deserialize(row.asInstanceOf[InternalRow], expressionConfig.getCellIdType))
var tiles = buffer.map(row => MosaicRasterTile.deserialize(row.asInstanceOf[InternalRow], expressionConfig.getCellIdType))

// If merging multiple index rasters, the index value is dropped
val idx = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null
val combined = CombineAVG.compute(tiles.map(_.getRaster)).flushCache()
var combined = CombineAVG.compute(tiles.map(_.getRaster)).flushCache()
// TODO: should parent path be an array?
val parentPath = tiles.head.getParentPath
val driver = tiles.head.getDriver
Expand All @@ -84,6 +82,9 @@ case class RST_CombineAvgAgg(
tiles.foreach(RasterCleaner.dispose(_))
RasterCleaner.dispose(result)

tiles = null
combined = null

result
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,9 @@ case class RST_DerivedBand(
val pythonFunc = arg1.asInstanceOf[UTF8String].toString
val funcName = arg2.asInstanceOf[UTF8String].toString
val index = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null
val result = PixelCombineRasters.combine(tiles.map(_.getRaster), pythonFunc, funcName)
MosaicRasterTile(
index,
result,
PixelCombineRasters.combine(tiles.map(_.getRaster), pythonFunc, funcName),
tiles.head.getParentPath,
tiles.head.getDriver
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ case class RST_DerivedBandAgg(
with TernaryLike[Expression]
with RasterExpressionSerialization {

GDAL.enable()

override lazy val deterministic: Boolean = true
override val nullable: Boolean = false
override val dataType: DataType = RasterTileType(expressionConfig.getCellIdType)
Expand Down Expand Up @@ -67,7 +65,7 @@ case class RST_DerivedBandAgg(
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override def eval(buffer: ArrayBuffer[Any]): Any = {
GDAL.enable()
GDAL.enable(expressionConfig)

if (buffer.isEmpty) {
null
Expand All @@ -78,12 +76,12 @@ case class RST_DerivedBandAgg(
val funcName = funcNameExpr.eval(null).asInstanceOf[UTF8String].toString

// Do do move the expression
val tiles = buffer.map(row => MosaicRasterTile.deserialize(row.asInstanceOf[InternalRow], expressionConfig.getCellIdType))
var tiles = buffer.map(row => MosaicRasterTile.deserialize(row.asInstanceOf[InternalRow], expressionConfig.getCellIdType))

// If merging multiple index rasters, the index value is dropped
val idx = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null

val combined = PixelCombineRasters.combine(tiles.map(_.getRaster), pythonFunc, funcName)
var combined = PixelCombineRasters.combine(tiles.map(_.getRaster), pythonFunc, funcName)
// TODO: should parent path be an array?
val parentPath = tiles.head.getParentPath
val driver = tiles.head.getDriver
Expand All @@ -95,6 +93,9 @@ case class RST_DerivedBandAgg(
tiles.foreach(RasterCleaner.dispose(_))
RasterCleaner.dispose(result)

tiles = null
combined = null

result
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ case class RST_FromBands(
expressionConfig: MosaicExpressionConfig
) extends RasterArrayExpression[RST_FromBands](
bandsExpr,
RasterTileType(expressionConfig.getCellIdType),
RasterTileType(expressionConfig.getCellIdType),
returnsRaster = true,
expressionConfig = expressionConfig
)
Expand All @@ -31,8 +31,7 @@ case class RST_FromBands(
* The stacked and resampled raster.
*/
override def rasterTransform(rasters: Seq[MosaicRasterTile]): Any = {
val raster = MergeBands.merge(rasters.map(_.getRaster), "bilinear")
rasters.head.copy(raster = raster)
rasters.head.copy(raster = MergeBands.merge(rasters.map(_.getRaster), "bilinear"))
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ case class RST_FromFile(
with NullIntolerant
with CodegenFallback {

GDAL.enable()

override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType)

protected val geometryAPI: GeometryAPI = GeometryAPI.apply(expressionConfig.getGeometryAPI)
Expand All @@ -60,26 +58,31 @@ case class RST_FromFile(
* The tiles.
*/
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
GDAL.enable()
GDAL.enable(expressionConfig)
val path = rasterPathExpr.eval(input).asInstanceOf[UTF8String].toString
val driver = MosaicRasterGDAL.identifyDriver(path)
val tmpPath = PathUtils.createTmpFilePath(GDAL.getExtension(driver))
val readPath = PathUtils.getCleanPath(path)
Files.copy(Paths.get(readPath), Paths.get(tmpPath), StandardCopyOption.REPLACE_EXISTING)
val driver = MosaicRasterGDAL.identifyDriver(path)
val targetSize = sizeInMB.eval(input).asInstanceOf[Int]
if (targetSize <= 0) {
val raster = MosaicRasterGDAL.readRaster(tmpPath, path)
val tile = MosaicRasterTile(null, raster, path, raster.getDriversShortName)
if (targetSize <= 0 && Files.size(Paths.get(readPath)) <= Integer.MAX_VALUE) {
var raster = MosaicRasterGDAL.readRaster(readPath, path)
var tile = MosaicRasterTile(null, raster, path, raster.getDriversShortName)
val row = tile.formatCellId(indexSystem).serialize()
RasterCleaner.dispose(raster)
RasterCleaner.dispose(tile)
Files.deleteIfExists(Paths.get(tmpPath))
raster = null
tile = null
Seq(InternalRow.fromSeq(Seq(row)))
} else {
val tiles = ReTileOnRead.localSubdivide(tmpPath, path, targetSize)
// If target size is <0 and we are here that means the file is too big to fit in memory
// We split to tiles of size 64MB
val tmpPath = PathUtils.createTmpFilePath(GDAL.getExtension(driver))
Files.copy(Paths.get(readPath), Paths.get(tmpPath), StandardCopyOption.REPLACE_EXISTING)
val size = if (targetSize <= 0) 64 else targetSize
var tiles = ReTileOnRead.localSubdivide(tmpPath, path, size)
val rows = tiles.map(_.formatCellId(indexSystem).serialize())
tiles.foreach(RasterCleaner.dispose(_))
Files.deleteIfExists(Paths.get(tmpPath))
tiles = null
rows.map(row => InternalRow.fromSeq(Seq(row)))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ case class RST_GeoReference(raster: Expression, expressionConfig: MosaicExpressi

/** Returns the georeference of the raster. */
override def rasterTransform(tile: MosaicRasterTile): Any = {
val raster = tile.getRaster
val geoTransform = raster.getRaster.GetGeoTransform()
val geoTransform = tile.getRaster.getRaster.GetGeoTransform()
buildMapDouble(
Map(
"upperLeftX" -> geoTransform(0),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ case class RST_GetSubdataset(raster: Expression, subsetName: Expression, express
/** Returns the subdatasets of the raster. */
override def rasterTransform(tile: MosaicRasterTile, arg1: Any): Any = {
val subsetName = arg1.asInstanceOf[UTF8String].toString
val subdataset = tile.getRaster.getSubdataset(subsetName)
tile.copy(raster = subdataset)
tile.copy(raster = tile.getRaster.getSubdataset(subsetName))
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ case class RST_InitNoData(
.map(GDAL.getNoDataConstant)
.mkString(" ")
val resultPath = PathUtils.createTmpFilePath(GDAL.getExtension(tile.getDriver))
val result = GDALWarp.executeWarp(
resultPath,
Seq(tile.getRaster),
command = s"""gdalwarp -of ${tile.getDriver} -dstnodata "$dstNoDataValues" -srcnodata "$noDataValues""""
val cmd = s"""gdalwarp -of ${tile.getDriver} -dstnodata "$dstNoDataValues" -srcnodata "$noDataValues""""
tile.copy(
raster = GDALWarp.executeWarp(
resultPath,
Seq(tile.getRaster),
command = cmd
)
)
tile.copy(raster = result)
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ case class RST_IsEmpty(raster: Expression, expressionConfig: MosaicExpressionCon

/** Returns true if the raster is empty. */
override def rasterTransform(tile: MosaicRasterTile): Any = {
val raster = tile.getRaster
(raster.ySize == 0 && raster.xSize == 0) || raster.isEmpty
var raster = tile.getRaster
val result = (raster.ySize == 0 && raster.xSize == 0) || raster.isEmpty
raster = null
result
}

}
Expand Down
Loading

0 comments on commit bdf6826

Please sign in to comment.