From b5779fd4564895bb4651ad4561ed84d78decc691 Mon Sep 17 00:00:00 2001
From: Ondrej Sykora
Date: Sun, 8 Sep 2024 22:29:59 +0000
Subject: [PATCH] [waypoint_collection] Added a container class for waypoints
 with efficient lookup.

Added WaypointCollection, a class that indexes waypoints in the form of
place IDs and latlng coordinates and provides efficient lookup.
---
 python/gmpro/waypoint_collection.py      | 106 ++++++++++++
 python/gmpro/waypoint_collection_test.py | 206 +++++++++++++++++++++++
 2 files changed, 312 insertions(+)
 create mode 100644 python/gmpro/waypoint_collection.py
 create mode 100644 python/gmpro/waypoint_collection_test.py

diff --git a/python/gmpro/waypoint_collection.py b/python/gmpro/waypoint_collection.py
new file mode 100644
index 00000000..9bed2788
--- /dev/null
+++ b/python/gmpro/waypoint_collection.py
@@ -0,0 +1,106 @@
+"""A collection of waypoints with efficient membership testing."""
+
+import bisect
+from collections.abc import Collection
+from typing import TypeAlias
+
+from .json import cfr_json
+
+
+LatLng: TypeAlias = tuple[float, float] | cfr_json.LatLng
+Waypoint: TypeAlias = str | LatLng
+
+
+def _as_tuple(latlng: LatLng) -> tuple[float, float]:
+  if isinstance(latlng, tuple):
+    return latlng
+  return latlng["latitude"], latlng["longitude"]
+
+
+class WaypointCollection:
+  """A collection of place IDs and latlngs with efficient membership testing.
+
+  Place IDs are matched as strings; this class makes no attempt at geocoding
+  them. Latlngs can be matched either precisely or within an L_inf distance.
+
+  Uses bisection by (latitude, longitude) for lookup of coordinates. Precise
+  lookups have O(logN) time complexity; approximate lookups have time complexity
+  O(logN + D) where N is the number of stored coordinates, and D is the maximal
+  number of coordinates whose latitude falls into the same interval of length
+  2*max_delta used in the query.
+  """
+
+  # TODO(ondrasej): Replace the lookup with an actual KD-tree in case the lookup
+  # performance is not sufficient or if we encounter pathological cases.
+
+  def __init__(self):
+    """Initializes an empty collection."""
+    self._place_ids: set[str] = set()
+    self._latlngs: list[tuple[float, float]] = []
+
+  def add_place_ids(self, place_ids: Collection[str]) -> None:
+    """Adds `place_ids` to the collection."""
+    self._place_ids.update(place_ids)
+
+  def add_latlngs(self, latlngs: Collection[LatLng]) -> None:
+    """Adds `latlngs` to the collection.
+
+    Runs in O(N*logN) where N is the number of latlngs in the collection after
+    the insertion.
+
+    Args:
+      latlngs: The latlngs to add.
+    """
+    unique_latlngs = {_as_tuple(latlng) for latlng in latlngs}
+    unique_latlngs.difference_update(self._latlngs)
+    self._latlngs.extend(unique_latlngs)
+    self._latlngs.sort()
+
+  def contains(self, waypoint: Waypoint, max_delta: float = 0) -> bool:
+    """Checks whether the collection contains a waypoint.
+
+    Args:
+      waypoint: The place ID or latlng to look up in the collection.
+      max_delta: The maximal L_inf distance in degrees for a match.
+
+    Returns:
+      True if there is a match, False otherwise.
+    """
+    if isinstance(waypoint, str):
+      return self.contains_place_id(waypoint)
+    else:
+      return self.contains_latlng(waypoint, max_delta)
+
+  def contains_place_id(self, place_id: str) -> bool:
+    """Checks whether the collection contains a place ID."""
+    return place_id in self._place_ids
+
+  def contains_latlng(self, latlng: LatLng, max_delta: float = 0) -> bool:
+    """Checks for coordinates within max_delta of latlng.
+
+    Args:
+      latlng: The coordinates to look up.
+      max_delta: The maximal L_inf distance in degrees in which the coordinates
+        are looked up.
+
+    Returns:
+      True if there is a match; False otherwise.
+    """
+    latlng = _as_tuple(latlng)
+
+    left = bisect.bisect_left(
+        self._latlngs, (latlng[0] - max_delta, latlng[1] - max_delta)
+    )
+    right = bisect.bisect_right(
+        self._latlngs, (latlng[0] + max_delta, latlng[1] + max_delta)
+    )
+    if left == right:
+      # We got an empty interval.
+      return False
+    longitude_min = latlng[1] - max_delta
+    longitude_max = latlng[1] + max_delta
+    for i in range(left, right):
+      candidate = self._latlngs[i]
+      if longitude_min <= candidate[1] <= longitude_max:
+        return True
+    return False
diff --git a/python/gmpro/waypoint_collection_test.py b/python/gmpro/waypoint_collection_test.py
new file mode 100644
index 00000000..b500851e
--- /dev/null
+++ b/python/gmpro/waypoint_collection_test.py
@@ -0,0 +1,206 @@
+"""Tests for the waypoint collection."""
+
+from collections.abc import Sequence
+import datetime
+import itertools
+import logging
+import random
+import unittest
+
+from . import waypoint_collection
+from .json import cfr_json
+
+
+class WaypointCollectionTest(unittest.TestCase):
+
+  # _JSON_LATLNGS_A and _TUPLE_LATLNGS_A contain the same coordinates in
+  # different representations (tuples and JSON latlng structures).
+  _JSON_LATLNGS_A: Sequence[cfr_json.LatLng] = (
+      {"latitude": 0, "longitude": 0},
+      {"latitude": 48.877104524088146, "longitude": 2.329973366337609},
+      {"latitude": 48.879156912623536, "longitude": 2.3270195883955864},
+  )
+  _TUPLE_LATLNGS_A: Sequence[tuple[float, float]] = (
+      (0, 0),
+      (48.877104524088146, 2.329973366337609),
+      (48.879156912623536, 2.3270195883955864),
+  )
+  # _JSON_LATLNGS_B and _TUPLE_LATLNGS_B use the same setup as the _A version,
+  # but they are different coordinates with no overlap between _A and _B.
+  _JSON_LATLNGS_B: Sequence[cfr_json.LatLng] = (
+      {"latitude": 37.42461654144618, "longitude": -122.09252441795736},
+      {"latitude": 37.422335039773735, "longitude": -122.0838965937761},
+      {"latitude": 37.42168743305142, "longitude": -122.0790336749436},
+  )
+  _TUPLE_LATLNGS_B: Sequence[tuple[float, float]] = (
+      (37.42461654144618, -122.09252441795736),
+      (37.422335039773735, -122.0838965937761),
+      (37.42168743305142, -122.0790336749436),
+  )
+
+  def test_empty_collection(self):
+    collection = waypoint_collection.WaypointCollection()
+    waypoints = (
+        "foo",
+        "bar",
+        "baz",
+        *self._JSON_LATLNGS_A,
+        *self._TUPLE_LATLNGS_A,
+    )
+    for waypoint in waypoints:
+      with self.subTest(waypoint=waypoint):
+        self.assertFalse(collection.contains(waypoint))
+        self.assertFalse(collection.contains(waypoint, max_delta=0.001))
+        self.assertFalse(collection.contains(waypoint, max_delta=0.1))
+
+  def test_place_ids(self):
+    place_ids = ("foo", "bar")
+    collection = waypoint_collection.WaypointCollection()
+    collection.add_place_ids(place_ids)
+
+    for place_id in place_ids:
+      self.assertTrue(collection.contains(place_id))
+
+    self.assertFalse(collection.contains("baz"))
+
+    latlngs = (*self._TUPLE_LATLNGS_B, *self._JSON_LATLNGS_A)
+    for latlng in latlngs:
+      self.assertFalse(collection.contains(latlng))
+
+  def test_latlng_exact_match(self):
+    collection = waypoint_collection.WaypointCollection()
+    collection.add_latlngs(self._JSON_LATLNGS_A)
+
+    for latlng in itertools.chain(self._JSON_LATLNGS_A, self._TUPLE_LATLNGS_A):
+      with self.subTest(latlng=latlng):
+        self.assertTrue(collection.contains(latlng))
+
+    for latlng in itertools.chain(self._JSON_LATLNGS_B, self._TUPLE_LATLNGS_B):
+      with self.subTest(latlng=latlng):
+        self.assertFalse(collection.contains(latlng))
+
+  def test_latlng_approximate_match(self):
+    collection = waypoint_collection.WaypointCollection()
+    latlng = self._TUPLE_LATLNGS_A[0]
+    collection.add_latlngs((latlng,))
+
+    max_delta = 0.0001
+    num_tests = 1000
+    rnd = random.Random(b"123456789")
+    for _ in range(num_tests):
+      perturbed_latlng = (
+          latlng[0] + rnd.uniform(-max_delta, max_delta),
+          latlng[1] + rnd.uniform(-max_delta, max_delta),
+      )
+      self.assertTrue(collection.contains(perturbed_latlng, max_delta))
+
+    deltas = (
+        (max_delta + 1e-6, 0),
+        (0, max_delta + 1e-6),
+        (-max_delta - 1e-6, 0),
+        (0, -max_delta - 1e-6),
+    )
+    for delta_lat, delta_lng in deltas:
+      perturbed_latlng = latlng[0] + delta_lat, latlng[1] + delta_lng
+      self.assertFalse(collection.contains(perturbed_latlng, max_delta))
+
+  def test_exact_latlng_lookup_large_set(self):
+    # The number of coordinates added to the collection.
+    num_collection_elements = 1000000
+    # The number of lookups performed in the lookup tests. The test makes
+    # `num_test_lookups` lookups for points that are in the collection and
+    # the same number of lookups for points that are not in the collection.
+    num_test_lookups = 1000
+
+    collection = waypoint_collection.WaypointCollection()
+
+    rnd = random.Random(b"123456789")
+    member_latlngs = set()
+    while len(member_latlngs) < num_collection_elements:
+      latlng = rnd.uniform(-80, 80), rnd.uniform(-180, 180)
+      member_latlngs.add(latlng)
+
+    collection.add_latlngs(member_latlngs)
+
+    non_member_latlngs = set()
+    while len(non_member_latlngs) < num_test_lookups:
+      latlng = rnd.uniform(-80, 80), rnd.uniform(-180, 180)
+      if latlng not in member_latlngs:
+        non_member_latlngs.add(latlng)
+
+    member_start = datetime.datetime.now()
+    for latlng in itertools.islice(member_latlngs, num_test_lookups):
+      self.assertTrue(collection.contains(latlng))
+    non_member_start = datetime.datetime.now()
+    for latlng in non_member_latlngs:
+      self.assertFalse(collection.contains(latlng))
+    non_member_end = datetime.datetime.now()
+
+    logging.info(
+        "member lookup = %fs",
+        (non_member_start - member_start).total_seconds(),
+    )
+    logging.info(
+        "non-member lookup = %fs",
+        (non_member_end - non_member_start).total_seconds(),
+    )
+
+  def test_approximate_lookup_large_set(self):
+    # The number of coordinates added to the collection.
+    num_collection_elements = 1000000
+    # The number of lookups performed in the lookup tests. The test makes
+    # `num_test_lookups` lookups for points that are in the collection and
+    # the same number of lookups for points that are not in the collection.
+    num_test_lookups = 1000
+
+    max_delta = 0.0001
+
+    collection = waypoint_collection.WaypointCollection()
+
+    # Member and non-member longitudes are at least 2 degrees apart (far more
+    # than max_delta), so a query for a non-member never returns True.
+    rnd = random.Random(b"123456789")
+    member_latlngs = set()
+    while len(member_latlngs) < num_collection_elements:
+      latlng = rnd.uniform(0, 80), rnd.uniform(1, 90)
+      member_latlngs.add(latlng)
+
+    collection.add_latlngs(member_latlngs)
+
+    non_member_latlngs = set()
+    while len(non_member_latlngs) < num_test_lookups:
+      latlng = rnd.uniform(0, 80), rnd.uniform(-90, -1)
+      if latlng not in member_latlngs:
+        non_member_latlngs.add(latlng)
+
+    member_start = datetime.datetime.now()
+    for latlng in itertools.islice(member_latlngs, num_test_lookups):
+      # Create a version of the element that is randomly perturbed but that is
+      # still within the tolerance specified by max_delta.
+      perturbed_latlng = (
+          latlng[0] + rnd.uniform(-max_delta, max_delta),
+          latlng[1] + rnd.uniform(-max_delta, max_delta),
+      )
+      self.assertTrue(collection.contains(perturbed_latlng, max_delta))
+    non_member_start = datetime.datetime.now()
+    for latlng in non_member_latlngs:
+      self.assertFalse(collection.contains(latlng, max_delta))
+    non_member_end = datetime.datetime.now()
+
+    logging.info(
+        "member lookup: %fs",
+        (non_member_start - member_start).total_seconds(),
+    )
+    logging.info(
+        "non-member lookup: %fs",
+        (non_member_end - non_member_start).total_seconds(),
+    )
+
+
+if __name__ == "__main__":
+  logging.basicConfig(
+      format="%(asctime)s %(levelname)-8s %(filename)s:%(lineno)d %(message)s",
+      level=logging.INFO,
+      datefmt="%Y-%m-%d %H:%M:%S",
+  )
+  unittest.main()