Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add slice interpolation #32

Open
wants to merge 28 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
4ccb5af
add initial interpolation widget
jamesyan-git May 30, 2022
b3a04af
incorporated changes to avoid interp_dim error
jamesyan-git Jun 2, 2022
46dceb1
basic container widget added
jamesyan-git Jun 16, 2022
d11ad4b
rough proof of concept for not asking user for input
jamesyan-git Jun 27, 2022
8f084f2
yapf
jamesyan-git Jun 27, 2022
3fcc3eb
refactored to track paint event for history
jamesyan-git Aug 2, 2022
15cb88d
added support for multiple lable IDs and multiple slices
jamesyan-git Aug 8, 2022
9cf0896
clean up yaml
jamesyan-git Aug 8, 2022
fc48307
Add docstrings
jamesyan-git Sep 26, 2022
e66aff1
add Genevieve tests, removed np.s_ and viewer
jamesyan-git Dec 2, 2022
2ecf462
remove line that reads data into array
jamesyan-git Jan 11, 2023
dd23e51
initial test for widget coverage
jamesyan-git Jan 11, 2023
7665367
add test for store_painted_sices
jamesyan-git Jan 13, 2023
c009b55
tidy up test_store_painted_slice
jamesyan-git Jan 13, 2023
a6dc939
removed unnecessary labels layer
jamesyan-git Jan 13, 2023
4442879
add test for distance_transfrom
jamesyan-git Jan 17, 2023
a7c50ff
added test for interpolate
jamesyan-git Feb 2, 2023
0d66c4e
added dimention combo box
jamesyan-git Feb 13, 2023
d5414cf
refactor interp_dim
jamesyan-git Feb 13, 2023
1dcfc45
test enter_interpolation_mode
jamesyan-git Feb 17, 2023
342b9a9
disable comboboxed when interpolating
jamesyan-git Feb 17, 2023
9f65ce4
check for layers before interpolation
jamesyan-git Feb 17, 2023
7fd2c21
Update gabrielBB action to aganders headless gui
jamesyan-git Apr 18, 2023
f9fc956
use setup-qt-libs
jamesyan-git Apr 18, 2023
5dcd6dd
update test_and_deploy to use latest versions of tests
jamesyan-git Apr 18, 2023
9ece5b0
Merge branch 'main' into slice-interpolation
jamesyan-git Apr 18, 2023
0b03144
fix tox.ini having two passenv on one line
jamesyan-git Apr 18, 2023
345dddc
Merge branch 'slice-interpolation' of github.com:jamesyan-git/zarpain…
jamesyan-git Apr 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/zarpaint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ._add_3d_points import add_points_3d_with_alt_click
from .reader import zarr_tensorstore
from ._copy_data import copy_data
from ._interpolate_labels import interpolate_between_slices

__all__ = [
'create_labels',
Expand All @@ -19,4 +20,5 @@
'add_points_3d_with_alt_click',
'zarr_tensorstore'
'copy_data',
'interpolate_between_slices',
]
341 changes: 341 additions & 0 deletions src/zarpaint/_interpolate_labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,341 @@
from collections import defaultdict
import warnings
from magicgui.widgets import Container, ComboBox, PushButton
import napari
import numpy as np
from scipy.interpolate import interpn
from scipy.ndimage import distance_transform_edt


def distance_transform(image):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

imho this should be renamed signed_distance_transform

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jamesyan-git can you rename this function? 😉

"""apply distance transform and return new image

Parameters
----------
image : np.array
the label layer data

Returns
-------
np.array
label layer data with distance transform applied
"""
image = image.astype(bool)
edt = distance_transform_edt(image) - distance_transform_edt(~image)
return edt


def point_and_values(image_1, image_2, interp_dim=0):
"""apply distance transforms to the 2 images to interpolate.

Apply distance transforms to the 2 images that are going to be
interpolated between. Return a tuple of points and values where
points represent

Parameters
----------
image_1 : np.array
The first slice to interpolate from
image_2 : _type_
The second slice to interpolate to
interp_dim : int, optional
The dimention along which to interpolate, by default 0

Returns
-------
tuple of numpy array
tuple og distance transform data and corresponding location
"""
edt_1 = distance_transform(image_1)
edt_2 = distance_transform(image_2)
values = np.stack([edt_1, edt_2], axis=interp_dim)
points = tuple([np.arange(i) for i in values.shape])
return points, values


def xi_coords(shape, percent=0.5, interp_dim=0):
"""
create array of coordinates to interpolate between
Parameters
----------
shape : tuple
Shape of the slice
percent : float, optional
Value to populate the xi array, by default 0.5
interp_dim : int, optional
The axis to interpolate along , by default 0

Returns
-------
numpy array
Coordinate denoting the area of the slice to interpolate along
"""
slices = [slice(0, i) for i in shape]
xi = np.moveaxis(np.mgrid[slices], 0,
-1).reshape(np.prod(shape), len(shape)).astype('float')
xi = np.insert(xi, interp_dim, percent, axis=1)
return xi


