Skip to content

Commit

Permalink
Fixes #73 (#74)
Browse files Browse the repository at this point in the history
* Fixes #73

* Add check for inf dist to NN search, ignore points outside max_distance,
add warning and tests

* Update changelog

Co-authored-by: Wolfgang Preimesberger <[email protected]>
  • Loading branch information
sebhahn and wpreimes authored Aug 12, 2021
1 parent ed21347 commit d54ea73
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 17 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
Changelog
=========

Version 0.4.1
=============

- Fixes a bug in the nearest neighbour lookup, where no points were returned
when less than k gpis are found in the selected max distance
(`Issue #73 <https://github.com/TUW-GEO/pygeogrids/issues/73>`_ ).

Version 0.4.0
=============

Expand Down
15 changes: 10 additions & 5 deletions src/pygeogrids/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,17 +443,22 @@ def find_k_nearest_gpi(self, lon, lat, max_dist=np.Inf, k=1):
Returns
-------
gpi : long
Grid point index.
distance : float
Distance of gpi to given lon, lat.
gpi : np.ndarray
Grid point indices.
distance : np.ndarray
Distance of gpi(s) to given lon, lat.
At the moment not on a great circle but in spherical
cartesian coordinates.
"""
if self.kdTree is None:
self._setup_kdtree()

distance, ind = self.kdTree.find_nearest_index(lon, lat, max_dist=max_dist, k=k)
distance, ind = self.kdTree.find_nearest_index(lon, lat,
max_dist=max_dist, k=k)
mask = np.isinf(distance)
if np.any(mask):
ind = ind[~mask]
distance = distance[~mask]

if self.gpidirect and self.allpoints or len(ind) == 0:
gpi = ind
Expand Down
20 changes: 13 additions & 7 deletions src/pygeogrids/nearest_neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import numpy as np
import warnings

try:
import pykdtree.kdtree as pykd
Expand Down Expand Up @@ -108,7 +109,8 @@ def __init__(self, lon, lat, geodatum, grid=False, kd_tree_name="pykdtree"):
self.lon_size = len(lon)
else:
if lat.shape != lon.shape:
raise Exception("lat and lon np.arrays have to have equal shapes")
raise Exception(
"lat and lon np.arrays have to have equal shapes")
lat_init = lat
lon_init = lon
# Earth radius
Expand Down Expand Up @@ -137,7 +139,8 @@ def _transform_lonlats(self, lon, lat):
lon = np.array(lon)
lat = np.array(lat)
coords = np.zeros((lon.size, 3), dtype=np.float64)
(coords[:, 0], coords[:, 1], coords[:, 2]) = self.geodatum.toECEF(lon, lat)
(coords[:, 0], coords[:, 1], coords[:, 2]
) = self.geodatum.toECEF(lon, lat)

return coords

Expand Down Expand Up @@ -202,17 +205,20 @@ def find_nearest_index(self, lon, lat, max_dist=np.Inf, k=1):

if k is None:
if self.kd_tree_name != "scipy":
raise NotImplementedError("Only available for the scipy kdTree")
raise NotImplementedError(
"Only available for the scipy kdTree")
query_coords = query_coords[0]
k = self.kdtree.query_ball_point(
query_coords, r=max_dist, return_length=True
)

d, ind = self.kdtree.query(query_coords, distance_upper_bound=max_dist, k=k)
d, ind = self.kdtree.query(
query_coords, distance_upper_bound=max_dist, k=k)

# if no point was found, d == inf
if not np.all(np.isfinite(d)):
d, ind = np.array([]), np.array([])
if np.any(np.isinf(d)):
warnings.warn(f"Less than k={k} points found within "
f"max_dist={max_dist}. Distance set to 'Inf'."
)

if not self.grid:
return d, ind
Expand Down
14 changes: 9 additions & 5 deletions tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def test_nearest_neighbor_numpy_single(self):
assert lon == 145.5
assert lat == 45.5


def test_k_nearest_neighbor(self):
gpi, dist = self.grid.find_k_nearest_gpi(14.3, 18.5, k=2)
assert gpi[0, 0] == 25754
Expand All @@ -131,6 +132,14 @@ def test_k_nearest_neighbor(self):
assert lon == 13.5
assert lat == 18.5

with pytest.warns(UserWarning):
gpi, dist = self.grid.find_k_nearest_gpi(14.3, 18.5, k=2,
max_dist=25000)
assert len(gpi) == len(dist) == 1
assert np.all(np.isfinite(dist))
assert gpi == 25754


def test_k_nearest_neighbor_list(self):
gpi, dist = self.grid.find_k_nearest_gpi(
[145.1, 90.2], [45.8, -16.3], k=2)
Expand Down Expand Up @@ -708,8 +717,3 @@ def test_BasicGrid_transform_lon():
# case 3: no warning and no transform
grid = BasicGrid(lon_pos, lat, transform_lon=False)
assert np.all(grid.arrlon == lon_pos)



if __name__ == "__main__":
unittest.main()

0 comments on commit d54ea73

Please sign in to comment.