Skip to content

Commit

Permalink
Task/tpi summarize hfi (bcgov#3800)
Browse files Browse the repository at this point in the history
Co-authored-by: dgboss <[email protected]>

Summarizes advisory area based on TPI. Uses the snow-masked pmtiles hfi layer for the given date to mask contributing TPI pixels.

After finding out the initial implementation uses up to 20GB of memory, commit bcgov@ab72307 and onwards introduce a more memory optimized implementation. Basically all datasets are loaded from S3, but data is only read in and transformed a chunk at a time. This implementation stays under 4.5GB of memory usage.
  • Loading branch information
conbrad authored Aug 8, 2024
1 parent 32f3cf6 commit 1322133
Show file tree
Hide file tree
Showing 11 changed files with 519 additions and 218 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/deployment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ jobs:
shell: bash
run: |
oc login "${{ secrets.OPENSHIFT_CLUSTER }}" --token="${{ secrets.OC4_DEV_TOKEN }}"
PROJ_DEV="e1e498-dev" bash openshift/scripts/oc_provision_nats.sh ${SUFFIX} apply
PROJ_DEV="e1e498-dev" MEMORY_REQUEST=250Mi MEMORY_LIMIT=500Mi bash openshift/scripts/oc_provision_nats.sh ${SUFFIX} apply
scan-dev:
name: ZAP Baseline Scan Dev
Expand Down
9 changes: 9 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
],
"typescript.preferences.importModuleSpecifier": "non-relative",
"cSpell.words": [
"actuals",
"actuals",
"aiobotocore",
"aiohttp",
Expand All @@ -64,15 +65,20 @@
"Behaviour",
"botocore",
"cffdrs",
"cutline",
"determinates",
"excinfo",
"FBAN",
"fastapi",
"fireweather",
"firezone",
"GDPS",
"geoalchemy",
"GEOGCS",
"geopackage",
"geospatial",
"geotiff",
"gpkg",
"grib",
"gribs",
"HAINES",
Expand All @@ -85,8 +91,10 @@
"maxy",
"miny",
"morecast",
"morecast",
"nats",
"ndarray",
"Neighbour",
"numba",
"ORJSON",
"osgeo",
Expand All @@ -99,6 +107,7 @@
"PROJCS",
"pydantic",
"RDPS",
"reproject",
"rocketchat",
"sfms",
"sqlalchemy",
Expand Down
54 changes: 54 additions & 0 deletions api/alembic/versions/6910d017b626_add_advisory_tpi_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""add advisory tpi table
Revision ID: 6910d017b626
Revises: be128a7bb4fd
Create Date: 2024-07-31 16:27:31.642156
"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "6910d017b626"
down_revision = "be128a7bb4fd"
branch_labels = None
depends_on = None


def upgrade():
    """Create the advisory_tpi_stats table with its foreign keys and indexes."""
    # ### commands auto generated by Alembic ###
    op.create_table(
        "advisory_tpi_stats",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("advisory_shape_id", sa.Integer(), nullable=False),
        sa.Column("run_parameters", sa.Integer(), nullable=False),
        sa.Column("valley_bottom", sa.Integer(), nullable=False),
        sa.Column("mid_slope", sa.Integer(), nullable=False),
        sa.Column("upper_slope", sa.Integer(), nullable=False),
        sa.Column("pixel_size_metres", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(["advisory_shape_id"], ["advisory_shapes.id"]),
        sa.ForeignKeyConstraint(["run_parameters"], ["run_parameters.id"]),
        sa.PrimaryKeyConstraint("id"),
        comment="Elevation TPI stats per fire shape",
    )
    # Index the lookup columns; creation order mirrors the auto-generated migration.
    for column in ("advisory_shape_id", "id", "run_parameters"):
        op.create_index(op.f(f"ix_advisory_tpi_stats_{column}"), "advisory_tpi_stats", [column], unique=False)
    # ### end Alembic commands ###


def downgrade():
    """Drop the advisory_tpi_stats indexes and table (reverse of upgrade)."""
    # ### commands auto generated by Alembic ###
    # Drop indexes in the reverse of the order upgrade() created them.
    for column in ("run_parameters", "id", "advisory_shape_id"):
        op.drop_index(op.f(f"ix_advisory_tpi_stats_{column}"), table_name="advisory_tpi_stats")
    op.drop_table("advisory_tpi_stats")
    # ### end Alembic commands ###
1 change: 1 addition & 0 deletions api/app/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,5 @@ OBJECT_STORE_SECRET=object_store_secret
OBJECT_STORE_BUCKET=object_store_bucket
DEM_NAME=dem_mosaic_250_max.tif
TPI_DEM_NAME=bc_dem_50m_tpi.tif
CLASSIFIED_TPI_DEM_NAME=bc_dem_50m_tpi_win100_classified.tif
SENTRY_DSN=some_dsn
135 changes: 132 additions & 3 deletions api/app/auto_spatial_advisory/elevation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Takes a classified HFI image and calculates basic elevation statistics associated with advisory areas per fire zone."""

from dataclasses import dataclass
from datetime import date, datetime
from time import perf_counter
import logging
import os
import tempfile
from typing import Dict
import numpy as np
from osgeo import gdal
from sqlalchemy.ext.asyncio import AsyncSession
Expand All @@ -13,16 +15,47 @@
from app import config
from app.auto_spatial_advisory.classify_hfi import classify_hfi
from app.auto_spatial_advisory.run_type import RunType
from app.db.crud.auto_spatial_advisory import get_run_parameters_id, save_advisory_elevation_stats
from app.db.crud.auto_spatial_advisory import get_run_parameters_id, save_advisory_elevation_stats, save_advisory_elevation_tpi_stats
from app.db.database import get_async_read_session_scope, get_async_write_session_scope, DB_READ_STRING
from app.db.models.auto_spatial_advisory import AdvisoryElevationStats
from app.db.models.auto_spatial_advisory import AdvisoryElevationStats, AdvisoryTPIStats
from app.auto_spatial_advisory.hfi_filepath import get_raster_filepath, get_raster_tif_filename
from app.utils.s3 import get_client
from app.utils.geospatial import raster_mul, warp_to_match_extent


logger = logging.getLogger(__name__)
DEM_GDAL_SOURCE = None


async def process_elevation_tpi(run_type: RunType, run_datetime: datetime, for_date: date):
    """
    Create new elevation TPI statistics records for the given parameters, unless
    stats for this run have already been computed.

    :param run_type: The type of run to process. (is it a forecast or actual run?)
    :param run_datetime: The date and time of the run to process. (when was the hfi file created?)
    :param for_date: The date of the hfi to process. (when is the hfi for?)
    """
    logger.info("Processing elevation stats %s for run date: %s, for date: %s", run_type, run_datetime, for_date)
    perf_start = perf_counter()
    async with get_async_write_session_scope() as session:
        # Get the id from run_parameters associated with the provided run_type, for_date and for_datetime
        run_parameters_id = await get_run_parameters_id(session, run_type, run_datetime, for_date)

        stmt = select(AdvisoryTPIStats).where(AdvisoryTPIStats.run_parameters == run_parameters_id)
        exists = (await session.execute(stmt)).scalars().first() is not None

        if exists:
            # Idempotency guard: don't recompute stats for a run that was already processed.
            logger.info("Elevation stats already computed")
        else:
            fire_zone_stats = await process_tpi_by_firezone(run_type, run_datetime.date(), for_date)
            await store_elevation_tpi_stats(session, run_parameters_id, fire_zone_stats)

    perf_end = perf_counter()
    delta = perf_end - perf_start
    logger.info("%f delta count before and after processing elevation stats", delta)


async def process_elevation(source_path: str, run_type: RunType, run_datetime: datetime, for_date: date):
"""Create new elevation statistics records for the given parameters.
Expand Down Expand Up @@ -178,6 +211,79 @@ def apply_threshold_mask_to_dem(threshold: int, mask_path: str, temp_dir: str):
return masked_dem_path


@dataclass(frozen=True)
class FireZoneTPIStats:
    """
    Immutable container for per-fire-zone TPI statistics.

    ``fire_zone_stats`` maps a fire zone's source_identifier to a frequency
    dictionary of the form {1: X, 2: Y, 3: Z}, where class 1 = valley bottom,
    2 = mid slope and 3 = upper slope, and X, Y, Z are the counts of TPI
    pixels in each class whose HFI exceeds the 4K threshold.

    ``pixel_size_metres`` is the TPI raster's pixel size, in metres.
    """

    fire_zone_stats: Dict[int, Dict[int, int]]
    pixel_size_metres: int


async def process_tpi_by_firezone(run_type: RunType, run_date: date, for_date: date) -> FireZoneTPIStats:
    """
    Given run parameters, lookup associated snow-masked HFI and static classified TPI geospatial data.
    Cut out each fire zone shape from the above and intersect the TPI and HFI pixels, counting each pixel contributing to the TPI class.
    Capture all fire zone stats keyed by its source_identifier.

    :param run_type: forecast or actual
    :param run_date: date the computation ran
    :param for_date: date the computation is for
    :return: fire zone TPI status
    """

    # Configure GDAL's /vsis3 virtual filesystem so rasters are streamed from
    # object storage a chunk at a time rather than downloaded in full.
    gdal.SetConfigOption("AWS_SECRET_ACCESS_KEY", config.get("OBJECT_STORE_SECRET"))
    gdal.SetConfigOption("AWS_ACCESS_KEY_ID", config.get("OBJECT_STORE_USER_ID"))
    gdal.SetConfigOption("AWS_S3_ENDPOINT", config.get("OBJECT_STORE_SERVER"))
    gdal.SetConfigOption("AWS_VIRTUAL_HOSTING", "FALSE")
    bucket = config.get("OBJECT_STORE_BUCKET")
    dem_file = config.get("CLASSIFIED_TPI_DEM_NAME")

    # Static classified TPI DEM raster held in S3.
    key = f"/vsis3/{bucket}/dem/tpi/{dem_file}"
    tpi_source: gdal.Dataset = gdal.Open(key, gdal.GA_ReadOnly)
    # GetGeoTransform()[1] is the pixel width; assumes square pixels sized in metres — TODO confirm.
    pixel_size_metres = int(tpi_source.GetGeoTransform()[1])

    # Snow-masked HFI raster for this run, also streamed from S3.
    hfi_raster_filename = get_raster_tif_filename(for_date)
    hfi_raster_key = get_raster_filepath(run_date, run_type, hfi_raster_filename)
    hfi_key = f"/vsis3/{bucket}/{hfi_raster_key}"
    hfi_source: gdal.Dataset = gdal.Open(hfi_key, gdal.GA_ReadOnly)

    # Warp HFI onto the TPI raster's extent, then multiply the two so only TPI
    # pixels coinciding with qualifying HFI pixels remain non-zero.
    warped_mem_path = f"/vsimem/warp_{hfi_raster_filename}"
    resized_hfi_source: gdal.Dataset = warp_to_match_extent(hfi_source, tpi_source, warped_mem_path)
    hfi_masked_tpi = raster_mul(tpi_source, resized_hfi_source)
    # Release GDAL dataset references promptly to cap memory, then delete the in-memory warp file.
    resized_hfi_source = None
    hfi_source = None
    tpi_source = None
    gdal.Unlink(warped_mem_path)

    fire_zone_stats: Dict[int, Dict[int, int]] = {}
    async with get_async_write_session_scope() as session:
        stmt = text("SELECT id, source_identifier FROM advisory_shapes;")
        result = await session.execute(stmt)

        for row in result:
            # row[0] = advisory_shapes.id, row[1] = source_identifier
            output_path = f"/vsimem/firezone_{row[1]}.tif"
            # Clip the masked TPI raster with this fire zone's geometry, read via a database cutline.
            warp_options = gdal.WarpOptions(format="GTiff", cutlineDSName=DB_READ_STRING, cutlineSQL=f"SELECT geom FROM advisory_shapes WHERE id={row[0]}", cropToCutline=True)
            cut_hfi_masked_tpi: gdal.Dataset = gdal.Warp(output_path, hfi_masked_tpi, options=warp_options)
            # Get unique values and their counts
            tpi_classes, counts = np.unique(cut_hfi_masked_tpi.GetRasterBand(1).ReadAsArray(), return_counts=True)
            # Release the clipped dataset before processing the next zone.
            cut_hfi_masked_tpi = None
            gdal.Unlink(output_path)
            tpi_class_freq_dist = dict(zip(tpi_classes, counts))

            # Drop TPI class 4, this is the no data value from the TPI raster
            tpi_class_freq_dist.pop(4, None)
            fire_zone_stats[row[1]] = tpi_class_freq_dist

    hfi_masked_tpi = None
    return FireZoneTPIStats(fire_zone_stats=fire_zone_stats, pixel_size_metres=pixel_size_metres)


async def process_elevation_by_firezone(threshold: int, masked_dem_path: str, run_parameters_id: int):
"""
Given a tif that only contains elevations values at pixels where HFI exceeds the threshold, calculate statistics
Expand Down Expand Up @@ -205,7 +311,7 @@ def intersect_raster_by_firezone(threshold: int, advisory_shape_id: int, source_
:param threshold: The current threshold being processed, 1 = 4k-10k, 2 = > 10k
:param advisory_shape_id: The id of the fire zone (aka advisory_shape object) to clip with
:param source_identifier: The source identifier of the fire zone.
:param raster_path: The path to the raster to be clipped.
:param temp_dir: A temporary location for storing intermediate files
"""
Expand Down Expand Up @@ -261,3 +367,26 @@ async def store_elevation_stats(session: AsyncSession, threshold: int, shape_id:
threshold=threshold,
)
await save_advisory_elevation_stats(session, advisory_elevation_stats)


async def store_elevation_tpi_stats(session: AsyncSession, run_parameters_id: int, fire_zone_tpi_stats: FireZoneTPIStats):
    """
    Writes elevation TPI statistics to the database.

    :param session: An async database session.
    :param run_parameters_id: The RunParameter object id associated with this run_type, for_date and run_datetime
    :param fire_zone_tpi_stats: Per fire zone classified TPI/HFI pixel counts plus the TPI raster pixel size
    """
    advisory_tpi_stats_list = [
        AdvisoryTPIStats(
            # shape_id may arrive as a numpy/str value; coerce to a plain int for the DB.
            advisory_shape_id=int(shape_id),
            run_parameters=run_parameters_id,
            # TPI classes: 1 = valley bottom, 2 = mid slope, 3 = upper slope; zero when absent.
            valley_bottom=tpi_freq_count.get(1, 0),
            mid_slope=tpi_freq_count.get(2, 0),
            upper_slope=tpi_freq_count.get(3, 0),
            pixel_size_metres=fire_zone_tpi_stats.pixel_size_metres,
        )
        for shape_id, tpi_freq_count in fire_zone_tpi_stats.fire_zone_stats.items()
    ]

    await save_advisory_elevation_tpi_stats(session, advisory_tpi_stats_list)
18 changes: 7 additions & 11 deletions api/app/auto_spatial_advisory/process_elevation_hfi.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
""" Code relating to processing HFI data related to elevation
"""
"""Code relating to processing HFI data related to elevation"""

import logging
from datetime import date, datetime
from time import perf_counter
from app.auto_spatial_advisory.common import get_s3_key
from app.auto_spatial_advisory.elevation import process_elevation
from app.auto_spatial_advisory.elevation import process_elevation_tpi
from app.auto_spatial_advisory.run_type import RunType

logger = logging.getLogger(__name__)


async def process_hfi_elevation(run_type: RunType, run_date: date, run_datetime: datetime, for_date: date):
    """Create a new elevation based hfi analysis records for the given date.

    :param run_type: The type of run to process. (is it a forecast or actual run?)
    :param run_date: The date of the run to process. (when was the hfi file created?)
    :param run_datetime: The date and time of the run to process.
    :param for_date: The date of the hfi to process. (when is the hfi for?)
    """

    logger.info("Processing HFI elevation %s for run date: %s, for date: %s", run_type, run_date, for_date)
    perf_start = perf_counter()

    # Delegate to the TPI-based elevation processing introduced in this commit.
    await process_elevation_tpi(run_type, run_datetime, for_date)

    perf_end = perf_counter()
    delta = perf_end - perf_start
    logger.info("%f delta count before and after processing HFI elevation", delta)
Loading

0 comments on commit 1322133

Please sign in to comment.