def slice_iterator(slice_index_1, slice_index_2):
"""create an iterable across the range of slices to be interpolated

Parameters
----------
slice_index_1 : int
number of one bound of the slice range
slice_index_2 : _type_
the opposite bound of the slice range

Returns
-------
zip of 2 numpy arrays
tuple of slice indicies, tuple of percentages to give xi coords
"""
intermediate_slices = np.arange(slice_index_1 + 1, slice_index_2)
n_slices = slice_index_2 - slice_index_1 + 1
stepsize = 1 / n_slices
intermediate_percentages = np.arange(0 + stepsize, 1, stepsize)
return zip(intermediate_slices, intermediate_percentages)


def interpolated_slice(percent, points, values, interp_dim=0, method='linear'):
"""Create the dtata for one of the interpolated slices

Parameters
----------
percent : array_like
A value to populate the xi array
points : tuple of ndarray of float
The points of the slice on which to paint
values : array_like
Data to draw on the slice
interp_dim : int, optional
The axis along which to interpolate, by default 0
method : str, optional
Interpolation method, by default 'linear'

Returns
-------
np array
A slice with interpolated data drawn on
"""
# TODO: check return type
img_shape = list(values.shape)
del img_shape[interp_dim]

xi = xi_coords(img_shape, percent=percent, interp_dim=interp_dim)
interpolated_img = interpn(points, values, xi, method=method)
interpolated_img = np.reshape(interpolated_img, img_shape) > 0
return interpolated_img


class InterpolateSliceWidget(Container):
def __init__(self, viewer: "napari.viewer.Viewer"):
"""Widget for handling the interpolate slice gui and event callbacks

Parameters
----------
viewer : napari.viewer.Viewer
napari viewer to add the widget to
"""
super().__init__()
self.viewer = viewer
self.labels_combo = ComboBox(
name='Labels Layer', choices=self.get_labels_layers
)
self.start_interpolation_btn = PushButton(name='Start Interpolation')
self.interpolate_btn = PushButton(name='Interpolate')
self.start_interpolation_btn.clicked.connect(self.enter_interpolation)
self.extend([
self.labels_combo, self.start_interpolation_btn,
self.interpolate_btn
])
self.interpolate_btn.hide()
self.painted_slice_history = defaultdict(set)
self.interp_dim = None

def get_labels_layers(self, combo):
"""Returns a list of existing labels to display

Parameters
----------
combo : magicgui ComboBox
A dropdown to dispaly the layers

Returns
-------
list[napari.layer.label]
A list of curently existing layers
"""
return [
layer for layer in self.viewer.layers
if isinstance(layer, napari.layers.Labels)
]

def paint_callback(self, event):
"""Identify slices that have been painetd on

Parameters
----------
event : Event
napari paint event
"""
last_label_history_item = event.value
real_item = []
# filter empty history atoms from item
for atom in last_label_history_item:
all_coords = list(atom[0])
if any([len(arr) for arr in all_coords]):
real_item.append(atom)
if not real_item:
return

last_label_history_item = real_item
last_label_coords = last_label_history_item[0][
0
] # first history atom is a tuple, first element of atom is coords
# here we can determine both the slice index that was painted *and* the interp dim
# it wil be the array in last_label_coords that has only one unique element in it
# the interp dim will be the index of that array in the tuple
if not last_label_coords:
return

unique_coords = list(map(np.unique, last_label_coords))
if self.interp_dim is None:
self._infer_interp_dim(unique_coords)

last_slice_painted = unique_coords[self.interp_dim][0]

label_id = last_label_history_item[-1][-1]

self.painted_slice_history[label_id].add(last_slice_painted)

def _infer_interp_dim(self, unique_coords):
"""Infer the dimension/axis on which to interpolate.

unique_coords contains a list for each dimension.
One of the lists will contain a single element referencing the slice
being painted on. This mean that that lists is the dimension being
painted across

Parameters
----------
unique_coords : List
A list of lists, containing coordinates which have been painted
and the label which is being painted
"""
interp_dim = None
for i in range(len(unique_coords)):
if len(unique_coords[i]) == 1:
interp_dim = i
break
if interp_dim == None:
warnings.warn(
"Couldn't determine axis for interpolation. Using axis 0 by default."
)
self.interp_dim = 0
else:
self.interp_dim = interp_dim

