diff --git a/src/skan/csr.py b/src/skan/csr.py index fac6b883..295a5d5a 100644 --- a/src/skan/csr.py +++ b/src/skan/csr.py @@ -6,6 +6,7 @@ from skimage import morphology from skimage.graph import central_pixel from skimage.util._map_array import map_array, ArrayMap +import numpy.typing as npt import numba import warnings @@ -13,7 +14,9 @@ from .summary_utils import find_main_branches -def _weighted_abs_diff(values0, values1, distances): +def _weighted_abs_diff( + values0: np.ndarray, values1: np.ndarray, distances: np.ndarray + ) -> np.ndarray: """A default edge function for complete image graphs. A pixel graph on an image with no edge values and no mask is a very @@ -521,6 +524,7 @@ def __init__( np.full(skeleton_image.ndim, spacing) ) if keep_images: + self.keep_images = keep_images self.skeleton_image = skeleton_image self.source_image = source_image @@ -550,7 +554,7 @@ def path(self, index): start, stop = self.paths.indptr[index:index + 2] return self.paths.indices[start:stop] - def path_coordinates(self, index): + def path_coordinates(self, index: int): """Return the image coordinates of the pixels in the path. Parameters @@ -566,7 +570,7 @@ def path_coordinates(self, index): path_indices = self.path(index) return self.coordinates[path_indices] - def path_with_data(self, index): + def path_with_data(self, index: int): """Return pixel indices and corresponding pixel values on a path. Parameters @@ -652,9 +656,28 @@ def path_stdev(self): means = self.path_means() return np.sqrt(np.clip(sumsq/lengths - means*means, 0, None)) - def prune_paths(self, indices) -> 'Skeleton': + def prune_paths(self, indices: npt.ArrayLike) -> 'Skeleton': + """Prune nodes from the skeleton. + + Parameters + ---------- + indices: List[int] + List of indices to be removed. + + Retruns + ------- + Skeleton + A new Skeleton object pruned. + """ # warning: slow image_cp = np.copy(self.skeleton_image) + if not np.all(np.array(indices) < self.n_paths): + raise ValueError( + f'The path index {np.max(indices)} does not exist in this ' + f'skeleton. (The highest path index is {self.n_paths}.)\n' + 'If you obtained the index from a summary table, you ' + 'probably need to resummarize the skeleton.' + ) for i in indices: pixel_ids_to_wipe = self.path(i) junctions = self.degrees[pixel_ids_to_wipe] > 2 @@ -668,6 +691,7 @@ def prune_paths(self, indices) -> 'Skeleton': new_skeleton, spacing=self.spacing, source_image=self.source_image, + keep_images=self.keep_images ) def __array__(self, dtype=None): @@ -676,8 +700,11 @@ def __array__(self, dtype=None): def summarize( - skel: Skeleton, *, value_is_height=False, find_main_branch=False - ): + skel: Skeleton, + *, + value_is_height: bool = False, + find_main_branch: bool = False + ) -> pd.DataFrame: """Compute statistics for every skeleton and branch in ``skel``. Parameters @@ -1037,10 +1064,8 @@ def _simplify_graph(skel): src_relab, dst_relab = fw_map[src], fw_map[dst] - edges = sparse.coo_matrix( - (distance, (src_relab, dst_relab)), - shape=(n_nodes, n_nodes) - ) + edges = sparse.coo_matrix((distance, (src_relab, dst_relab)), + shape=(n_nodes, n_nodes)) dir_csgraph = edges.tocsr() simp_csgraph = dir_csgraph + dir_csgraph.T # make undirected diff --git a/src/skan/test/test_csr.py b/src/skan/test/test_csr.py index b7370125..80123052 100644 --- a/src/skan/test/test_csr.py +++ b/src/skan/test/test_csr.py @@ -5,7 +5,7 @@ from numpy.testing import assert_equal, assert_almost_equal from skimage.draw import line -from skan import csr +from skan import csr, summarize from skan._testdata import ( tinycycle, tinyline, skeleton0, skeleton1, skeleton2, skeleton3d, topograph1d, skeleton4 @@ -179,6 +179,62 @@ def test_transpose_image(): ) +@pytest.mark.parametrize( + "skeleton,prune_branch,target", + [ + ( + skeleton1, 1, + np.array([[0, 1, 1, 1, 1, 1, 0], [1, 0, 0, 0, 0, 0, 1], + [0, 1, 1, 0, 1, 1, 0], [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0]]) + ), + ( + skeleton1, 2, + np.array([[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0], [1, 0, 0, 2, 0, 0, 0], + [1, 0, 0, 0, 2, 2, 2]]) + ), + # There are no isolated cycles to be pruned + ( + skeleton1, 3, + np.array([[0, 1, 1, 1, 1, 1, 0], [1, 0, 0, 0, 0, 0, 1], + [0, 3, 2, 0, 1, 1, 0], [3, 0, 0, 4, 0, 0, 0], + [3, 0, 0, 0, 4, 4, 4]]) + ), + ] + ) +def test_prune_paths( + skeleton: np.ndarray, prune_branch: int, target: np.ndarray + ) -> None: + """Test pruning of paths.""" + s = csr.Skeleton(skeleton, keep_images=True) + summary = summarize(s) + indices_to_remove = summary.loc[summary['branch-type'] == prune_branch + ].index + pruned = s.prune_paths(indices_to_remove) + np.testing.assert_array_equal(pruned, target) + + +def test_prune_paths_exception_single_point() -> None: + """Test exceptions raised when pruning leaves a single point and Skeleton object + can not be created and returned.""" + s = csr.Skeleton(skeleton0) + summary = summarize(s) + indices_to_remove = summary.loc[summary['branch-type'] == 1].index + with pytest.raises(ValueError): + s.prune_paths(indices_to_remove) + + +def test_prune_paths_exception_invalid_path_index() -> None: + """Test exceptions raised when trying to prune paths that do not exist in the summary. This can arise if skeletons + are not updated correctly during iterative pruning.""" + s = csr.Skeleton(skeleton0) + summary = summarize(s) + indices_to_remove = [6] + with pytest.raises(ValueError): + s.prune_paths(indices_to_remove) + + def test_fast_graph_center_idx(): s = csr.Skeleton(skeleton0) i = csr._fast_graph_center_idx(s)