Memory footprint #503

Merged
merged 53 commits into from
Dec 19, 2024
53 commits
e488ad8
trying to fix bug in plausibility metrix
pat-alt Nov 13, 2024
9baf890
added additional distance measures
pat-alt Nov 15, 2024
4739326
updated deps and added more measures
pat-alt Nov 15, 2024
b9f1f86
damn
pat-alt Nov 15, 2024
d322ef4
small change to speed up benchmark function
pat-alt Nov 15, 2024
7b15cab
adding option to avoid concatenating benchmark results from multiple …
pat-alt Nov 18, 2024
f772958
also making sure interim results for different runs are stored
pat-alt Nov 18, 2024
111bd80
typo
pat-alt Nov 18, 2024
95a33bf
buff
pat-alt Nov 18, 2024
1d5dac3
adds a function to concatenate benchmarks
pat-alt Nov 18, 2024
de1225d
updated changelog
pat-alt Nov 18, 2024
eb4d2d8
fixed small bug
pat-alt Nov 18, 2024
f8c2462
uh
pat-alt Nov 18, 2024
a02fc28
goddamn
pat-alt Nov 18, 2024
08e50cc
bleh
pat-alt Nov 18, 2024
4a74e94
added functionality to set serialization state
pat-alt Nov 19, 2024
031bcb9
ufff
pat-alt Nov 19, 2024
30d244c
uh
pat-alt Nov 19, 2024
331c345
Added functionality to explicitly specify what transformation of the …
pat-alt Nov 19, 2024
0f203f9
changelog
pat-alt Nov 19, 2024
9ba160d
this should be helpful
pat-alt Nov 19, 2024
83f3d6f
hm
pat-alt Dec 11, 2024
2ea1c78
pls
pat-alt Dec 11, 2024
866c631
pls
pat-alt Dec 11, 2024
497ef67
this is so bloody frustrating
pat-alt Dec 11, 2024
68d3072
so close now
pat-alt Dec 11, 2024
fcd9721
still very puzzled
pat-alt Dec 12, 2024
3d93bcc
huh
pat-alt Dec 12, 2024
6a30fac
kill me
pat-alt Dec 13, 2024
c75e42f
trying to free up memory a bit
pat-alt Dec 14, 2024
88289d0
oh yehhh
pat-alt Dec 14, 2024
5a0256e
hmm
pat-alt Dec 16, 2024
fb8df41
hmm
pat-alt Dec 16, 2024
c3f24fb
hmm
pat-alt Dec 16, 2024
8f6b585
respecifying the vertical_split argument to improve intuition
pat-alt Dec 16, 2024
99907c4
log thing
pat-alt Dec 16, 2024
08eda17
storing input data as subarray
pat-alt Dec 19, 2024
260919a
FlattenedCE added
pat-alt Dec 19, 2024
84a5020
o
pat-alt Dec 19, 2024
f65d21d
okkk
pat-alt Dec 19, 2024
c62de6c
come on
pat-alt Dec 19, 2024
5711690
bloody namespacing
pat-alt Dec 19, 2024
91b813e
why is this fialign locally
pat-alt Dec 19, 2024
19cb447
puzzling
pat-alt Dec 19, 2024
e60f448
come on
pat-alt Dec 19, 2024
c0bb4cf
that should do it
pat-alt Dec 19, 2024
562e6da
test ratio
pat-alt Dec 19, 2024
7ab7f18
gotmat
pat-alt Dec 19, 2024
9c36013
more on test ratio
pat-alt Dec 19, 2024
a9423ea
formatter
pat-alt Dec 19, 2024
bf22e1a
fixing errors
pat-alt Dec 19, 2024
a198a48
fixing errors
pat-alt Dec 19, 2024
4ee73a4
lets go
pat-alt Dec 19, 2024
18 changes: 17 additions & 1 deletion CHANGELOG.md
@@ -6,7 +6,23 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),

*Note*: We try to adhere to these practices as of version [v1.1.1].

## Version [1.3.6]
## Version [1.4.0]

### Added

- Added a new `FlattenedCE` struct and a conversion function `flatten(ce::CounterfactualExplanation)::FlattenedCE` for flattening a `CounterfactualExplanation` object. In the short term, this can be useful for compact storage or transmission of explanations. In the long term, we may consider using the flattened representation as much as possible to optimize performance. [#502]
- Added an `unflatten` function to convert a `FlattenedCE` object back to its original `CounterfactualExplanation` form. This is used in benchmarking, where flattened objects are used in the first parallelization (generating counterfactuals) and full objects are used for evaluation. This is a temporary solution until we address the fact that downstream `Evaluation` functions currently expect the full `CounterfactualExplanation` form. [#502]
- Added additional aliases for penalties including `distance_cosine`.
- Added `concatenate_output::Bool=true` keyword argument to `benchmark` function. This allows users to suppress concatenation of output in benchmarking (`concatenate_output=false`), which can be useful when memory usage is critical.
- Added a `concatenate_benchmarks(storage_path::String)` function that can be used to concatenate multiple benchmark results into a single file.
- Added functionality to set global serialization state. This is useful for suppressing serialization on non-root ranks in parallel computations.
- Added functionality to explicitly specify what transformation of the `CounterfactualExplanation` object should be stored in evaluation data frames.

### Changed

- `Benchmark` objects now have an additional field `counterfactuals` that stores a `DataFrame` with a sample ID column `:sample` and a counterfactuals column `:ce`.

## Version [1.3.6] - 2024-11-08

### Changed

2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "CounterfactualExplanations"
uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
authors = ["Patrick Altmeyer <[email protected]> and contributors"]
version = "1.3.6"
version = "1.4.0"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
1 change: 0 additions & 1 deletion docs/make.jl
@@ -58,7 +58,6 @@ makedocs(;
"FeatureTweak" => "explanation/generators/feature_tweak.md",
"Gravitational" => "explanation/generators/gravitational.md",
"Greedy" => "explanation/generators/greedy.md",
"GrowingSpheres" => "explanation/generators/growing_spheres.md",
"PROBE" => "explanation/generators/probe.md",
"REVISE" => "explanation/generators/revise.md",
"MINT" => "explanation/generators/mint.md",
8 changes: 5 additions & 3 deletions src/CounterfactualExplanations.jl
@@ -7,7 +7,7 @@ function __init__()
end

# Dependencies:
using Flux
using Flux: Flux
using TaijaBase: TaijaBase

# Setup:
@@ -54,6 +54,7 @@ export Linear, MLP, DeepEnsemble
export flux_training_params
export probs, logits
export standard_models_catalogue, all_models_catalogue, model_evaluation, predict_label
export fit_model

# Convergence
include("convergence/Convergence.jl")
@@ -79,7 +80,6 @@ export FeatureTweakGenerator
export GenericGenerator
export GravitationalGenerator
export GreedyGenerator
export GrowingSpheresGenerator
export REVISEGenerator
export DiCEGenerator
export WachterGenerator
@@ -91,10 +91,12 @@ export @objective
# argmin
###
include("counterfactuals/Counterfactuals.jl")
export CounterfactualExplanation
export CounterfactualExplanation, FlattenedCE
export generate_counterfactual
export total_steps, converged, terminated, path, target_probs
export animate_path
export flatten, unflatten, FlattenedCE
export target_encoded

include("evaluation/Evaluation.jl")
using .Evaluation
2 changes: 1 addition & 1 deletion src/convergence/Convergence.jl
@@ -1,7 +1,7 @@
module Convergence

using Distributions
using Flux
using Flux: Flux
using LinearAlgebra
using ..CounterfactualExplanations
using ..Models
4 changes: 2 additions & 2 deletions src/counterfactuals/Counterfactuals.jl
@@ -4,7 +4,7 @@ using .GenerativeModels
using .Generators
using .Models
using ChainRulesCore: ChainRulesCore
using Flux
using Flux: Flux
using MLUtils: MLUtils
using MultivariateStats
using Statistics: Statistics
@@ -14,7 +14,6 @@ using StatsBase
include("core_struct.jl")
include("encodings.jl")
include("generate_counterfactual.jl")
include("growing_spheres.jl")
include("info_extraction.jl")
include("initialisation.jl")
include("path_tracking.jl")
@@ -23,6 +22,7 @@ include("search.jl")
include("termination.jl")
include("utils.jl")
include("vectorised.jl")
include("flatten.jl")

# Counterfactual Rule Explanations:
include("CRE.jl")
54 changes: 54 additions & 0 deletions src/counterfactuals/flatten.jl
@@ -0,0 +1,54 @@
"""
FlattenedCE <: AbstractCounterfactualExplanation

A flattened representation of a `CounterfactualExplanation`, containing only the `factual`, `target`, and `counterfactual` attributes. This can be useful for compact storage or transmission of explanations.
"""
struct FlattenedCE <: AbstractCounterfactualExplanation
factual::AbstractArray
target::RawTargetType
counterfactual::AbstractArray
end

"""
(ce::CounterfactualExplanation)()::FlattenedCE

Calling the `ce::CounterfactualExplanation` object results in a [`FlattenedCE`](@ref) instance, which is the flattened version of the original.
"""
(ce::CounterfactualExplanation)()::FlattenedCE =
FlattenedCE(ce.factual, ce.target, ce.counterfactual)

"""
flatten(ce::CounterfactualExplanation)

Alias for `(ce::CounterfactualExplanation)()`. Converts a `CounterfactualExplanation` to its flattened form.
"""
flatten(ce::CounterfactualExplanation) = ce()

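"""
unflatten(flat_ce::FlattenedCE, data::CounterfactualData, M::Models.AbstractModel, generator::Generators.AbstractGenerator; kwargs...)

Rebuilds a `CounterfactualExplanation` from `flat_ce.factual` and `flat_ce.target` together with the supplied `data`, model `M` and `generator`. The number of counterfactuals is taken from `size(flat_ce.counterfactual, 2)`.
"""
function unflatten(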
flat_ce::FlattenedCE,
data::CounterfactualData,
M::Models.AbstractModel,
generator::Generators.AbstractGenerator;
initialization::Symbol=:add_perturbation,
convergence::Union{AbstractConvergence,Symbol}=:decision_threshold,
)::CounterfactualExplanation
return CounterfactualExplanation(
flat_ce.factual,
flat_ce.target,
data,
M,
generator;
initialization=initialization,
convergence=convergence,
num_counterfactuals=size(flat_ce.counterfactual, 2),
)
end

"""
target_encoded(flat_ce::FlattenedCE, data::CounterfactualData)

Returns the encoded representation of `flat_ce.target`.
"""
function target_encoded(flat_ce::FlattenedCE, data::CounterfactualData)
return data.output_encoder(flat_ce.target; y_levels=data.y_levels)
end
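
The flatten/unflatten round trip in this file can be illustrated with a minimal, self-contained sketch. Everything below is hypothetical: `ToyExplanation`, `ToyFlat`, `toy_flatten` and `toy_unflatten` are stand-ins invented for illustration and are not the package's types or functions. Note also that the sketch copies the counterfactual back directly, whereas the `unflatten` shown above goes through the `CounterfactualExplanation` constructor.

```julia
# Toy stand-ins for illustration only; these are NOT the package's types.
struct ToyExplanation
    factual::Matrix{Float64}
    target::Int
    counterfactual::Matrix{Float64}
    search_history::Vector{Matrix{Float64}}  # heavy state we may not want to ship around
end

# Flattened form: only the three lightweight fields survive.
struct ToyFlat
    factual::Matrix{Float64}
    target::Int
    counterfactual::Matrix{Float64}
end

# Drop everything except factual, target and counterfactual.
toy_flatten(ce::ToyExplanation) = ToyFlat(ce.factual, ce.target, ce.counterfactual)

# Rebuild a full object; the search history is not recoverable, so it starts empty.
function toy_unflatten(flat::ToyFlat)
    return ToyExplanation(flat.factual, flat.target, flat.counterfactual, Matrix{Float64}[])
end

x = reshape([1.0, 2.0], 2, 1)
ce = ToyExplanation(x, 1, x .+ 0.5, [x, x .+ 0.25, x .+ 0.5])
flat = toy_flatten(ce)
restored = toy_unflatten(flat)
```

The design trade-off is that the flattened object is cheap to serialize or send between processes, but anything not stored in it has to be supplied again (or recomputed) when the full object is rebuilt.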
15 changes: 11 additions & 4 deletions src/counterfactuals/generate_counterfactual.jl
@@ -9,6 +9,7 @@
initialization::Symbol=:add_perturbation,
convergence::Union{AbstractConvergence,Symbol}=:decision_threshold,
timeout::Union{Nothing,Real}=nothing,
return_flattened::Bool=false,
)

The core function that is used to run counterfactual search for a given factual `x`, target, counterfactual data, model and generator. Keywords can be used to specify the desired threshold for the predicted target class probability and the maximum number of iterations.
@@ -24,6 +25,7 @@ The core function that is used to run counterfactual search for a given factual
- `initialization::Symbol=:add_perturbation`: Initialization method. By default, the initialization is done by adding a small random perturbation to the factual to achieve more robustness.
- `convergence::Union{AbstractConvergence,Symbol}=:decision_threshold`: Convergence criterion. By default, the convergence is based on the decision threshold. Possible values are `:decision_threshold`, `:max_iter`, `:generator_conditions` or a concrete convergence object (e.g. [`DecisionThresholdConvergence`](@ref)).
- `timeout::Union{Nothing,Real}=nothing`: Timeout in seconds.
- `return_flattened::Bool=false`: If `true`, a [`FlattenedCE`](@ref) is returned instead of the full `CounterfactualExplanation` object.

# Examples

@@ -83,7 +85,10 @@ function generate_counterfactual(
initialization::Symbol=:add_perturbation,
convergence::Union{AbstractConvergence,Symbol}=:decision_threshold,
timeout::Union{Nothing,Real}=nothing,
return_flattened::Bool=false,
)
output(ce::CounterfactualExplanation) = return_flattened ? flatten(ce) : ce

# Initialize:
ce = CounterfactualExplanation(
x,
@@ -98,14 +103,14 @@ function generate_counterfactual(

# Check for redundancy (assess if already converged with respect to factual):
if Convergence.converged(ce.convergence, ce, ce.factual)
@info "Factual already in target class and probability exceeds threshold γ=$(ce.convergence.decision_threshold)."
return ce
@info "Factual already in target class and probability exceeds threshold."
return output(ce)
end

# Check for incompatibility:
if Generators.incompatible(ce.generator, ce)
@info "Generator is incompatible with other specifications for the counterfactual explanation (e.g. the model). See warnings for details. No search completed."
return ce
return output(ce)
end

# Search:
@@ -120,7 +125,9 @@ function generate_counterfactual(
end
end
end
return ce

# Return full or flattened explanation:
return output(ce)
end

"""
2 changes: 1 addition & 1 deletion src/counterfactuals/initialisation.jl
@@ -1,4 +1,4 @@
using Flux
using Flux: Flux

"""
initialize_state(ce::CounterfactualExplanation)
22 changes: 17 additions & 5 deletions src/counterfactuals/utils.jl
@@ -73,16 +73,28 @@ function adjust_shape!(ce::CounterfactualExplanation)
end

"""
find_potential_neighbors(ce::AbstractCounterfactualExplanation)
find_potential_neighbours(
ce::AbstractCounterfactualExplanation, data::CounterfactualData, n::Int=1000
)

Finds potential neighbors for the selected factual data point.
"""
function find_potential_neighbours(ce::AbstractCounterfactualExplanation, n::Int=1000)
nobs = size(ce.data.X, 2)
data = DataPreprocessing.subsample(ce.data, minimum([nobs, n]))
function find_potential_neighbours(
ce::AbstractCounterfactualExplanation, data::CounterfactualData, n::Int=1000
)
nobs = size(data.X, 2)
data = DataPreprocessing.subsample(data, minimum([nobs, n]))
ids = findall(data.output_encoder.labels .== ce.target)
n_candidates = minimum([size(ce.data.y, 2), n])
n_candidates = minimum([size(data.y, 2), n])
candidates = DataPreprocessing.select_factual(data, rand(ids, n_candidates))
potential_neighbours = reduce(hcat, map(x -> x[1], collect(candidates)))
return potential_neighbours
end

"""
find_potential_neighbours(ce::CounterfactualExplanation, n::Int=1000)

Convenience method for [`CounterfactualExplanation`](@ref) that uses the explanation's own data (`ce.data`) when no data is provided.
"""
find_potential_neighbours(ce::CounterfactualExplanation, n::Int=1000) =
find_potential_neighbours(ce, ce.data, n)
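
The pattern in this change (a general method that takes the data explicitly, plus a convenience method that forwards the data carried by the full object) can be sketched as follows. All names here (`ToyData`, `ToyFullCE`, `ToyFlatCE`, `toy_neighbours`) are hypothetical and exist only for illustration. As a further simplifying assumption, the sketch picks the first `n` matching columns deterministically instead of subsampling at random as the code above does.

```julia
# Hypothetical toy types; not the package's API.
struct ToyData
    X::Matrix{Float64}      # features, one column per observation
    labels::Vector{Int}     # one label per observation
end

struct ToyFullCE
    target::Int
    data::ToyData           # the full object carries its own data
end

struct ToyFlatCE
    target::Int             # the flattened object does not
end

# General method: the data is passed explicitly, so any object with a
# `target` field works, including the flattened one.
function toy_neighbours(ce, data::ToyData, n::Int=1000)
    ids = findall(data.labels .== ce.target)
    ids = ids[1:min(n, length(ids))]   # deterministic here; an assumption of this sketch
    return data.X[:, ids]
end

# Convenience method: the full object forwards its own data.
toy_neighbours(ce::ToyFullCE, n::Int=1000) = toy_neighbours(ce, ce.data, n)

data = ToyData([1.0 2.0 3.0; 4.0 5.0 6.0], [1, 2, 1])
full = ToyFullCE(1, data)
flat = ToyFlatCE(1)

neighbours_full = toy_neighbours(full)        # uses full.data
neighbours_flat = toy_neighbours(flat, data)  # data supplied by the caller
```

Because the second positional argument is either a `ToyData` or an `Int`, dispatch between the two methods is unambiguous.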
4 changes: 2 additions & 2 deletions src/data_preprocessing/counterfactual_data.jl
@@ -47,7 +47,7 @@ Stores data and metadata for counterfactual explanations.
"""

mutable struct CounterfactualData
X::AbstractMatrix
X::AbstractArray
y::EncodedOutputArrayType
likelihood::Symbol
mutability::Union{Vector{Symbol},Nothing}
@@ -103,7 +103,7 @@ mutable struct CounterfactualData

if all(conditions)
new(
X,
view(X, :, :),
y,
likelihood,
mutability,
27 changes: 26 additions & 1 deletion src/evaluation/Evaluation.jl
@@ -7,13 +7,29 @@ using ..Models
using LinearAlgebra: LinearAlgebra
using Statistics

include("serialization.jl")
include("benchmark.jl")
include("evaluate.jl")
include("measures.jl")

export global_serializer, Serializer, NullSerializer, _serialization_state
export global_output_identifier, DefaultOutputIdentifier, _output_id, get_global_output_id
export ExplicitOutputIdentifier
export get_global_ce_transform,
global_ce_transform, IdentityTransformer, ExplicitCETransformer
export Benchmark, benchmark, evaluate, default_measures
export validity, redundancy
export plausibility, faithfulness
export plausibility
export plausibility_energy_differential,
plausibility_cosine, plausibility_distance_from_target
export faithfulness
export plausibility_measures, default_measures, distance_measures, all_measures
export concatenate_benchmarks

"Available plausibility measures."
const plausibility_measures = [
plausibility_energy_differential, plausibility_cosine, plausibility_distance_from_target
]

"The default evaluation measures."
const default_measures = [
@@ -28,4 +44,13 @@ const distance_measures = [
CounterfactualExplanations.Objectives.distance_linf,
]

"All measures."
const all_measures = [
validity,
redundancy,
collect(values(CounterfactualExplanations.Objectives.penalties_catalogue))...,
plausibility_measures...,
faithfulness,
]

end
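
The way `all_measures` is assembled above (individual functions, the values of a catalogue, and another vector, all splatted into one array literal) can be sketched in isolation. The names `toy_l1`, `toy_l2`, `toy_catalogue`, `toy_extra` and `toy_all` are hypothetical and only mirror the shape of the real definition.

```julia
# Hypothetical stand-in measures; each maps a vector to a number.
toy_l1(x) = sum(abs, x)
toy_l2(x) = sqrt(sum(abs2, x))

# A catalogue keyed by name, similar in shape to a penalties catalogue.
toy_catalogue = Dict(:l1 => toy_l1, :l2 => toy_l2)

# A separate group of measures kept in its own vector.
toy_extra = [maximum]

# Splat both collections into one flat vector of functions.
toy_all = [length, collect(values(toy_catalogue))..., toy_extra...]

# Apply every measure to the same input and key the results by function name.
x = [3.0, -4.0]
results = Dict(nameof(f) => f(x) for f in toy_all)
```

Since `Dict` iteration order is not guaranteed, the order of the catalogue entries inside `toy_all` should not be relied on.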