Skip to content

Commit

Permalink
come on
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Dec 19, 2024
1 parent 19cb447 commit e60f448
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/counterfactuals/flatten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ Returns the encoded representation of `flat_ce.target`.
"""
function target_encoded(flat_ce::FlattenedCE, data::CounterfactualData)
return data.output_encoder(flat_ce.target; y_levels=data.y_levels)
end
end
3 changes: 2 additions & 1 deletion src/counterfactuals/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,5 @@ end
Overloads the function for [`CounterfactualExplanation`](@ref) to use the counterfactual data's labels if no data is provided.
"""
find_potential_neighbours(ce::CounterfactualExplanation, n::Int=1000) = find_potential_neighbours(ce, ce.data, n)
find_potential_neighbours(ce::CounterfactualExplanation, n::Int=1000) =
find_potential_neighbours(ce, ce.data, n)
12 changes: 7 additions & 5 deletions src/objectives/penalties.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,13 @@ function distance_from_target(
)

# Get potential neighbours:
get!(
ce.search,
:potential_neighbours,
CounterfactualExplanations.find_potential_neighbours(ce, K),
)
ChainRulesCore.ignore_derivatives() do
get!(
ce.search,
:potential_neighbours,
CounterfactualExplanations.find_potential_neighbours(ce, K),
)
end

ys = ce.search[:potential_neighbours]
if K > size(ys, 2)
Expand Down

0 comments on commit e60f448

Please sign in to comment.