Skip to content

Commit

Permalink
fix(csr): Set type of node_props to float64 (#235)
Browse files Browse the repository at this point in the history
`dtype` of `node_props` passed from `csr_to_nbgraph()` to `NBGraph` was
`float32`. This sets the type to explicitly be `np.float64` which
matches the expectation of `csr_spec_float` passed to
`numba.experimental.jitclass()`.

The tests that were previously failing in
[TopoStats](https://github.com/AFM-SPM/TopoStats/) (see #234 for
details) now pass locally. 🙂

Closes #234

---------

Co-authored-by: Juan Nunez-Iglesias <[email protected]>
  • Loading branch information
ns-rse and jni authored Dec 13, 2024
1 parent 3d5da88 commit 207383e
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/skan/csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def csr_to_nbgraph(csr, node_props=None):
csr.indices,
csr.data,
np.array(csr.shape, dtype=np.int32),
node_props,
node_props.astype(np.float64),
)


Expand Down Expand Up @@ -525,7 +525,7 @@ def __init__(
if np.issubdtype(skeleton_image.dtype, np.floating):
self.pixel_values = skeleton_image[coords]
elif np.issubdtype(skeleton_image.dtype, np.integer):
self.pixel_values = skeleton_image.astype(float)[coords]
self.pixel_values = skeleton_image.astype(np.float64)[coords]
else:
self.pixel_values = None
self.graph = graph
Expand Down
49 changes: 48 additions & 1 deletion src/skan/test/test_csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from numpy.testing import assert_equal, assert_almost_equal
import pandas as pd
import pytest
import scipy
from scipy import ndimage as ndi
from skimage import data
from skimage.draw import line
from skimage.morphology import skeletonize

from skan import csr
from skan._testdata import (
Expand Down Expand Up @@ -357,6 +361,16 @@ def test_skeleton_integer_dtype(dtype):
assert stats['mean_pixel_value'].max() > 1


@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_skeleton_all_float_dtypes(dtype):
"""Test that skeleton data types can be both float32 and float64."""
horse = ~data.horse()
skeleton_image = skeletonize(horse)
dt = ndi.distance_transform_edt(horse)
float_skel = (dt * skeleton_image).astype(dtype)
_ = csr.Skeleton(float_skel)


def test_default_summarize_separator():
with pytest.warns(np.exceptions.VisibleDeprecationWarning,
match='separator in column name'):
Expand Down Expand Up @@ -523,7 +537,7 @@ def test_nx_to_skeleton(


@pytest.mark.parametrize(
'wrong_skeleton',
('wrong_skeleton'),
[
pytest.param(skeleton0, id='Numpy Array.'),
pytest.param(csr.Skeleton(skeleton0), id='Skeleton.'),
Expand All @@ -538,3 +552,36 @@ def test_nx_to_skeleton_attribute_error(wrong_skeleton: Any) -> None:
"""Test various errors are raised by nx_to_skeleton()."""
with pytest.raises(Exception):
csr.nx_to_skeleton(wrong_skeleton)


@pytest.mark.parametrize(
('skeleton'),
[
pytest.param(skeleton0, id='Numpy Array'),
pytest.param(csr.Skeleton(skeleton0), id='Skeleton'),
pytest.param(nx_graph, id='NetworkX Graph without edges.'),
],
)
def test_csr_to_nbgraph_attribute_error(skeleton: Any) -> None:
"""Raise AttributeError if csr_to_nbgraph() passed incomplete objects."""
with pytest.raises(AttributeError):
csr.csr_to_nbgraph(skeleton)


@pytest.mark.parametrize(
('graph'),
[
pytest.param(
scipy.sparse.csr_matrix(skeleton0),
id='Sparse matrix directly from Numpy Array',
),
pytest.param(
scipy.sparse.csr_matrix(csr.Skeleton(skeleton0)),
id='Sparse matrix from csr.Skeleton',
),
],
)
def test_csr_to_nbgraph_type_error(graph: scipy.sparse.csr_matrix) -> None:
"""Test TypeError is raised by csr_to_nbgraph() if wrong type is passed."""
with pytest.raises(TypeError):
csr.csr_to_nbgraph(graph)

0 comments on commit 207383e

Please sign in to comment.