Skip to content

Commit

Permalink
float dice score metric + also plot lots of stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
Richard Lane committed Nov 16, 2024
1 parent 8be7928 commit c1e5e70
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 13 deletions.
30 changes: 29 additions & 1 deletion fishjaw/images/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,35 @@ def dice_score(truth: np.ndarray, pred: np.ndarray) -> float:
volume1 = np.sum(truth)
volume2 = np.sum(pred)

# Both arrays are empty, consider Dice score as 1
# Both arrays are empty, consider Dice score to be 1
if volume1 + volume2 == 0:
warnings.warn("Both arrays are empty, returning Dice score of 1")
return 1.0

return 2.0 * intersection / (volume1 + volume2)


def float_dice(arr1: np.ndarray, arr2: np.ndarray) -> float:
"""
Calculate the Dice score between two float arrays.
:param arr1: float array.
:param arr2: float array.
:returns: "Dice" score - actually just the product divided by the sum
:raises: ValueError if the shapes of the arrays do not match.
"""
if arr1.shape != arr2.shape:
raise ValueError(
f"Shape mismatch: {arr1.shape=} and {arr2.shape=}"
)

intersection = np.sum(arr1 * arr2)
volume1 = np.sum(arr1)
volume2 = np.sum(arr2)

# Both arrays are empty, consider Dice score to be 1
if volume1 + volume2 == 0:
warnings.warn("Both arrays are empty, returning Dice score of 1")
return 1.0
Expand Down
53 changes: 41 additions & 12 deletions scripts/ablate_attention.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""
Perform an ablation study on the requested data.
Perform an ablation study on the requested data, assuming the model is a monai
AttentionUnet (which it probably is?).
This will run the model on some data with and without the attention mechanism enabled-
we replace the attention block with the identity function, effectively disabling it.
Expand All @@ -11,14 +12,17 @@
import argparse

import torch
import numpy as np
import torchio as tio
from tqdm import tqdm
import matplotlib.pyplot as plt
from monai.networks.nets.attentionunet import AttentionBlock, AttentionUnet

from fishjaw.model import model, data
from fishjaw.util import files
from fishjaw.inference import read, mesh
from fishjaw.visualisation import plot_meshes
from fishjaw.images import metrics


def ablated_psi(module, input_, output):
Expand Down Expand Up @@ -56,7 +60,7 @@ def _plot(
inference_subject: tio.Subject,
ax: tuple[plt.Axes, plt.Axes, plt.Axes],
indices: tuple[int] | None = None,
) -> None:
) -> np.ndarray:
"""
Possibly disable some attention mechanism(s), Run inference
and plot the results on the provided axes
Expand All @@ -83,6 +87,8 @@ def _plot(
for hook in hooks:
hook.remove()

return prediction


def main(args: argparse.Namespace):
"""
Expand Down Expand Up @@ -116,28 +122,51 @@ def main(args: argparse.Namespace):
else read.test_subject(config["model_path"])
)

# the number of attention layers
for idx in range(5):
fig, axes = plt.subplots(
2, 3, figsize=(15, 10), subplot_kw={"projection": "3d"}
)

# Choose which layers(and combinations of indices) to ablate
n_attention_blocks = sum(
1 for module in net.modules() if isinstance(module, AttentionBlock)
)
to_ablate = [(i,) for i in range(n_attention_blocks)] + [ # Single layers
# Pairs
(i, j)
for i in range(n_attention_blocks)
for j in range(n_attention_blocks)
if i != j
]

# Create figures for showing the projections with and without attention
projection_fig_ax = [
plt.subplots(2, 3, figsize=(15, 10), subplot_kw={"projection": "3d"})
for _ in to_ablate
]

for indices, (fig, axes) in tqdm(
zip(to_ablate, projection_fig_ax), total=len(to_ablate)
):
# Plot with attention
_plot(net, config, inference_subject, axes[0])
with_attention = _plot(net, config, inference_subject, axes[0])
axes[0, 0].set_zlabel("With attention")

# Plot without attention
_plot(net, config, inference_subject, axes[0], indices=(idx,))
without_attention = _plot(
net, config, inference_subject, axes[1], indices=indices
)
axes[1, 0].set_zlabel("Without attention")

# Find the Dice similarity between them
dice = metrics.float_dice(with_attention, without_attention)

fig.suptitle(
f"Ablation study - {'test fish' if args.subject is None else f'subject {args.subject}'}"
f"\nDice similarity: {dice:.4f}"
f"\nRemoved attention blocks: {indices}"
"\nThresholded at 0.5"
)

fig.tight_layout()
fig.savefig(
out_dir
/ f"ablation_{'test' if args.subject is None else args.subject}_{idx}.png"
out_dir / f"ablation_{'test' if args.subject is None else args.subject}"
f"_{'_'.join(str(index) for index in indices)}.png"
)
plt.close(fig)

Expand Down

0 comments on commit c1e5e70

Please sign in to comment.