Skip to content

Commit

Permalink
Add tests for existing filter expectations
Browse files Browse the repository at this point in the history
  • Loading branch information
mdales committed Nov 12, 2024
1 parent 76b53f9 commit df11e55
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 29 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,7 @@ jobs:
python -m pip install -r requirements.txt
- name: Lint with pylint
run: |
python3 -m pylint deltap predictors prepare-layers prepare-species
python3 -m pylint deltap predictors prepare_layers prepare_species
- name: Tests
run: |
python3 -m pytest ./tests
2 changes: 1 addition & 1 deletion predictors/endemism.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def endemism(

species_rasters = {}
for raster_path in aohs:
speciesid = os.path.splitext(os.path.basename(raster_path))[0]
speciesid = os.path.basename(raster_path).split('_')[0]
full_path = os.path.join(aohs_dir, raster_path)
try:
species_rasters[speciesid].add(full_path)
Expand Down
2 changes: 1 addition & 1 deletion predictors/species_richness.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def species_richness(

species_rasters = {}
for raster_path in aohs:
speciesid = os.path.splitext(os.path.basename(raster_path))[0]
speciesid = os.path.basename(raster_path).split('_')[0]
full_path = os.path.join(aohs_dir, raster_path)
try:
species_rasters[speciesid].add(full_path)
Expand Down
52 changes: 26 additions & 26 deletions prepare_species/extract_species_psql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from functools import partial
from multiprocessing import Pool
from typing import Optional, Tuple
from typing import Dict, List, Optional, Tuple

# import pyshark # pylint: disable=W0611
import geopandas as gpd
Expand Down Expand Up @@ -113,17 +113,12 @@ def tidy_reproject_save(
res_projected = res.to_crs(target_crs)
res_projected.to_file(output_path, driver="GeoJSON")

def process_row_inner(
row: Tuple,
habitats_data: Tuple,
geometries_data: Tuple,
) -> Tuple:

id_no, _, _, _ = row
def process_habitats(
habitats_data: List,
) -> Dict:

if len(habitats_data) == 0:
logger.debug("Dropping %s as no habitats found", id_no)
return
raise ValueError("No habitats found")

# Clean up habitats to ensure they're unique (the system agg in the SQL statement might duplicate them)
# In the database there are the following seasons:
Expand Down Expand Up @@ -153,17 +148,14 @@ def process_row_inner(
case 'non-breeding' | 'Non-Breeding Season':
season_code = 3
case _:
raise ValueError(f"Unexpected season {season} for {id_no}")
raise ValueError(f"Unexpected season {season}")

if systems is None:
logger.debug("Dropping %s: no systems in DB", id_no)
continue
if "Marine" in systems:
logger.debug("Dropping %s: marine in systems", id_no)
return
raise ValueError("Marine in systems")

if habitat_values is None:
logger.debug("Dropping %s: no habitats in DB", id_no)
continue
habitat_set = set(habitat_values.split('|'))
if len(habitat_set) == 0:
Expand All @@ -172,22 +164,22 @@ def process_row_inner(
habitats[season_code] = habitat_set | habitats.get(season_code, set())

if major_importance == 'Yes':
major_habitats_lvl_1[season_code] = {int(float(x)) for x in habitat_set} | major_habitats_lvl_1.get(season_code, {})
major_habitats_lvl_1[season_code] = {int(float(x)) for x in habitat_set} | major_habitats_lvl_1.get(season_code, set())

# habitat based filtering
if len(habitats) == 0:
logger.debug("Dropping %s: No habitats", id_no)
return
raise ValueError("No filtered habitats")

for season in major_habitats_lvl_1:
major_habitats = major_habitats_lvl_1[season_code]
major_habitats = major_habitats_lvl_1[season]
if any((x == 7) for x in major_habitats):
logger.debug("Dropping %s: Habitat 7 in major importance habitat list", id_no)
return
raise ValueError("Habitat 7 in major importance habitat list")

return habitats

def process_geometries(geometries_data: List) -> Dict:
if len(geometries_data) == 0:
logger.info("Dropping %s: no geometries", id_no)
return
raise ValueError("No geometries")

geometries = {}
for season, geometry in geometries_data:
Expand All @@ -206,7 +198,7 @@ def process_row_inner(
except KeyError:
geometries[season_code] = grange

return habitats, geometries
return geometries

def process_row(
output_directory_path: str,
Expand All @@ -222,11 +214,19 @@ def process_row(

cursor.execute(HABITATS_STATEMENT, (assessment_id,))
habitats_data = cursor.fetchall()
try:
habitats = process_habitats(habitats_data)
except ValueError as exc:
logging.info("Dropping %s: %s", id_no, str(exc))
return

cursor.execute(GEOMETRY_STATEMENT, (assessment_id, presence))
geometries_data = cursor.fetchall()

habitats, geometries = process_row_inner(row, habitats_data, geometries_data)
try:
geometries = process_geometries(geometries_data)
except ValueError as exc:
logging.info("Dropping %s: %s", id_no, str(exc))
return

seasons = set(geometries.keys()) | set(habitats.keys())

Expand Down
93 changes: 93 additions & 0 deletions tests/test_species_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import pytest

from prepare_species.extract_species_psql import process_habitats

@pytest.mark.parametrize("label", [
    "resident",
    "Resident",
    "unknown",
    "Seasonal Occurrence Unknown",
    None
])
def test_simple_resident_species_filter(label):
    """Every season label that denotes a resident species maps to season code 1.

    The parametrized ``label`` is used as the season of each habitat row, so
    each spelling (and a missing season, ``None``) is actually exercised
    rather than re-running the literal "resident" case five times.
    """
    # Rows look like (season, major importance, '|'-joined habitat codes,
    # systems) -- inferred from how the rows are used here; confirm against
    # HABITATS_STATEMENT.
    habitat_data = [
        (label, "Yes", "4.1|4.2", "Terrestrial"),
        (label, "No", "4.3", "Terrestrial"),
    ]
    res = process_habitats(habitat_data)

    # Just resident: both rows collapse into the single season code 1.
    assert list(res.keys()) == [1]
    assert res[1] == set(["4.1", "4.2", "4.3"])

@pytest.mark.parametrize("breeding_label,non_breeding_label", [
    ("breeding", "non-breeding"),
    ("Breeding Season", "non-breeding"),
    ("breeding", "Non-Breeding Season"),
    ("Breeding Season", "Non-Breeding Season"),
])
def test_simple_migratory_species_filter(breeding_label, non_breeding_label):
    """Breeding and non-breeding rows are kept apart under season codes 2 and 3."""
    rows = [
        (breeding_label, "Yes", "4.1|4.2", "Terrestrial"),
        (non_breeding_label, "No", "4.3", "Terrestrial"),
    ]

    result = process_habitats(rows)

    # Breeding (2) and non-breeding (3) only, in the order the rows were given.
    assert list(result) == [2, 3]
    assert result[2] == {"4.1", "4.2"}
    assert result[3] == {"4.3"}

def test_reject_if_marine_in_system():
    """A habitat row whose systems include Marine makes process_habitats raise."""
    terrestrial_row = ("resident", "Yes", "4.1|4.2", "Terrestrial")
    marine_row = ("resident", "No", "4.3", "Terrestrial|Marine")

    with pytest.raises(ValueError):
        process_habitats([terrestrial_row, marine_row])

def test_reject_if_caves_in_major_habitat():
    """A level-1 habitat 7 code flagged as major importance makes process_habitats raise."""
    major_row_with_caves = ("resident", "Yes", "4.1|7.2", "Terrestrial")
    minor_row = ("resident", "No", "4.3", "Terrestrial")

    with pytest.raises(ValueError):
        process_habitats([major_row_with_caves, minor_row])

def test_do_not_reject_if_caves_in_minor_habitat():
    """A habitat 7 code that is not of major importance is kept, not rejected."""
    rows = [
        ("resident", "Yes", "4.1|4.2", "Terrestrial"),
        ("resident", "No", "7.3", "Terrestrial"),
    ]

    result = process_habitats(rows)

    # Only the resident season (code 1), with the 7.x code retained.
    assert list(result) == [1]
    assert result[1] == {"4.1", "4.2", "7.3"}

@pytest.mark.parametrize("label", ["passage", "Passage"])
def test_passage_ignored(label):
    """Rows for the passage season contribute nothing to the result."""
    resident_row = ("resident", "Yes", "4.1|4.2", "Terrestrial")
    passage_row = (label, "No", "4.3", "Terrestrial")

    result = process_habitats([resident_row, passage_row])

    # Only the resident season (code 1); the passage row's 4.3 must be absent.
    assert list(result) == [1]
    assert result[1] == {"4.1", "4.2"}

def test_fail_no_habitats_before_filter():
    """An empty list of habitat rows makes process_habitats raise."""
    no_rows = []

    with pytest.raises(ValueError):
        process_habitats(no_rows)

def test_fail_no_habitats_after_filter():
    """A species left with no habitats once rows are filtered is rejected.

    The only row is a passage row (expected to be skipped, as exercised by
    test_passage_ignored) that also lists a 7.x code as major importance.
    process_habitats raises ValueError for several different reasons, so the
    message is matched to make sure the failure comes from the "nothing left
    after filtering" check and not from the habitat 7 rule or any other one.
    """
    habitat_data = [
        ("passage", "Yes", "4.1|7.2", "Terrestrial"),
    ]
    # `match` is a regex search against str(exception).
    with pytest.raises(ValueError, match="No filtered habitats"):
        _ = process_habitats(habitat_data)

0 comments on commit df11e55

Please sign in to comment.