Skip to content

Commit

Permalink
Merge pull request #192 from kushalbakshi/dev_widget
Browse files Browse the repository at this point in the history
Add ROI drawing widget
  • Loading branch information
ttngu207 authored Apr 25, 2024
2 parents 3b420a7 + df2c834 commit c8805a5
Show file tree
Hide file tree
Showing 8 changed files with 468 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,6 @@ example_data

# vscode
*.code-workspace

# dash widget
file_system_backend
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
[Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.

## [0.10.0] - 2024-04-09

+ Add - ROI mask creation widget
+ Update documentation for using the included widgets in the package

## [0.9.5] - 2024-03-22

+ Add - pytest
Expand Down Expand Up @@ -209,6 +214,9 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
+ Add - `scan` and `imaging` modules
+ Add - Readers for `ScanImage`, `ScanBox`, `Suite2p`, `CaImAn`

[0.10.0]: https://github.com/datajoint/element-calcium-imaging/releases/tag/0.10.0
[0.9.5]: https://github.com/datajoint/element-calcium-imaging/releases/tag/0.9.5
[0.9.4]: https://github.com/datajoint/element-calcium-imaging/releases/tag/0.9.4
[0.9.3]: https://github.com/datajoint/element-calcium-imaging/releases/tag/0.9.3
[0.9.2]: https://github.com/datajoint/element-calcium-imaging/releases/tag/0.9.2
[0.9.1]: https://github.com/datajoint/element-calcium-imaging/releases/tag/0.9.1
Expand Down
2 changes: 2 additions & 0 deletions docs/src/roadmap.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ the common motifs to create Element Calcium Imaging. Major features include:
- [ ] Deepinterpolation
- [x] Data export to NWB
- [x] Data publishing to DANDI
- [x] Widgets for manual ROI mask creation and curation for cell segmentation of Fluorescent voltage sensitive indicators, neurotransmitter imaging, and neuromodulator imaging
- [ ] Expand creation widget to provide pixel weights for each mask based on Fluorescence intensity traces at each pixel

Further development of this Element is community driven. Upon user requests and based on
guidance from the Scientific Steering Group we will continue adding features to this
Expand Down
5 changes: 5 additions & 0 deletions docs/src/tutorials/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,8 @@ please set `processing_method="extract"` in the
ProcessingParamSet table, and provide the `params` attribute of the ProcessingParamSet
table in the `{'suite2p': {...}, 'extract': {...}}` dictionary format. Please also
install the [MATLAB engine](https://pypi.org/project/matlabengine/) API for Python.

## Manual ROI Mask Creation and Curation

+ Manual creation of ROI masks for fluorescence activity extraction is supported by the `draw_rois.py` plotly/dash widget. This widget allows the user to draw new ROI masks and submit them to the database. The widget can be launched in a Jupyter notebook after following the [installation instructions](#installation-instructions-for-active-projects) and importing `draw_rois` from the module.
+ ROI masks can be curated using the `widget.py` jupyter widget that allows the user to mark each regions as either a `cell` or `non-cell`. This widget can be launched in a Jupyter notebook after following the [installation instructions](#installation-instructions-for-active-projects) and importing `main` from the module.
228 changes: 228 additions & 0 deletions element_calcium_imaging/plotting/draw_rois.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
import yaml
import datajoint as dj
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from dash import no_update
from dash_extensions.enrich import (
DashProxy,
Input,
Output,
State,
html,
dcc,
Serverside,
ServersideOutputTransform,
)

from .utilities import *


logger = dj.logger


def draw_rois(db_prefix: str):
scan = dj.create_virtual_module("scan", f"{db_prefix}scan")
imaging = dj.create_virtual_module("imaging", f"{db_prefix}imaging")
all_keys = (imaging.MotionCorrection).fetch("KEY")

colors = {"background": "#111111", "text": "#00a0df"}

app = DashProxy(transforms=[ServersideOutputTransform()])
app.layout = html.Div(
[
html.H2("Draw ROIs", style={"color": colors["text"]}),
html.Label(
"Select data key from dropdown", style={"color": colors["text"]}
),
dcc.Dropdown(
id="toplevel-dropdown", options=[str(key) for key in all_keys]
),
html.Br(),
html.Div(
[
html.Button(
"Load Image",
id="load-image-button",
style={"margin-right": "20px"},
),
dcc.RadioItems(
id="image-type-radio",
options=[
{"label": "Average Image", "value": "average_image"},
{
"label": "Max Projection Image",
"value": "max_projection_image",
},
],
value="average_image",
labelStyle={"display": "inline-block", "margin-right": "10px"},
style={"display": "inline-block", "color": colors["text"]},
),
html.Div(
[
html.Button("Submit Curated Masks", id="submit-button"),
],
style={
"textAlign": "right",
"flex": "1",
"display": "inline-block",
},
),
],
style={
"display": "flex",
"justify-content": "flex-start",
"align-items": "center",
},
),
html.Br(),
html.Br(),
html.Div(
[
dcc.Graph(
id="avg-image",
config={
"modeBarButtonsToAdd": [
"drawclosedpath",
"drawrect",
"drawcircle",
"drawline",
"eraseshape",
],
},
style={"width": "100%", "height": "100%"},
)
],
style={
"display": "flex",
"justify-content": "center",
"align-items": "center",
"padding": "0.0",
"margin": "auto",
},
),
html.Pre(id="annotations"),
html.Div(id="button-output"),
dcc.Store(id="store-key"),
dcc.Store(id="store-mask"),
dcc.Store(id="store-movie"),
html.Div(id="submit-output"),
]
)

@app.callback(
Output("store-key", "value"),
Input("toplevel-dropdown", "value"),
)
def store_key(value):
if value is not None:
return Serverside(value)
else:
return no_update

@app.callback(
Output("avg-image", "figure"),
Output("store-movie", "average_images"),
State("store-key", "value"),
Input("load-image-button", "n_clicks"),
Input("image-type-radio", "value"),
prevent_initial_call=True,
)
def create_figure(value, render_n_clicks, image_type):
if render_n_clicks is not None:
if image_type == "average_image":
summary_images = (
imaging.MotionCorrection.Summary & yaml.safe_load(value)
).fetch("average_image")
else:
summary_images = (
imaging.MotionCorrection.Summary & yaml.safe_load(value)
).fetch("max_proj_image")
average_images = [image.astype("float") for image in summary_images]
roi_contours = get_contours(yaml.safe_load(value), db_prefix)
logger.info("Generating figure.")
fig = px.imshow(
np.asarray(average_images),
animation_frame=0,
binary_string=True,
labels=dict(animation_frame="plane"),
)
for contour in roi_contours:
# Note: contour[:, 1] are x-coordinates, contour[:, 0] are y-coordinates
fig.add_trace(
go.Scatter(
x=contour[:, 1], # Plotly uses x, y order for coordinates
y=contour[:, 0],
mode="lines", # Display as lines (not markers)
line=dict(color="white", width=0.5), # Set line color and width
showlegend=False, # Do not show legend for each contour
)
)
fig.update_layout(
dragmode="drawrect",
autosize=True,
height=550,
newshape=dict(opacity=0.6, fillcolor="#00a0df"),
plot_bgcolor=colors["background"],
paper_bgcolor=colors["background"],
font_color=colors["text"],
)
fig.update_annotations(bgcolor="#00a0df")
else:
return no_update
return fig, Serverside(average_images)

@app.callback(
Output("store-mask", "annotation_list"),
Input("avg-image", "relayoutData"),
prevent_initial_call=True,
)
def on_relayout(relayout_data):
if not relayout_data:
return no_update
else:
if "shapes" in relayout_data:
global shape_type
try:
shape_type = relayout_data["shapes"][-1]["type"]
return Serverside(relayout_data)
except IndexError:
return no_update
elif any(["shapes" in key for key in relayout_data]):
return Serverside(relayout_data)

@app.callback(
Output("submit-output", "children"),
Input("submit-button", "n_clicks"),
State("store-mask", "annotation_list"),
State("store-key", "value"),
)
def submit_annotations(n_clicks, annotation_list, value):
x_mask_li = []
y_mask_li = []
if n_clicks is not None:
if annotation_list:
if "shapes" in annotation_list:
logger.info("Creating Masks.")
shapes = [d["type"] for d in annotation_list["shapes"]]
for shape, annotation in zip(shapes, annotation_list["shapes"]):
mask = create_mask(annotation, shape)
y_mask_li.append(mask[0])
x_mask_li.append(mask[1])
print("Masks created")
insert_into_database(
scan, imaging, yaml.safe_load(value), x_mask_li, y_mask_li
)
else:
logger.warn(
"Incorrect annotation list format. This is a known bug. Please draw a line anywhere on the image and click `Submit Curated Masks`. It will be ignored in the final submission but will format the list correctly."
)
return no_update
else:
logger.warn("No annotations to submit.")
return no_update
else:
return no_update

return app
Loading

0 comments on commit c8805a5

Please sign in to comment.