Skip to content

Commit

Permalink
Merge pull request #7 from TimMonko/#4-fix
Browse files Browse the repository at this point in the history
Batch training and prediction widgets
  • Loading branch information
TimMonko authored Feb 11, 2023
2 parents f89c174 + 01bd63c commit de91d41
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 3 deletions.
4 changes: 4 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ install_requires =
magicgui
qtpy
aicsimageio
napari
apoc
pyclesperanto_prototype
dask

python_requires = >=3.8
include_package_data = True
Expand Down
4 changes: 3 additions & 1 deletion src/napari_ndev/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
except ImportError:
__version__ = "unknown"

from ._widget import batch_annotator
from ._widget import batch_annotator, batch_predict, batch_training

__all__ = [
"batch_annotator",
"batch_training",
"batch_predict",
]
40 changes: 39 additions & 1 deletion src/napari_ndev/_tests/test_widget.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from napari_ndev import batch_annotator
from napari_ndev import batch_annotator # , batch_predict, batch_training


# make_napari_viewer is a pytest fixture that returns a napari viewer object
Expand All @@ -22,3 +22,41 @@ def test_batch_annotator(make_napari_viewer, capsys):
# read captured output and check that it's as we expected
# captured = capsys.readouterr()
# assert captured.out == "napari has 1 layers\n"


# def test_batch_training(make_napari_viewer, capsys):
# # make viewer and add an image layer using our fixture
# viewer = make_napari_viewer()
# test_image = np.random.random((100, 100))
# viewer.add_image(test_image)
# test_thresh = test_image > 1
# viewer.add_labels(test_thresh)

# # create our widget, passing in the viewer
# my_widget = batch_training()
# my_widget()
# # call our widget method
# # my_widget._on_click()

# # read captured output and check that it's as we expected
# # captured = capsys.readouterr()
# # assert captured.out == "napari has 1 layers\n"


# def test_batch_predict(make_napari_viewer, capsys):
# # make viewer and add an image layer using our fixture
# viewer = make_napari_viewer()
# test_image = np.random.random((100, 100))
# viewer.add_image(test_image)
# test_thresh = test_image > 1
# viewer.add_labels(test_thresh)

# # create our widget, passing in the viewer
# my_widget = batch_predict()
# my_widget()
# # call our widget method
# # my_widget._on_click()

# # read captured output and check that it's as we expected
# # captured = capsys.readouterr()
# # assert captured.out == "napari has 1 layers\n"
125 changes: 125 additions & 0 deletions src/napari_ndev/_widget.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
"""
neural development (nDev) widget collection
"""
import os
import pathlib
import string
from enum import Enum
from typing import TYPE_CHECKING

import apoc
import dask.array as da
import napari
import numpy as np
import pyclesperanto_prototype as cle
from aicsimageio import AICSImage
from magicgui import magic_factory
from napari import layers
Expand Down Expand Up @@ -83,3 +88,123 @@ def saver(image_type, folder_suffix, save_suffix_str):
saver(image, "_images", save_suffix)
saver(labels, "_labels", save_suffix)
return "Saved Successfully"


channel_nums = [0, 1, 2, 3, 4]
PDFS = Enum("PDFS", apoc.PredefinedFeatureSet._member_names_)


@magic_factory(
auto_call=False,
call_button="Batch Train",
result_widget=True,
image_directory=dict(widget_type="FileEdit", mode="d"),
label_directory=dict(widget_type="FileEdit", mode="d"),
predefined_features=dict(widget_type="ComboBox", choices=PDFS),
channel_list=dict(widget_type="Select", choices=channel_nums),
)
def batch_training(
image_directory=pathlib.Path(),
label_directory=pathlib.Path(),
cl_filename: str = "classifier.cl",
predefined_features=PDFS(1),
custom_features: str = None,
channel_list: int = 0,
img_dims: str = "TYX",
label_dims: str = "ZYX",
):
image_list = os.listdir(image_directory)

apoc.erase_classifier(cl_filename)
custom_classifier = apoc.PixelClassifier(opencl_filename=cl_filename)

for file in image_list:

image_stack = []
img = AICSImage(image_directory / file)

def channel_image(img, dims: str, channel: str or int):
if isinstance(channel, str):
channel_index = img.channel_names.index(channel)
elif isinstance(channel, int):
channel_index = channel
channel_img = img.get_image_data(dims, C=channel_index)
return channel_img

for channels in channel_list:
ch_img = channel_image(img=img, dims=img_dims, channel=channels)
image_stack.append(ch_img)

dask_stack = da.stack(image_stack, axis=0)

lbl = AICSImage(label_directory / file)
labels = channel_image(img=lbl, dims=label_dims, channel=0)

if predefined_features.value == 1:
print("custom")
feature_set = custom_features

else:
print("predefined")
feature_set = apoc.PredefinedFeatureSet[
predefined_features.name
].value

custom_classifier.train(
features=feature_set,
image=dask_stack,
ground_truth=labels,
continue_training=True,
)

feature_importances = custom_classifier.feature_importances()
print("success")
# return pd.Series(feature_importances).plot.bar()

return feature_importances


@magic_factory(
auto_call=False,
call_button="Batch Predict",
image_directory=dict(widget_type="FileEdit", mode="d"),
result_directory=dict(widget_type="FileEdit", mode="d"),
classifier_path=dict(widget_type="FileEdit", mode="r"),
channel_list=dict(widget_type="Select", choices=channel_nums),
)
def batch_predict(
image_directory=pathlib.Path(),
result_directory=pathlib.Path(),
classifier_path=pathlib.Path(),
channel_list: int = 0,
img_dims: str = "TYX",
):
image_list = os.listdir(image_directory)
custom_classifier = apoc.PixelClassifier(opencl_filename=classifier_path)

for file in image_list:
# print('started predicting: ', file)
image_stack = []
img = AICSImage(image_directory / file)

def channel_image(img, dims: str, channel: str or int):
if isinstance(channel, str):
channel_index = img.channel_names.index(channel)
elif isinstance(channel, int):
channel_index = channel
channel_img = img.get_image_data(dims, C=channel_index)
return channel_img

for channels in channel_list:
ch_img = channel_image(img=img, dims=img_dims, channel=channels)
image_stack.append(ch_img)

dask_stack = da.stack(image_stack, axis=0)

result = custom_classifier.predict(
image=dask_stack,
)

AICSImage(cle.pull(result)).save(uri=result_directory / file)

return result
12 changes: 11 additions & 1 deletion src/napari_ndev/napari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,18 @@ contributions:
commands:
- id: napari-ndev.make_batch_annotator
python_name: napari_ndev._widget:batch_annotator
title: Make Batch Annotator
title: Make Batch Annotator Widget
- id: napari-ndev.make_batch_training
python_name: napari_ndev._widget:batch_training
title: Make Batch Training Widget
- id: napari-ndev.make_batch_predict
python_name: napari_ndev._widget:batch_predict
title: Make Batch Predict Widget

widgets:
- command: napari-ndev.make_batch_annotator
display_name: Batch Annotator
- command: napari-ndev.make_batch_training
display_name: Batch APOC Training
- command: napari-ndev.make_batch_predict
display_name: Batch APOC Predict

0 comments on commit de91d41

Please sign in to comment.