Skip to content

Commit

Permalink
Python: Add progress argument to exact_extract
Browse files Browse the repository at this point in the history
  • Loading branch information
dbaston committed Jul 26, 2024
1 parent 38ea8a5 commit 2cb8c96
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 2 deletions.
34 changes: 32 additions & 2 deletions python/src/exactextract/exact_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def exact_extract(
max_cells_in_memory: int = 30000000,
output: str = "geojson",
output_options: Optional[Mapping] = None,
progress=False,
):
"""Calculate zonal statistics
Expand Down Expand Up @@ -297,7 +298,9 @@ def exact_extract(
which may be significant for operations with large result sizes
such as ``cell_id``, ``values``, etc.
output_options: an optional dictionary of options passed to the :py:class:`writer.JSONWriter`, :py:class:`writer.PandasWriter`, or :py:class:`writer.GDALWriter`.
progress: if `True`, a progress bar will be displayed. Alternatively, a
function may be provided that will be called with the completion fraction
and a status message.
"""
rast = prep_raster(rast)
weights = prep_raster(weights, name_root="weight")
Expand All @@ -317,8 +320,35 @@ def exact_extract(
if include_geom:
processor.add_geom()
processor.set_max_cells_in_memory(max_cells_in_memory)
processor.process()

if progress:
processor.show_progress(True)

if callable(progress):
processor.set_progress_fn(progress)
elif progress is True:
try:
import tqdm

bar = tqdm.tqdm(total=100)

def status(frac, message):
pct = frac * 100
bar.update(pct - bar.n)
bar.set_description(message)
if pct == 100:
bar.close()

except ImportError:

def status(frac, message):
print(f"[{frac*100:0.1f}%] {message}")

processor.set_progress_fn(status)
else:
raise ValueError("progress should be True or a function")

processor.process()
writer.finish()

return writer.features()
10 changes: 10 additions & 0 deletions python/tests/test_exact_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,3 +1190,13 @@ def test_explicit_operation():
results = exact_extract(None, square, op)

assert results[0]["properties"]["my_op"] == 4.0


def test_progress():

rast = NumPyRasterSource(np.arange(9).reshape(3, 3))
square = make_rect(0.5, 0.5, 2.5, 2.5)

squares = [square] * 10

exact_extract(rast, squares, "count", progress=True)

0 comments on commit 2cb8c96

Please sign in to comment.