Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Causal and asymmetric Shapley values implementation #273

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
9220773
Add code for computing asymmetric (Frye et al.) and causal (Heskes et…
igbucur Aug 21, 2021
53da84b
Bugfix variable name in explain.causal
igbucur Aug 21, 2021
90ee920
Add references to explain.causal
igbucur Aug 21, 2021
cc8616b
Update NAMESPACE and documentation for asymmetric and causal Shapley …
igbucur Aug 21, 2021
2707b6e
Bugfix in explain.causal, explainer_x_test removed in shapr master
igbucur Aug 21, 2021
17de7e9
Bugfix in prepare_data.causal where check for NULL value failed.
igbucur Aug 21, 2021
a13ec74
Bugfix moved all checks for the default value of causal_ordering to s…
igbucur Aug 22, 2021
fe0b08a
Merge branch 'master' of https://github.com/NorskRegnesentral/shapr i…
igbucur Aug 25, 2021
cb40de1
Add basic tests for sample_causal and minor fix for the fully conditi…
igbucur Aug 25, 2021
c2e1d67
Replace default causal_ordering value in feature_exact and feature_co…
igbucur Aug 25, 2021
7ea3b1d
Add extra warnings for cases that are not yet implemented.
igbucur Aug 25, 2021
429bfee
Add basic tests to cover the asymmetric case in features functions
igbucur Aug 25, 2021
d7f7de3
Update known values for test objects after adding asymmetric and caus…
igbucur Aug 25, 2021
8597e86
Update docs feature_combinations
igbucur Aug 26, 2021
8e1588c
minor fix in docs, line too long
igbucur Aug 26, 2021
317bb3c
Merge branch 'master' of https://github.com/NorskRegnesentral/shapr i…
igbucur Oct 5, 2021
dc6fbc4
Updated new test objects with asymmetric and causal_ordering components
igbucur Oct 5, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand

S3method(explain,causal)
S3method(explain,combined)
S3method(explain,copula)
S3method(explain,ctree)
Expand All @@ -25,6 +26,7 @@ S3method(predict_model,glm)
S3method(predict_model,lm)
S3method(predict_model,ranger)
S3method(predict_model,xgb.Booster)
S3method(prepare_data,causal)
S3method(prepare_data,copula)
S3method(prepare_data,ctree)
S3method(prepare_data,empirical)
Expand Down
73 changes: 67 additions & 6 deletions R/explanation.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,12 @@ explain <- function(x, explainer, approach, prediction_zero,
if (!(is.vector(approach) &&
is.atomic(approach) &&
(length(approach) == 1 | length(approach) == length(explainer$feature_list$labels)) &&
all(is.element(approach, c("empirical", "gaussian", "copula", "ctree", "independence"))))
all(is.element(approach, c("empirical", "gaussian", "causal", "copula", "ctree", "independence"))))
) {
stop(
paste(
"It seems that you passed a non-valid value for approach.",
"It should be either 'empirical', 'gaussian', 'copula', 'ctree', 'independence' or",
"It should be either 'empirical', 'gaussian', 'copula', 'ctree', 'independence', 'causal' or",
"a vector of length=ncol(x) with only the above characters."
)
)
Expand Down Expand Up @@ -284,11 +284,11 @@ explain.empirical <- function(x, explainer, approach, prediction_zero,

#' @param mu Numeric vector. (Optional) Containing the mean of the data generating distribution.
#' If \code{NULL} the expected values are estimated from the data. Note that this is only used
#' when \code{approach = "gaussian"}.
#' when \code{approach = "gaussian"} or \code{approach = "causal"}.
#'
#' @param cov_mat Numeric matrix. (Optional) Containing the covariance matrix of the data
#' generating distribution. \code{NULL} means it is estimated from the data if needed
#' (in the Gaussian approach).
#' (in the Gaussian or causal approach).
#'
#' @rdname explain
#'
Expand All @@ -304,7 +304,6 @@ explain.gaussian <- function(x, explainer, approach, prediction_zero, n_samples
explainer$approach <- approach
explainer$n_samples <- n_samples


# If mu is not provided directly, use mean of training data
if (is.null(mu)) {
explainer$mu <- unname(colMeans(explainer$x_train))
Expand All @@ -331,7 +330,63 @@ explain.gaussian <- function(x, explainer, approach, prediction_zero, n_samples
}


#' @param confounding Logical vector that specifies whether we assume confounding or not.
#' If a single value is specified, then the assumption is set globally for all components.
#' Otherwise, the logical vector must contain a value for each component in the ordering.
#'
#' @author Tom Heskes, Ioan Gabriel Bucur
#'
#' @references
#' Frye, C., Rowat, C., & Feige, I. (2020). Asymmetric Shapley values:
#' incorporating causal knowledge into model-agnostic explainability.
#' Advances in Neural Information Processing Systems, 33.
#'
#' Heskes, T., Sijben, E., Bucur, I. G., & Claassen, T. (2020). Causal Shapley Values:
#' Exploiting Causal Knowledge to Explain Individual Predictions of Complex Models.
#' Advances in Neural Information Processing Systems, 33.
#'
#' @rdname explain
#'
#' @export
explain.causal <- function(x, explainer, approach, prediction_zero, n_samples = 1e3,
                           mu = NULL, cov_mat = NULL, confounding = FALSE, ...) {

  # Attach the preprocessed test data and the sampling settings to the explainer.
  # The fields are assigned in a fixed order: x_test, approach, n_samples, mu,
  # cov_mat, confounding.
  explainer$x_test <- as.matrix(preprocess_data(x, explainer$feature_list)$x_dt)
  explainer$approach <- approach
  explainer$n_samples <- n_samples

  # Mean vector: the user-supplied one, otherwise the (unnamed) column means
  # of the training data.
  explainer$mu <- if (is.null(mu)) unname(colMeans(explainer$x_train)) else mu

  # Covariance matrix: the user-supplied one, otherwise the sample covariance
  # of the training data.
  sigma <- if (is.null(cov_mat)) stats::cov(explainer$x_train) else cov_mat

  # Keep the matrix as is only when every eigenvalue exceeds 1e-06; otherwise
  # substitute the nearest positive-definite matrix.
  spectrum <- eigen(sigma)$values
  if (all(spectrum > 1e-06)) {
    explainer$cov_mat <- sigma
  } else {
    explainer$cov_mat <- as.matrix(Matrix::nearPD(sigma)$mat)
  }

  # Confounding assumption(s), stored unchanged for use by prepare_data()
  explainer$confounding <- confounding

  # Generate the data; if explainer$return is set, hand back the generated
  # data directly instead of predicting.
  dt <- prepare_data(explainer, ...)
  if (!is.null(explainer$return)) return(dt)

  # Predict on the generated data and return the result
  prediction(dt, prediction_zero, explainer)
}


#' @rdname explain
Expand Down Expand Up @@ -478,7 +533,7 @@ explain.combined <- function(x, explainer, approach, prediction_zero, n_samples
#' \code{length(n_features) <= 2^m}, where \code{m} equals the number
#' of features.
#' @param approach Character vector of length \code{m}. All elements should be
#' either \code{"empirical"}, \code{"gaussian"} or \code{"copula"}.
#' either \code{"empirical"}, \code{"causal"}, \code{"gaussian"} or \code{"copula"}.
#'
#' @keywords internal
#'
Expand All @@ -502,6 +557,12 @@ get_list_approaches <- function(n_features, approach) {
l$empirical <- which(n_features %in% x)
}

x <- which(approach == "causal")
if (length(x) > 0) {
if (approach[1] == "causal") x <- c(0, x)
l$causal <- which(n_features %in% x)
}

x <- which(approach == "gaussian")
if (length(x) > 0) {
if (approach[1] == "gaussian") x <- c(0, x)
Expand Down
39 changes: 36 additions & 3 deletions R/features.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@
#' @param group_num List. Contains vector of integers indicating the feature numbers for the
#' different groups.
#'
#' @param asymmetric Logical. The flag specifies whether we want to compute
#' asymmetric Shapley values. If so, a causal ordering also needs to be specified
#' and we only consider variable permutations with the given causal ordering.
#'
#' @param causal_ordering List. Contains vectors specifying (partial) causal ordering.
#' Each element in the list is a component in the order, which can contain one
#' or more variable indices in a vector. For example, in list(1, c(2, 3)),
#' 2 > 1 and 3 > 1, but 2 and 3 are not comparable.
#'
#' @return A data.table that contains the following columns:
#' \describe{
#' \item{id_combination}{Positive integer. Represents a unique key for each combination. Note that the table
Expand All @@ -35,7 +44,8 @@
#'
#' # Subsample of combinations
#' x <- feature_combinations(exact = FALSE, m = 10, n_combinations = 1e2)
feature_combinations <- function(m, exact = TRUE, n_combinations = 200, weight_zero_m = 10^6, group_num = NULL) {
feature_combinations <- function(m, exact = TRUE, n_combinations = 200, weight_zero_m = 10^6,
group_num = NULL, asymmetric = FALSE, causal_ordering = list(1:m)) {


m_group <- length(group_num) # The number of groups
Expand Down Expand Up @@ -103,8 +113,14 @@ feature_combinations <- function(m, exact = TRUE, n_combinations = 200, weight_z
if (m_group == 0) {
# Here if feature-wise Shapley values
if (exact) {
dt <- feature_exact(m, weight_zero_m)
dt <- feature_exact(m, weight_zero_m, asymmetric, causal_ordering)
} else {
if (asymmetric) {
cat(paste0(
"Input asymmetric = TRUE is not supported in combination with exact = FALSE.\n",
"Changing to asymmetric = FALSE"
))
}
dt <- feature_not_exact(m, n_combinations, weight_zero_m)
stopifnot(
data.table::is.data.table(dt),
Expand All @@ -126,12 +142,21 @@ feature_combinations <- function(m, exact = TRUE, n_combinations = 200, weight_z
p <- NULL # due to NSE notes in R CMD check
dt[, p := NULL]
}

# TODO: change flag explicitly somewhere?
if (asymmetric) {
cat(paste0(
"Input asymmetric = TRUE is not supported for group-wise Shapley values.\n",
"Changing to asymmetric = FALSE"
))
}
}
return(dt)
}

#' @keywords internal
feature_exact <- function(m, weight_zero_m = 10^6) {
feature_exact <- function(m, weight_zero_m = 10^6, asymmetric = FALSE, causal_ordering = list(1:m)) {

features <- id_combination <- n_features <- shapley_weight <- N <- NULL # due to NSE notes in R CMD check

dt <- data.table::data.table(id_combination = seq(2^m))
Expand All @@ -141,6 +166,14 @@ feature_exact <- function(m, weight_zero_m = 10^6) {
dt[, N := .N, n_features]
dt[, shapley_weight := shapley_weights(m = m, N = N, n_components = n_features, weight_zero_m)]

if (asymmetric) {

# Filter out the features that do not agree with the order
dt <- dt[sapply(dt$features, respects_order, causal_ordering), ]
dt[, N := .(count = .N), by = n_features]
dt[, shapley_weight := .(shapley_weights(m, N, n_features))]
}

return(dt)
}

Expand Down
53 changes: 53 additions & 0 deletions R/observations.R
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,11 @@ prepare_data.gaussian <- function(x, index_features = NULL, ...) {
features <- x$X$features[index_features]
}

# For asymmetric Shapley values, we filter out the features inconsistent with the causal ordering.
if (x$asymmetric) {
features <- features[sapply(features, respects_order, causal_ordering = x$causal_ordering)]
}

for (i in seq(n_xtest)) {
l <- lapply(
X = features,
Expand All @@ -253,6 +258,54 @@ prepare_data.gaussian <- function(x, index_features = NULL, ...) {



#' @rdname prepare_data
#' @export
#'
#' @author Tom Heskes, Ioan Gabriel Bucur
prepare_data.causal <- function(x, seed = 1, index_features = NULL, ...) {
  # Builds the sampled data used by the "causal" approach.
  #
  # Arguments:
  #   x              Explainer list. Fields read here: x_test, X$features,
  #                  asymmetric, causal_ordering, n_samples, mu, cov_mat,
  #                  confounding.
  #   seed           Passed to set.seed() before sampling; NULL leaves the
  #                  RNG state untouched.
  #   index_features Optional positions into x$X$features restricting which
  #                  feature combinations are sampled.
  #   ...            Currently unused by this method.
  #
  # Returns a data.table holding, for every test row and feature combination,
  # the output of sample_causal() plus the columns id_combination, w and id.

  id <- id_combination <- w <- NULL # due to NSE notes in R CMD check

  n_xtest <- nrow(x$x_test)
  dt_l <- vector("list", n_xtest) # preallocate instead of growing in the loop
  if (!is.null(seed)) set.seed(seed)

  if (is.null(index_features)) {
    features <- x$X$features
  } else {
    features <- x$X$features[index_features]
  }

  # For asymmetric Shapley values, we filter out the features inconsistent with
  # the causal ordering. isTRUE() treats a missing `asymmetric` field as FALSE
  # rather than failing in `if`. vapply() guarantees a logical index even when
  # `features` is empty (sapply() would return list(), an invalid subscript).
  # NOTE(review): assumes respects_order() returns a single TRUE/FALSE - confirm.
  if (isTRUE(x$asymmetric)) {
    keep <- vapply(
      X = features,
      FUN = respects_order,
      FUN.VALUE = logical(1),
      causal_ordering = x$causal_ordering
    )
    features <- features[keep]
  }

  # seq_len() so that an empty test set does not iterate over c(1, 0)
  for (i in seq_len(n_xtest)) {

    # One call to sample_causal() per feature combination for test row i
    l <- lapply(
      X = features,
      FUN = sample_causal,
      n_samples = x$n_samples,
      mu = x$mu,
      cov_mat = x$cov_mat,
      m = ncol(x$x_test),
      x_test = x$x_test[i, , drop = FALSE],
      causal_ordering = x$causal_ordering,
      confounding = x$confounding
    )

    # id_combination is the position of each element within `features`
    dt_i <- data.table::rbindlist(l, idcol = "id_combination")
    dt_i[, w := 1 / x$n_samples] # equal weight for every row
    dt_i[, id := i]

    # Map positions within the subset back to the caller-supplied indices
    if (!is.null(index_features)) dt_i[, id_combination := index_features[id_combination]]

    dt_l[[i]] <- dt_i
  }

  dt <- data.table::rbindlist(dt_l, use.names = TRUE, fill = TRUE)

  # Combinations 1 and 2^m get weight 1
  # (presumably the empty and the full feature set - verify against x$X).
  dt[id_combination %in% c(1, 2^ncol(x$x_test)), w := 1.0]
  return(dt)
}


#' @rdname prepare_data
#' @export
prepare_data.copula <- function(x, index_features = NULL, ...) {
Expand Down
Loading