diff --git a/src/counterfactuals/flatten.jl b/src/counterfactuals/flatten.jl index 28754ae14..6498f4576 100644 --- a/src/counterfactuals/flatten.jl +++ b/src/counterfactuals/flatten.jl @@ -51,4 +51,4 @@ Returns the encoded representation of `flat_ce.target`. """ function target_encoded(flat_ce::FlattenedCE, data::CounterfactualData) return data.output_encoder(flat_ce.target; y_levels=data.y_levels) -end \ No newline at end of file +end diff --git a/src/counterfactuals/utils.jl b/src/counterfactuals/utils.jl index ca981046d..3e1e2ee59 100644 --- a/src/counterfactuals/utils.jl +++ b/src/counterfactuals/utils.jl @@ -96,4 +96,5 @@ end Overloads the function for [`CounterfactualExplanation`](@ref) to use the counterfactual data's labels if no data is provided. """ -find_potential_neighbours(ce::CounterfactualExplanation, n::Int=1000) = find_potential_neighbours(ce, ce.data, n) +find_potential_neighbours(ce::CounterfactualExplanation, n::Int=1000) = + find_potential_neighbours(ce, ce.data, n) diff --git a/src/objectives/penalties.jl b/src/objectives/penalties.jl index 392eb7303..ebb2b37a7 100644 --- a/src/objectives/penalties.jl +++ b/src/objectives/penalties.jl @@ -121,11 +121,13 @@ function distance_from_target( ) # Get potential neighbours: - get!( - ce.search, - :potential_neighbours, - CounterfactualExplanations.find_potential_neighbours(ce, K), - ) + ChainRulesCore.ignore_derivatives() do + get!( + ce.search, + :potential_neighbours, + CounterfactualExplanations.find_potential_neighbours(ce, K), + ) + end ys = ce.search[:potential_neighbours] if K > size(ys, 2)