Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow 'nearest_dtos' 2-d regridding to work with discrete sampling geometry source grids #833

Merged
merged 3 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ version NEXTVERSION

**2024-??-??**

* Allow ``'nearest_dtos'`` 2-d regridding to work with discrete
sampling geometry source grids
(https://github.com/NCAS-CMS/cf-python/issues/832)
* New method: `cf.Field.filled`
(https://github.com/NCAS-CMS/cf-python/issues/811)
* New method: `cf.Field.is_discrete_axis`
Expand Down
1 change: 0 additions & 1 deletion cf/cfimplementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
TiePointIndex,
)
from .data import Data

from .data.array import (
BoundsFromNodesArray,
CellConnectivityArray,
Expand Down
52 changes: 45 additions & 7 deletions cf/data/dask_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,17 +507,20 @@ def _regrid(
# Note: It is much more efficient to access
# 'weights.indptr', 'weights.indices', and
# 'weights.data' directly, rather than iterating
# over rows of 'weights' and using 'weights.getrow'.
# over rows of 'weights' and using
# 'weights.getrow'. Also, 'np.count_nonzero' is much
# faster than 'np.any' and 'np.all'.
count_nonzero = np.count_nonzero
indptr = weights.indptr.tolist()
indices = weights.indices
data = weights.data
for j, (i0, i1) in enumerate(zip(indptr[:-1], indptr[1:])):
mask = src_mask[indices[i0:i1]]
if not count_nonzero(mask):
n_masked = count_nonzero(mask)
if not n_masked:
continue

if mask.all():
if n_masked == mask.size:
dst_mask[j] = True
continue

Expand All @@ -529,8 +532,8 @@ def _regrid(

del indptr

elif method in ("linear", "bilinear", "nearest_dtos"):
# 2) Linear and nearest neighbour methods:
elif method in ("linear", "bilinear"):
# 2) Linear methods:
#
# Mask out any row j that contains at least one positive
# (i.e. greater than or equal to 'min_weight') w_ji that
Expand All @@ -546,7 +549,9 @@ def _regrid(
# Note: It is much more efficient to access
# 'weights.indptr', 'weights.indices', and
# 'weights.data' directly, rather than iterating
# over rows of 'weights' and using 'weights.getrow'.
# over rows of 'weights' and using
# 'weights.getrow'. Also, 'np.count_nonzero' is much
# faster than 'np.any' and 'np.all'.
count_nonzero = np.count_nonzero
where = np.where
indptr = weights.indptr.tolist()
Expand All @@ -562,12 +567,45 @@ def _regrid(

del indptr, pos_data

elif method == "nearest_dtos":
# 3) Nearest neighbour dtos method:
#
# Set to 0 any weight that corresponds to a masked source
sadielbartholomew marked this conversation as resolved.
Show resolved Hide resolved
# grid cell.
#
# Mask out any row j for which all source grid cells are
# masked.
dst_size = weights.shape[0]
if dst_mask is None:
dst_mask = np.zeros((dst_size,), dtype=bool)
else:
dst_mask = dst_mask.copy()

# Note: It is much more efficient to access
# 'weights.indptr', 'weights.indices', and
# 'weights.data' directly, rather than iterating
# over rows of 'weights' and using
# 'weights.getrow'. Also, 'np.count_nonzero' is much
# faster than 'np.any' and 'np.all'.
count_nonzero = np.count_nonzero
indptr = weights.indptr.tolist()
indices = weights.indices
for j, (i0, i1) in enumerate(zip(indptr[:-1], indptr[1:])):
mask = src_mask[indices[i0:i1]]
n_masked = count_nonzero(mask)
if n_masked == mask.size:
dst_mask[j] = True
elif n_masked:
weights.data[np.arange(i0, i1)[mask]] = 0

del indptr

elif method in (
"patch",
"conservative_2nd",
"nearest_stod",
):
# 3) Patch recovery and second-order conservative methods:
# 4) Patch recovery and second-order conservative methods:
#
# A reference source data mask has already been
# incorporated into the weights matrix, and 'a' is assumed
Expand Down
1 change: 0 additions & 1 deletion cf/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from ..units import Units
from .collapse import Collapse
from .creation import generate_axis_identifiers, to_dask

from .dask_utils import (
_da_ma_allclose,
cf_asanyarray,
Expand Down
4 changes: 3 additions & 1 deletion cf/docstring/docstring.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@
mapped to the closest destination point. A
destination point can be mapped to multiple source
points. Some destination points may not be
mapped. Useful for regridding of categorical data.
mapped. Each regridded value is the sum of its
contributing source elements. Useful for binning or
for categorical data.

* `None`: This is the default and can only be used
when *dst* is a `RegridOperator`.""",
Expand Down
19 changes: 13 additions & 6 deletions cf/regrid/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,11 +523,10 @@ def regrid(
"are a UGRID mesh"
)

if src_grid.is_locstream or dst_grid.is_locstream:
if dst_grid.is_locstream:
raise ValueError(
f"{method!r} regridding is (at the moment) only available "
"when neither the source and destination grids are "
"DSG featureTypes."
f"{method!r} regridding is (at the moment) not available "
"when the destination grid is a DSG featureType."
)

elif cartesian and (src_grid.is_mesh or dst_grid.is_mesh):
Expand Down Expand Up @@ -656,6 +655,7 @@ def regrid(
dst=dst,
weights_file=weights_file if from_file else None,
src_mesh_location=src_grid.mesh_location,
src_featureType=src_grid.featureType,
dst_featureType=dst_grid.featureType,
src_z=src_grid.z,
dst_z=dst_grid.z,
Expand All @@ -674,6 +674,9 @@ def regrid(
)

if return_operator:
# Note: The `RegridOperator.tosparse` method will also set
# 'dst_mask' to False for destination points with all
# zero weights.
regrid_operator.tosparse()
return regrid_operator

Expand Down Expand Up @@ -1279,7 +1282,7 @@ def spherical_grid(

# Set cyclicity of X axis
if mesh_location or featureType:
cyclic = None
cyclic = False
elif cyclic is None:
cyclic = f.iscyclic(x_axis)
else:
Expand Down Expand Up @@ -2281,6 +2284,11 @@ def create_esmpy_locstream(grid, mask=None):
# but the esmpy mask requires 0/1 for masked/unmasked
# elements.
mask = np.invert(mask).astype("int32")
if mask.size == 1:
# Make sure that there's a mask element for each point in
# the locstream (rather than a scalar that applies to all
# elements).
mask = np.full((location_count,), mask, dtype="int32")
else:
# No masked points
mask = np.full((location_count,), 1, dtype="int32")
Expand Down Expand Up @@ -2465,7 +2473,6 @@ def create_esmpy_weights(
from netCDF4 import Dataset

from .. import __version__

from ..data.array.locks import netcdf_lock

if (
Expand Down
81 changes: 73 additions & 8 deletions cf/test/test_regrid_featureType.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,6 @@
except ImportError:
pass

disallowed_methods = (
"conservative",
"conservative_2nd",
"nearest_dtos",
)

methods = (
"linear",
"nearest_stod",
Expand Down Expand Up @@ -169,6 +163,78 @@ def test_Field_regrid_grid_to_featureType_3d(self):
else:
self.assertFalse(y.mask.any())

@unittest.skipUnless(esmpy_imported, "Requires esmpy/ESMF package.")
def test_Field_regrid_featureType_to_grid_2d(self):
self.assertFalse(cf.regrid_logging())

# Create some nice data
src = self.dst_featureType
src.del_construct("cellmethod0")
src = src[:12]
src[...] = 273 + np.arange(12)
x = src.coord("X")
x[...] = [4, 6, 9, 11, 14, 16, 4, 6, 9, 11, 14, 16]
y = src.coord("Y")
y[...] = [41, 41, 31, 31, 21, 21, 39, 39, 29, 29, 19, 19]

dst = self.src_grid.copy()
x = dst.coord("X")
x[...] = [5, 10, 15, 20]
y = dst.coord("Y")
y[...] = [10, 20, 30, 40]

# Mask some destination grid points
dst[0, 0, 1, 2] = cf.masked

# Expected destination regridded values
y0 = np.ma.array(
[[0, 0, 0, 0], [0, 0, 1122, 0], [0, 1114, 0, 0], [1106, 0, 0, 0]],
mask=[
[True, True, True, True],
[True, True, False, True],
[True, False, True, True],
[False, True, True, True],
],
)

for src_masked in (False, True):
y = y0.copy()
if src_masked:
src = src.copy()
src[6:8] = cf.masked
# This following element should be smaller, because it
# now only has two source cells conrtibuting to it,
davidhassell marked this conversation as resolved.
Show resolved Hide resolved
# rather than four.
y[3, 0] = 547

# Loop over whether or not to use the destination grid
# masked points
for use_dst_mask in (False, True):
if use_dst_mask:
y = y.copy()
y[1, 2] = np.ma.masked

kwargs = {"use_dst_mask": use_dst_mask}
method = "nearest_dtos"
for return_operator in (False, True):
if return_operator:
r = src.regrids(
dst, method=method, return_operator=True, **kwargs
)
x = src.regrids(r)
else:
x = src.regrids(dst, method=method, **kwargs)

a = x.array

self.assertEqual(y.size, a.size)
self.assertTrue(np.allclose(y, a, atol=atol, rtol=rtol))

if isinstance(a, np.ma.MaskedArray):
self.assertTrue((y.mask == a.mask).all())
else:
self.assertFalse(y.mask.any())

@unittest.skipUnless(esmpy_imported, "Requires esmpy/ESMF package.")
def test_Field_regrid_grid_to_featureType_2d(self):
self.assertFalse(cf.regrid_logging())
Expand Down Expand Up @@ -196,7 +262,6 @@ def test_Field_regrid_grid_to_featureType_2d(self):
a = x.array

y = esmpy_regrid(coord_sys, method, src, dst, **kwargs)

self.assertEqual(y.size, a.size)
self.assertTrue(np.allclose(y, a, atol=atol, rtol=rtol))

Expand Down Expand Up @@ -259,7 +324,7 @@ def test_Field_regrid_featureType_bad_methods(self):
dst = self.dst_featureType.copy()
src = self.src_grid.copy()

for method in disallowed_methods:
for method in ("conservative", "conservative_2nd"):
with self.assertRaises(ValueError):
src.regrids(dst, method=method)

Expand Down
Loading