Support for distributed skeleton analysis #124

Draft
wants to merge 5 commits into base: main
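In brief, this draft changes `skeleton_to_csgraph` to return two values (the pixel graph and the pixel coordinates) instead of three, drops the gala-derived `pad` helper in favour of `np.pad`, and adds a `make_degree_image` function that computes the degree image on demand and falls back to dask-image convolution for non-NumPy arrays. A minimal sketch of the resulting API, assuming the names in this diff and a purely illustrative toy skeleton:

import numpy as np
from skan import csr

# toy 3-pixel horizontal line; illustrative input only
skeleton = np.zeros((5, 5), dtype=bool)
skeleton[2, 1:4] = True

# two return values after this PR (previously a 3-tuple ending in a degree image)
graph, coordinates = csr.skeleton_to_csgraph(skeleton)

# the degree image is now computed separately, on demand
degrees = csr.make_degree_image(skeleton)
print(degrees[2, 1:4])  # end pixels have degree 1, the middle pixel degree 2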
4 changes: 2 additions & 2 deletions benchmarks/bench_skan.py
@@ -24,11 +24,11 @@ def bench_suite():
times = OrderedDict()
skeleton = np.load(os.path.join(rundir, 'infected3.npz'))['skeleton']
with timer() as t_build_graph:
g, indices, degrees = csr.skeleton_to_csgraph(skeleton,
g, indices = csr.skeleton_to_csgraph(skeleton,
spacing=2.24826)
times['build graph'] = t_build_graph[0]
with timer() as t_build_graph2:
g, indices, degrees = csr.skeleton_to_csgraph(skeleton,
g, indices = csr.skeleton_to_csgraph(skeleton,
spacing=2.24826)
times['build graph again'] = t_build_graph2[0]
with timer() as t_stats:
11 changes: 5 additions & 6 deletions doc/getting_started.ipynb
@@ -226,7 +226,7 @@
"source": [
"from skan import skeleton_to_csgraph\n",
"\n",
"pixel_graph, coordinates, degrees = skeleton_to_csgraph(skeleton0)"
"pixel_graph, coordinates = skeleton_to_csgraph(skeleton0)"
]
},
{
@@ -242,8 +242,7 @@
"metadata": {},
"outputs": [],
"source": [
"pixel_graph0, coordinates0, degrees0 = skeleton_to_csgraph(skeleton0,\n",
" spacing=spacing_nm)"
"pixel_graph0, coordinates0 = skeleton_to_csgraph(skeleton0, spacing=spacing_nm)"
]
},
{
@@ -288,8 +287,8 @@
],
"source": [
"from skan import _testdata\n",
"g0, c0, _ = skeleton_to_csgraph(_testdata.skeleton0)\n",
"g1, c1, _ = skeleton_to_csgraph(_testdata.skeleton1)\n",
"g0, c0 = skeleton_to_csgraph(_testdata.skeleton0)\n",
"g1, c1 = skeleton_to_csgraph(_testdata.skeleton1)\n",
"fig, axes = plt.subplots(1, 2)\n",
"\n",
"draw.overlay_skeleton_networkx(g0, c0, image=_testdata.skeleton0,\n",
@@ -710,4 +709,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
16 changes: 4 additions & 12 deletions doc/getting_started.rst
@@ -1,4 +1,3 @@

Getting started: Skeleton analysis with Skan
============================================

@@ -131,7 +130,7 @@ branches along that network.

>>> from skan import skeleton_to_csgraph
>>>
>>> pixel_graph, coordinates, degrees = skeleton_to_csgraph(skeleton0)
>>> pixel_graph, coordinates = skeleton_to_csgraph(skeleton0)

The pixel graph is a SciPy `CSR
matrix <https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html>`__
@@ -146,8 +145,7 @@ in physical units instead of pixels:

.. nbplot::

>>> pixel_graph0, coordinates0, degrees0 = skeleton_to_csgraph(skeleton0,
... spacing=spacing_nm)
>>> pixel_graph0, coordinates0 = skeleton_to_csgraph(skeleton0, spacing=spacing_nm)

The second variable contains the coordinates (in pixel units) of the
points in the pixel graph. Finally, ``degrees`` is an image of the
@@ -166,8 +164,8 @@ recommended for very small networks.)
.. nbplot::

>>> from skan import _testdata
>>> g0, c0, _ = skeleton_to_csgraph(_testdata.skeleton0)
>>> g1, c1, _ = skeleton_to_csgraph(_testdata.skeleton1)
>>> g0, c0 = skeleton_to_csgraph(_testdata.skeleton0)
>>> g1, c1 = skeleton_to_csgraph(_testdata.skeleton1)
>>> fig, axes = plt.subplots(1, 2)
>>>
>>> draw.overlay_skeleton_networkx(g0, c0, image=_testdata.skeleton0,
@@ -506,9 +504,3 @@ This is of course a toy example. For the full dataset and analysis, see:
But we hope this minimal example will serve for inspiration for your
future analysis of skeleton images.

If you are interested in how we used `numba <https://numba.pydata.org/>`_
to accelerate some parts of Skan, check out `jni's talk <https://www.youtube.com/watch?v=0pUPNMglnaE>`_
at the the SciPy 2019 conference.


.. code-links:: python clear
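Since the docs above no longer unpack a ``degrees`` image from ``skeleton_to_csgraph``, per-node degrees can still be recovered from the CSR graph itself, as the Skeleton class does in the csr.py changes below with ``np.diff(self.graph.indptr)``. A small sketch, assuming the same test-data module used elsewhere in these docs:

import numpy as np
from skan import skeleton_to_csgraph, _testdata

graph, coords = skeleton_to_csgraph(_testdata.skeleton0)
degrees_per_node = np.diff(graph.indptr)  # one degree entry per graph node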
66 changes: 53 additions & 13 deletions skan/csr.py
@@ -5,7 +5,7 @@
from scipy import spatial
import numba

from .nputil import pad, raveled_steps_to_neighbors
from .nputil import raveled_steps_to_neighbors


## NBGraph and Numba-based implementation
@@ -318,9 +318,11 @@ class Skeleton:
def __init__(self, skeleton_image, *, spacing=1, source_image=None,
_buffer_size_offset=None, keep_images=True,
unique_junctions=True):
graph, coords, degrees = skeleton_to_csgraph(skeleton_image,
spacing=spacing,
unique_junctions=unique_junctions)
graph, coords = skeleton_to_csgraph(
skeleton_image,
spacing=spacing,
unique_junctions=unique_junctions,
)
if np.issubdtype(skeleton_image.dtype, np.float_):
pixel_values = ndi.map_coordinates(skeleton_image, coords.T,
order=3)
@@ -336,7 +338,6 @@ def __init__(self, skeleton_image, *, spacing=1, source_image=None,
self._distances_initialized = False
self.skeleton_image = None
self.source_image = None
self.degrees_image = degrees
self.degrees = np.diff(self.graph.indptr)
self.spacing = (np.asarray(spacing) if not np.isscalar(spacing)
else np.full(skeleton_image.ndim, spacing))
@@ -595,11 +596,12 @@ def skeleton_to_csgraph(skel, *, spacing=1, value_is_height=False,
to pixel coordinates in `degree_image` or `skel`. Array entry
(0,:) contains currently always zeros to index the pixels, which
start at 1, directly to the coordinates.
degree_image : array of int, same shape as skel
An image where each pixel value contains the degree of its
corresponding node in `graph`. This is useful to classify nodes.
"""
height = pad(skel, 0.) if value_is_height else None
height = (
np.pad(skel, 1, mode='constant', constant_values=0.)
if value_is_height
else None
)
# ensure we have a bool image, since we later use it for bool indexing
skel = skel.astype(bool)
ndim = skel.ndim
@@ -627,15 +629,16 @@
pixel_indices[np.unique(labeled_junctions)[1:]] = centroids

num_edges = np.sum(degree_image) # *2, which is how many we need to store
skelint = pad(skelint, 0) # pad image to prevent looparound errors
# pad image to prevent looparound errors
skelint = np.pad(skelint, 1, mode='constant', constant_values=0)
steps, distances = raveled_steps_to_neighbors(skelint.shape, ndim,
spacing=spacing)
graph = _pixel_graph(skelint, steps, distances, num_edges, height)

if unique_junctions:
_uniquify_junctions(graph, pixel_indices,
labeled_junctions, centroids, spacing=spacing)
return graph, pixel_indices, degree_image
return graph, pixel_indices


@numba.jit(nopython=True, cache=True)
@@ -857,8 +860,8 @@ def summarise(image, *, spacing=1, using_height=False):
"""
ndim = image.ndim
spacing = np.ones(ndim, dtype=float) * spacing
g, coords_img, degrees = skeleton_to_csgraph(image, spacing=spacing,
value_is_height=using_height)
g, coords_img = skeleton_to_csgraph(image, spacing=spacing,
value_is_height=using_height)
num_skeletons, skeleton_ids = csgraph.connected_components(g,
directed=False)
if np.issubdtype(image.dtype, np.float_) and not using_height:
@@ -941,3 +944,40 @@ def compute_centroids(image):
sums = np.add.reduceat(coords[grouping], np.cumsum(sizes)[:-1])
means = sums / sizes[1:, np.newaxis]
return labeled_image, means


def make_degree_image(skeleton_image):
"""Create a array showing the degree of connectivity of each pixel.

Parameters
----------
skeleton_image : array
An input image in which every nonzero pixel is considered part of
the skeleton, and links between pixels are determined by a full
n-dimensional neighborhood.

Returns
-------
degree_image : array of int, same shape as skeleton_image
An image containing the degree of connectivity of each pixel in the
skeleton to neighboring pixels.
"""
bool_skeleton = skeleton_image.astype(bool)
degree_kernel = np.ones((3,) * bool_skeleton.ndim)
degree_kernel[(1,) * bool_skeleton.ndim] = 0 # remove centre pixel
if isinstance(bool_skeleton, np.ndarray):
degree_image = ndi.convolve(
bool_skeleton.astype(int),
degree_kernel,
mode='constant',
) * bool_skeleton
# use dask image for any array other than a numpy array (which isn't
# supported yet anyway)
else:
import dask.array as da
from dask_image.ndfilters import convolve as dask_convolve
if isinstance(bool_skeleton, da.Array):
degree_image = dask_convolve(bool_skeleton.astype(int),
degree_kernel,
mode='constant') * bool_skeleton
return degree_image
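For the distributed case this PR targets, ``make_degree_image`` dispatches on the array type: NumPy inputs go through ``scipy.ndimage.convolve``, while anything else is handed to ``dask_image.ndfilters.convolve``. A hedged sketch of both paths, assuming dask and dask-image are installed and using illustrative data and chunk sizes:

import numpy as np
import dask.array as da
from skan import csr

skeleton = np.zeros((4, 4), dtype=bool)
skeleton[1, :] = True  # illustrative one-row skeleton

deg_numpy = csr.make_degree_image(skeleton)  # eager, scipy.ndimage path
deg_dask = csr.make_degree_image(da.from_array(skeleton, chunks=(2, 2)))  # lazy, dask-image path
print(np.array_equal(deg_numpy, np.asarray(deg_dask)))  # the two paths should agree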
6 changes: 3 additions & 3 deletions skan/image_stats.py
@@ -29,8 +29,8 @@ def mesh_sizes(skeleton):
... [0, 1, 0, 1, 0]])
>>> print(mesh_sizes(image))
[]
>>> from skan.nputil import pad
>>> image2 = pad(image, 1) # make sure mesh not touching border
>>> # make sure mesh not touching border
>>> image2 = np.pad(image, 1, mode='constant', constant_values=1)
>>> print(mesh_sizes(image2)) # sizes in row order of first pixel in space
[7 2 3 1]
"""
@@ -65,7 +65,7 @@ def image_summary(skeleton, *, spacing=1):
"""
stats = pd.DataFrame()
stats['scale'] = [spacing]
g, coords, degimg = csr.skeleton_to_csgraph(skeleton, spacing=spacing)
g, coords = csr.skeleton_to_csgraph(skeleton, spacing=spacing)
degrees = np.diff(g.indptr)
num_junctions = np.sum(degrees > 2)
stats['number of junctions'] = num_junctions
63 changes: 0 additions & 63 deletions skan/nputil.py
@@ -68,69 +68,6 @@ def smallest_int_dtype(number, *, signed=False, min_dtype=np.int8):
return dtype


# adapted from github.com/janelia-flyem/gala
def pad(ar, vals, *, axes=None):
"""Pad an array with values in `vals` along `axes`.

Parameters
----------
ar : array, shape (M, N, ...)
The input array.
vals : int or iterable of int, shape (K,)
The values to pad with.
axes : int in {0, ..., `ar.ndim`}, or iterable thereof, optional
The axes of `ar` to pad. If None, pad along all axes.

Returns
-------
ar2 : array, shape (M+2K, N+2K, ...)
The padded array.

Examples
--------
>>> ar = np.array([4, 5, 6])
>>> pad(ar, 0)
array([0, 4, 5, 6, 0])
>>> pad(ar, [0, 1])
array([1, 0, 4, 5, 6, 0, 1])
>>> ar = np.array([[4, 5, 6]])
>>> pad(ar, 0)
array([[0, 0, 0, 0, 0],
[0, 4, 5, 6, 0],
[0, 0, 0, 0, 0]])
>>> pad(ar, 0, axes=1)
array([[0, 4, 5, 6, 0]])
"""
if axes is None:
axes = list(range(ar.ndim))
if not isinstance(vals, collections.Iterable):
vals = [vals]
if not isinstance(axes, collections.Iterable):
axes = [axes]
p = len(vals)
newshape = np.array(ar.shape)
for ax in axes:
newshape[ax] += 2*p
vals = np.reshape(vals, (p,) + (1,) * (ar.ndim-1))
new_dtype = ar.dtype
if np.issubdtype(new_dtype, np.integer):
maxval = max([np.max(vals), np.max(ar)])
minval = min([np.min(vals), np.min(ar)])
signed = (minval < 0)
maxval = max(abs(minval), maxval)
new_dtype = smallest_int_dtype(maxval, signed=signed,
min_dtype=new_dtype)
ar2 = np.empty(newshape, dtype=new_dtype)
center = np.ones(newshape, dtype=bool)
for ax in axes:
ar2.swapaxes(0, ax)[p-1::-1,...] = vals
ar2.swapaxes(0, ax)[-p:,...] = vals
center.swapaxes(0, ax)[p-1::-1,...] = False
center.swapaxes(0, ax)[-p:,...] = False
ar2[center] = ar.ravel()
return ar2


def raveled_steps_to_neighbors(shape, connectivity=1, *, order='C', spacing=1,
return_distances=True):
"""Return raveled coordinate steps for given array shape and neighborhood.
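The removed ``pad`` helper above is replaced throughout this PR by plain ``np.pad`` calls (see skan/csr.py and skan/image_stats.py). A quick equivalence check against the old docstring example, for illustration only:

import numpy as np

ar = np.array([4, 5, 6])
print(np.pad(ar, 1, mode='constant', constant_values=0))  # [0 4 5 6 0], matching the old pad(ar, 0)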
18 changes: 7 additions & 11 deletions skan/test/test_csr.py
@@ -11,7 +11,7 @@


def test_tiny_cycle():
g, idxs, degimg = csr.skeleton_to_csgraph(tinycycle)
g, idxs = csr.skeleton_to_csgraph(tinycycle)
expected_indptr = [0, 0, 2, 4, 6, 8]
expected_indices = [2, 3, 1, 4, 1, 4, 2, 3]
expected_data = np.sqrt(2)
@@ -20,14 +20,12 @@ def test_tiny_cycle():
assert_equal(g.indices, expected_indices)
assert_almost_equal(g.data, expected_data)

expected_degrees = np.array([[0, 2, 0], [2, 0, 2], [0, 2, 0]])
assert_equal(degimg, expected_degrees)
assert_equal(np.ravel_multi_index(idxs.astype(int).T, tinycycle.shape),
[0, 1, 3, 5, 7])


def test_skeleton1_stats():
g, idxs, degimg = csr.skeleton_to_csgraph(skeleton1)
g, idxs = csr.skeleton_to_csgraph(skeleton1)
stats = csr.branch_statistics(g)
assert_equal(stats.shape, (4, 4))
keys = map(tuple, stats[:, :2].astype(int))
@@ -62,9 +60,8 @@ def test_summarise_spacing():


def test_line():
g, idxs, degimg = csr.skeleton_to_csgraph(tinyline)
g, idxs = csr.skeleton_to_csgraph(tinyline)
assert_equal(np.ravel(idxs), [0, 1, 2, 3])
assert_equal(degimg, [0, 1, 2, 1, 0])
assert_equal(g.shape, (4, 4))
assert_equal(csr.branch_statistics(g), [[1, 3, 2, 0]])

@@ -76,15 +73,15 @@


def test_3d_spacing():
g, idxs, degimg = csr.skeleton_to_csgraph(skeleton3d, spacing=[5, 1, 1])
g, idxs = csr.skeleton_to_csgraph(skeleton3d, spacing=[5, 1, 1])
stats = csr.branch_statistics(g)
assert_equal(stats.shape, (5, 4))
assert_almost_equal(stats[0], [1, 5, 10.467, 1], decimal=3)
assert_equal(np.unique(stats[:, 3].astype(int)), [1, 2, 3])


def test_topograph():
g, idxs, degimg = csr.skeleton_to_csgraph(topograph1d,
g, idxs = csr.skeleton_to_csgraph(topograph1d,
value_is_height=True)
stats = csr.branch_statistics(g)
assert stats.shape == (1, 4)
@@ -101,10 +98,9 @@ def test_topograph_summary():

def test_junction_multiplicity():
"""Test correct distances when a junction has more than one pixel."""
g, idxs, degimg = csr.skeleton_to_csgraph(skeleton0)
g, idxs = csr.skeleton_to_csgraph(skeleton0)
assert_almost_equal(g[3, 5], 2.0155644)
g, idxs, degimg = csr.skeleton_to_csgraph(skeleton0,
unique_junctions=False)
g, idxs = csr.skeleton_to_csgraph(skeleton0, unique_junctions=False)
assert_almost_equal(g[2, 3], 1.0)
assert_almost_equal(g[3, 6], np.sqrt(2))

4 changes: 2 additions & 2 deletions skan/test/test_draw.py
@@ -89,8 +89,8 @@ def filtered(skeleton):


def test_networkx_plot():
g0, c0, _ = csr.skeleton_to_csgraph(_testdata.skeleton0)
g1, c1, _ = csr.skeleton_to_csgraph(_testdata.skeleton1)
g0, c0 = csr.skeleton_to_csgraph(_testdata.skeleton0)
g1, c1 = csr.skeleton_to_csgraph(_testdata.skeleton1)
fig, axes = plt.subplots(1, 2)
draw.overlay_skeleton_networkx(g0, c0, image=_testdata.skeleton0,
axis=axes[0])