Asymmetric causal Shapley values with adaptive sampling #400

Merged
merged 290 commits on Oct 13, 2024
Changes from 250 commits
Commits
290 commits
151bba7
lintr
martinju Aug 3, 2024
361d8f9
.
martinju Aug 3, 2024
12ca661
apply name changes to test files
martinju Aug 5, 2024
a5c666b
rename regular output name
martinju Aug 5, 2024
86fed31
adding setup adaptive ++
martinju Aug 5, 2024
410c05d
update regular tests
martinju Aug 5, 2024
ae29313
bugfix, improve printing and init adaptive tests
martinju Aug 6, 2024
34c0905
update test files
martinju Aug 6, 2024
970d08d
div
martinju Aug 6, 2024
4e319e7
remove timing arg and add hidden testing arg
martinju Aug 6, 2024
a81a22b
fixing broken testing objects after updates
martinju Aug 6, 2024
8a4f0db
update tests with testing = TRUE, and remove timing = FALSE
martinju Aug 6, 2024
bb4e385
rds files
martinju Aug 7, 2024
643e5f0
styler
martinju Aug 7, 2024
24ae4d4
[skip actions] .
martinju Aug 7, 2024
d0a1ad6
move functions to appropriate files
martinju Aug 7, 2024
854fee7
[skip actions] doc + temporary and hiddenly adding unique_sampling
martinju Aug 7, 2024
17c94ef
add timing + experiment with improved bootstrapping code
martinju Aug 8, 2024
c7f3e2b
[skip actions] fix non-unique sampling
martinju Aug 8, 2024
f95e7ef
init moving to max_n_combinations
martinju Aug 9, 2024
5c5f436
add feature_samples to iter_list in setup for convenience
martinju Aug 9, 2024
21c43dc
simplifying explain view + improve max_n_combinations sets and checks
martinju Aug 9, 2024
2e7e864
man
martinju Aug 9, 2024
699bde0
Merge commit '2e7e86450686f61d6f4c7f63ac87c5857ff0094e' into convergence
martinju Aug 9, 2024
c9b679e
.
martinju Aug 9, 2024
95dda97
[skip actions] remaining stuff of max_n_combinations. Works, i think
martinju Aug 9, 2024
01b017e
[skip actions] remaining stuff of max_n_combinations. Works, i think
martinju Aug 9, 2024
d296e87
new bootstrap introduced with tests
martinju Aug 9, 2024
6dbcaff
making tests work
martinju Aug 9, 2024
303e323
tests OK
martinju Aug 10, 2024
fb6d050
some more ok tests. Forecast dont work as of now
martinju Aug 10, 2024
db6c221
apply the feature_combination stuff also to groups
martinju Aug 10, 2024
63281a6
Not 100% sure this actually works as it should
martinju Aug 10, 2024
1d6fb63
add and fix group tests
martinju Aug 12, 2024
5bc4efe
new
martinju Aug 12, 2024
6086d2a
adaptive OK
martinju Aug 16, 2024
11df7de
all tests pass
martinju Aug 16, 2024
ca58ce4
styler
martinju Aug 16, 2024
6b4931d
man
martinju Aug 16, 2024
803a181
fix checks
martinju Aug 16, 2024
14f360e
Disable rcpp approx solve warnings
martinju Aug 16, 2024
38be193
temporary fix forecast (not adaptive yet)
martinju Aug 16, 2024
45a657e
tests
martinju Aug 16, 2024
28acb00
combinations -> coalitions and merging all features/groups-code
martinju Aug 16, 2024
886f93b
add features to coalition table for both features and groups
martinju Aug 19, 2024
7ba0408
bugfix groups + some plot test updates
martinju Aug 19, 2024
48d6b81
more fixing
martinju Aug 19, 2024
b9bd1fe
adaptive tests
martinju Aug 19, 2024
8b0bd2c
remaining tests
martinju Aug 19, 2024
8bdb6be
forcast tests (something is up with forecast grouping, though)
martinju Aug 20, 2024
3de91ff
[skip actions] style
martinju Aug 20, 2024
7d747b8
adding reweighting strategy on all cond
martinju Aug 21, 2024
8f46e38
[skip actions] add reweighting strategies + non-unique paired sampling
martinju Aug 23, 2024
9daf94e
n_samples -> n_MC_samples
martinju Aug 23, 2024
b4125d2
tests OK
martinju Aug 23, 2024
cd7b0a8
fix iterative with paired sampling
martinju Sep 5, 2024
e3373d3
update tests after bootstrap change + .Rprofile for smoother testing
martinju Sep 5, 2024
fd7734f
add intermediate saving
martinju Sep 6, 2024
9a35c14
working version of continue training
martinju Sep 6, 2024
a064b9c
moves prev_shapr_object handling to setup and add validity test
martinju Sep 6, 2024
e46822a
working
martinju Sep 6, 2024
55434c1
[skip actions] Working
martinju Sep 6, 2024
f4e7931
man + testthat
martinju Sep 9, 2024
7f18f91
new adaptive output testfiles
martinju Sep 9, 2024
681d197
Fix cutting of coalition list per horizon in ```shapley_setup_forecas…
jonlachmann Sep 10, 2024
ccc39af
Merge remote-tracking branch 'jonlachmann/convergence' into convergence
martinju Sep 11, 2024
23016dc
update OK forecast test files
martinju Sep 11, 2024
db7c15d
extra forecast test file update
martinju Sep 11, 2024
a4d02fb
add max_batch_size og min_n_batches
martinju Sep 19, 2024
c04571c
apply the new n_batches settings in practice
martinju Sep 19, 2024
2118826
remove all traces of n_batches in the older code
martinju Sep 26, 2024
5bee4b9
adaptive tests ok
martinju Sep 26, 2024
56fa77b
more tests
martinju Sep 26, 2024
cba572c
adding checks for adaptive argument formats
martinju Sep 27, 2024
1e481e2
update tests
martinju Sep 27, 2024
c780b70
regression
martinju Sep 27, 2024
33301c3
temporary disabling the forecast tests
martinju Sep 27, 2024
2e0e09d
new test files
martinju Sep 27, 2024
cbd01f4
move reweighting and set new defaults
martinju Sep 27, 2024
cc737aa
moving towards new defaults
martinju Sep 27, 2024
e76d068
adpative-output at least OK
martinju Sep 27, 2024
0e4ba15
[skip tests] new test files
martinju Sep 27, 2024
99864b9
[skip actions] other tests ok
martinju Sep 30, 2024
882ed16
Merge remote-tracking branch 'origin/convergence' into convergence
martinju Sep 30, 2024
75ae9a4
[skip actions] .
martinju Sep 30, 2024
92f45c1
[skip actions] documenting explain
martinju Sep 30, 2024
f7f5a49
more documentation
martinju Sep 30, 2024
9095c64
.
martinju Sep 30, 2024
d6bc603
[skip actions] Slight restructure + update of main vignette
martinju Sep 30, 2024
1789783
checks for the adaptive argument
martinju Oct 1, 2024
1a5d034
NSE warnings
martinju Oct 1, 2024
35cd82f
test updates
martinju Oct 1, 2024
59e1c23
man and zzz
martinju Oct 1, 2024
dadbce2
man + tests
martinju Oct 1, 2024
231c862
tmp
martinju Oct 1, 2024
08a9e87
deal with cont estimation for non-adaptive
martinju Oct 1, 2024
3627112
first vignette
martinju Oct 1, 2024
21a2c29
vaeac also need X in setup
martinju Oct 1, 2024
7cb617b
vaeac vignette works
martinju Oct 1, 2024
55f5219
[skip actions] init update of regression vignette
martinju Oct 1, 2024
f56eb49
+ regression
martinju Oct 1, 2024
f093fa0
fix docs
martinju Oct 1, 2024
3f222b3
style
martinju Oct 2, 2024
61336bf
linting
martinju Oct 2, 2024
a02c9c9
fix man
martinju Oct 2, 2024
0d96221
fix vignette
martinju Oct 2, 2024
91de6ec
remove (>= 3.0.0) for testthat for tesitng
martinju Oct 2, 2024
fde5a3e
Merge branch 'verbose' into convergence
martinju Oct 2, 2024
54aa94b
Merge branch 'convergence' into verbose
martinju Oct 2, 2024
92936a4
replacing the old verbose syntax in vaeac and regresion
martinju Oct 2, 2024
971637b
move everything to string verbose
martinju Oct 2, 2024
bb2092f
playing around with cli progress
martinju Oct 2, 2024
26c0cff
more testing
martinju Oct 3, 2024
851a2b6
more work
martinju Oct 3, 2024
06c878a
working OK for now
martinju Oct 3, 2024
d581017
more work
martinju Oct 3, 2024
b8f4331
separate regression done
martinju Oct 3, 2024
4108811
also regression_surrogate done
martinju Oct 3, 2024
8f89af3
fixed no testthat package?
martinju Oct 4, 2024
941b95e
consider myself done for now
martinju Oct 4, 2024
31a9203
vignettes
martinju Oct 4, 2024
996a261
testfile updates
martinju Oct 4, 2024
2676501
styler
martinju Oct 4, 2024
758725f
lint and some checks
martinju Oct 4, 2024
16ff5b8
Merge branch 'verbose' into convergence
martinju Oct 4, 2024
e91963e
hoping to avoid the missing testthat package error on GHA
martinju Oct 4, 2024
2e75796
Added the cpp files
LHBO Oct 4, 2024
332bbb4
Added all files to vignettes. Still need to make the vignette runnabl…
LHBO Oct 4, 2024
10c0b7f
Added all test files.
LHBO Oct 4, 2024
bc3931e
Added references
LHBO Oct 4, 2024
81c02aa
Added documentation to explain.
LHBO Oct 4, 2024
25e98e4
Updated Gaussian
LHBO Oct 5, 2024
ff0c435
Updated setup
LHBO Oct 5, 2024
e55cd94
Updated shapley setup. Need to discuss sampling with Martin.
LHBO Oct 5, 2024
a90a6d7
Added file with all the asymmetric and causal functions
LHBO Oct 5, 2024
bb87601
Updated copula
LHBO Oct 5, 2024
6fb3185
Typos in gaussian
LHBO Oct 5, 2024
e4c93bb
Typo in documentation of categorical
LHBO Oct 5, 2024
f358c75
Added categorical
LHBO Oct 5, 2024
0e0fbf0
Updated vaeac
LHBO Oct 5, 2024
8767548
Updated compute_vS to support stepwise causal sampling
LHBO Oct 5, 2024
4fb2491
Added todo comment
LHBO Oct 5, 2024
47f05b3
Forgot to add `causal_sampling` to compute_vS
LHBO Oct 5, 2024
eb96c45
combination -> coalition typos in gaussian and copula
LHBO Oct 5, 2024
fc72cd2
stylr
LHBO Oct 5, 2024
6a41efb
Updated cli with basic information about asymmetric and causal Shaple…
LHBO Oct 5, 2024
70023a1
combination -> coalition
LHBO Oct 6, 2024
d345d30
combination -> coalition, and n_samples -> n_MC_samples
LHBO Oct 6, 2024
ad9e589
n_samples -> n_MC_samples
LHBO Oct 6, 2024
4e7c2a6
Removed `vS_details` verbose from asymmetric/causal Shapley values af…
LHBO Oct 6, 2024
2975b05
Added variable so that error messages distinguish between feature-wis…
LHBO Oct 6, 2024
68adb8d
Updated the setup testfiles and ran them
LHBO Oct 6, 2024
ac75a85
n_samples -> n_MC_samples in C++ for consistency
LHBO Oct 6, 2024
ccc6f15
Bike dataset for the vignette. Copied from Heskes PR.
LHBO Oct 6, 2024
14d8c5c
Simulation study justifying chaning the categorical prepare data func…
LHBO Oct 6, 2024
5a2efce
File where (symmetric conditional) regular and causal Shapley values …
LHBO Oct 6, 2024
bdf892f
File used to compare with Heskes implementation on https://gitlab.sci…
LHBO Oct 6, 2024
42f01bc
styler
LHBO Oct 6, 2024
496f800
updated test output file
LHBO Oct 6, 2024
ad38546
Updated the plot functions to support plotting the average feature va…
LHBO Oct 6, 2024
a0c20cc
Update variable name
LHBO Oct 6, 2024
bb361ec
Update to adaptive
LHBO Oct 6, 2024
308b32f
Delete not needed functions
LHBO Oct 7, 2024
f973dcf
Name change: legit -> valid.
LHBO Oct 7, 2024
58b7ece
Changed so valid coalitions are now a data table
LHBO Oct 7, 2024
0bad63e
Add test for unique sampling. Add suport for sampling of causal coali…
LHBO Oct 7, 2024
62ae312
fixes to causal
LHBO Oct 7, 2024
4544abc
Plot updates
LHBO Oct 7, 2024
7220e51
Restructures tests
LHBO Oct 7, 2024
a0cb6ee
Updates to vignette
LHBO Oct 7, 2024
3f6d77d
Accept changes in files I have not touched
LHBO Oct 7, 2024
65cd09e
timing
LHBO Oct 7, 2024
ce1b679
documentation
LHBO Oct 7, 2024
54c738a
explain_forecast
LHBO Oct 7, 2024
f4d15ae
compute_estimates
LHBO Oct 7, 2024
190807b
shapr-package
LHBO Oct 7, 2024
ef62b6e
REFERENCES
LHBO Oct 7, 2024
fa4d95f
adaptive test snaps
LHBO Oct 7, 2024
9d7ea1d
vignette figures
LHBO Oct 7, 2024
5e7d585
Accept changes to vignettes
LHBO Oct 7, 2024
abffdf5
output snaps and tests
LHBO Oct 7, 2024
4e4b22a
approach timeseries
LHBO Oct 7, 2024
217c14f
check convergence
LHBO Oct 7, 2024
fbed1c3
print_iter
LHBO Oct 7, 2024
f815cd1
finalize explanation
LHBO Oct 7, 2024
8a341ed
Shapley setup
LHBO Oct 7, 2024
4372e9d
Explain
LHBO Oct 7, 2024
0c34b6f
VAEAC
LHBO Oct 7, 2024
3b19930
regression_separate
LHBO Oct 7, 2024
fa8fba1
Setup
LHBO Oct 7, 2024
4f7f1fc
typo
LHBO Oct 7, 2024
849edd0
new manuals
LHBO Oct 7, 2024
ad0fcb0
manuals
LHBO Oct 7, 2024
73e23ce
delete categorical old
LHBO Oct 7, 2024
b6aaa78
roxygen
LHBO Oct 7, 2024
1b80550
vaeac manuals
LHBO Oct 7, 2024
37abdf5
verbose
LHBO Oct 7, 2024
e8455d1
styler + lintr
LHBO Oct 7, 2024
70c4720
YAML since I do not have the rights to change it
LHBO Oct 7, 2024
53041a3
The same again
LHBO Oct 7, 2024
aa00e3d
Remove causal printouts
LHBO Oct 7, 2024
270a070
Test snap update
LHBO Oct 7, 2024
176a0d7
typo
LHBO Oct 7, 2024
e44cb9d
Test asym / caus output
LHBO Oct 7, 2024
fb3c0aa
vignette cache + logical error
LHBO Oct 7, 2024
9a5cbfd
To wide printout
LHBO Oct 7, 2024
1bb38ba
rerun test with new width output
LHBO Oct 7, 2024
26c99b5
redundant docu in copula
LHBO Oct 7, 2024
e67c5ad
rcpp update
LHBO Oct 7, 2024
5be36bb
built vignette
LHBO Oct 7, 2024
0c03611
man
martinju Oct 9, 2024
1ce2c6e
update vignettes
martinju Oct 9, 2024
bc490c9
update test files
martinju Oct 9, 2024
83f4eeb
check updates
martinju Oct 9, 2024
ac3810e
rename test files
martinju Oct 9, 2024
653b57e
avoid mvnfast dependency ++
martinju Oct 9, 2024
306ae41
remove sort_feature_list
martinju Oct 9, 2024
861f442
[skip actions] run on GHA
martinju Oct 9, 2024
6cf1c75
checks
martinju Oct 10, 2024
0b57d0f
styler and lint
martinju Oct 10, 2024
2698dd8
doc
martinju Oct 10, 2024
fb6f101
test updates
martinju Oct 10, 2024
d774ca2
remove check_coalitiotns_respect_order
martinju Oct 10, 2024
05f89c5
update test files
martinju Oct 10, 2024
f5c9276
dontrun internal examples functions
martinju Oct 10, 2024
6c5a197
Remove douplicates in finalize_explantaion.R
LHBO Oct 10, 2024
3fe8661
Increased readability in approach vaeac
LHBO Oct 10, 2024
236cecf
Delete old `batch_prepare_vS_MC_auxiliary` function and rename `batch…
LHBO Oct 10, 2024
d385d99
[skip actions] fixing GHA notes
martinju Oct 10, 2024
5204a15
Updated the format of the explain documentations, corrected some typo…
LHBO Oct 10, 2024
9ce8106
Refactored approach_gaussian.R for readability
LHBO Oct 10, 2024
5c6b135
Refactored approach_copula.R for readability
LHBO Oct 10, 2024
972c060
Removed todo note
LHBO Oct 10, 2024
ab81de8
Updated documentation. Some missing parameters
LHBO Oct 10, 2024
be26384
Removed check function not used
LHBO Oct 10, 2024
a9cef22
typo in explain
LHBO Oct 10, 2024
05d5ee8
typo in explain
LHBO Oct 10, 2024
58195b9
update documentation
LHBO Oct 10, 2024
b14129d
Forgot to update batch_prepare_vS_MC_auxiliary function call
LHBO Oct 10, 2024
ba7383a
n_MC_samples_updated -> n_MC_samples
LHBO Oct 10, 2024
32b0e2e
Update docu about dt_valid_causal_coalitions
LHBO Oct 10, 2024
789927c
documentation update
LHBO Oct 10, 2024
cf5596b
Merge Martin's updates into my branch
LHBO Oct 10, 2024
110c5de
updates to asym vignette
martinju Oct 10, 2024
6ec1c8f
Logical error in approach gaussiand and copula after refactoring. Ver…
LHBO Oct 11, 2024
4b4dad0
vignette
martinju Oct 11, 2024
503fb9b
Merge remote-tracking branch 'LHBO/CausalShapleyNew' into CausalShapl…
martinju Oct 11, 2024
0f9f668
.
martinju Oct 11, 2024
8c33ac1
fixed timing
martinju Oct 11, 2024
cbd2889
bugfix includegraphics
martinju Oct 11, 2024
4 changes: 4 additions & 0 deletions NAMESPACE
@@ -80,8 +80,11 @@ export(plot_MSEv_eval_crit)
export(plot_SV_several_approaches)
export(predict_model)
export(prepare_data)
export(prepare_data_causal)
export(prepare_data_copula_cpp)
export(prepare_data_copula_cpp_caus)
export(prepare_data_gaussian_cpp)
export(prepare_data_gaussian_cpp_caus)
export(prepare_next_iteration)
export(print_iter)
export(regression.train_model)
@@ -136,5 +139,6 @@ importFrom(utils,capture.output)
importFrom(utils,head)
importFrom(utils,methods)
importFrom(utils,modifyList)
importFrom(utils,relist)
importFrom(utils,tail)
useDynLib(shapr, .registration = TRUE)
54 changes: 54 additions & 0 deletions R/RcppExports.R
@@ -138,6 +138,36 @@ prepare_data_copula_cpp <- function(MC_samples_mat, x_explain_mat, x_explain_gau
.Call(`_shapr_prepare_data_copula_cpp`, MC_samples_mat, x_explain_mat, x_explain_gaussian_mat, x_train_mat, S, mu, cov_mat)
}

#' Generate (Gaussian) Copula MC samples for the causal setup with a single MC sample for each explicand
#'
#' @param MC_samples_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing samples from the
#' univariate standard normal. The i'th row will be applied to the i'th row in `x_explain_mat`.
#' @param x_explain_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the observations to
#' explain on the original scale. The MC sample for the i'th explicand is based on the i'th row in `MC_samples_mat`.
#' @param x_explain_gaussian_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the
#' observations to explain after being transformed using the Gaussian transform, i.e., the samples have been
#' transformed to a standardized normal distribution.
#' @param x_train_mat arma::mat. Matrix of dimension (`n_train`, `n_features`) containing the training observations.
#' @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of
#' the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones.
#' This is not a problem internally in shapr as the empty and grand coalitions are treated differently.
#' @param mu arma::vec. Vector of length `n_features` containing the mean of each feature after being transformed
#' using the Gaussian transform, i.e., the samples have been transformed to a standardized normal distribution.
#' @param cov_mat arma::mat. Matrix of dimension (`n_features`, `n_features`) containing the pairwise covariance
#' between all pairs of features after being transformed using the Gaussian transform, i.e., the samples have been
#' transformed to a standardized normal distribution.
#'
#' @return An arma::mat/2D array of dimension (`n_explain` * `n_coalitions`, `n_features`),
#' where the rows (n_explain * S_ind, n_explain * (S_ind + 1) - 1) contains the single
#' conditional Gaussian MC samples for each explicand and `S_ind` coalition.
#'
#' @export
#' @keywords internal
#' @author Lars Henry Berge Olsen
prepare_data_copula_cpp_caus <- function(MC_samples_mat, x_explain_mat, x_explain_gaussian_mat, x_train_mat, S, mu, cov_mat) {
.Call(`_shapr_prepare_data_copula_cpp_caus`, MC_samples_mat, x_explain_mat, x_explain_gaussian_mat, x_train_mat, S, mu, cov_mat)
}
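
A minimal usage sketch (not part of the diff): since `prepare_data_copula_cpp_caus()` is exported in the NAMESPACE above, it can be called directly with toy inputs to see the expected shapes. Everything below (data, seed, dimensions) is an illustrative assumption; in the package, `prepare_data.copula()` builds these matrices internally.

```r
library(shapr)
set.seed(123)
n_explain <- 3; n_train <- 50; n_features <- 4

# Toy data: pretend the Gaussian transform has already been applied,
# so the Gaussian-scale explicand matrix is simply the explicand matrix here.
x_train_mat <- matrix(rnorm(n_train * n_features), n_train, n_features)
x_explain_mat <- matrix(rnorm(n_explain * n_features), n_explain, n_features)
x_explain_gaussian_mat <- x_explain_mat
MC_samples_mat <- matrix(rnorm(n_explain * n_features), n_explain, n_features) # one N(0, 1) draw per explicand

S <- matrix(c(1, 0, 1, 0), nrow = 1) # a single coalition; neither empty nor grand
mu <- rep(0, n_features)
cov_mat <- diag(n_features)

dt <- prepare_data_copula_cpp_caus(
  MC_samples_mat = MC_samples_mat,
  x_explain_mat = x_explain_mat,
  x_explain_gaussian_mat = x_explain_gaussian_mat,
  x_train_mat = x_train_mat,
  S = S,
  mu = mu,
  cov_mat = cov_mat
)
dim(dt) # (n_explain * n_coalitions) x n_features = 3 x 4
```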

#' Generate Gaussian MC samples
#'
#' @param MC_samples_mat arma::mat. Matrix of dimension (`n_MC_samples`, `n_features`) containing samples from the
@@ -162,6 +192,30 @@ prepare_data_gaussian_cpp <- function(MC_samples_mat, x_explain_mat, S, mu, cov_
.Call(`_shapr_prepare_data_gaussian_cpp`, MC_samples_mat, x_explain_mat, S, mu, cov_mat)
}

#' Generate Gaussian MC samples for the causal setup with a single MC sample for each explicand
#'
#' @param MC_samples_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing samples from the
#' univariate standard normal. The i'th row will be applied to the i'th row in `x_explain_mat`.
#' @param x_explain_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the observations
#' to explain. The MC sample for the i'th explicand is based on the i'th row in `MC_samples_mat`
#' @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of
#' the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones.
#' This is not a problem internally in shapr as the empty and grand coalitions are treated differently.
#' @param mu arma::vec. Vector of length `n_features` containing the mean of each feature.
#' @param cov_mat arma::mat. Matrix of dimension (`n_features`, `n_features`) containing the pairwise covariance
#' between all pairs of features.
#'
#' @return An arma::mat/2D array of dimension (`n_explain` * `n_coalitions`, `n_features`),
#' where the rows (n_explain * S_ind, n_explain * (S_ind + 1) - 1) contains the single
#' conditional Gaussian MC samples for each explicand and `S_ind` coalition.
#'
#' @export
#' @keywords internal
#' @author Lars Henry Berge Olsen
prepare_data_gaussian_cpp_caus <- function(MC_samples_mat, x_explain_mat, S, mu, cov_mat) {
.Call(`_shapr_prepare_data_gaussian_cpp_caus`, MC_samples_mat, x_explain_mat, S, mu, cov_mat)
}
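
Similarly, a hedged sketch for the Gaussian counterpart (toy values and dimensions, not from the PR); the Gaussian approach's `prepare_data` method normally supplies these arguments:

```r
library(shapr)
set.seed(123)
n_explain <- 3; n_features <- 4

x_explain_mat <- matrix(rnorm(n_explain * n_features), n_explain, n_features)
MC_samples_mat <- matrix(rnorm(n_explain * n_features), n_explain, n_features) # one N(0, 1) draw per explicand
S <- matrix(c(1, 1, 0, 0), nrow = 1) # a single coalition; neither empty nor grand
mu <- rep(0, n_features)
cov_mat <- diag(n_features)

dt <- prepare_data_gaussian_cpp_caus(
  MC_samples_mat = MC_samples_mat,
  x_explain_mat = x_explain_mat,
  S = S,
  mu = mu,
  cov_mat = cov_mat
)
dim(dt) # (n_explain * n_coalitions) x n_features = 3 x 4
```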

#' (Generalized) Mahalanobis distance
#'
#' Used to get the Euclidean distance as well by setting \code{mcov} = \code{diag(m)}.
122 changes: 84 additions & 38 deletions R/approach_categorical.R
@@ -7,7 +7,7 @@
#'
#' @param categorical.epsilon Numeric value. (Optional)
#' If \code{joint_probability_dt} is not supplied, probabilities/frequencies are
#' estimated using `x_train`. If certain observations occur in `x_train` and NOT in `x_explain`,
#' estimated using `x_train`. If certain observations occur in `x_explain` and NOT in `x_train`,
#' then epsilon is used as the proportion of times that these observations occur in the training data.
#' In theory, this proportion should be zero, but this causes an error later in the Shapley computation.
#'
@@ -36,35 +44,44 @@ setup_approach.categorical <- function(internal,

# estimate joint_prob_dt if it is not passed to the function
if (is.null(joint_probability_dt)) {
# Get the frequency of the unique feature value combinations in the training data
joint_prob_dt0 <- x_train[, .N, eval(feature_names)]

explain_not_in_train <- data.table::setkeyv(data.table::setDT(x_explain), feature_names)[!x_train]
# Get the feature value combinations in the explicands that are NOT in the training data and their frequency
explain_not_in_train <- data.table::setkeyv(data.table::setDT(data.table::copy(x_explain)), feature_names)[!x_train]
N_explain_not_in_train <- nrow(unique(explain_not_in_train))

# Add these feature value combinations, and their corresponding frequency, to joint_prob_dt0
if (N_explain_not_in_train > 0) {
joint_prob_dt0 <- rbind(joint_prob_dt0, cbind(explain_not_in_train, N = categorical.epsilon))
}

# Compute the joint probability for each feature value combination
joint_prob_dt0[, joint_prob := N / .N]
joint_prob_dt0[, joint_prob := joint_prob / sum(joint_prob)]
data.table::setkeyv(joint_prob_dt0, feature_names)

# Remove the frequency column and add an id column
joint_probability_dt <- joint_prob_dt0[, N := NULL][, id_all := .I]
} else {
# The `joint_probability_dt` is passed to explain by the user, and we do some checks.
for (i in colnames(x_explain)) {
# Check that feature name is present
is_error <- !(i %in% names(joint_probability_dt))

if (is_error > 0) {
stop(paste0(i, " is in x_explain but not in joint_probability_dt."))
}

# Check that the feature has the same levels
is_error <- !all(levels(x_explain[[i]]) %in% levels(joint_probability_dt[[i]]))

if (is_error > 0) {
stop(paste0(i, " in x_explain has different factor levels than in joint_probability_dt."))
}
}

# Check that dt contains a `joint_prob` column, that all entries are probabilities between 0 and 1 (inclusive), and that they sum to 1.
is_error <- !("joint_prob" %in% names(joint_probability_dt)) |
!all(joint_probability_dt$joint_prob <= 1) |
!all(joint_probability_dt$joint_prob >= 0) |
@@ -76,9 +85,11 @@
sum(joint_prob) must equal to 1.')
}

# Add an id column
joint_probability_dt <- joint_probability_dt[, id_all := .I]
}

# Store the `joint_probability_dt` data table
internal$parameters$categorical.joint_prob_dt <- joint_probability_dt

return(internal)
@@ -90,24 +101,12 @@
#' @rdname prepare_data
#' @export
#' @keywords internal
#' @author Annabelle Redelmeier and Lars Henry Berge Olsen
prepare_data.categorical <- function(internal, index_features = NULL, ...) {
x_train <- internal$data$x_train
x_explain <- internal$data$x_explain

joint_probability_dt <- internal$parameters$categorical.joint_prob_dt

iter <- length(internal$iter_list)

X <- internal$iter_list[[iter]]$X
S <- internal$iter_list[[iter]]$S


if (is.null(index_features)) { # 2,3
features <- X$features # list of [1], [2], [2, 3]
} else {
features <- X$features[index_features] # list of [1],
# Use a faster function when index_features contains only a single coalition, as in causal Shapley values.
if (length(index_features) == 1) {
return(prepare_data_single_coalition(internal, index_features))
}
feature_names <- internal$parameters$feature_names

# 3 id columns: id, id_coalition, and id_all
# id: for each x_explain observation
@@ -116,19 +115,25 @@ prepare_data.categorical <- function(internal, index_features = NULL, ...) {
# the training data (not necessarily the ones in the explain data)


# Extract the needed objects/variables
x_explain <- internal$data$x_explain
joint_probability_dt <- internal$parameters$categorical.joint_prob_dt
feature_names <- internal$parameters$feature_names
feature_conditioned <- paste0(feature_names, "_conditioned")
feature_conditioned_id <- c(feature_conditioned, "id")

S_dt <- data.table::data.table(S)
# Extract from iterative list
iter <- length(internal$iter_list)
S <- internal$iter_list[[iter]]$S
S_dt <- data.table::data.table(S[index_features, , drop = FALSE])
S_dt[S_dt == 0] <- NA
S_dt[, id_coalition := seq_len(nrow(S_dt))]

S_dt[, id_coalition := index_features]
data.table::setnames(S_dt, c(feature_conditioned, "id_coalition"))

# (1) Compute marginal probabilities

# multiply table of probabilities nrow(S) times
joint_probability_mult <- joint_probability_dt[rep(id_all, nrow(S))]
# multiply table of probabilities length(index_features) times
joint_probability_mult <- joint_probability_dt[rep(id_all, length(index_features))]

data.table::setkeyv(joint_probability_mult, "id_all")
j_S_dt <- cbind(joint_probability_mult, S_dt) # combine joint probability and S matrix
@@ -156,14 +161,10 @@

cond_dt <- j_S_all_feat[marg_dt, on = feature_conditioned]
cond_dt[, cond_prob := joint_prob / marg_prob]
cond_dt[id_coalition == 1, marg_prob := 0]
cond_dt[id_coalition == 1, cond_prob := 1]

# check marginal probabilities
cond_dt_unique <- unique(cond_dt, by = feature_conditioned)
check <- cond_dt_unique[id_coalition != 1][, .(sum_prob = sum(marg_prob)),
by = "id_coalition"
][["sum_prob"]]
check <- cond_dt_unique[id_coalition != 1][, .(sum_prob = sum(marg_prob)), by = "id_coalition"][["sum_prob"]]
if (!all(round(check) == 1)) {
print("Warning - not all marginal probabilities sum to 1. There could be a problem
with the joint probabilities. Consider checking.")
Expand All @@ -181,9 +182,7 @@ prepare_data.categorical <- function(internal, index_features = NULL, ...) {
dt <- cond_dt[dt_explain_just_conditioned, on = feature_conditioned, allow.cartesian = TRUE]

# check conditional probabilities
check <- dt[id_coalition != 1][, .(sum_prob = sum(cond_prob)),
by = c("id_coalition", "id")
][["sum_prob"]]
check <- dt[id_coalition != 1][, .(sum_prob = sum(cond_prob)), by = c("id_coalition", "id")][["sum_prob"]]
if (!all(round(check) == 1)) {
print("Warning - not all conditional probabilities sum to 1. There could be a problem
with the joint probabilities. Consider checking.")
Expand All @@ -192,11 +191,58 @@ prepare_data.categorical <- function(internal, index_features = NULL, ...) {
setnames(dt, "cond_prob", "w")
data.table::setkeyv(dt, c("id_coalition", "id"))

# here we merge so that we only return the combinations found in our actual explain data
# this merge does not change the number of rows in dt
# dt <- merge(dt, x$X[, .(id_coalition, n_features)], by = "id_coalition")
# dt[n_features %in% c(0, ncol(x_explain)), w := 1.0]
dt[id_coalition %in% c(1, 2^ncol(x_explain)), w := 1.0]
ret_col <- c("id_coalition", "id", feature_names, "w")
return(dt[id_coalition %in% index_features, mget(ret_col)])
# Return the relevant columns
return(dt[, mget(c("id_coalition", "id", feature_names, "w"))])
}

#' Compute the conditional probabilities for a single coalition for the categorical approach
#'
#' The [shapr::prepare_data.categorical()] function is slow when evaluated for a single coalition.
#' This is a bottleneck for Causal Shapley values, which repeatedly call that function with single coalitions.
#'
#' @inheritParams default_doc
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
prepare_data_single_coalition <- function(internal, index_features) {
# if (length(index_features) != 1) stop("`index_features` must be single integer.")

# Extract the needed objects
x_explain <- internal$data$x_explain
feature_names <- internal$parameters$feature_names
joint_probability_dt <- internal$parameters$categorical.joint_prob_dt

# Extract from iterative list
iter <- length(internal$iter_list)
S <- internal$iter_list[[iter]]$S

# Add an id column to x_explain (copy as this changes `x_explain` outside the function)
x_explain_copy <- data.table::copy(x_explain)[, id := .I]

# Extract the feature names of the features we are to condition on
cond_cols <- feature_names[S[index_features, ] == 1]
cond_cols_with_id <- c("id", cond_cols)

# Extract the feature values to condition and including the id column
dt_conditional_feature_values <- x_explain_copy[, cond_cols_with_id, with = FALSE]

# Merge (right outer join) the joint_probability_dt data with the conditional feature values
results_id_coalition <- data.table::merge.data.table(joint_probability_dt,
dt_conditional_feature_values,
by = cond_cols,
allow.cartesian = TRUE
)

# Get the weights/conditional probabilities for each valid X_sbar conditioned on X_s for all explicands
results_id_coalition[, w := joint_prob / sum(joint_prob), by = id]
results_id_coalition[, c("id_all", "joint_prob") := NULL]

# Set the index_features to their correct value
results_id_coalition[, id_coalition := index_features]

# Set id_coalition and id to be the keys and the two first columns for consistency with other approaches
data.table::setkeyv(results_id_coalition, c("id_coalition", "id"))
data.table::setcolorder(results_id_coalition, c("id_coalition", "id", feature_names))

return(results_id_coalition)
}
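
To make the conditioning step in `prepare_data_single_coalition()` concrete, here is a small standalone data.table sketch of the same idea (toy data and hypothetical feature names, not from the PR): join the explicands' conditioned feature values onto a joint-probability table and renormalise the joint probabilities within each explicand.

```r
library(data.table)

# Toy joint probability table over two categorical features A and B
joint_probability_dt <- data.table(
  A = factor(c("a", "a", "b", "b")),
  B = factor(c("x", "y", "x", "y")),
  joint_prob = c(0.4, 0.1, 0.2, 0.3),
  id_all = 1:4
)

# Two explicands; the single coalition conditions on feature A only
x_explain <- data.table(A = factor(c("a", "b")), B = factor(c("y", "x")))[, id := .I]
cond_cols <- "A"

# Right outer join on the conditioned columns, then renormalise per explicand
res <- merge(joint_probability_dt, x_explain[, c("id", cond_cols), with = FALSE],
             by = cond_cols, allow.cartesian = TRUE)
res[, w := joint_prob / sum(joint_prob), by = id]
res[] # w holds P(B | A = observed value) for each explicand id
```
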
33 changes: 28 additions & 5 deletions R/approach_copula.R
@@ -57,18 +57,41 @@ prepare_data.copula <- function(internal, index_features, ...) {
copula.mu <- internal$parameters$copula.mu
copula.cov_mat <- internal$parameters$copula.cov_mat
copula.x_explain_gaussian_mat <- as.matrix(internal$data$copula.x_explain_gaussian)
causal_sampling <- internal$parameters$causal_sampling

# Update the number of MC samples for causal Shapley values not in the first step
causal_first_step <- isTRUE(internal$parameters$causal_first_step) # Only set when called from `prepare_data_causal`
n_MC_samples_updated <- if (causal_sampling && !causal_first_step) n_explain else n_MC_samples

# Update the `copula.x_explain_gaussian_mat` for causal Shapley values not in the first step
if (causal_sampling && !causal_first_step) {
copula.x_explain_gaussian <- apply(
X = rbind(x_explain_mat, x_train_mat),
MARGIN = 2,
FUN = gaussian_transform_separate,
n_y = nrow(x_explain_mat)
)
if (is.null(dim(copula.x_explain_gaussian))) copula.x_explain_gaussian <- t(as.matrix(copula.x_explain_gaussian))
copula.x_explain_gaussian_mat <- as.matrix(copula.x_explain_gaussian)
}

iter <- length(internal$iter_list)

S <- internal$iter_list[[iter]]$S[index_features, , drop = FALSE]

# Generate the MC samples from N(0, 1)
MC_samples_mat <- matrix(rnorm(n_MC_samples * n_features), nrow = n_MC_samples, ncol = n_features)
MC_samples_mat <- matrix(rnorm(n_MC_samples_updated * n_features), nrow = n_MC_samples_updated, ncol = n_features)

# Determine which copula data generating function to use
prepare_data_copula <-
if (causal_sampling && !causal_first_step) prepare_data_copula_cpp_caus else prepare_data_copula_cpp

# Use C++ to convert the MC samples to N(mu_{Sbar|S}, Sigma_{Sbar|S}), for all coalitions and explicands,
# and then transforming them back to the original scale using the inverse Gaussian transform in C++.
# The object `dt` is a 3D array of dimension (n_MC_samples, n_explain * n_coalitions, n_features).
dt <- prepare_data_copula_cpp(
# The `dt` object is a 3D array of dimension (n_MC_samples, n_explain * n_coalitions, n_features) for regular
# Shapley and in the first step for causal Shapley values. For later steps in the causal Shapley value framework,
# the `dt` object is a matrix of dimension (n_explain * n_coalitions, n_features).
dt <- prepare_data_copula(
MC_samples_mat = MC_samples_mat,
x_explain_mat = x_explain_mat,
x_explain_gaussian_mat = copula.x_explain_gaussian_mat,
@@ -78,8 +101,8 @@ cov_mat = copula.cov_mat
cov_mat = copula.cov_mat
)

# Reshape `dt` to a 2D array of dimension (n_MC_samples * n_explain * n_coalitions, n_features).
dim(dt) <- c(n_coalitions_now * n_explain * n_MC_samples, n_features)
# Reshape `dt` to a 2D array of dimension (n_MC_samples * n_explain * n_coalitions, n_features) when needed
if (!causal_sampling || causal_first_step) dim(dt) <- c(n_coalitions_now * n_explain * n_MC_samples, n_features)

# Convert to a data.table and add extra identification columns
dt <- data.table::as.data.table(dt)
Expand Down