Skip to content

Commit

Permalink
FeatureSource: remove need to specify id_field
Browse files Browse the repository at this point in the history
  • Loading branch information
dbaston committed Jan 2, 2024
1 parent 9096b8e commit f377115
Show file tree
Hide file tree
Showing 14 changed files with 212 additions and 118 deletions.
6 changes: 3 additions & 3 deletions python/src/exactextract/feature_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class GDALFeatureSource(FeatureSource):
def __init__(self, src):
super().__init__("")
super().__init__()

if isinstance(src, (str, os.PathLike)):
from osgeo import ogr
Expand All @@ -34,7 +34,7 @@ def __iter__(self):

class JSONFeatureSource(FeatureSource):
def __init__(self, src):
super().__init__("")
super().__init__()
if type(src) is dict:
self.src = [src]
else:
Expand All @@ -47,7 +47,7 @@ def __iter__(self):

class GeoPandasFeatureSource(FeatureSource):
def __init__(self, src):
super().__init__("")
super().__init__()
self.src = src

def __iter__(self):
Expand Down
13 changes: 3 additions & 10 deletions python/src/pybindings/feature_source_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ namespace exactextract {
class PyFeatureSourceBase : public FeatureSource
{
public:
PyFeatureSourceBase(const std::string& id_field)
: m_id_field(id_field)
, m_initialized(false)
PyFeatureSourceBase()
: m_initialized(false)
{
}

Expand All @@ -50,11 +49,6 @@ class PyFeatureSourceBase : public FeatureSource
return m_feature.cast<const Feature&>();
}

const std::string& id_field() const override
{
return m_id_field;
}

virtual py::object py_iter() = 0;

// debug
Expand All @@ -68,7 +62,6 @@ class PyFeatureSourceBase : public FeatureSource
}

private:
std::string m_id_field;
py::object m_src;
py::object m_feature;
py::iterator m_it;
Expand All @@ -95,7 +88,7 @@ bind_feature_source(py::module& m)
py::class_<PyFeatureSourceBase>(m, "PyFeatureSourceBase");

