Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add slice interpolation #32

Open
wants to merge 28 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
4ccb5af
add initial interpolation widget
jamesyan-git May 30, 2022
b3a04af
incorporated changes to avoid interp_dim error
jamesyan-git Jun 2, 2022
46dceb1
basic container widget added
jamesyan-git Jun 16, 2022
d11ad4b
rough proof of concept for not asking user for input
jamesyan-git Jun 27, 2022
8f084f2
yapf
jamesyan-git Jun 27, 2022
3fcc3eb
refactored to track paint event for history
jamesyan-git Aug 2, 2022
15cb88d
added support for multiple lable IDs and multiple slices
jamesyan-git Aug 8, 2022
9cf0896
clean up yaml
jamesyan-git Aug 8, 2022
fc48307
Add docstrings
jamesyan-git Sep 26, 2022
e66aff1
add Genevieve tests, removed np.s_ and viewer
jamesyan-git Dec 2, 2022
2ecf462
remove line that reads data into array
jamesyan-git Jan 11, 2023
dd23e51
initial test for widget coverage
jamesyan-git Jan 11, 2023
7665367
add test for store_painted_sices
jamesyan-git Jan 13, 2023
c009b55
tidy up test_store_painted_slice
jamesyan-git Jan 13, 2023
a6dc939
removed unnecessary labels layer
jamesyan-git Jan 13, 2023
4442879
add test for distance_transfrom
jamesyan-git Jan 17, 2023
a7c50ff
added test for interpolate
jamesyan-git Feb 2, 2023
0d66c4e
added dimention combo box
jamesyan-git Feb 13, 2023
d5414cf
refactor interp_dim
jamesyan-git Feb 13, 2023
1dcfc45
test enter_interpolation_mode
jamesyan-git Feb 17, 2023
342b9a9
disable comboboxed when interpolating
jamesyan-git Feb 17, 2023
9f65ce4
check for layers before interpolation
jamesyan-git Feb 17, 2023
7fd2c21
Update gabrielBB action to aganders headless gui
jamesyan-git Apr 18, 2023
f9fc956
use setup-qt-libs
jamesyan-git Apr 18, 2023
5dcd6dd
update test_and_deploy to use latest versions of tests
jamesyan-git Apr 18, 2023
9ece5b0
Merge branch 'main' into slice-interpolation
jamesyan-git Apr 18, 2023
0b03144
fix tox.ini having two passenv on one line
jamesyan-git Apr 18, 2023
345dddc
Merge branch 'slice-interpolation' of github.com:jamesyan-git/zarpain…
jamesyan-git Apr 18, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions .github/workflows/test_and_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ jobs:
strategy:
matrix:
platform: [ubuntu-latest, windows-latest, macos-latest]
python-version: [3.8, 3.9, '3.10']
python-version: ['3.8', '3.9', '3.10']

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

Expand All @@ -51,14 +51,15 @@ jobs:

# this runs the platform-specific tests declared in tox.ini
- name: Test with tox
uses: GabrielBB/xvfb-action@v1
uses: aganders3/headless-gui@v1
with:
run: python -m tox
env:
PLATFORM: ${{ matrix.platform }}

- name: Coverage
uses: codecov/codecov-action@v2
uses: codecov/codecov-action@v3


deploy:
# this will run when you have tagged a commit, starting with "v*"
Expand All @@ -68,9 +69,9 @@ jobs:
runs-on: ubuntu-latest
if: contains(github.ref, 'tags')
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: "3.x"
- name: Install dependencies
Expand Down
2 changes: 2 additions & 0 deletions src/zarpaint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ._add_3d_points import add_points_3d_with_alt_click
from .reader import zarr_tensorstore
from ._copy_data import copy_data
from ._interpolate_labels import interpolate_between_slices

__all__ = [
'create_labels',
Expand All @@ -19,4 +20,5 @@
'add_points_3d_with_alt_click',
'zarr_tensorstore'
'copy_data',
'interpolate_between_slices',
]
329 changes: 329 additions & 0 deletions src/zarpaint/_interpolate_labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,329 @@
from collections import defaultdict
from magicgui.widgets import Container, ComboBox, PushButton
import napari
import numpy as np
from scipy.interpolate import interpn
from scipy.ndimage import distance_transform_edt


