diff --git a/src/distance_explainer.py b/src/distance_explainer.py index ff0b43f..bd8d918 100644 --- a/src/distance_explainer.py +++ b/src/distance_explainer.py @@ -127,6 +127,14 @@ def describe(x, name): input_distance = DistanceExplainer.calculate_distances(input_prediction, embedded_reference) neutral_value = np.exp(-input_distance) + # for one-sided experiments, use "meaningful" neutral value (the unperturbed distance), otherwise center on 0 + if len(lowest_mask_weights) > 0 and len(highest_mask_weights) == 0: + neutral_value = neutral_value + if len(highest_mask_weights) > 0 and len(lowest_mask_weights) == 0: + neutral_value = -neutral_value + if len(highest_mask_weights) > 0 and len(lowest_mask_weights) > 0: + neutral_value = 0 + return saliency, neutral_value @staticmethod