Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add faithfulness metric #455

Merged
merged 51 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
8383b74
code copied over
pat-alt May 23, 2024
640bd52
formatter
pat-alt May 23, 2024
cd01b30
added plausibility metric
pat-alt May 24, 2024
edd2907
removed a bug in the find_potential_neighbours method
pat-alt May 24, 2024
1236b2d
formatting
pat-alt May 24, 2024
f444798
done in principle, but still lacking documentation etc.
pat-alt May 24, 2024
153d7f0
goddamn
pat-alt May 24, 2024
1af160d
uff
pat-alt Jun 4, 2024
885e8cc
not working
pat-alt Jun 4, 2024
65b485e
ok ok
pat-alt Jun 4, 2024
4b070c5
getting there
pat-alt Jun 5, 2024
b02d769
niiice
pat-alt Jun 5, 2024
bdc006d
now we're cooking
pat-alt Jun 5, 2024
e0e5dc2
now working on docs
pat-alt Jun 5, 2024
1ca9b38
still easy to find examples where the posterior is clearly degenerate
pat-alt Jun 5, 2024
4f5275e
same overshooting related things still observed for some models
pat-alt Jun 5, 2024
cffa9f0
could consider running proper SGLD only for final posterior construction
pat-alt Jun 5, 2024
1e4c065
formatter
pat-alt Jun 5, 2024
cb3fb20
removed explicit link of energy sampler to dataset, now just model
pat-alt Jun 6, 2024
2af8cea
adding burnin-removal
pat-alt Jun 6, 2024
2c359c6
work on making energy sampler part of model
pat-alt Jun 7, 2024
bb78959
fixed issue with artifacts
pat-alt Jun 7, 2024
1953936
small fix
pat-alt Jun 7, 2024
ebcf646
part of model now
pat-alt Jun 10, 2024
1cc1921
work on tutorial
pat-alt Jun 10, 2024
ec3e739
formatter
pat-alt Jun 10, 2024
f47910f
omg
pat-alt Jun 10, 2024
1835e23
ok think we're nearly there now
pat-alt Jun 10, 2024
ca67bde
finally finally
pat-alt Jun 11, 2024
46d786c
omnibus PR
pat-alt Jun 11, 2024
cc100a2
uff
pat-alt Jun 12, 2024
509fc05
soon
pat-alt Jun 12, 2024
6ff0f48
soon
pat-alt Jun 12, 2024
4ad69e1
regularization currently seems to have opposite of intended effect lo…
pat-alt Jun 13, 2024
138807c
decay schedule sorted
pat-alt Jun 13, 2024
89c44dd
more work on tutorial
pat-alt Jun 13, 2024
e984afd
uf
pat-alt Jun 14, 2024
a20ce64
uff
pat-alt Jun 14, 2024
7d50a87
buf
pat-alt Jun 14, 2024
74ec2da
merged main
pat-alt Sep 9, 2024
161eb14
now depending on EnergyBasedSamplers
pat-alt Sep 9, 2024
984ab0e
fixed manifest
pat-alt Sep 9, 2024
b8c9e25
getting rid of manifests
pat-alt Sep 9, 2024
2f3c768
gitignoring manifests
pat-alt Sep 9, 2024
4167b1f
trying to remove bug in LA extension
pat-alt Sep 9, 2024
c43f4c0
manifest issue
pat-alt Sep 9, 2024
7eab3ba
trying to fix error
pat-alt Sep 9, 2024
c03a041
fixing error with LA
pat-alt Sep 10, 2024
067cac6
tutorial almost done just need to render
pat-alt Sep 10, 2024
f42853d
not sure why rendering doens't work
pat-alt Sep 10, 2024
6be80a7
done
pat-alt Sep 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),

### Added