def signed_distance_transform(image):
"""apply distance transform and return new image

Parameters
----------
image : np.array
the label layer data

Returns
-------
np.array
label layer data with distance transform applied
"""
image = image.astype(bool)
edt = distance_transform_edt(image) - distance_transform_edt(~image)
return edt


def point_and_values(image_1, image_2, interp_axis=0):
"""apply distance transforms to the 2 images to interpolate.

Apply distance transforms to the 2 images that are going to be
interpolated between. Return a tuple of points and values where
points represent

Parameters
----------
image_1 : np.array
The first slice to interpolate from
image_2 : _type_
The second slice to interpolate to
interp_axis : int, optional
The dimension along which to interpolate, by default 0

Returns
-------
tuple of numpy array
tuple og distance transform data and corresponding location
"""
edt_1 = signed_distance_transform(image_1)
edt_2 = signed_distance_transform(image_2)
values = np.stack([edt_1, edt_2], axis=interp_axis)
points = tuple([np.arange(i) for i in values.shape])
return points, values


def xi_coords(shape, percent=0.5, interp_axis=0):
"""
create array of coordinates to interpolate between
Parameters
----------
shape : tuple
Shape of the slice
percent : float, optional
Value to populate the xi array, by default 0.5
interp_axis : int, optional
The axis to interpolate along , by default 0

Returns
-------
numpy array
Coordinate denoting the area of the slice to interpolate along
"""
slices = [slice(0, i) for i in shape]
xi = np.moveaxis(np.mgrid[slices], 0,
-1).reshape(np.prod(shape), len(shape)).astype('float')
xi = np.insert(xi, interp_axis, percent, axis=1)
return xi


def slice_iterator(slice_index_1, slice_index_2):
"""create an iterable across the range of slices to be interpolated

Parameters
----------
slice_index_1 : int
number of one bound of the slice range
slice_index_2 : _type_
the opposite bound of the slice range

Returns
-------
zip of 2 numpy arrays
tuple of slice indicies, tuple of percentages to give xi coords
"""
intermediate_slices = np.arange(slice_index_1 + 1, slice_index_2)
n_slices = slice_index_2 - slice_index_1 + 1
stepsize = 1 / n_slices
intermediate_percentages = np.arange(0 + stepsize, 1, stepsize)
return zip(intermediate_slices, intermediate_percentages)


def interpolated_slice(
percent, points, values, interp_axis=0, method='linear'
):
"""Create the dtata for one of the interpolated slices

Parameters
----------
percent : array_like
A value to populate the xi array
points : tuple of ndarray of float
The points of the slice on which to paint
values : array_like
Data to draw on the slice
interp_axis : int, optional
The axis along which to interpolate, by default 0
method : str, optional
Interpolation method, by default 'linear'

Returns
-------
np array
A slice with interpolated data drawn on
"""
# TODO: check return type
img_shape = list(values.shape)
del img_shape[interp_axis]

xi = xi_coords(img_shape, percent=percent, interp_axis=interp_axis)
interpolated_img = interpn(points, values, xi, method=method)
interpolated_img = np.reshape(interpolated_img, img_shape) > 0
return interpolated_img


class InterpolateSliceWidget(Container):
def __init__(self, viewer: "napari.viewer.Viewer"):
"""Widget for handling the interpolate slice gui and event callbacks

Parameters
----------
viewer : napari.viewer.Viewer
napari viewer to add the widget to
"""
super().__init__()
self.viewer = viewer
self.painted_slice_history = defaultdict(set)

self.labels_combo = ComboBox(
name='Labels Layer', choices=self.get_labels_layers
)
self.interp_dim_combo = ComboBox(
name="Interpolation Dimension",
choices=self.update_dim_choices
)

self.start_interpolation_btn = PushButton(name='Start Interpolation')
self.interpolate_btn = PushButton(name='Interpolate')
self.start_interpolation_btn.clicked.connect(
self.enter_interpolation_mode
)
self.extend([
self.labels_combo, self.interp_dim_combo,
self.start_interpolation_btn, self.interpolate_btn
])
self.interpolate_btn.hide()

self.interpolate_axis = 0
self.selected_layer = None

def update_dim_choices(self, interp_dim_combo):
layer_name = self.labels_combo.current_choice
if not layer_name:
return []
return list(range(self.viewer.layers[layer_name].data.ndim))

