Skip to content

Commit

Permalink
Tidies up Images.plot_and_save()
Browse files Browse the repository at this point in the history
Closes #806

Whislt #806 was meant to address why no Matplotlib image was generated and tested/compared to a target image when
plotting without axes or colorbar I discovered this was because the image is saved with `Images.save_array_figure()`.

This in turn uses
[`matplotlib.pyplot.imsave()`](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.imsave.html) to save the
image and the test was using
[`skimage.io.imread()`](https://scikit-image.org/docs/stable/api/skimage.io.html#skimage.io.imread) to read the image
back as an array and test the `np.sum()` and `img.shape`.

We no longer need to save arrays nor images as readable arrays (see #804 / #802) and so the need to save the image in
this manner seemed redudant. The `Images.save_figure()` already has logic which excludes the axes and scale bars (see
lines 321 of current changeset).

Further the logic for deciding what to save within `Image.plot_and_save()` seemed overly complicated and if
`Images.save_array_figure()` were being called `Images.plot_and_save()` did not return a `fig` (Matplotlib Image) that
could be saved and tested using `pytest-mpl` extension.

To this end I have...

1. Removed `Images.save_array_figure()`.
2. Tweaked the plotting options under `Images.savefig()` starting at line 321 to use a [tight
   layout](https://matplotlib.org/2.0.2/users/tight_layout_guide.html) which ensures there is no border (see note
   below).
3. Simplified the logic in `Images.plot_and_save()` controlling whether images are saved so that `ValueError` are raised
   if `Images.savefig == False` or `if Images.image_set in ["all", "core"] or self.core_set:` evaluates to
   `False`. Ensures images are explicitly closed to reduce memory usage.
4. Added tests for the raising of `ValueError` exceptions in `Images.plot_and_save()`
5. Tidied up some test names to be consistent (resulted in image name change).
6. Tidied up passing of `plotting_config` into these tests.
7. Added a `pytest-mpl` test for `test_plot_and_save_no_axes_no_colorbar()`, adding the target image as required by #806

Note on Borders

The image generated by `pytest-mpl` against which comparisons are made _does_ have a border because it is saving the
returned `fig` object itself. The actual generated image doesn't have a border (manual checks have been made).
  • Loading branch information
ns-rse committed Mar 4, 2024
1 parent 8798977 commit 07db718
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 51 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
59 changes: 34 additions & 25 deletions tests/test_plottingfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import pytest
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from skimage import io

from topostats.grains import Grains
from topostats.io import LoadScans
Expand Down Expand Up @@ -108,30 +107,19 @@ def test_save_figure(
assert isinstance(ax, Axes)


def test_save_array_figure(tmp_path: Path):
"""Tests that the image array is saved."""
rng2 = np.random.default_rng()
Images(
data=rng2.random((10, 10)),
output_dir=tmp_path,
filename="result",
).save_array_figure()
assert Path(tmp_path / "result.png").exists()


@pytest.mark.mpl_image_compare(baseline_dir="resources/img/")
def test_plot_and_save_no_colorbar(load_scan_data: LoadScans, tmp_path: Path) -> None:
def test_plot_and_save_no_colorbar(load_scan_data: LoadScans, plotting_config: dict, tmp_path: Path) -> None:
"""Test plotting without colorbar."""
plotting_config["colorbar"] = False
fig, _ = Images(
data=load_scan_data.image,
output_dir=tmp_path,
filename="01-raw_heightmap",
pixel_to_nm_scaling=load_scan_data.pixel_to_nm_scaling,
title="Raw Height",
colorbar=False,
axes=True,
image_set="all",
**plotting_config,
).plot_and_save()
# pytest.fail()
return fig


Expand All @@ -145,17 +133,15 @@ def test_plot_histogram_and_save(load_scan_data: LoadScans, tmp_path: Path) -> N


@pytest.mark.mpl_image_compare(baseline_dir="resources/img/")
def test_plot_and_save_colorbar(load_scan_data: LoadScans, tmp_path: Path) -> None:
"""Test plotting with colorbar."""
def test_plot_and_save_colorbar_and_axes(load_scan_data: LoadScans, plotting_config: dict, tmp_path: Path) -> None:
"""Test plotting with colorbar and axes (True in default_config.yaml)."""
fig, _ = Images(
data=load_scan_data.image,
output_dir=tmp_path,
filename="01-raw_heightmap",
pixel_to_nm_scaling=load_scan_data.pixel_to_nm_scaling,
title="Raw Height",
colorbar=True,
axes=True,
image_set="all",
**plotting_config,
).plot_and_save()
return fig

Expand All @@ -174,20 +160,43 @@ def test_plot_and_save_no_axes(load_scan_data: LoadScans, plotting_config: dict,
return fig


@pytest.mark.mpl_image_compare(baseline_dir="resources/img/")
def test_plot_and_save_no_axes_no_colorbar(load_scan_data: LoadScans, plotting_config: dict, tmp_path: Path) -> None:
"""Test plotting without axes and without the colourbar."""
plotting_config["axes"] = False
plotting_config["colorbar"] = False
Images(
fig, _ = Images(
data=load_scan_data.image,
output_dir=tmp_path,
filename="01-raw_heightmap",
title="Raw Height",
**plotting_config,
).plot_and_save()
img = io.imread(tmp_path / "01-raw_heightmap.png")
assert np.sum(img) == 1535334
assert img.shape == (64, 64, 4)
return fig


@pytest.mark.parametrize(
("save", "image_set", "core_set"),
[
pytest.param(False, "all", False, id="save option is false, no images can be plotted"),
pytest.param(True, "nothing", False, id="image_set is invalid, core_set false, no images can be plotted"),
],
)
def test_plot_and_save_false_raises_exception(
load_scan_data: LoadScans, plotting_config: dict, save: bool, image_set: str, core_set: bool, tmp_path: Path
) -> None:
"""Test ValueError is raised if invalid save/image_set/core_set values are passed."""
plotting_config["save"] = save
plotting_config["image_set"] = image_set
plotting_config["core_set"] = core_set
with pytest.raises(ValueError): # noqa: PT011
_, _ = Images(
data=load_scan_data.image,
output_dir=tmp_path,
filename="01-raw_heightmap",
title="Raw Height",
**plotting_config,
).plot_and_save()


@pytest.mark.mpl_image_compare(baseline_dir="resources/img/")
Expand Down
39 changes: 13 additions & 26 deletions topostats/plottingfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def plot_histogram_and_save(self):

def plot_and_save(self):
"""
Plot and save the images with savefig or imsave depending on config file parameters.
Plot and save the image.
Returns
-------
Expand All @@ -251,19 +251,16 @@ def plot_and_save(self):
"""
fig, ax = None, None
if self.save:
if self.image_set == "all" or self.core_set:
if self.axes or self.colorbar:
fig, ax = self.save_figure()
else:
if isinstance(self.masked_array, np.ndarray) or self.region_properties:
fig, ax = self.save_figure()
else:
self.save_array_figure()
LOGGER.info(
f"[{self.filename}] : Image saved to : {str(self.output_dir / self.filename)}.{self.savefig_format}\
| DPI: {self.savefig_dpi}"
)
return fig, ax
if self.image_set in ["all", "core"] or self.core_set:
fig, ax = self.save_figure()
LOGGER.info(
f"[{self.filename}] : Image saved to : {str(self.output_dir / self.filename)}.{self.savefig_format}"
" | DPI: {self.savefig_dpi}"
)
plt.close()
return fig, ax
raise ValueError(f"Invalid image_set ({self.image_set=}) or core_set ({self.core_set=})")
raise ValueError(f"The 'save' option is False ({self.save=}), set to 'True' to save figures.")

def save_figure(self):
"""Save figures as plt.savefig objects.
Expand Down Expand Up @@ -325,6 +322,8 @@ def save_figure(self):
if not self.axes and not self.colorbar:
plt.title("")
fig.frameon = False
plt.box(False)
plt.tight_layout()
plt.savefig(
(self.output_dir / f"{self.filename}.{self.savefig_format}"),
bbox_inches="tight",
Expand All @@ -345,18 +344,6 @@ def save_figure(self):
plt.close()
return fig, ax

def save_array_figure(self) -> None:
"""Save the image array as an image using plt.imsave()."""
plt.imsave(
(self.output_dir / f"{self.filename}.{self.savefig_format}"),
self.data,
cmap=self.cmap,
vmin=self.zrange[0],
vmax=self.zrange[1],
format=self.savefig_format,
)
plt.close()


def add_bounding_boxes_to_plot(fig, ax, shape, region_properties: list, pixel_to_nm_scaling: float) -> None:
"""Add the bounding boxes to a plot.
Expand Down

0 comments on commit 07db718

Please sign in to comment.