- Added support for an energy constraint as in Altmeyer et al. ([2024](https://scholar.google.com/scholar?cluster=3697701546144846732&hl=en&as_sdt=0,5)). This is the first step towards adding functionality for ECCCo. [387]
- Added new evaluation metric to measure unfaithfulness as in Altmeyer et al. ([2024](https://scholar.google.com/scholar?cluster=3697701546144846732&hl=en&as_sdt=0,5)). [#454]
- Added support for an energy constraint as in Altmeyer et al. ([2024](https://scholar.google.com/scholar?cluster=3697701546144846732&hl=en&as_sdt=0,5)). This is the first step towards adding functionality for ECCCo. [#387]

### Removed

- Removed bug in `find_potential_neighbours` method. [#454]

## Version [1.1.6] - 2024-05-19

Expand Down
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[weakdeps]
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
JointEnergyModels = "48c56d24-211d-4463-bbc0-7a701b291131"
LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
NeuroTreeModels = "1db4e0a5-a364-4b0c-897c-2bd5a4a3a1f2"

[extensions]
DecisionTreeExt = "DecisionTree"
JEMExt = "JointEnergyModels"
LaplaceReduxExt = "LaplaceRedux"
NeuroTreeExt = "NeuroTreeModels"

Expand All @@ -44,6 +46,7 @@ DataFrames = "1"
DecisionTree = "0.12.3, 0.12.4"
Distributions = "0.25.97"
Flux = "0.12, 0.13, 0.14"
JointEnergyModels = "0.1.3"
LaplaceRedux = "0.1.4, 0.2, 1.0"
LazyArtifacts = "1"
LinearAlgebra = "1.6, 1.7, 1.8, 1.9, 1.10"
Expand All @@ -68,9 +71,10 @@ julia = "1.6, 1.7, 1.8, 1.9, 1.10"
[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
JointEnergyModels = "48c56d24-211d-4463-bbc0-7a701b291131"
LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
NeuroTreeModels = "1db4e0a5-a364-4b0c-897c-2bd5a4a3a1f2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "DecisionTree", "LaplaceRedux", "NeuroTreeModels", "Test"]
test = ["Aqua", "DecisionTree", "JointEnergyModels", "LaplaceRedux", "NeuroTreeModels", "Test"]
9 changes: 9 additions & 0 deletions ext/JEMExt/JEMExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module JEMExt

using CounterfactualExplanations
using CounterfactualExplanations.Models: Models
using JointEnergyModels: JointEnergyModels

include("jem.jl")

end
108 changes: 108 additions & 0 deletions ext/JEMExt/jem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
using CounterfactualExplanations.Models
using Distributions
using Flux
using MLJBase
using Tables: columntable
using TaijaBase: TaijaBase

"""
CounterfactualExplanations.JEM(
model::JointEnergyModels.JointEnergyClassifier; likelihood::Symbol=:classification_binary
)

Outer constructor for a neural network with Laplace Approximation from `LaplaceRedux.jl`.
"""
function CounterfactualExplanations.JEM(
model::JointEnergyModels.JointEnergyClassifier;
likelihood::Symbol=:classification_binary,
)
return Models.Model(model, CounterfactualExplanations.JEM(); likelihood=likelihood)
end

"""
(M::Model)(data::CounterfactualData, type::JEM; kwargs...)

Constructs a differentiable tree-based model for the given data.
"""
function (M::Models.Model)(
data::CounterfactualData, type::CounterfactualExplanations.JEM; kwargs...
)
n = CounterfactualExplanations.DataPreprocessing.outdim(data)
𝒟y = Categorical(ones(n) ./ n)
𝒟x = Normal()
input_dim = size(data.X,1)
sampler = JointEnergyModels.ConditionalSampler(𝒟x, 𝒟y; input_size=(input_dim,), batch_size=50)
pat-alt marked this conversation as resolved.
Show resolved Hide resolved
model = JointEnergyModels.JointEnergyClassifier(sampler; kwargs...)
M = CounterfactualExplanations.JEM(model; likelihood=data.likelihood)
return M
end

"""
train(M::JEM, data::CounterfactualData; kwargs...)

Fits the model `M` to the data in the `CounterfactualData` object.
This method is not called by the user directly.

# Arguments
- `M::JEM`: The wrapper for an JEM model.
- `data::CounterfactualData`: The `CounterfactualData` object containing the data to be used for training the model.

# Returns
- `M::JEM`: The fitted JEM model.
"""
function Models.train(
M::Models.Model,
type::CounterfactualExplanations.JEM,
data::CounterfactualData;
kwargs...,
)
X, y = CounterfactualExplanations.DataPreprocessing.preprocess_data_for_mlj(data)
if M.likelihood ∉ [:classification_multi, :classification_binary]
y = float.(y.refs)
end
X = columntable(X)
mach = MLJBase.machine(M.model, X, y)
MLJBase.fit!(mach)
Flux.testmode!(mach.fitresult[1])
M.fitresult = mach.fitresult
return M
end

"""
Models.logits(M::JEM, X::AbstractArray)

Calculates the logit scores output by the model M for the input data X.

# Arguments
- `M::JEM`: The model selected by the user. Must be a model from the MLJ library.
- `X::AbstractArray`: The feature vector for which the logit scores are calculated.

# Returns
- `logits::Matrix`: A matrix of logits for each output class for each data point in X.

# Example
logits = Models.logits(M, x) # calculates the logit scores for each output class for the data point x
"""
function Models.logits(
M::Models.Model, type::CounterfactualExplanations.JEM, X::AbstractArray
)
return M.fitresult[1](X)
end

"""
Models.probs(
M::Models.Model,
type::CounterfactualExplanations.JEM,
X::AbstractArray,
)

Overloads the [probs](@ref) method for NeuroTree models.
"""
function Models.probs(
M::Models.Model, type::CounterfactualExplanations.JEM, X::AbstractArray
)
if ndims(X) == 1
X = X[:, :] # account for 1-dimensional inputs
end
return softmax(logits(M, X))
end
3 changes: 0 additions & 3 deletions src/base_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ abstract type AbstractCounterfactualExplanation end
"Base type for models."
abstract type AbstractModel end

"Alias for `AbstractModel` (deprecated)."
const AbstractFittedModel = AbstractModel

"Treat `AbstractModel` as scalar when broadcasting."
Base.broadcastable(model::AbstractModel) = Ref(model)

Expand Down
14 changes: 7 additions & 7 deletions src/counterfactuals/utils.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""
output_dim(ce::CounterfactualExplanation)
outdim(ce::CounterfactualExplanation)

A convenience method that returns the output dimension of the predictive model.
"""
function output_dim(ce::CounterfactualExplanation)
return size(Models.probs(ce.M, ce.x))[1]
function outdim(ce::CounterfactualExplanation)
return CounterfactualExplanations.DataPreprocessing.outdim(ce.data)
end

"""
Expand Down Expand Up @@ -77,12 +77,12 @@ end

Finds potential neighbors for the selected factual data point.
"""
function find_potential_neighbours(ce::AbstractCounterfactualExplanation)
function find_potential_neighbours(ce::AbstractCounterfactualExplanation, n::Int=1000)
nobs = size(ce.data.X, 2)
data = DataPreprocessing.subsample(ce.data, minimum([nobs, 1000]))
data = DataPreprocessing.subsample(ce.data, minimum([nobs, n]))
ids = findall(Models.predict_label(ce.M, data) .== ce.target)
n_candidates = minimum([size(ce.data.y, 2), 1000])
candidates = DataPreprocessing.select_factual(ce.data, rand(ids, n_candidates))
n_candidates = minimum([size(ce.data.y, 2), n])
candidates = DataPreprocessing.select_factual(data, rand(ids, n_candidates))
potential_neighbours = reduce(hcat, map(x -> x[1], collect(candidates)))
return potential_neighbours
end
9 changes: 9 additions & 0 deletions src/data_preprocessing/counterfactual_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,12 @@ function transformable_features(
# Returns indices of columns that have varying values:
return counterfactual_data.features_continuous[idx_not_all_equal]
end

"""
outdim(data::CounterfactualData)

Returns the number of output classes.
"""
function outdim(data::CounterfactualData)
return length(data.y_levels)
end
1 change: 1 addition & 0 deletions src/evaluation/Evaluation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using Statistics

include("benchmark.jl")
include("evaluate.jl")
include("utils.jl")
include("measures.jl")

export Benchmark, benchmark, evaluate, default_measures
Expand Down
27 changes: 27 additions & 0 deletions src/evaluation/faithfulness/faithfulness.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
include("utils.jl")

using CounterfactualExplanations.Objectives

function faithfulness(ce::AbstractCounterfactualExplanation; kwrgs...)
return faithfulness(ce, distance_from_posterior; kwrgs...)
end

"""
faithfulness(
ce::CounterfactualExplanation,
fun::typeof(Objectives.distance_from_target);
λ::AbstractFloat=1.0,
kwrgs...,
)

Computes the faithfulness of a counterfactual explanation based on the distance from the target. Specifically, the function computes the faithfulness as the exponential decay of the distance from the samples drawn from the learned posterior of the model with rate parameter `λ`. Larger values of `λ` result in a faster decay of the faithfulness. If you input data is not normalized, you may want to adjust the rate parameter `λ` accordingly, e.g. higher values for larger feature scales.
"""
function faithfulness(
ce::CounterfactualExplanation,
fun::typeof(distance_from_posterior);
λ::AbstractFloat=0.5,
kwrgs...,
)
Δ = fun(ce; kwrgs...)
return exp_decay(Δ, λ)
end
Loading
Loading