
Commit

Merge pull request #305 from understandable-machine-intelligence-lab/fix-channels-first-issue

Fix channels first issue
annahedstroem authored Oct 23, 2023
2 parents 72a519b + 30805d8 commit 1f48195
Showing 3 changed files with 26 additions and 30 deletions.
29 changes: 24 additions & 5 deletions quantus/functions/explanation_func.py
@@ -266,7 +266,11 @@ def generate_tf_explanation(
     if not isinstance(targets, np.ndarray):
         targets = np.array([targets])

-    channel_first = kwargs.get("channel_first", infer_channel_first(inputs))
+    channel_first = (
+        kwargs["channel_first"]
+        if "channel_first" in kwargs
+        else infer_channel_first(inputs)
+    )
     inputs = make_channel_last(inputs, channel_first)

     explanation: np.ndarray = np.zeros_like(inputs)
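The change above looks cosmetic but is not: dict.get(key, default) evaluates its default argument eagerly, so the old one-liner called infer_channel_first(inputs) even when the caller had passed channel_first explicitly. A minimal standalone sketch of the difference, with a stand-in infer_channel_first that is assumed here (not verified against quantus) to raise on shapes it cannot disambiguate:

    import numpy as np

    def infer_channel_first(x: np.ndarray) -> bool:
        # Stand-in for the real inference helper; assumed to raise when
        # the channel axis cannot be determined from the shape alone.
        raise ValueError("ambiguous input shape, cannot infer channel order")

    inputs = np.zeros((8, 3, 32, 3))   # 3 appears on both candidate axes
    kwargs = {"channel_first": True}   # caller states the layout explicitly

    # Old pattern: the default is built *before* .get() looks up the key,
    # so inference runs (and here raises) even though the key is present.
    try:
        channel_first = kwargs.get("channel_first", infer_channel_first(inputs))
    except ValueError as err:
        print("old pattern:", err)

    # New pattern: inference is only evaluated when the kwarg is absent.
    channel_first = (
        kwargs["channel_first"]
        if "channel_first" in kwargs
        else infer_channel_first(inputs)
    )
    print("new pattern:", channel_first)  # True, no inference call

The same substitution appears twice more below, for generate_captum_explanation and generate_zennit_explanation.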
@@ -477,7 +481,11 @@ def generate_captum_explanation(
     Returns np.ndarray of same shape as inputs.
     """

-    channel_first = kwargs.get("channel_first", infer_channel_first(inputs))
+    channel_first = (
+        kwargs["channel_first"]
+        if "channel_first" in kwargs
+        else infer_channel_first(inputs)
+    )

     softmax = kwargs.get("softmax", None)
     if softmax is not None:
@@ -540,22 +548,28 @@ def f_reduce_axes(a):
         method = constants.DEPRECATED_XAI_METHODS_CAPTUM[method]

     if method in ["GradientShap", "DeepLift", "DeepLiftShap"]:
+        baselines = (
+            kwargs["baseline"] if "baseline" in kwargs else torch.zeros_like(inputs)
+        )
         attr_func = eval(method)
         explanation = f_reduce_axes(
             attr_func(model, **xai_lib_kwargs).attribute(
                 inputs=inputs,
                 target=targets,
-                baselines=kwargs.get("baseline", torch.zeros_like(inputs)),
+                baselines=baselines,
             )
         )

     elif method == "IntegratedGradients":
+        baselines = (
+            kwargs["baseline"] if "baseline" in kwargs else torch.zeros_like(inputs)
+        )
         attr_func = eval(method)
         explanation = f_reduce_axes(
             attr_func(model, **xai_lib_kwargs).attribute(
                 inputs=inputs,
                 target=targets,
-                baselines=kwargs.get("baseline", torch.zeros_like(inputs)),
+                baselines=baselines,
                 n_steps=10,
                 method="riemann_trapezoid",
             )
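Hoisting the baseline out of the .attribute(...) call follows the same eager-evaluation logic: with kwargs.get, torch.zeros_like(inputs) allocated a full batch-sized tensor on every call, only to be thrown away whenever the caller supplied its own baseline. A minimal sketch with illustrative shapes:

    import torch

    inputs = torch.randn(16, 3, 224, 224)
    kwargs = {"baseline": torch.full_like(inputs, 0.5)}

    # Old pattern: torch.zeros_like runs unconditionally, materializing a
    # throwaway zero tensor whenever a baseline is already provided.
    baselines_old = kwargs.get("baseline", torch.zeros_like(inputs))

    # New pattern: the zero tensor is only allocated when actually needed.
    baselines = (
        kwargs["baseline"] if "baseline" in kwargs else torch.zeros_like(inputs)
    )
    assert torch.equal(baselines, kwargs["baseline"])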
@@ -736,8 +750,13 @@ def generate_zennit_explanation(
     """

-    channel_first = kwargs.get("channel_first", infer_channel_first(inputs))
+    channel_first = (
+        kwargs["channel_first"]
+        if "channel_first" in kwargs
+        else infer_channel_first(inputs)
+    )
+
     softmax = kwargs.get("softmax", None)

     if softmax is not None:
         warnings.warn(
             f"Softmax argument has been passed to the explanation function. Different XAI "
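With all three hunks applied, an explicit channel_first kwarg short-circuits inference in the tf, captum, and zennit paths alike. A hypothetical end-to-end call, assuming (as the hunks suggest but this page does not show) that quantus.explain forwards extra kwargs to these generator functions and accepts numpy batches:

    import numpy as np
    import torch
    import quantus

    # Tiny stand-in network; any model matching the input shape would do.
    model = torch.nn.Sequential(
        torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10)
    )
    x_batch = np.random.rand(8, 3, 32, 32).astype(np.float32)
    y_batch = np.arange(8) % 10

    # channel_first=True is now honored without calling infer_channel_first,
    # which matters for shapes the inference cannot disambiguate.
    a_batch = quantus.explain(
        model, x_batch, y_batch, method="Saliency", channel_first=True
    )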
1 change: 1 addition & 0 deletions quantus/helpers/model/tf_model.py
@@ -209,6 +209,7 @@ def shape_input(
         """
         if channel_first is None:
             channel_first = utils.infer_channel_first(x)
+
         # Expand first dimension if this is just a single instance.
         if not batched:
             x = x.reshape(1, *shape)
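For context, utils.infer_channel_first has to guess the layout from the shape alone, which is why an explicit override matters. The heuristic below is purely illustrative, not quantus' actual implementation: the smaller of the two candidate axes is taken to be the channel axis, and ties raise rather than guess.

    import numpy as np

    def infer_channel_first_sketch(x: np.ndarray) -> bool:
        # Illustrative only: images rarely have more channels than pixels
        # per side, so the smaller candidate axis is treated as channels.
        if x.ndim != 4:
            raise ValueError("expected a batched 4D array")
        n_first, n_last = x.shape[1], x.shape[-1]
        if n_first < n_last:
            return True    # e.g. (B, 3, 224, 224) -> channels-first
        if n_last < n_first:
            return False   # e.g. (B, 224, 224, 3) -> channels-last
        raise ValueError("channel order cannot be inferred from shape")

    print(infer_channel_first_sketch(np.zeros((8, 3, 32, 32))))  # True
    print(infer_channel_first_sketch(np.zeros((8, 32, 32, 3))))  # False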
26 changes: 1 addition & 25 deletions tests/metrics/test_localisation_metrics.py
@@ -625,18 +625,6 @@ def test_pointing_game(
             },
             0.0,
         ),
-        (
-            lazy_fixture("load_1d_1ch_conv_model"),
-            lazy_fixture("none_in_gt_zeros_1d_3ch"),
-            {
-                "init": {
-                    "k": 200,
-                    "disable_warnings": True,
-                    "display_progressbar": False,
-                },
-            },
-            0.38,  # TODO: verify correctness
-        ),
         (
             lazy_fixture("load_mnist_model"),
             lazy_fixture("none_in_gt_zeros_2d_3ch"),
@@ -649,18 +637,6 @@ def test_pointing_game(
             },
             {"min": 0.1, "max": 0.25},
         ),
-        (
-            lazy_fixture("load_1d_1ch_conv_model"),
-            lazy_fixture("half_in_gt_zeros_1d_3ch"),
-            {
-                "init": {
-                    "k": 50,
-                    "disable_warnings": True,
-                    "display_progressbar": False,
-                },
-            },
-            0.9800000000000001,  # TODO: verify correctness
-        ),
         (
             lazy_fixture("load_mnist_model"),
             lazy_fixture("half_in_gt_zeros_2d_3ch"),
@@ -683,7 +659,7 @@
                     "display_progressbar": False,
                 },
             },
-            0.4,  # TODO: verify correctness
+            0.4,
         ),
         (
             lazy_fixture("load_mnist_model"),
