Skip to content

Commit

Permalink
Merge pull request #122 from TimMonko/image_overview_wrap_single
Browse files Browse the repository at this point in the history
Add wrapping to image overview for only one image set
  • Loading branch information
TimMonko authored Dec 10, 2024
2 parents 991e193 + d1d7f4e commit 8cde771
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 27 deletions.
69 changes: 68 additions & 1 deletion src/napari_ndev/_tests/test_image_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,46 @@
from matplotlib_scalebar.scalebar import ScaleBar

from napari_ndev import nImage
from napari_ndev.image_overview import ImageOverview, image_overview
from napari_ndev.image_overview import (
ImageOverview,
_add_scalebar,
image_overview,
)


def test_image_overview_wrap():
# create a random numpy array of size 100 x 100
data = np.random.rand(100, 100)
# create a dictionary with the image data
five_image_set = {
'image': [data, data, data, data, data],
'title': ['Image 1', 'Image 2', 'Image 3', 'Image 4', 'Image 5'],
}
fig = image_overview(five_image_set)
assert isinstance(fig, plt.Figure)
assert np.array_equal(
fig.get_size_inches(), np.array([9, 6]) # 3 columns * 3 width, 2 rows * 3 height
)
assert len(fig.axes) == 6
assert fig.axes[0].get_title() == 'Image 1'
assert not fig.axes[5].get_title()
assert not fig.axes[5].get_images()

def test_image_overview_nowrap():
# create a random numpy array of size 100 x 100
data = np.random.rand(100, 100)
# create a dictionary with the image data
five_image_set = {
'image': [data, data, data],
'title': ['Image 1', 'Image 2', 'Image 3'],
}
fig = image_overview(five_image_set)
assert isinstance(fig, plt.Figure)
assert np.array_equal(
fig.get_size_inches(), np.array([9, 3]) # 3 columns * 3 width, 2 rows * 3 height
)
assert len(fig.axes) == 3
assert fig.axes[0].get_title() == 'Image 1'


@pytest.fixture
Expand Down Expand Up @@ -73,6 +112,34 @@ def test_image_overview_plot_title(image_and_label_sets):
assert isinstance(fig, plt.Figure)
assert fig._suptitle.get_text() == test_title

def test_add_scalebar_float(image_and_label_sets):
fig = image_overview(image_and_label_sets)
_add_scalebar(fig.axes[0], 0.5)

assert isinstance(fig, plt.Figure)
scalebar = [
child for child in fig.axes[0].get_children()
if isinstance(child, ScaleBar)
]
assert len(scalebar) == 1

def test_add_scalebar_dict(image_and_label_sets):
fig = image_overview(image_and_label_sets)
scalebar_dict = {
'dx': 0.25,
'units': 'mm',
'location': 'upper right',
'badkey': 'badvalue',
}
_add_scalebar(fig.axes[0], scalebar_dict)

assert isinstance(fig, plt.Figure)
scalebar = [
child for child in fig.axes[0].get_children()
if isinstance(child, ScaleBar)
]
assert len(scalebar) == 1

def test_image_overview_scalebar_float(image_and_label_sets):
fig = image_overview(image_and_label_sets, scalebar=0.5)

Expand Down
83 changes: 57 additions & 26 deletions src/napari_ndev/image_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import inspect

import matplotlib.pyplot as plt
import numpy as np
import stackview


Expand Down Expand Up @@ -126,8 +127,23 @@ def image_overview(
# convert input to list if needed
image_sets = [image_sets] if isinstance(image_sets, dict) else image_sets
# create the subplot grid
num_rows = len(image_sets)
num_columns = max([len(image_set['image']) for image_set in image_sets])

# if only one image set, wrap rows and columns to get a nice aspect ratio
if len(image_sets) == 1:
num_images = len(image_sets[0]['image'])

if num_images <= 3:
num_columns = num_images
num_rows = 1
# wrap so it is roughly a square aspect ratio
else:
num_columns = int(np.ceil(np.sqrt(num_images)))
num_rows = int(np.ceil(num_images / num_columns))

if len(image_sets) > 1:
num_rows = len(image_sets)
num_columns = max([len(image_set['image']) for image_set in image_sets])

# multiply scale of plot by number of columns and rows
fig, axs = plt.subplots(
num_rows,
Expand All @@ -141,10 +157,19 @@ def image_overview(
axs = [[ax] for ax in axs]

# iterate through the image sets
for row, image_set in enumerate(image_sets):
for col, _image in enumerate(image_set['image']):
for image_set_idx, image_set in enumerate(image_sets):
for image_idx, _image in enumerate(image_set['image']):

# calculate the correct row and column for the subplot
if len(image_sets) == 1:
row = image_idx // num_columns
col = image_idx % num_columns
if len(image_sets) > 1:
row = image_set_idx
col = image_idx

# create a dictionary from the col-th values of each key
image_dict = {key: value[col] for key, value in image_set.items()}
image_dict = {key: value[image_idx] for key, value in image_set.items()}

# turn off the subplot and continue if there is no image
if image_dict.get('image') is None:
Expand All @@ -161,30 +186,36 @@ def image_overview(

# add scalebar, if dict is present
if scalebar is not None:
from matplotlib_scalebar.scalebar import ScaleBar
_add_scalebar(axs[row][col], scalebar)

# get a default dictionary to pass to sb_dict, and only overwrite the keys that are present in scalebar
sb_dict = {
'dx': 1,
'units': 'um',
'frameon': True,
'location': 'lower right',
}

# if scalebar is just float, convert to dict
if isinstance(scalebar, float):
sb_valid_dict = {'dx': scalebar}

# if scalebar is dict, only keep the keys that are valid for ScaleBar
if isinstance(scalebar, dict):
sb_valid_dict = {k: v for k, v in scalebar.items() if k in inspect.signature(ScaleBar).parameters}

# update key: values in sb_dict with values from scalebar if key is present
sb_dict.update(sb_valid_dict)

axs[row][col].add_artist(ScaleBar(**sb_dict))
# remove empty subplots
for ax in fig.get_axes():
ax.axis('off') if not ax.get_images() else None

plt.suptitle(fig_title, fontsize=16)
plt.tight_layout(pad=0.3)

return fig

def _add_scalebar(ax, scalebar):
from matplotlib_scalebar.scalebar import ScaleBar

# get a default dictionary to pass to sb_dict,
# and only overwrite the keys that are present in scalebar
sb_dict = {
'dx': 1,
'units': 'um',
'frameon': True,
'location': 'lower right',
}

# if scalebar is just float, convert to dict
if isinstance(scalebar, float):
sb_dict = {'dx': scalebar}
# if scalebar is dict, only keep the keys that are valid for ScaleBar
elif isinstance(scalebar, dict):
sb_valid_dict = {k: v for k, v in scalebar.items() if k in inspect.signature(ScaleBar).parameters}
# update key: values in sb_dict with values from scalebar if key is present
sb_dict.update(sb_valid_dict)

ax.add_artist(ScaleBar(**sb_dict))

0 comments on commit 8cde771

Please sign in to comment.