diff --git a/scripts/ablate_attention.py b/scripts/ablate_attention.py index 12c0524..3d03ff0 100644 --- a/scripts/ablate_attention.py +++ b/scripts/ablate_attention.py @@ -183,7 +183,7 @@ def main(args: argparse.Namespace): # Heatmap of Dice similarities fig, axis = plt.subplots() - im = axis.imshow(dice_matrix, cmap="plasma_r") + im = axis.imshow(dice_matrix, cmap="plasma_r", vmin=0, vmax=1) axis.set_xlabel("Removed attention block(s)") axis.set_ylabel("Removed attention block(s)")