py::class_<PyFeatureSource, PyFeatureSourceBase, FeatureSource>(m, "FeatureSource")
.def(py::init<const std::string&>())
.def(py::init<>())
.def("__iter__", &PyFeatureSourceBase::py_iter)
// debug
.def("feature", &PyFeatureSourceBase::feature)
Expand Down
5 changes: 1 addition & 4 deletions src/coverage_processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ CoverageProcessor::process()
while (m_shp.next()) {
const Feature& f_in = m_shp.feature();

progress(f_in, m_shp.id_field());
progress(f_in, m_include_cols.empty() ? "." : m_include_cols.front());

auto geom = f_in.geometry();
auto feature_bbox = geos_get_box(m_geos_context, geom);
Expand All @@ -74,9 +74,6 @@ CoverageProcessor::process()
for (const auto& loc : RasterCoverageIteration<ValueType, WeightType>(coverage_fractions, values, weights, grid, areas.get())) {

auto f_out = m_output.create_feature();
if (m_shp.id_field() != "") {
f_out->set(m_shp.id_field(), f_in);
}
for (const auto& col : m_include_cols) {
f_out->set(col, f_in);
}
Expand Down
92 changes: 73 additions & 19 deletions src/exactextract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,15 @@ using exactextract::GDALRasterWrapper;
using exactextract::Operation;

static GDALDatasetWrapper
load_dataset(const std::string& descriptor, const std::string& field_name);
load_dataset(const std::string& descriptor,
const std::vector<std::string>& include_cols,
const std::string& src_id_name,
const std::string& dst_id_name,
const std::string& dst_id_type);

static std::unordered_map<std::string, GDALRasterWrapper>
load_rasters(const std::vector<std::string>& descriptors);

static std::vector<std::unique_ptr<Operation>>
prepare_operations(const std::vector<std::string>& descriptors,
std::unordered_map<std::string, GDALRasterWrapper>& rasters,
Expand All @@ -52,7 +58,7 @@ main(int argc, char** argv)
{
CLI::App app{ "Zonal statistics using exactextract: version " + exactextract::version() };

std::string poly_descriptor, field_name, output_filename, strategy, id_type, id_name;
std::string poly_descriptor, src_id_name, output_filename, strategy, dst_id_type, dst_id_name;
std::vector<std::string> stats;
std::vector<std::string> raster_descriptors;
std::vector<std::string> include_cols;
Expand All @@ -69,13 +75,13 @@ main(int argc, char** argv)

app.add_option("-p,--polygons", poly_descriptor, "polygon dataset")->required(true);
app.add_option("-r,--raster", raster_descriptors, "raster dataset")->required(true);
app.add_option("-f,--fid", field_name, "id from polygon dataset to retain in output")->required(true);
app.add_option("-f,--fid", src_id_name, "id from polygon dataset to retain in output")->required(false);
app.add_option("-o,--output", output_filename, "output filename")->required(true);
app.add_option("-s,--stat", stats, "statistics")->required(false)->expected(-1);
app.add_option("--max-cells", max_cells_in_memory, "maximum number of raster cells to read in memory at once, in millions")->required(false)->default_val("30");
app.add_option("--strategy", strategy, "processing strategy")->required(false)->default_val("feature-sequential");
app.add_option("--id-type", id_type, "override type of id field in output")->required(false);
app.add_option("--id-name", id_name, "override name of id field in output")->required(false);
app.add_option("--id-type", dst_id_type, "override type of id field in output")->required(false);
app.add_option("--id-name", dst_id_name, "override name of id field in output")->required(false);
app.add_flag("--include-xy", coverage_opts.include_xy, "include cell center coordinates with coverage fractions");
app.add_flag("--include-cell", coverage_opts.include_cell, "include cell identifier with coverage fractions");
app.add_flag("--include-area", include_area, "include cell area with coverage fractions");
Expand All @@ -90,10 +96,13 @@ main(int argc, char** argv)
}
CLI11_PARSE(app, argc, argv)

if (id_name.empty() != id_type.empty()) {
if (dst_id_name.empty() != dst_id_type.empty()) {
std::cerr << "Must specify both --id_type and --id_name" << std::endl;
return 1;
}
if (src_id_name.empty() && !dst_id_name.empty()) {
src_id_name = dst_id_name;
}

max_cells_in_memory *= 1000000;

Expand All @@ -105,11 +114,11 @@ main(int argc, char** argv)
OGRRegisterAll();
auto rasters = load_rasters(raster_descriptors);

GDALDatasetWrapper shp = load_dataset(poly_descriptor, field_name);

if (include_area) {
const GDALRasterWrapper& rast = rasters.begin()->second;
coverage_opts.area_method = rast.cartesian() ? exactextract::CoverageOperation::AreaMethod::CARTESIAN : exactextract::CoverageOperation::AreaMethod::SPHERICAL;
coverage_opts.area_method = rast.cartesian()
? exactextract::CoverageOperation::AreaMethod::CARTESIAN
: exactextract::CoverageOperation::AreaMethod::SPHERICAL;
}

auto operations = prepare_operations(stats, rasters, coverage_opts);
Expand All @@ -120,12 +129,21 @@ main(int argc, char** argv)
}
}

std::unique_ptr<exactextract::GDALWriter> gdal_writer = defer_writing ? std::make_unique<exactextract::DeferredGDALWriter>(output_filename) : std::make_unique<exactextract::GDALWriter>(output_filename);
if (!id_name.empty() && !id_type.empty()) {
gdal_writer->add_id_field(id_name, id_type);
} else {
gdal_writer->copy_id_field(shp);
std::unique_ptr<exactextract::GDALWriter> gdal_writer = defer_writing
? std::make_unique<
exactextract::DeferredGDALWriter>(
output_filename)
: std::make_unique<exactextract::GDALWriter>(
output_filename);

GDALDatasetWrapper shp = load_dataset(poly_descriptor, include_cols, src_id_name, dst_id_name, dst_id_type);

if (!dst_id_name.empty()) {
include_cols.insert(include_cols.begin(), dst_id_name);
} else if (!src_id_name.empty()) {
include_cols.insert(include_cols.begin(), src_id_name);
}

for (const auto& field : include_cols) {
gdal_writer->copy_field(shp, field);
}
Expand Down Expand Up @@ -169,11 +187,46 @@ main(int argc, char** argv)
}

static GDALDatasetWrapper
load_dataset(const std::string& descriptor, const std::string& field_name)
load_dataset(const std::string& descriptor,
const std::vector<std::string>& include_cols,
const std::string& src_id_name,
const std::string& dst_id_name,
const std::string& dst_id_type)
{
auto parsed = exactextract::parse_dataset_descriptor(descriptor);
const auto parsed = exactextract::parse_dataset_descriptor(descriptor);

std::vector<std::string> select;

if (!src_id_name.empty()) {
std::string id_select;

if (!dst_id_type.empty()) {
id_select += "CAST(";
}
id_select += src_id_name;

if (!dst_id_type.empty()) {
id_select += " AS " + dst_id_type + ")";
}

if (!dst_id_name.empty()) {
id_select += " AS " + dst_id_name + "";
}

select.push_back(id_select);
}

for (const auto& col : include_cols) {
select.push_back(col);
}

auto ds = GDALDatasetWrapper{ parsed.first, parsed.second };

if (!select.empty()) {
ds.set_select(select);
}

return GDALDatasetWrapper{ parsed.first, parsed.second, field_name };
return ds;
}

static std::unordered_map<std::string, GDALRasterWrapper>
Expand Down Expand Up @@ -205,7 +258,7 @@ prepare_operations(
bool found_stat = false;

for (const auto& descriptor : descriptors) {
auto stat = exactextract::parse_stat_descriptor(descriptor);
const auto stat = exactextract::parse_stat_descriptor(descriptor);

auto values_it = rasters.find(stat.values);
if (values_it == rasters.end()) {
Expand All @@ -228,7 +281,8 @@ prepare_operations(

if (stat.stat == "coverage") {
found_coverage = true;
ops.emplace_back(std::make_unique<exactextract::CoverageOperation>(stat.name, values, weights, coverage_opts));
ops.emplace_back(
std::make_unique<exactextract::CoverageOperation>(stat.name, values, weights, coverage_opts));
} else {
found_stat = true;
ops.emplace_back(std::make_unique<Operation>(stat.stat, stat.name, values, weights));
Expand Down
5 changes: 1 addition & 4 deletions src/feature_sequential_processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ FeatureSequentialProcessor::process()

auto geom = f_in.geometry();

progress(f_in, m_shp.id_field());
progress(f_in, m_include_cols.empty() ? "." : m_include_cols.front());

Box feature_bbox = exactextract::geos_get_box(m_geos_context, geom);

Expand Down Expand Up @@ -88,9 +88,6 @@ FeatureSequentialProcessor::process()
}

auto f_out = m_output.create_feature();
if (m_shp.id_field() != "") {
f_out->set(m_shp.id_field(), f_in);
}
for (const auto& col : m_include_cols) {
f_out->set(col, f_in);
}
Expand Down
2 changes: 0 additions & 2 deletions src/feature_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ class FeatureSource
virtual const Feature& feature() const = 0;

virtual bool next() = 0;

virtual const std::string& id_field() const = 0;
};

}
62 changes: 53 additions & 9 deletions src/gdal_dataset_wrapper.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2018-2023 ISciences, LLC.
// Copyright (c) 2018-2024 ISciences, LLC.
// All rights reserved.
//
// This software is licensed under the Apache License, Version 2.0 (the "License").
Expand All @@ -16,19 +16,20 @@

#include <algorithm>
#include <memory>
#include <sstream>
#include <stdexcept>

namespace exactextract {
#include "utils.h"

namespace exactextract {
const Feature&
GDALDatasetWrapper::feature() const
{
return m_feature;
}

GDALDatasetWrapper::GDALDatasetWrapper(const std::string& filename, const std::string& layer, std::string id_field)
: m_id_field{ std::move(id_field) }
, m_feature(nullptr)
GDALDatasetWrapper::GDALDatasetWrapper(const std::string& filename, const std::string& layer)
: m_feature(nullptr)
{
m_dataset = GDALOpenEx(filename.c_str(), GDAL_OF_VECTOR, nullptr, nullptr, nullptr);
if (m_dataset == nullptr) {
Expand All @@ -48,13 +49,53 @@ GDALDatasetWrapper::GDALDatasetWrapper(const std::string& filename, const std::s
}

OGR_L_ResetReading(m_layer);
}

GDALDatasetWrapper::GDALDatasetWrapper(GDALDatasetWrapper&& other) noexcept
: m_dataset(other.m_dataset)
, m_layer(other.m_layer)
, m_feature(std::move(other.m_feature))
, m_layer_is_sql(other.m_layer_is_sql)

Check warning on line 58 in src/gdal_dataset_wrapper.cpp

View check run for this annotation

Codecov / codecov/patch

src/gdal_dataset_wrapper.cpp#L54-L58

Added lines #L54 - L58 were not covered by tests
{
other.m_dataset = nullptr;
other.m_layer = nullptr;
other.m_layer_is_sql = false;
}

Check warning on line 63 in src/gdal_dataset_wrapper.cpp

View check run for this annotation

Codecov / codecov/patch

src/gdal_dataset_wrapper.cpp#L60-L63

Added lines #L60 - L63 were not covered by tests

GDALDatasetWrapper&
GDALDatasetWrapper::operator=(GDALDatasetWrapper&& other) noexcept

Check warning on line 66 in src/gdal_dataset_wrapper.cpp

View check run for this annotation

Codecov / codecov/patch

src/gdal_dataset_wrapper.cpp#L66

Added line #L66 was not covered by tests
{
m_dataset = other.m_dataset;
m_layer = other.m_layer;
m_feature = std::move(other.m_feature);
m_layer_is_sql = other.m_layer_is_sql;

Check warning on line 71 in src/gdal_dataset_wrapper.cpp

View check run for this annotation

Codecov / codecov/patch

src/gdal_dataset_wrapper.cpp#L68-L71

Added lines #L68 - L71 were not covered by tests

auto defn = OGR_L_GetLayerDefn(m_layer);
auto index = OGR_FD_GetFieldIndex(defn, m_id_field.c_str());
other.m_dataset = nullptr;
other.m_layer = nullptr;
other.m_layer_is_sql = false;

Check warning on line 75 in src/gdal_dataset_wrapper.cpp

View check run for this annotation

Codecov / codecov/patch

src/gdal_dataset_wrapper.cpp#L73-L75

Added lines #L73 - L75 were not covered by tests

if (index == -1) {
throw std::runtime_error("ID field '" + m_id_field + "' not found in " + filename + ".");
return *this;

Check warning on line 77 in src/gdal_dataset_wrapper.cpp

View check run for this annotation

Codecov / codecov/patch

src/gdal_dataset_wrapper.cpp#L77

Added line #L77 was not covered by tests
}

void
GDALDatasetWrapper::set_select(const std::vector<std::string>& cols)
{
const char* layer_name = OGR_L_GetName(m_layer);
std::stringstream sql;

sql << "SELECT ";
for (std::size_t i = 0; i < cols.size(); i++) {
if (i > 0) {
sql << ", ";
}
sql << cols[i];
}
sql << " FROM " << layer_name;

std::string sql_str = sql.str();

m_layer = GDALDatasetExecuteSQL(m_dataset, sql.str().c_str(), nullptr, nullptr);
m_layer_is_sql = true;
}

bool
Expand Down Expand Up @@ -82,6 +123,9 @@ GDALDatasetWrapper::copy_field(const std::string& name, OGRLayerH copy_to) const

GDALDatasetWrapper::~GDALDatasetWrapper()
{
if (m_layer_is_sql) {
GDALDatasetReleaseResultSet(m_dataset, m_layer);
}
GDALClose(m_dataset);
}
}
Loading

0 comments on commit f377115

Please sign in to comment.