def enter_interpolation(self, event):
"""Connect the paint callback and change button text

Parameters
----------
event : Event
Event spawned by button click
"""
self.selected_layer = self.viewer.layers[
self.labels_combo.current_choice]

self.selected_layer.events.paint.connect(self.paint_callback)

self.start_interpolation_btn.hide()
self.interpolate_btn.show()

self.interpolate_btn.clicked.connect(self.interpolate)

def interpolate(self, event):
"""For each label_id, iterate over each slice that has been painted on
and perform pairwise (i, i+1) interpolation on each pair.

Parameters
----------
event : Event
Object created upon clicking "interpolate" in the widget
"""
assert self.interp_dim is not None, 'Cannot interpolate without knowing dimension'

for label_id, slices_painted in self.painted_slice_history.items():
slices_painted = list(sorted(slices_painted))
if len(slices_painted) > 1:
for i in range(1, len(slices_painted)):
interpolate_between_slices(
self.selected_layer, slices_painted[i - 1],
slices_painted[i], label_id, self.interp_dim
)

self.reset()

def reset(self):
"""Reset button text and clear paint event history
"""
self.selected_layer.events.paint.disconnect(self.paint_callback)
self.painted_slice_history.clear()
self.interp_dim = None
self.interpolate_btn.clicked.disconnect(self.interpolate)

self.interpolate_btn.hide()
self.start_interpolation_btn.show()


def interpolate_between_slices(
label_layer: "napari.layers.Labels",
slice_index_1: int,
slice_index_2: int,
label_id: int = 1,
interp_dim: int = 0
):
"""Compute and draw interpolation between 2 label annotations.

Parameters
----------
label_layer : napari.layers.Labels
The label layer to draw on
slice_index_1 : int
slice containing the first label annotation
slice_index_2 : int
slice containing the second label anotation
interpolation occurs between slice_index_1 slice_index_2
label_id : int, optional
the id of the annotation that is to be painted, by default 1
interp_dim : int, optional
the dimension/axis to interpolate across, by default 0
"""

if slice_index_1 > slice_index_2:
slice_index_1, slice_index_2 = slice_index_2, slice_index_1
layer_data = np.asarray(label_layer.data)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line needs to be removed. (I assume it's here because you were debugging something, but we can't have an entire full size volume read into memory).

We should replace the following two lines with:

    slice_1 = np.take(label_layer.data, slice_index_1, axis=interp_dim)
    slice_2 = np.take(label_layer.data, slice_index_2, axis=interp_dim)

slice_1 = np.take(layer_data, slice_index_1, axis=interp_dim)
slice_2 = np.take(layer_data, slice_index_2, axis=interp_dim)

slice_1 = np.where(slice_1 == label_id, 1, 0)
slice_2 = np.where(slice_2 == label_id, 1, 0)

points, values = point_and_values(slice_1, slice_2, interp_dim)

for slice_number, percentage in slice_iterator(slice_index_1,
slice_index_2):
interpolated_img = interpolated_slice(
percentage,
points,
values,
interp_dim=interp_dim,
method='linear'
)
indices = [slice(None) for _ in range(label_layer.data.ndim)]
indices[interp_dim] = slice_number
indices = tuple(indices)
label_layer.data[indices][interpolated_img] = label_id
label_layer.refresh()
1 change: 0 additions & 1 deletion src/zarpaint/_tests/test_watershed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ def test_watershed_split_2d(make_napari_viewer):
# create 2 squares with one corner overlapping
data[1:10, 1:10] = 1
data[8:17, 8:17] = 1
print(data)

# palce points in the centre of the 2 squares
points = np.asarray([[5, 5], [12, 12]])
Expand Down
5 changes: 5 additions & 0 deletions src/zarpaint/napari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ contributions:
- id: zarpaint.copy_data
title: Copy Data…
python_name: zarpaint:copy_data
- id: zarpaint.interpolate_widg
title: Interpolate
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
title: Interpolate
title: Interpolate Slices with Signed Distance Transform

?

Maybe we can remove "signed" or Interpolate Slices (w/ Distance Transform) if length is an issue

python_name: zarpaint._interpolate_labels:InterpolateSliceWidget

readers:
- command: zarpaint.get_reader
Expand All @@ -42,3 +45,5 @@ contributions:
display_name: Split With Watershed
- command: zarpaint.copy_data
display_name: Copy Data
- command: zarpaint.interpolate_widg
display_name: Interpolate Slices
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
display_name: Interpolate Slices
display_name: Interpolate Slices with Signed Distance Transform

Just preparing for when we add future methods. I also hate that it's so hard for me to find out in other interpolation tools what method they're actually using!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like when we have future methods, there should be a dropdown selection box in the interpolation widget to switch between them (instead of creating additional widgets, which this comment seems to suggest).