Skip to content

Commit

Permalink
Improved error reporting and tests for prune_paths() methods (#212)
Browse files Browse the repository at this point in the history
Closes #206

The error arises during usage of `csr.Skeleton.prune_paths()` which
takes a list of indices, typical from `csr.summarize()`, which are to be
pruned. If the index is outside of the range of rows in the data frame a
`ValueError` is thrown.

It was first highlighted during early development of pruning iteratively
but can arise any time in inappropriate value that is greater than the
number of paths in a skeleton is passed to `prune_path()` method.

---------

Co-authored-by: slackline <[email protected]>
Co-authored-by: Juan Nunez-Iglesias <[email protected]>
  • Loading branch information
3 people authored Oct 16, 2023
1 parent b61b24a commit 91d4e7c
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 11 deletions.
45 changes: 35 additions & 10 deletions src/skan/csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
from skimage import morphology
from skimage.graph import central_pixel
from skimage.util._map_array import map_array, ArrayMap
import numpy.typing as npt
import numba
import warnings

from .nputil import _raveled_offsets_and_distances
from .summary_utils import find_main_branches


def _weighted_abs_diff(values0, values1, distances):
def _weighted_abs_diff(
values0: np.ndarray, values1: np.ndarray, distances: np.ndarray
) -> np.ndarray:
"""A default edge function for complete image graphs.
A pixel graph on an image with no edge values and no mask is a very
Expand Down Expand Up @@ -521,6 +524,7 @@ def __init__(
np.full(skeleton_image.ndim, spacing)
)
if keep_images:
self.keep_images = keep_images
self.skeleton_image = skeleton_image
self.source_image = source_image

Expand Down Expand Up @@ -550,7 +554,7 @@ def path(self, index):
start, stop = self.paths.indptr[index:index + 2]
return self.paths.indices[start:stop]

def path_coordinates(self, index):
def path_coordinates(self, index: int):
"""Return the image coordinates of the pixels in the path.
Parameters
Expand All @@ -566,7 +570,7 @@ def path_coordinates(self, index):
path_indices = self.path(index)
return self.coordinates[path_indices]

def path_with_data(self, index):
def path_with_data(self, index: int):
"""Return pixel indices and corresponding pixel values on a path.
Parameters
Expand Down Expand Up @@ -652,9 +656,28 @@ def path_stdev(self):
means = self.path_means()
return np.sqrt(np.clip(sumsq/lengths - means*means, 0, None))

def prune_paths(self, indices) -> 'Skeleton':
def prune_paths(self, indices: npt.ArrayLike) -> 'Skeleton':
"""Prune nodes from the skeleton.
Parameters
----------
indices: List[int]
List of indices to be removed.
Retruns
-------
Skeleton
A new Skeleton object pruned.
"""
# warning: slow
image_cp = np.copy(self.skeleton_image)
if not np.all(np.array(indices) < self.n_paths):
raise ValueError(
f'The path index {np.max(indices)} does not exist in this '
f'skeleton. (The highest path index is {self.n_paths}.)\n'
'If you obtained the index from a summary table, you '
'probably need to resummarize the skeleton.'
)
for i in indices:
pixel_ids_to_wipe = self.path(i)
junctions = self.degrees[pixel_ids_to_wipe] > 2
Expand All @@ -668,6 +691,7 @@ def prune_paths(self, indices) -> 'Skeleton':
new_skeleton,
spacing=self.spacing,
source_image=self.source_image,
keep_images=self.keep_images
)

def __array__(self, dtype=None):
Expand All @@ -676,8 +700,11 @@ def __array__(self, dtype=None):


def summarize(
skel: Skeleton, *, value_is_height=False, find_main_branch=False
):
skel: Skeleton,
*,
value_is_height: bool = False,
find_main_branch: bool = False
) -> pd.DataFrame:
"""Compute statistics for every skeleton and branch in ``skel``.
Parameters
Expand Down Expand Up @@ -1037,10 +1064,8 @@ def _simplify_graph(skel):

src_relab, dst_relab = fw_map[src], fw_map[dst]

edges = sparse.coo_matrix(
(distance, (src_relab, dst_relab)),
shape=(n_nodes, n_nodes)
)
edges = sparse.coo_matrix((distance, (src_relab, dst_relab)),
shape=(n_nodes, n_nodes))
dir_csgraph = edges.tocsr()
simp_csgraph = dir_csgraph + dir_csgraph.T # make undirected

Expand Down
58 changes: 57 additions & 1 deletion src/skan/test/test_csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from numpy.testing import assert_equal, assert_almost_equal
from skimage.draw import line

from skan import csr
from skan import csr, summarize
from skan._testdata import (
tinycycle, tinyline, skeleton0, skeleton1, skeleton2, skeleton3d,
topograph1d, skeleton4
Expand Down Expand Up @@ -179,6 +179,62 @@ def test_transpose_image():
)


@pytest.mark.parametrize(
"skeleton,prune_branch,target",
[
(
skeleton1, 1,
np.array([[0, 1, 1, 1, 1, 1, 0], [1, 0, 0, 0, 0, 0, 1],
[0, 1, 1, 0, 1, 1, 0], [0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0]])
),
(
skeleton1, 2,
np.array([[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0], [1, 0, 0, 2, 0, 0, 0],
[1, 0, 0, 0, 2, 2, 2]])
),
# There are no isolated cycles to be pruned
(
skeleton1, 3,
np.array([[0, 1, 1, 1, 1, 1, 0], [1, 0, 0, 0, 0, 0, 1],
[0, 3, 2, 0, 1, 1, 0], [3, 0, 0, 4, 0, 0, 0],
[3, 0, 0, 0, 4, 4, 4]])
),
]
)
def test_prune_paths(
skeleton: np.ndarray, prune_branch: int, target: np.ndarray
) -> None:
"""Test pruning of paths."""
s = csr.Skeleton(skeleton, keep_images=True)
summary = summarize(s)
indices_to_remove = summary.loc[summary['branch-type'] == prune_branch
].index
pruned = s.prune_paths(indices_to_remove)
np.testing.assert_array_equal(pruned, target)


def test_prune_paths_exception_single_point() -> None:
"""Test exceptions raised when pruning leaves a single point and Skeleton object
can not be created and returned."""
s = csr.Skeleton(skeleton0)
summary = summarize(s)
indices_to_remove = summary.loc[summary['branch-type'] == 1].index
with pytest.raises(ValueError):
s.prune_paths(indices_to_remove)


def test_prune_paths_exception_invalid_path_index() -> None:
"""Test exceptions raised when trying to prune paths that do not exist in the summary. This can arise if skeletons
are not updated correctly during iterative pruning."""
s = csr.Skeleton(skeleton0)
summary = summarize(s)
indices_to_remove = [6]
with pytest.raises(ValueError):
s.prune_paths(indices_to_remove)


def test_fast_graph_center_idx():
s = csr.Skeleton(skeleton0)
i = csr._fast_graph_center_idx(s)
Expand Down

0 comments on commit 91d4e7c

Please sign in to comment.