diff --git a/tests/test_points.py b/tests/test_points.py index 117951a..003c882 100644 --- a/tests/test_points.py +++ b/tests/test_points.py @@ -1,6 +1,12 @@ import numpy as np import pytest -from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection, MultiLineString +from shapely.geometry import ( + Point, + MultiPoint, + LineString, + GeometryCollection, + MultiLineString, +) from sleap_roots import Series from sleap_roots.lengths import get_max_length_pts from sleap_roots.points import ( @@ -815,9 +821,17 @@ def test_extract_from_empty_geometrycollection(): @pytest.mark.parametrize( "geometry, expected", [ - (MultiLineString([[(0, 0), (1, 1)], [(2, 2), (3, 3)]]), []), - (GeometryCollection([Point(1, 2), LineString([(0, 0), (1, 1)]), MultiLineString([[(0, 0), (1, 1)], [(2, 2), (3, 3)]])]), [np.array([1, 2]), np.array([0, 0]), np.array([1, 1])]), # GeometryCollection with MultiLineString + ( + GeometryCollection( + [ + Point(1, 2), + LineString([(0, 0), (1, 1)]), + MultiLineString([[(0, 0), (1, 1)], [(2, 2), (3, 3)]]), + ] + ), + [np.array([1, 2]), np.array([0, 0]), np.array([1, 1])], + ), # GeometryCollection with MultiLineString (MultiLineString(), []), # Empty MultiLineString ], )