def get_labels_layers(self, combo):
"""Returns a list of existing labels to display

Parameters
----------
combo : magicgui ComboBox
A dropdown to dispaly the layers

Returns
-------
list[napari.layer.label]
A list of curently existing layers
"""
return [
layer for layer in self.viewer.layers
if isinstance(layer, napari.layers.Labels)
]

def store_painted_slices(self, event):
"""Identify slices that have been painetd on

Parameters
----------
event : Event
napari paint event
"""
last_label_history_item = event.value
real_item = []
# filter empty history atoms from item
for atom in last_label_history_item:
all_coords = list(atom[0])
if any([len(arr) for arr in all_coords]):
real_item.append(atom)
if not real_item:
return

# item is list of atoms. atom is (tuple of e.g. (y, x) painted coords, array of original label, new label)
last_label_history_item = real_item
last_label_coords = last_label_history_item[0][0]

unique_coords = list(map(np.unique, last_label_coords))

last_slice_painted = unique_coords[self.interpolate_axis][0]

label_id = last_label_history_item[-1][-1]

self.painted_slice_history[label_id].add(last_slice_painted)

def enter_interpolation_mode(self, event):
"""Connect the paint callback and change button text

Parameters
----------
event : Event
Event spawned by button click
"""
if not self.labels_combo.current_choice:
raise RuntimeError("No labels layer selected.")

self.selected_layer = self.viewer.layers[
self.labels_combo.current_choice]

self.selected_layer.events.paint.connect(self.store_painted_slices)

self.start_interpolation_btn.hide()
self.interpolate_btn.show()

self.interpolate_btn.clicked.connect(self.interpolate)
self.interpolate_axis = self.interp_dim_combo.get_value()
self.labels_combo.enabled = False
self.interp_dim_combo.enabled = False

def interpolate(self, event):
"""For each label_id, iterate over each slice that has been painted on
and perform pairwise (i, i+1) interpolation on each pair.

Parameters
----------
event : Event
Object created upon clicking "interpolate" in the widget
"""

for label_id, slices_painted in self.painted_slice_history.items():
slices_painted = list(sorted(slices_painted))
if len(slices_painted) > 1:
for i in range(1, len(slices_painted)):
interpolate_between_slices(
self.selected_layer, slices_painted[i - 1],
slices_painted[i], label_id, self.interpolate_axis
)

self.reset()

def reset(self):
"""Reset button text and clear paint event history
"""
self.selected_layer.events.paint.disconnect(self.store_painted_slices)
self.painted_slice_history.clear()
self.interpolate_axis = None
self.interpolate_btn.clicked.disconnect(self.interpolate)

self.interpolate_btn.hide()
self.start_interpolation_btn.show()
self.labels_combo.enabled = True
self.interp_dim_combo.enabled = True


def interpolate_between_slices(
label_layer: "napari.layers.Labels",
slice_index_1: int,
slice_index_2: int,
label_id: int = 1,
interpolate_axis: int = 0
):
"""Compute and draw interpolation between 2 label annotations.

Parameters
----------
label_layer : napari.layers.Labels
The label layer to draw on
slice_index_1 : int
slice containing the first label annotation
slice_index_2 : int
slice containing the second label anotation
interpolation occurs between slice_index_1 slice_index_2
label_id : int, optional
the id of the annotation that is to be painted, by default 1
interpolate_axis : int, optional
the dimension/axis to interpolate across, by default 0
"""

if slice_index_1 > slice_index_2:
slice_index_1, slice_index_2 = slice_index_2, slice_index_1
slice_1 = np.take(label_layer.data, slice_index_1, axis=interpolate_axis)
slice_2 = np.take(label_layer.data, slice_index_2, axis=interpolate_axis)

slice_1 = np.where(slice_1 == label_id, 1, 0)
slice_2 = np.where(slice_2 == label_id, 1, 0)

points, values = point_and_values(slice_1, slice_2, interpolate_axis)

for slice_number, percentage in slice_iterator(slice_index_1,
slice_index_2):
interpolated_img = interpolated_slice(
percentage,
points,
values,
interp_axis=interpolate_axis,
method='linear'
)
indices = [slice(None) for _ in range(label_layer.data.ndim)]
indices[interpolate_axis] = slice_number
indices = tuple(indices)
label_layer.data[indices][interpolated_img] = label_id
label_layer.refresh()
Loading