-
Notifications
You must be signed in to change notification settings - Fork 2
/
graph_util_get_neighbors.py
136 lines (121 loc) · 3.8 KB
/
graph_util_get_neighbors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
Contains the get_neighbors function and associated helper functions.
"""
from enum import Enum
from typing import Union, List, Tuple, Dict, Any
import numpy as np
import shapely.geometry
from g2o import SE3Quat
from shapely.geometry import LineString
import scipy.spatial
from map_processing.transform_utils import se3_quat_average
class _NeighborType(Enum):
    """How ``get_neighbors`` adds connections beyond the sequential path edges."""

    # Splice a new node in wherever two path segments cross (in the x-z plane).
    INTERSECTION = 0
    # Link two vertices directly when they are spatially close.
    CLOSE_DISTANCE = 1
def get_neighbors(
    vertices: np.ndarray,
    vertex_ids: Union[List[int], None] = None,
    neighbor_type: _NeighborType = _NeighborType.INTERSECTION,
) -> Tuple[List[List[int]], List[Dict[str, Any]]]:
    """Build an adjacency list for an ordered path of poses and add extra links.

    Each vertex is first linked to its predecessor and successor along the
    path. Then every pair of vertices whose columns 0-2 lie within distance 1
    of each other (found with a KD-tree ball query) is examined:

    * ``INTERSECTION``: the path segments ending at the two vertices are
      intersected in the x-z plane. Each single-point crossing becomes a new
      node with a fresh id. The node is spliced into both segments and
      described in the returned ``intersections`` list.
    * ``CLOSE_DISTANCE``: the two vertices are linked directly if
      ``_is_close_enough`` accepts them and they are not already adjacent.

    A message is printed for each extra connection that is added.

    Args:
        vertices: Array with one row per vertex. Columns 0-2 are used for the
            proximity search. Each row is also handed to ``SE3Quat`` when
            computing intersections (presumably ``[x, y, z, qx, qy, qz, qw]``
            — TODO confirm).
        vertex_ids: Id of each row of ``vertices``. Defaults to the row indices.
        neighbor_type: Which kind of extra connection to add.

    Returns:
        A tuple ``(neighbors, intersections)``. ``neighbors[i]`` lists the ids
        connected to row ``i``; these are vertex ids or newly created
        intersection ids. ``intersections`` holds one dict per intersection
        node created and is always empty in ``CLOSE_DISTANCE`` mode.
    """
    nvertices = vertices.shape[0]
    if vertex_ids is None:
        vertex_ids = list(range(nvertices))
    # Fewer than two vertices means there are no path edges. The construction
    # below would also index past the end of ``vertex_ids`` in that case.
    if nvertices < 2:
        return [[] for _ in range(nvertices)], []

    # Sequential path edges: each vertex links to its predecessor/successor.
    neighbors = (
        [[vertex_ids[1]]]
        + [[vertex_ids[i - 1], vertex_ids[i + 1]] for i in range(1, nvertices - 1)]
        + [[vertex_ids[-2]]]
    )
    # Ids for new intersection nodes start above every existing vertex id.
    curr_id = max(vertex_ids) + 1
    intersections = []

    # Candidate pairs: rows within distance 1 of each other in columns 0-2.
    intersection_detector = scipy.spatial.KDTree(vertices[:, 0:3])
    intersections_detected = intersection_detector.query_ball_point(
        vertices[:, 0:3], 1, workers=-1, return_sorted=True
    )
    for id1, close_detections_list in enumerate(intersections_detected):
        for id2 in close_detections_list:
            # Every row matches itself and each pair is reported from both
            # sides; handle each unordered pair once, with id1 < id2.
            if id2 <= id1:
                continue
            if neighbor_type == _NeighborType.INTERSECTION:
                # Segment i runs from vertex i - 1 to vertex i. Vertex 0 ends no
                # segment (id1 - 1 would wrap around to the last row). Also,
                # consecutive segments always "meet" at their shared vertex
                # id1, which is not a real crossing.
                if id1 == 0 or id2 == id1 + 1:
                    continue
                intersection = _get_intersection(vertices, id1, id2, curr_id)
                if intersection is None:
                    continue
                intersections.append(intersection)
                # Splice the new node into both segments: overwrite the
                # segment start's successor entry and the segment end's
                # predecessor entry.
                # NOTE(review): if a segment takes part in several crossings,
                # only the last one stays wired in.
                neighbors[id1 - 1][-1] = curr_id
                neighbors[id1][0] = curr_id
                neighbors[id2 - 1][-1] = curr_id
                neighbors[id2][0] = curr_id
                curr_id += 1
            elif neighbor_type == _NeighborType.CLOSE_DISTANCE and _is_close_enough(
                vertices, id1, id2
            ):
                # Store vertex ids (not row indices), like the path edges do,
                # and skip pairs that are already adjacent along the path.
                if vertex_ids[id2] in neighbors[id1]:
                    continue
                neighbors[id1].append(vertex_ids[id2])
                neighbors[id2].append(vertex_ids[id1])
                print(f"Point {id1} and {id2} are close enough, adding neighbors")
    return neighbors, intersections
def _get_intersection(vertices, id1, id2, curr_id):
    """Intersect the path segments ending at two vertices, in the x-z plane.

    Segment ``i`` joins vertex ``i - 1`` to vertex ``i``. The segments ending
    at ``id1`` and ``id2`` are projected onto columns 0 and 2 of ``vertices``
    and intersected with Shapely. If they meet at a single point, a new pose
    node is built there. Its x and z come from the crossing point. Its y and
    its rotation come from ``se3_quat_average`` of the four segment endpoints.

    Args:
        vertices: Array of pose rows. Each row is passed to ``SE3Quat``
            (presumably ``[x, y, z, qx, qy, qz, qw]`` — TODO confirm).
        id1: Row index of the end of the first segment. Must be >= 1 so that
            ``id1 - 1`` is the previous vertex rather than wrapping around.
        id2: Row index of the end of the second segment (same requirement).
        curr_id: Id to assign to the new intersection node.

    Returns:
        A dict with ``translation``, ``rotation``, ``poseId`` and ``neighbors``
        keys describing the new node. Returns ``None`` when the segments do
        not meet at exactly one point (disjoint, or overlapping along a line).
    """
    line1 = LineString(
        [
            (vertices[id1 - 1][0], vertices[id1 - 1][2]),
            (vertices[id1][0], vertices[id1][2]),
        ]
    )
    line2 = LineString(
        [
            (vertices[id2 - 1][0], vertices[id2 - 1][2]),
            (vertices[id2][0], vertices[id2][2]),
        ]
    )
    intersect_pt = line1.intersection(line2)
    # Only a single crossing point is usable. Reject empty results of any
    # geometry type, and collinear overlaps (which come back as a LineString).
    if intersect_pt.is_empty or not isinstance(
        intersect_pt, shapely.geometry.point.Point
    ):
        return None
    # Pose averaging is only needed once a crossing is known to exist.
    average = se3_quat_average(
        [
            SE3Quat(vertices[id1 - 1]),
            SE3Quat(vertices[id1]),
            SE3Quat(vertices[id2 - 1]),
            SE3Quat(vertices[id2]),
        ]
    ).to_vector()
    print(f"Intersection at {intersect_pt}, between {id1} and {id2}")
    # Assumes ``to_vector()`` yields [x, y, z, qx, qy, qz, qw] — TODO confirm.
    # The intersection was computed in the (x, z) plane, so the point's
    # .x/.y map to world x/z.
    return {
        "translation": {"x": intersect_pt.x, "y": average[1], "z": intersect_pt.y},
        "rotation": {
            "x": average[3],
            "y": average[4],
            "z": average[5],
            "w": average[6],
        },
        "poseId": curr_id,
        # NOTE(review): these are row indices into ``vertices``, not the
        # caller's vertex ids — confirm this is what consumers expect.
        "neighbors": [id1 - 1, id1, id2 - 1, id2],
    }
def _is_close_enough(vertices, id1, id2):
    """Report whether two vertices are near each other.

    Two rows count as close when their column-1 values differ by less than 1
    AND their Euclidean distance over columns 0 and 2 is also less than 1.
    Column 1 is presumably the vertical axis — confirm.

    Args:
        vertices: Indexable collection of rows with at least three numeric
            columns.
        id1: Index of the first vertex.
        id2: Index of the second vertex.

    Returns:
        The result of the comparisons: a Python bool for plain lists, or a
        NumPy bool when ``vertices`` is an array.
    """
    a = vertices[id1]
    b = vertices[id2]
    same_level = abs(a[1] - b[1]) < 1
    # Short-circuit exactly like ``x and y``: return the falsy result itself.
    if not same_level:
        return same_level
    horizontal_gap = ((a[0] - b[0]) ** 2 + (a[2] - b[2]) ** 2) ** 0.5
    return horizontal_gap < 1