Skip to content

Commit

Permalink
fix one-sided mode (one of the two ranges empty)
Browse files Browse the repository at this point in the history
Added a test for one-sided mode and fixed the resulting failures by checking whether the set of masks selected for each range is empty.
  • Loading branch information
egpbos committed Feb 8, 2024
1 parent e66394d commit a827a84
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 29 deletions.
34 changes: 22 additions & 12 deletions src/distance_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,26 +89,36 @@ def explain_image_distance(self, model_or_function, input_data, embedded_referen

self.predictions = np.concatenate(batch_predictions)

lowest_distances_masks, lowest_mask_weights = self._get_lowest_distance_masks_and_weights(
embedded_reference,
self.predictions, self.masks,
self.mask_selection_range_min,
self.mask_selection_range_max)
def describe(x, name):
return f'Description of {name}\nmean:{np.mean(x)}\nstd:{np.std(x)}\nmin:{np.min(x)}\nmax:{np.max(x)}'

statistics = []

highest_distances_masks, highest_mask_weights = self._get_lowest_distance_masks_and_weights(
embedded_reference,
self.predictions, self.masks,
self.mask_selection_negative_range_min,
self.mask_selection_negative_range_max)

def describe(x, name):
return f'Description of {name}\nmean:{np.mean(x)}\nstd:{np.std(x)}\nmin:{np.min(x)}\nmax:{np.max(x)}'
if len(highest_mask_weights) > 0:
statistics.append(describe(highest_mask_weights, 'highest_mask_weights'))
unnormalized_sal_highest = np.mean(highest_distances_masks, axis=0)
else:
unnormalized_sal_highest = 0

self.statistics = '\n'.join([
describe(highest_mask_weights, 'highest_mask_weights'),
describe(lowest_mask_weights, 'lowest_mask_weights')])
lowest_distances_masks, lowest_mask_weights = self._get_lowest_distance_masks_and_weights(
embedded_reference,
self.predictions, self.masks,
self.mask_selection_range_min,
self.mask_selection_range_max)

if len(lowest_mask_weights) > 0:
statistics.append(describe(lowest_mask_weights, 'lowest_mask_weights'))
unnormalized_sal_lowest = np.mean(lowest_distances_masks, axis=0)
else:
unnormalized_sal_lowest = 0

unnormalized_sal_lowest = np.mean(lowest_distances_masks, axis=0)
unnormalized_sal_highest = np.mean(highest_distances_masks, axis=0)
self.statistics = '\n'.join(statistics)
unnormalized_sal = unnormalized_sal_lowest - unnormalized_sal_highest

saliency = unnormalized_sal
Expand Down
Binary file not shown.
Binary file not shown.
70 changes: 53 additions & 17 deletions tests/test_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,45 +2,81 @@
import os
from typing import Callable
import numpy as np
from numpy.typing import ArrayLike
import pytest
from distance_explainer import DistanceExplainer
from tests.config import get_default_config
from tests.config import get_default_config, Config
import dataclasses

DUMMY_EMBEDDING_DIMENSIONALITY = 10


def test_dummy_data_exact_expected_output(set_all_the_seeds: Callable, dummy_model: Callable):
"""Code output should be identical to recorded output."""
[expected_saliency, expected_value] = np.load(
'./tests/test_data/test_dummy_data_exact_expected_output.npz').values()
@pytest.fixture(autouse=True)
def set_all_the_seeds(seed_value=0):
    """Seed the random sources the tests rely on.

    Declared ``autouse`` so it is applied to every test in this module
    without being requested explicitly.

    Args:
        seed_value: Seed used for NumPy's global random state and exported
            as ``PYTHONHASHSEED``. Defaults to 0.
    """
    # Seeds NumPy's legacy global RNG, which the np.random.* calls in this
    # module draw from.
    np.random.seed(seed_value)
    # NOTE(review): CPython reads PYTHONHASHSEED at interpreter startup, so
    # this presumably only influences child processes - confirm it is needed.
    os.environ['PYTHONHASHSEED'] = str(seed_value)

@pytest.fixture()
def dummy_model() -> Callable:
    """Provide a stand-in model that ignores the content of its input.

    Returns:
        A callable that takes a batch ``x`` and returns an array of shape
        ``(x.shape[0], DUMMY_EMBEDDING_DIMENSIONALITY)`` drawn from NumPy's
        global random state.
    """
    def _random_embeddings(batch):
        # Only the batch size is used; the values themselves are random.
        n_samples = batch.shape[0]
        return np.random.randn(n_samples, DUMMY_EMBEDDING_DIMENSIONALITY)

    return _random_embeddings


@pytest.fixture
def dummy_data() -> tuple[ArrayLike, ArrayLike]:
    """Provide a random reference embedding and a random input array.

    Returns:
        A pair ``(embedded_reference, input_arr)``: the reference has shape
        ``(1, DUMMY_EMBEDDING_DIMENSIONALITY)`` and the input has shape
        ``(32, 32, 3)``.
    """
    # Draw order matters for reproducibility: the reference is sampled
    # before the input, both from NumPy's global random state.
    reference_shape = (1, DUMMY_EMBEDDING_DIMENSIONALITY)
    reference = np.random.randn(*reference_shape)
    image = np.random.random((32, 32, 3))
    return reference, image


config = get_default_config()
def get_explainer(config: Config, axis_labels={2: 'channels'}, preprocess_function=None):
    """Build a DistanceExplainer from a test configuration.

    Args:
        config: Supplies the mask selection ranges, the number of masks,
            ``feature_res`` and ``p_keep``.
        axis_labels: Forwarded unchanged to ``DistanceExplainer``. The
            default dict is shared between calls, so it must not be mutated.
        preprocess_function: Forwarded unchanged to ``DistanceExplainer``.

    Returns:
        The configured ``DistanceExplainer`` instance.
    """
    # axis_labels and preprocess_function are passed exactly once, taken from
    # this function's parameters; passing the hard-coded values alongside
    # them would repeat the keyword arguments.
    return DistanceExplainer(
        mask_selection_range_max=config.mask_selection_range_max,
        mask_selection_range_min=config.mask_selection_range_min,
        mask_selection_negative_range_max=config.mask_selection_negative_range_max,
        mask_selection_negative_range_min=config.mask_selection_negative_range_min,
        n_masks=config.number_of_masks,
        axis_labels=axis_labels,
        preprocess_function=preprocess_function,
        feature_res=config.feature_res,
        p_keep=config.p_keep,
    )


def test_distance_explainer(dummy_data: tuple[ArrayLike, ArrayLike],
                            dummy_model: Callable):
    """Code output should be identical to recorded output.

    Runs the explainer with the default configuration on the dummy data and
    compares both returned values against the arrays stored in
    ``test_dummy_data_exact_expected_output.npz``.
    """
    embedded_reference, input_arr = dummy_data
    explainer = get_explainer(get_default_config())

    saliency, value = explainer.explain_image_distance(dummy_model, input_arr, embedded_reference)

    assert saliency.shape == (1,) + input_arr.shape[:2] + (1,)  # Has correct shape

    # Read the recorded arrays inside a context manager so the archive's
    # file handle is closed instead of being left to the garbage collector.
    with np.load('./tests/test_data/test_dummy_data_exact_expected_output.npz') as recorded:
        expected_saliency, expected_value = recorded.values()
    assert np.allclose(expected_saliency, saliency)  # Has correct value
    assert np.allclose(expected_value, value)  # Has correct value


@pytest.fixture()
def dummy_model() -> Callable:
"""Get a dummy model that returns a random embedding for every input in a batch."""
return lambda x: np.random.randn(x.shape[0], DUMMY_EMBEDDING_DIMENSIONALITY)
@pytest.mark.parametrize("empty_side,expected_tag",
                         [({"mask_selection_range_max": 0.}, "pos_empty"),
                          ({"mask_selection_negative_range_min": 1.}, "neg_empty")])
def test_distance_explainer_one_sided(dummy_data: tuple[ArrayLike, ArrayLike],
                                      dummy_model: Callable,
                                      empty_side: dict[str, float],
                                      expected_tag: str):
    """Code output should be identical to recorded output.

    Each case overrides a single bound of the default configuration so that,
    as its tag indicates, one of the two mask selection ranges is empty.

    Args:
        dummy_data: Pair of reference embedding and input array.
        dummy_model: Callable producing an embedding per input in a batch.
        empty_side: Config field override applied via ``dataclasses.replace``.
        expected_tag: Suffix of the recorded ``.npz`` file for this case.
    """
    embedded_reference, input_arr = dummy_data

    config = dataclasses.replace(get_default_config(), **empty_side)
    explainer = get_explainer(config)

    saliency, value = explainer.explain_image_distance(dummy_model, input_arr, embedded_reference)

    assert saliency.shape == (1,) + input_arr.shape[:2] + (1,)  # Has correct shape

    # To regenerate the recorded output for a case, temporarily run:
    # np.savez(f'./tests/test_data/test_dummy_data_exact_expected_output_{expected_tag}.npz',
    #          expected_saliency=saliency, expected_value=value)
    # Read the recorded arrays inside a context manager so the archive's
    # file handle is closed instead of being left to the garbage collector.
    with np.load(f'./tests/test_data/test_dummy_data_exact_expected_output_{expected_tag}.npz') as recorded:
        expected_saliency, expected_value = recorded.values()
    assert np.allclose(expected_saliency, saliency)  # Has correct value
    assert np.allclose(expected_value, value)  # Has correct value

0 comments on commit a827a84

Please sign in to comment.