Skip to content

Commit

Permalink
come on
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Dec 19, 2024
1 parent f65d21d commit c62de6c
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 30 deletions.
7 changes: 4 additions & 3 deletions src/counterfactuals/flatten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ end
Calling the `ce::CounterfactualExplanation` object results in a [`FlattenedCE`](@ref) instance, which is the flattened version of the original.
"""
(ce::CounterfactualExplanation)()::FlattenedCE = FlattenedCE(ce.factual, ce.target, ce.counterfactual)
(ce::CounterfactualExplanation)()::FlattenedCE =
FlattenedCE(ce.factual, ce.target, ce.counterfactual)

"""
flatten(ce::CounterfactualExplanation)
Expand All @@ -31,7 +32,7 @@ function unflatten(
initialization::Symbol=:add_perturbation,
convergence::Union{AbstractConvergence,Symbol}=:decision_threshold,
)::CounterfactualExplanation
CounterfactualExplanation(
return CounterfactualExplanation(
flat_ce.factual,
flat_ce.target,
data,
Expand All @@ -41,4 +42,4 @@ function unflatten(
convergence=convergence,
num_counterfactuals=size(flat_ce.counterfactual, 2),
)
end
end
3 changes: 1 addition & 2 deletions src/counterfactuals/generate_counterfactual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,8 @@ function generate_counterfactual(
initialization::Symbol=:add_perturbation,
convergence::Union{AbstractConvergence,Symbol}=:decision_threshold,
timeout::Union{Nothing,Real}=nothing,
return_flattened::Bool=false
return_flattened::Bool=false,
)

output(ce::CounterfactualExplanation) = return_flattened ? flatten(ce) : ce

# Initialize:
Expand Down
2 changes: 1 addition & 1 deletion src/data_preprocessing/counterfactual_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ mutable struct CounterfactualData

if all(conditions)
new(
view(X,:,:),
view(X, :, :),
y,
likelihood,
mutability,
Expand Down
8 changes: 5 additions & 3 deletions src/evaluation/Evaluation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ include("measures.jl")
export global_serializer, Serializer, NullSerializer, _serialization_state
export global_output_identifier, DefaultOutputIdentifier, _output_id, get_global_output_id
export ExplicitOutputIdentifier
export get_global_ce_transform, global_ce_transform, IdentityTransformer, ExplicitCETransformer
export get_global_ce_transform,
global_ce_transform, IdentityTransformer, ExplicitCETransformer
export Benchmark, benchmark, evaluate, default_measures
export validity, redundancy
export plausibility
export plausibility_energy_differential, plausibility_cosine, plausibility_distance_from_target
export plausibility_energy_differential,
plausibility_cosine, plausibility_distance_from_target
export faithfulness
export plausibility_measures, default_measures, distance_measures, all_measures
export concatenate_benchmarks
Expand All @@ -44,7 +46,7 @@ const distance_measures = [

"All measures."
const all_measures = [
validity,
validity,
redundancy,
collect(values(CounterfactualExplanations.Objectives.penalties_catalogue))...,
plausibility_measures...,
Expand Down
22 changes: 7 additions & 15 deletions src/evaluation/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,7 @@ function benchmark(
end

function unflatten_for_eval(
ces::Vector{FlattenedCE},
data::CounterfactualData,
Ms,
gens,
kwrgs
ces::Vector{FlattenedCE}, data::CounterfactualData, Ms, gens, kwrgs
)
unflatten_kwrgs = (;)
initialization = get(kwrgs, :initialization, nothing)
Expand Down Expand Up @@ -368,11 +364,7 @@ function benchmark(
yhat = CounterfactualExplanations.predict_label(M, test_data)
for i in 1:n_individuals
# For each individual and specified factual label, randomly choose index of a factual observation:
chosen_ind = rand(
findall(
yhat .== factual[i],
),
)[1]
chosen_ind = rand(findall(yhat .== factual[i]))[1]
push!(chosen, chosen_ind)
end
xs = CounterfactualExplanations.select_factual(test_data, chosen)
Expand All @@ -390,9 +382,9 @@ function benchmark(
@debug "Length of grid: $(length(grid))"

# Split grid vertically into groups of `vertical_splits` elements:
if split_vertically
if split_vertically
npart = minimum([length(grid), vertical_splits])
else
else
npart = length(grid)
end
@debug "Number of elements per partition: $npart"
Expand Down Expand Up @@ -442,7 +434,7 @@ function benchmark(
# Unflatten for evaluation:
@assert typeof(ces) == Vector{FlattenedCE} "Expecting a vector of `FlattenedCE`. Did you accidentally set `return_flattened=false`?"
ces = unflatten_for_eval(ces, data, Ms, gens, kwrgs)

# Free up memory:
xs = nothing
targets = nothing
Expand Down Expand Up @@ -510,7 +502,7 @@ function concatenate_benchmarks(storage_path::String)
return nothing
end
bmk_files = get_benchmark_files(storage_path)
bmks = Serialization.deserialize.(bmk_files)
bmks = Serialization.deserialize.(bmk_files)
bmks = reduce(vcat, bmks)
return bmks
end
Expand All @@ -521,7 +513,7 @@ end
Returns a list of all benchmark files stored in `storage_path`.
"""
function get_benchmark_files(storage_path::String)
# No results:
# No results:
if length(storage_path) == 0
@warn "No interim results found"
return nothing
Expand Down
2 changes: 1 addition & 1 deletion src/evaluation/evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ struct ExplicitCETransformer <: AbstractCETransformer
fun::Function
function ExplicitCETransformer(fun::Function)
@assert hasmethod(fun, Tuple{CounterfactualExplanation}) "Measure function must have a method for `CounterfactualExplanation`"
new(fun)
return new(fun)
end
end

Expand Down
4 changes: 2 additions & 2 deletions src/evaluation/serialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ OutputID(identifier::DefaultOutputIdentifier) = ""
global _output_id::String = ""

"And explicit output identifier that takes the string value of `id`. "
struct ExplicitOutputIdentifier <: AbstractOutputIdentifier
struct ExplicitOutputIdentifier <: AbstractOutputIdentifier
id::String
end

Expand All @@ -53,4 +53,4 @@ Set the global output identifier to `identifier` and return its string represent
function global_output_identifier(identifier::AbstractOutputIdentifier)
global _output_id = OutputID(identifier)
return _output_id
end
end
3 changes: 1 addition & 2 deletions src/objectives/penalties.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,10 @@ Compute the distance from a counterfactual to the target manifold using cosine s
- `ce::AbstractCounterfactualExplanation`: The counterfactual explanation object.
- `kwrgs...`: Additional keyword arguments for the distance function.
"""
function distance_from_target_cosine(ce::AbstractCounterfactualExplanation;kwrgs...)
function distance_from_target_cosine(ce::AbstractCounterfactualExplanation; kwrgs...)
return distance_from_target(ce; cosine=true, kwrgs...)
end


"""
function model_loss_penalty(
ce::AbstractCounterfactualExplanation;
Expand Down
7 changes: 6 additions & 1 deletion test/other/evaluation.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
using TaijaData: load_moons, load_circles
using CounterfactualExplanations.Evaluation:
Benchmark, evaluate, validity, distance_measures
Benchmark,
evaluate,
validity,
distance_measures,
ExplicitCETransformer,
global_serializer
using CounterfactualExplanations.Objectives: distance

# Dataset
Expand Down

0 comments on commit c62de6c

Please sign in to comment.