Skip to content

Commit

Permalink
[waypoint_collection] Added a container class for waypoints with effi…
Browse files Browse the repository at this point in the history
…cient lookup.

Added WaypointCollection, a class that indexes waypoints in the form of place
IDs and latlng coordinates and provides efficient lookup.
  • Loading branch information
ondrasej committed Sep 8, 2024
1 parent 5ea11c1 commit b5779fd
Show file tree
Hide file tree
Showing 2 changed files with 312 additions and 0 deletions.
106 changes: 106 additions & 0 deletions python/gmpro/waypoint_collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""A collection of waypoints with efficient membership testing."""

import bisect
from collections.abc import Collection
from typing import TypeAlias

from .json import cfr_json


LatLng: TypeAlias = tuple[float, float] | cfr_json.LatLng
Waypoint: TypeAlias = str | LatLng


def _as_tuple(latlng: LatLng) -> tuple[float, float]:
if isinstance(latlng, tuple):
return latlng
return latlng["latitude"], latlng["longitude"]


class WaypointCollection:
"""A collection of place IDs and latlongs with efficient membership testing.
Place IDs are matched as strings, this class makes no attempt at geocoding
them. Latlngs can be matched either precisely or within an L_inf distance.
Uses bisection by (latitude, longitude) for lookup of coordinates. Precise
lookups have O(logN) time complexity; approximate lookups have time complexity
O(logN + D) where N is the number of stored coordinates, and D is the maximal
number of coordinates whose latitude falls into the same interval of length
2*max_delta used in the query.
"""

# TODO(ondrasej): Replace the lookup with an actual KD-tree in case the lookup
# performance is not sufficient or if we encounter pathologic cases.

def __init__(self):
"""Initializes an empty collection."""
self._place_ids: set[str] = set()
self._latlngs: list[tuple[float, float]] = []

def add_place_ids(self, place_ids: Collection[str]) -> None:
"""Adds `place_ids` to the collection."""
self._place_ids.update(place_ids)

def add_latlngs(self, latlngs: Collection[LatLng]) -> None:
"""Adds `latlngs` to the collection.
Runs in O(N*logN) where N is the number of latlngs in the collection after
the insertion.
Args:
latlngs: The latlngs to add.
"""
unique_latlngs = set(_as_tuple(latlng) for latlng in latlngs)
unique_latlngs.difference_update(self._latlngs)
self._latlngs.extend(unique_latlngs)
self._latlngs.sort()

def contains(self, waypoint: Waypoint, max_delta: float = 0) -> bool:
"""Checks whether the collection contains a waypoint.
Args:
waypoint: The place ID or latlng to look up in the collection.
max_delta: The maximal L_inf distance in degrees for a match.
Returns:
True if there is a match, False otherwise.
"""
if isinstance(waypoint, str):
return self.contains_place_id(waypoint)
else:
return self.contains_latlng(waypoint, max_delta)

def contains_place_id(self, place_id: str) -> bool:
"""Checks whether the collection contains a place ID."""
return place_id in self._place_ids

def contains_latlng(self, latlng: LatLng, max_delta: float = 0) -> bool:
"""Checks for coordinates within max_delta of latlng.
Args:
latlng: The coordinates to look up.
max_delta: The maximal L_inf distance in degrees in which the coordinates
are looked up.
Returns:
True if there is a match; False otherwise.
"""
latlng = _as_tuple(latlng)

left = bisect.bisect_left(
self._latlngs, (latlng[0] - max_delta, latlng[1] - max_delta)
)
right = bisect.bisect_right(
self._latlngs, (latlng[0] + max_delta, latlng[1] + max_delta)
)
if left == right:
# We got an empty interval.
return False
longitude_min = latlng[1] - max_delta
longitude_max = latlng[1] + max_delta
for i in range(left, right):
candidate = self._latlngs[i]
if longitude_min <= candidate[1] <= longitude_max:
return True
return False
206 changes: 206 additions & 0 deletions python/gmpro/waypoint_collection_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
"""Tests for the waypoint collection."""

from collections.abc import Sequence
import datetime
import itertools
import logging
import random
import unittest

from . import waypoint_collection
from .json import cfr_json


class WaypointCollectionTest(unittest.TestCase):

# _JSON_LATLNGS_A and _TUPLE_LATLNGS_A contain the same coordinates in
# different representations (tuples and JSON latlng structures).
_JSON_LATLNGS_A: Sequence[cfr_json.LatLng] = (
{"latitude": 0, "longitude": 0},
{"latitude": 48.877104524088146, "longitude": 2.329973366337609},
{"latitude": 48.879156912623536, "longitude": 2.3270195883955864},
)
_TUPLE_LATLNGS_A: Sequence[tuple[float, float]] = (
(0, 0),
(48.877104524088146, 2.329973366337609),
(48.879156912623536, 2.3270195883955864),
)
# _JSON_LATLNGS_B and _TUPLE_LATLNGS_B use the same setup as the _A version,
# but they are different coordinates with no overlap between _A and _B.
_JSON_LATLNGS_B: Sequence[cfr_json.LatLng] = (
{"latitude": 37.42461654144618, "longitude": -122.09252441795736},
{"latitude": 37.422335039773735, "longitude": -122.0838965937761},
{"latitude": 37.42168743305142, "longitude": -122.0790336749436},
)
_TUPLE_LATLNGS_B: Sequence[tuple[float, float]] = (
(37.42461654144618, -122.09252441795736),
(37.422335039773735, -122.0838965937761),
(37.42168743305142, -122.0790336749436),
)

def test_empty_collection(self):
collection = waypoint_collection.WaypointCollection()
waypoints = (
"foo",
"bar",
"baz",
*self._JSON_LATLNGS_A,
*self._TUPLE_LATLNGS_A,
)
for waypoint in waypoints:
with self.subTest(waypoint=waypoint):
self.assertFalse(collection.contains(waypoint))
self.assertFalse(collection.contains(waypoint, max_delta=0.001))
self.assertFalse(collection.contains(waypoint, max_delta=0.1))

def test_place_ids(self):
place_ids = ("foo", "bar")
collection = waypoint_collection.WaypointCollection()
collection.add_place_ids(place_ids)

for place_id in place_ids:
self.assertTrue(collection.contains(place_id))

self.assertFalse(collection.contains("baz"))

latlngs = (*self._TUPLE_LATLNGS_B, *self._JSON_LATLNGS_A)
for latlng in latlngs:
self.assertFalse(collection.contains(latlng))

def test_latlng_exact_match(self):
collection = waypoint_collection.WaypointCollection()
collection.add_latlngs(self._JSON_LATLNGS_A)

for latlng in itertools.chain(self._JSON_LATLNGS_A, self._TUPLE_LATLNGS_A):
with self.subTest(latlng=latlng):
self.assertTrue(collection.contains(latlng))

for latlng in itertools.chain(self._JSON_LATLNGS_B, self._TUPLE_LATLNGS_B):
with self.subTest(latlng=latlng):
self.assertFalse(collection.contains(latlng))

def test_latlng_approximate_match(self):
collection = waypoint_collection.WaypointCollection()
latlng = self._TUPLE_LATLNGS_A[0]
collection.add_latlngs((latlng,))

max_delta = 0.0001
num_tests = 1000
rnd = random.Random(b"123456789")
for _ in range(num_tests):
perturbed_latlng = (
latlng[0] + rnd.uniform(-max_delta, max_delta),
latlng[1] + rnd.uniform(-max_delta, max_delta),
)
self.assertTrue(collection.contains(perturbed_latlng, max_delta))

deltas = (
(max_delta + 1e-6, 0),
(0, max_delta + 1e-6),
(-max_delta - 1e-6, 0),
(0, -max_delta - 1e-6),
)
for delta_lat, delta_lng in deltas:
perturbed_latlng = latlng[0] + delta_lat, latlng[1] + delta_lng
self.assertFalse(collection.contains(perturbed_latlng, max_delta))

def test_exact_latlng_lookup_large_set(self):
# The number of coordinates added to the collection.
num_collection_elements = 1000000
# The number of lookups performed in the lookup tests. The test makes
# `num_test_lookups` of lookups for points that are in the collection and
# the same number of lookups for points that are not in the collection.
num_test_lookups = 1000

collection = waypoint_collection.WaypointCollection()

rnd = random.Random(b"123456789")
member_latlngs = set()
while len(member_latlngs) < num_collection_elements:
latlng = rnd.uniform(-80, 80), rnd.uniform(-180, 180)
member_latlngs.add(latlng)

collection.add_latlngs(member_latlngs)

non_member_latlngs = set()
while len(non_member_latlngs) < num_test_lookups:
latlng = rnd.uniform(-80, 80), rnd.uniform(-180, 180)
if latlng not in member_latlngs:
non_member_latlngs.add(latlng)

member_start = datetime.datetime.now()
for latlng in itertools.islice(member_latlngs, num_test_lookups):
self.assertTrue(collection.contains(latlng))
non_member_start = datetime.datetime.now()
for latlng in non_member_latlngs:
self.assertFalse(collection.contains(latlng))
non_member_end = datetime.datetime.now()

logging.info(
"member lookup = %fs",
(non_member_start - member_start).total_seconds(),
)
logging.info(
"non-member lookup = %fs",
(non_member_end - non_member_start).total_seconds(),
)

def test_approximate_lookup_large_set(self):
# The number of coordinates added to the collection.
num_collection_elements = 1000000
# The number of lookups performed in the lookup tests. The test makes
# `num_test_lookups` of lookups for points that are in the collection and
# the same number of lookups for points that are not in the collection.
num_test_lookups = 1000

max_delta = 0.0001

collection = waypoint_collection.WaypointCollection()

# Member and non-member coordinates are generated in different octants, so
# that a query for a non-member never returns True.
rnd = random.Random(b"123456789")
member_latlngs = set()
while len(member_latlngs) < num_collection_elements:
latlng = rnd.uniform(0, 80), rnd.uniform(1, 90)
member_latlngs.add(latlng)

collection.add_latlngs(member_latlngs)

non_member_latlngs = set()
while len(non_member_latlngs) < num_test_lookups:
latlng = rnd.uniform(0, 80), rnd.uniform(-90, -1)
if latlng not in member_latlngs:
non_member_latlngs.add(latlng)

member_start = datetime.datetime.now()
for latlng in itertools.islice(member_latlngs, num_test_lookups):
# Create a version of the element that is randomly perturbed but that is
# still within the tolerance specified by max_delta.
perturbed_latlng = (
latlng[0] + rnd.uniform(-max_delta, max_delta),
latlng[1] + rnd.uniform(-max_delta, max_delta),
)
self.assertTrue(collection.contains(perturbed_latlng, max_delta))
non_member_start = datetime.datetime.now()
for latlng in non_member_latlngs:
self.assertFalse(collection.contains(latlng), max_delta)
non_member_end = datetime.datetime.now()

logging.info(
"member lookup: %fs",
(non_member_start - member_start).total_seconds(),
)
logging.info(
"non-member lookup: %fs",
(non_member_end - non_member_start).total_seconds(),
)


if __name__ == "__main__":
logging.basicConfig(
format="%(asctime)s %(levelname)-8s %(filename)s:%(lineno)d %(message)s",
level=logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
)
unittest.main()

0 comments on commit b5779fd

Please sign in to comment.