473 generator add t crex #476

Merged
merged 28 commits into from
Sep 23, 2024
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),

*Note*: We try to adhere to these practices as of version [v1.1.1].

## Version [1.3.0] - 2024-09-16

### Added

- Added basic support for the T-CREx counterfactual generator. [#473]

## Version [1.2.0] - 2024-09-10

### Added
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "CounterfactualExplanations"
uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
authors = ["Patrick Altmeyer <[email protected]>"]
-version = "1.2.0"
+version = "1.3.0"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
@@ -0,0 +1,11 @@
{
"hash": "f038aceeeef09b691dfd63f09787272c",
"result": {
"engine": "jupyter",
"markdown": "```@meta\nCurrentModule = CounterfactualExplanations \n```\n\n\n\n# T-CREx Generator\n\nT-CREx is a novel model-agnostic counterfactual generator that can be used to generate local and global Counterfactual Rule Explanations (CREx) [@bewley2024counterfactual].\n\n\n```{=commonmark}\n!!! warning \"Breaking Changes Expected\"\n Work on this feature is still in its very early stages and breaking changes should be expected. This new generator introduces new concepts such as global counterfactual explanations that are not explained anywhere else in this documentation. If you want to use this generator, please make sure you are familiar with the related literature. \n```\n\n\n## Usage\n\nThe implementation of the `TCRExGenerator` depends on [DecisionTree.jl](https://github.com/JuliaAI/DecisionTree.jl). For the time being, we have decided not to add a strong dependency on DecisionTree.jl to the package. Instead, the functionality of the `TCRExGenerator` is made available through the `DecisionTreeExt` extension, which is loaded conditionally when [DecisionTree.jl](https://github.com/JuliaAI/DecisionTree.jl) is loaded (see [Julia docs](https://pkgdocs.julialang.org/v1/creating-packages/#Conditional-loading-of-code-in-packages-(Extensions)) for more details on extensions):\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\nusing DecisionTree\n```\n:::\n\n\nLet us first set up the problem by loading some data, fitting a simple model and determining a target and factual class:\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\n# Counterfactual data and model:\nn = 3000\ndata = CounterfactualData(load_moons(n; noise=0.25)...)\nX = data.X\nM = fit_model(data, :MLP)\nfx = predict_label(M, data)\ntarget = 1\nfactual = 0\nchosen = rand(findall(predict_label(M, data) .== factual))\nx = select_factual(data, chosen)\n```\n:::\n\n\nNext, we instantiate the generator much like any other counterfactual generator in our 
package:\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\nρ = 0.02 # feasibility threshold (see Bewley et al. (2024))\nτ = 0.9 # accuracy threshold (see Bewley et al. (2024))\ngenerator = Generators.TCRExGenerator(ρ=ρ, τ=τ)\n```\n:::\n\n\nFinally, we can use the `TCRExGenerator` instance to generate a (global) counterfactual rule explanation (CRE) for the given target, data and model as follows:\n\n::: {.cell execution_count=5}\n``` {.julia .cell-code}\ncre = generator(target, data, M) # counterfactual rule explanation (global)\n```\n:::\n\n\nThe CRE can be applied to our factual `x` to derive a (local) counterfactual point explanation (CPE):\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\nidx, optimal_rule = cre(x) # counterfactual point explanation (local)\n```\n:::\n\n\n## Worked Example from @bewley2024counterfactual\n\nTo make better sense of this, we will now go through the worked example presented in @bewley2024counterfactual. For this purpose, we need to make the functions of the `DecisionTreeExt` extension available.\n\n\n```{=commonmark}\n!!! warning \"Private API\"\n Please note that the `DecisionTreeExt` extension is loaded here purely for demonstration purposes. You should not load the extension like this in your own work.\n```\n\n::: {.cell execution_count=7}\n``` {.julia .cell-code}\nDTExt = Base.get_extension(CounterfactualExplanations, :DecisionTreeExt)\n```\n:::\n\n\n### (a) Tree-based surrogate model \n\nIn the first step, we train a tree-based surrogate model based on the data and the black-box model `M`. Specifically, the surrogate model is trained on pairs of observed input data and the labels predicted by the black-box model: $\\{(x_i, M(x_i))\\}_{1\\leq i \\leq n}$. Following @bewley2024counterfactual, we impose a minimum number of samples per leaf to ensure counterfactual *feasibility* (also often referred to as *plausibility*). 
This number is computed under the hood and based on the `generator.ρ` field of the `TCRExGenerator`, which can be used to specify the minimum fraction of all samples that is contained by any given rule. \n\n::: {.cell execution_count=8}\n``` {.julia .cell-code}\n# Surrogate:\nmodel, fitresult = DTExt.grow_surrogate(generator, data, M)\nM_sur = CounterfactualExplanations.DecisionTreeModel(model; fitresult=fitresult)\nplot(M_sur, data; ms=3, markerstrokewidth=0, size=(500, 500), colorbar=false)\n```\n\n::: {.cell-output .cell-output-display execution_count=41}\n![Tree-based surrogate model](tcrex_files/figure-commonmark/fig-surr-output-1.svg){#fig-surr}\n:::\n:::\n\n\nWe can reassure ourselves that the feasibility constraint is indeed respected:\n\n::: {.cell execution_count=9}\n``` {.julia .cell-code}\n# Extract rules:\nR = DTExt.extract_rules(fitresult[1])\n\n# Compute feasibility and accuracy:\nfeas = DTExt.rule_feasibility.(R, (X,))\n@assert minimum(feas) >= ρ\n@info \"Minimum fraction of samples across all rules is $(round(minimum(feas), digits=3))\"\nacc_factual = DTExt.rule_accuracy.(R, (X,), (fx,), (factual,))\nacc_target = DTExt.rule_accuracy.(R, (X,), (fx,), (target,))\n@assert all(acc_target .+ acc_factual .== 1.0)\n```\n:::\n\n\n### (b) Maximal-valid rules \n\nFrom the complete set of rules derived from the surrogate tree, we can derive the maximal-valid rules next. Intuitively, \"a maximal-valid rule is one that cannot be made any larger without violating the validity conditions\", where validity is defined in terms of both *feasibility* (`generator.ρ`) and accuracy (`generator.τ`). 
\n\n::: {.cell execution_count=10}\n``` {.julia .cell-code}\nR_max = DTExt.max_valid(R, X, fx, target, τ)\nfeas_max = DTExt.rule_feasibility.(R_max, (X,))\nacc_max = DTExt.rule_accuracy.(R_max, (X,), (fx,), (target,))\nplt = plot(data; ms=3, markerstrokewidth=0, size=(500, 500))\np1 = deepcopy(plt)\nrectangle(w, h, x, y) = Shape(x .+ [0,w,w,0], y .+ [0,0,h,h])\nfor (i, rule) in enumerate(R_max)\n ubx, uby = minimum([rule[1][2], maximum(X[1, :])]),\n minimum([rule[2][2], maximum(X[2, :])])\n lbx, lby = maximum([rule[1][1], minimum(X[1, :])]),\n maximum([rule[2][1], minimum(X[2, :])])\n _feas = round(feas_max[i]; digits=2)\n _n = Int(round(feas_max[i] * n; digits=2))\n _acc = round(acc_max[i]; digits=2)\n @info \"Rectangle R$i with feasibility $(_feas) (n≈$(_n)) and accuracy $(_acc)\"\n lab = \"R$i (ρ̂=$(_feas), τ̂=$(_acc))\"\n plot!(p1, rectangle(ubx-lbx,uby-lby,lbx,lby), opacity=.5, color=i+2, label=lab)\nend\np1\n```\n\n::: {.cell-output .cell-output-display execution_count=43}\n![Maximal-valid rules.](tcrex_files/figure-commonmark/fig-max-output-1.svg){#fig-max}\n:::\n:::\n\n\n### (c) Induced grid partition\n\nBased on the set of maximal-valid rules, we compute and plot the induced grid partition below.\n\n::: {.cell execution_count=11}\n``` {.julia .cell-code}\n_grid = DTExt.induced_grid(R_max)\np2 = deepcopy(p1)\nfunction plot_grid!(p)\n for (i, (bounds_x, bounds_y)) in enumerate(_grid)\n lbx, ubx = bounds_x\n lby, uby = bounds_y\n lbx = maximum([lbx, minimum(X[1, :])])\n lby = maximum([lby, minimum(X[2, :])])\n ubx = minimum([ubx, maximum(X[1, :])])\n uby = minimum([uby, maximum(X[2, :])])\n plot!(\n p,\n rectangle(ubx - lbx, uby - lby, lbx, lby);\n fillcolor=\"black\",\n fillalpha=0.1,\n label=nothing,\n lw=2,\n )\n end\nend\nplot_grid!(p2)\np2\n```\n\n::: {.cell-output .cell-output-display execution_count=44}\n![Induced grid partition.](tcrex_files/figure-commonmark/fig-grid-output-1.svg){#fig-grid}\n:::\n:::\n\n\n### (d) Grid cell prototypes \n\nNext, we 
pick prototypes from each cell in the induced grid. By setting `pick_arbitrary=false` here we enforce that prototypes correspond to cell centroids, although this is not strictly necessary. For each prototype, we compute the corresponding CRE, which is indicated by the color of the large markers in the figure below:\n\n::: {.cell execution_count=12}\n``` {.julia .cell-code}\nxs = DTExt.prototype.(_grid, (X,); pick_arbitrary=false)\nRᶜ = DTExt.cre.((R_max,), xs, (X,); return_index=true) \np3 = deepcopy(p2)\nscatter!(p3, eachrow(hcat(xs...))..., ms=10, label=nothing, color=Rᶜ.+2)\np3\n```\n\n::: {.cell-output .cell-output-display execution_count=45}\n![Grid cell prototypes.](tcrex_files/figure-commonmark/fig-proto-output-1.svg){#fig-proto}\n:::\n:::\n\n\n### (e) - (f) Global CE representation\n\nBased on the prototypes and their corresponding rule assignments, we fit a CART classification tree with restricted feature thresholds. Specifically, feature thresholds are restricted to the partition bounds induced by the set of maximal-valid rules as in @bewley2024counterfactual. The figure below shows the resulting global CE representation (i.e. 
the metarules).\n\n::: {.cell execution_count=13}\n``` {.julia .cell-code}\nbounds = DTExt.partition_bounds(R_max)\ntree = DTExt.classify_prototypes(hcat(xs...)', Rᶜ, bounds)\nR_final, labels = DTExt.extract_leaf_rules(tree) \np4 = deepcopy(plt)\nfor (i, rule) in enumerate(R_final)\n ubx, uby = minimum([rule[1][2], maximum(X[1, :])]),\n minimum([rule[2][2], maximum(X[2, :])])\n lbx, lby = maximum([rule[1][1], minimum(X[1, :])]),\n maximum([rule[2][1], minimum(X[2, :])])\n plot!(\n p4,\n rectangle(ubx - lbx, uby - lby, lbx, lby);\n fillalpha=0.5,\n label=nothing,\n color=labels[i] + 2\n )\nend\np4\n```\n\n::: {.cell-output .cell-output-display execution_count=46}\n![Global CE representation.](tcrex_files/figure-commonmark/fig-global-output-1.svg){#fig-global}\n:::\n:::\n\n\n### (g) Local CE example\n\nTo generate a local explanation based on the global CE representation, we simply apply the CART decision tree classifier from the previous step to our factual:\n\n::: {.cell execution_count=14}\n``` {.julia .cell-code}\noptimal_rule = apply_tree(tree, vec(x))\np5 = deepcopy(p2)\nscatter!(p5, [x[1]], [x[2]], ms=10, color=2+optimal_rule, label=\"Local CE (move to R$optimal_rule)\")\np5\n```\n\n::: {.cell-output .cell-output-display execution_count=47}\n![Local CE example.](tcrex_files/figure-commonmark/fig-local-output-1.svg){#fig-local}\n:::\n:::\n\n\n## References\n\n",
"supporting": [
"tcrex_files"
],
"filters": []
}
}
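
The frozen tutorial above checks each rule against a feasibility threshold `ρ` and an accuracy threshold `τ`. As a rough illustration of what those two quantities measure, here is a minimal sketch. It assumes a rule is a vector of `(lower, upper)` bounds per feature and that `X` is a features × samples matrix, as the plotting code in the tutorial suggests. The helper names `in_rule`, `feasibility_sketch` and `accuracy_sketch` are hypothetical and are not the `DecisionTreeExt` implementation; the half-open interval convention is also an assumption.

```julia
# Hypothetical helpers for illustration only; NOT the DecisionTreeExt API.
# A rule is a vector of (lower, upper) bounds, one tuple per feature.
# X is a features × samples matrix.

# True if sample x lies inside the hyperrectangle described by `rule`.
# Assumption: bounds are treated as half-open intervals (lb, ub].
in_rule(rule, x) = all(lb < xi <= ub for ((lb, ub), xi) in zip(rule, x))

# Fraction of all samples contained in the rule (to be compared with ρ).
function feasibility_sketch(rule, X)
    return count(x -> in_rule(rule, x), eachcol(X)) / size(X, 2)
end

# Fraction of contained samples whose predicted label equals `target`
# (to be compared with τ). Returns 0.0 for a rule that contains no samples.
function accuracy_sketch(rule, X, fx, target)
    inside = [in_rule(rule, x) for x in eachcol(X)]
    n_inside = count(inside)
    n_inside == 0 && return 0.0
    return count(fx[inside] .== target) / n_inside
end

# Toy data: 2 features, 4 samples, with predicted labels fx.
X = [0.1 0.4 0.6 0.9;
     0.2 0.8 0.3 0.7]
fx = [0, 1, 1, 1]
rule = [(0.3, Inf), (-Inf, Inf)]  # "feature 1 > 0.3", feature 2 unrestricted

feas = feasibility_sketch(rule, X)     # 3 of the 4 samples fall inside the rule
acc = accuracy_sketch(rule, X, fx, 1)  # all 3 contained samples are labelled 1
```

Under these assumptions, a rule would count as valid for target `1` when `feas >= ρ` and `acc >= τ`, which mirrors the assertions in the tutorial's feasibility check.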