Calling a scikit-learn model #4

Open
pat-alt opened this issue Nov 10, 2023 · 2 comments

Comments


pat-alt commented Nov 10, 2023

I implemented the things mentioned above, but when I apply them to a toy workflow inspired by the docs, it does not currently work. I am unclear on what the error message means. The cause could be what you mentioned in your previous reply, i.e. that gradient-based generators won't work with a random forest model.

What I did can be found here: https://github.com/fabiensatalia/SklearnBackend.jl

Here is the error message I get when running test.jl:

ERROR: Compiling Tuple{typeof(convert), Type{PyCall.PyAny}, PyCall.PyObject}: ArgumentError: array must be non-empty
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:101 [inlined]
  [2] _pullback(::Zygote.Context{true}, ::typeof(convert), ::Type{PyCall.PyAny}, ::PyCall.PyObject)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:101
  [3] _pullback
    @ ~/.julia/packages/PyCall/ilqDX/src/PyCall.jl:318 [inlined]
  [4] _pullback(::Zygote.Context{true}, ::typeof(getproperty), ::PyCall.PyObject, ::Symbol)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [5] macro expansion
    @ ~/.julia/packages/Zygote/YYT6v/src/lib/literal_getproperty.jl:93 [inlined]
  [6] _pullback
    @ ~/.julia/packages/Zygote/YYT6v/src/lib/literal_getproperty.jl:117 [inlined]
  [7] _pullback
    @ ~/projects/CounterFactuals_contribu/julia/SklearnBackend/src/SklearnBackend.jl:36 [inlined]
  [8] _pullback(::Zygote.Context{true}, ::typeof(probs), ::SklearnModel, ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [9] _pullback
    @ ~/projects/CounterFactuals_contribu/julia/SklearnBackend/src/SklearnBackend.jl:40 [inlined]
 [10] _pullback(::Zygote.Context{true}, ::typeof(logits), ::SklearnModel, ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [11] _pullback
    @ ~/.julia/packages/CounterfactualExplanations/ATUIK/src/objectives/loss_functions.jl:23 [inlined]
 [12] _pullback(::Zygote.Context{true}, ::CounterfactualExplanations.Objectives.var"##logitcrossentropy#12", ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(Flux.Losses.logitcrossentropy), ::CounterfactualExplanation)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [13] _pullback
    @ ~/.julia/packages/CounterfactualExplanations/ATUIK/src/objectives/loss_functions.jl:22 [inlined]
 [14] _pullback(ctx::Zygote.Context{true}, f::typeof(Flux.Losses.logitcrossentropy), args::CounterfactualExplanation)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [15] _pullback
    @ ~/.julia/packages/CounterfactualExplanations/ATUIK/src/generators/loss.jl:18 [inlined]
 [16] _pullback(::Zygote.Context{true}, ::typeof(CounterfactualExplanations.Generators.ℓ), ::CounterfactualExplanations.Generators.GradientBasedGenerator, ::Nothing, ::CounterfactualExplanation)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [17] _pullback
    @ ~/.julia/packages/CounterfactualExplanations/ATUIK/src/generators/loss.jl:7 [inlined]
 [18] _pullback(::Zygote.Context{true}, ::typeof(CounterfactualExplanations.Generators.ℓ), ::CounterfactualExplanations.Generators.GradientBasedGenerator, ::CounterfactualExplanation)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [19] _pullback
    @ ~/.julia/packages/CounterfactualExplanations/ATUIK/src/generators/gradient_based/loss.jl:13 [inlined]
 [20] _pullback(::Zygote.Context{true}, ::CounterfactualExplanations.Generators.var"#15#16"{CounterfactualExplanations.Generators.GradientBasedGenerator, CounterfactualExplanation})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [21] pullback(f::Function, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:414
 [22] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:96
 [23] ∂ℓ(generator::CounterfactualExplanations.Generators.GradientBasedGenerator, M::SklearnModel, ce::CounterfactualExplanation)
    @ CounterfactualExplanations.Generators ~/.julia/packages/CounterfactualExplanations/ATUIK/src/generators/gradient_based/loss.jl:13
 [24] ∇(generator::CounterfactualExplanations.Generators.GradientBasedGenerator, M::SklearnModel, ce::CounterfactualExplanation)
    @ CounterfactualExplanations.Generators ~/.julia/packages/CounterfactualExplanations/ATUIK/src/generators/gradient_based/loss.jl:46
 [25] propose_state(generator::CounterfactualExplanations.Generators.GradientBasedGenerator, ce::CounterfactualExplanation)
    @ CounterfactualExplanations.Generators ~/.julia/packages/CounterfactualExplanations/ATUIK/src/generators/gradient_based/utils.jl:9
 [26] generate_perturbations(generator::CounterfactualExplanations.Generators.GradientBasedGenerator, ce::CounterfactualExplanation)
    @ CounterfactualExplanations.Generators ~/.julia/packages/CounterfactualExplanations/ATUIK/src/generators/gradient_based/utils.jl:24
 [27] update!(ce::CounterfactualExplanation)
    @ CounterfactualExplanations ~/.julia/packages/CounterfactualExplanations/ATUIK/src/counterfactuals/search.jl:9
 [28] generate_counterfactual(x::Matrix{Float64}, target::Int64, data::CounterfactualData, M::SklearnModel, generator::CounterfactualExplanations.Generators.GradientBasedGenerator; num_counterfactuals::Int64, initialization::Symbol, generative_model_params::NamedTuple{(), Tuple{}}, max_iter::Int64, decision_threshold::Float64, gradient_tol::Float64, min_success_rate::Float64, converge_when::Symbol, timeout::Nothing, invalidation_rate::Float64, learning_rate::Float64, variance::Float64)
    @ CounterfactualExplanations ~/.julia/packages/CounterfactualExplanations/ATUIK/src/generate_counterfactual/core.jl:86
 [29] generate_counterfactual(x::Matrix{Float64}, target::Int64, data::CounterfactualData, M::SklearnModel, generator::CounterfactualExplanations.Generators.GradientBasedGenerator)
    @ CounterfactualExplanations ~/.julia/packages/CounterfactualExplanations/ATUIK/src/generate_counterfactual/core.jl:43
 [30] top-level scope
    @ REPL[49]:1

Originally posted by @fabiensatalia in JuliaTrustworthyAI/CounterfactualExplanations.jl#338 (reply in thread)

pat-alt transferred this issue from JuliaTrustworthyAI/CounterfactualExplanations.jl Nov 10, 2023

pat-alt commented Nov 10, 2023

Thanks @fabiensatalia for this.

The reason you get this error (also for differentiable sklearn models) is that Zygote cannot differentiate through calls into Python models: as the stack trace shows, it fails while trying to trace through PyCall. If you want to enable support for gradient-based generators for Python models, you need to overload the Generators.∂ℓ method, as we have done here for PyTorch models.
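
As a rough illustration of what "supplying the gradient yourself" can mean when the model is a black box, the sketch below computes a central finite-difference gradient in plain Julia. It is not taken from CounterfactualExplanations.jl and is not how the PyTorch overload works; `fd_gradient` and `blackbox_loss` are hypothetical names, and the toy loss stands in for a function that would call into Python.

```julia
# Hypothetical helper (not part of any package): central finite-difference
# gradient of a scalar-valued black-box function `f` at the point `x`.
function fd_gradient(f, x::Vector{Float64}; h::Float64=1e-5)
    n = length(x)
    g = zeros(Float64, n)
    for i in 1:n
        step = zeros(Float64, n)
        step[i] = h
        # Symmetric difference quotient along coordinate i.
        g[i] = (f(x .+ step) - f(x .- step)) / (2h)
    end
    return g
end

# Assumption: a smooth toy function standing in for a loss evaluated by a
# Python model. Its true gradient is 2x.
blackbox_loss(x) = sum(abs2, x)

g = fd_gradient(blackbox_loss, [1.0, -2.0, 0.5])
println(g)
```

Note that this only helps for models whose output varies smoothly with the input. Tree-based models such as random forests are piecewise constant, so a numerical gradient like this is zero almost everywhere and gives a gradient-based generator nothing to follow.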

In the meantime, the only generator that should work for sklearn models without gradient access is GrowingSpheres. I've tried this here using your code, but I am getting the following error:

(3)   , variance::Float64)
      /Users/paltmeyer/.julia/packages/CounterfactualExplanations/q96hi/src/generate_counterfactual/core.jl:121
      ❯ 121 Generators.growing_spheres_generation!(ce)

(4)   growing_spheres_generation!(ce::CounterfactualExplanation)
      /Users/paltmeyer/.julia/packages/CounterfactualExplanations/q96hi/src/generators/non_gradient_based/growing_spheres/growing_spheres.jl:111
      ❯ 111 ce.s′ = counterfactual_candidates[:, counterfactual]

      Skipped 6 frames in Base

(11)  to_index(i::Nothing)    [in module Base, ERROR LINE]
      ./indices.jl:300

ArgumentError: invalid index: nothing of type Nothing

Tagging @kmariuszk here who added support for GrowingSpheres. Perhaps you have an idea what might be going on?
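
The trace shows that `counterfactual` was `nothing` when it was used as a column index on line 111. The plain-Julia sketch below reproduces that kind of failure and shows one possible guard. It assumes, without having checked the package source, that the index comes from a search such as `findfirst` that found no candidate with the target label; `candidates`, `predicted`, `target`, and `pick_candidate` are hypothetical names.

```julia
# Hypothetical stand-ins: one candidate per column and its predicted label.
candidates = [0.1 0.4 0.9;
              0.2 0.5 0.8]
predicted = [0, 0, 0]      # no candidate reaches the target class
target = 1

# `findfirst` returns `nothing` when no element matches.
idx = findfirst(==(target), predicted)

# Indexing with `nothing` throws; capture the exception to inspect it.
err = try
    candidates[:, idx]
catch e
    e
end
println(typeof(err))

# A guarded lookup: return `nothing` instead of throwing when no candidate
# has the target label, so the caller can decide how to proceed.
function pick_candidate(candidates::AbstractMatrix, predicted::AbstractVector, target)
    i = findfirst(==(target), predicted)
    return i === nothing ? nothing : candidates[:, i]
end
```

If this is indeed the cause, the underlying question is why none of the sampled candidates is classified as the target, for example because the model wrapper returns labels or probabilities in a form the generator does not expect.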


pat-alt commented Nov 10, 2023

I should add that I think adding general support for all sklearn models using your approach in SklearnBackend.jl is very ambitious (and maybe we should make that clearer in the docs), because you would need to work out how to access gradients for each of them.

If we add that sort of support anywhere, it will be right here in TaijaInteroperability, but it would be quite difficult. In terms of priorities, I think we'd first try to extend support for the many models supported by MLJ.jl.
