From df0d5775d43b19cae2afdb5d6f733dce054d6485 Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Mon, 6 May 2024 20:37:55 -0700 Subject: [PATCH] use helper function to get intersection points for different geometries --- sleap_roots/convhull.py | 19 +++++++++++-------- sleap_roots/points.py | 2 +- tests/test_convhull.py | 2 +- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/sleap_roots/convhull.py b/sleap_roots/convhull.py index 88ad97c..2c02bfa 100644 --- a/sleap_roots/convhull.py +++ b/sleap_roots/convhull.py @@ -548,18 +548,21 @@ def get_chull_intersection_vectors( # Get the intersection points if not intersection.is_empty: - intersect_points = ( - np.array([[point.x, point.y] for point in intersection.geoms]) - if intersection.geom_type == "MultiPoint" - else np.array([[intersection.x, intersection.y]]) - ) + intersect_points = extract_points_from_geometry(intersection) else: # Return two vectors of NaNs if there is no intersection return leftmost_vector, rightmost_vector - # Get the leftmost and rightmost intersection points - leftmost_intersect = intersect_points[np.argmin(intersect_points[:, 0])] - rightmost_intersect = intersect_points[np.argmax(intersect_points[:, 0])] + # Convert the list of NumPy arrays to a 2D NumPy array + intersection_points_array = np.vstack(intersect_points) + + # Find the leftmost and rightmost intersection points + leftmost_intersect = intersection_points_array[ + np.argmin(intersection_points_array[:, 0]) + ] + rightmost_intersect = intersection_points_array[ + np.argmax(intersection_points_array[:, 0]) + ] # Make a vector from the leftmost r0 point to the leftmost intersection point leftmost_vector = (leftmost_intersect - leftmost_r0).reshape(1, -1) diff --git a/sleap_roots/points.py b/sleap_roots/points.py index 6d5c5c1..4ca3d05 100644 --- a/sleap_roots/points.py +++ b/sleap_roots/points.py @@ -8,7 +8,7 @@ from typing import List, Optional, Tuple -def extract_points_from_geometry(geometry): +def extract_points_from_geometry(geometry) -> List[np.ndarray]: """Extracts coordinates as a list of numpy arrays from any given Shapely geometry object. This function supports Point, MultiPoint, LineString, and GeometryCollection types. diff --git a/tests/test_convhull.py b/tests/test_convhull.py index f506312..e7eba99 100644 --- a/tests/test_convhull.py +++ b/tests/test_convhull.py @@ -305,7 +305,7 @@ def test_basic_functionality(pts_shape_3_6_2): r0_pts, r1_pts, pts, hull ) - # Assertions depend on the expected outcome, which you'll need to calculate based on your function's logic + # TODO: Add more specific tests as needed assert not np.isnan(left_vector).any(), "Left vector should not contain NaNs" assert not np.isnan(right_vector).any(), "Right vector should not contain NaNs"