diff --git a/NAMESPACE b/NAMESPACE index c808b7e3b..9f334230c 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,5 +1,6 @@ # Generated by roxygen2: do not edit by hand +S3method(explain,causal) S3method(explain,combined) S3method(explain,copula) S3method(explain,ctree) @@ -25,6 +26,7 @@ S3method(predict_model,glm) S3method(predict_model,lm) S3method(predict_model,ranger) S3method(predict_model,xgb.Booster) +S3method(prepare_data,causal) S3method(prepare_data,copula) S3method(prepare_data,ctree) S3method(prepare_data,empirical) diff --git a/R/explanation.R b/R/explanation.R index 05fb91597..19c7af115 100644 --- a/R/explanation.R +++ b/R/explanation.R @@ -167,12 +167,12 @@ explain <- function(x, explainer, approach, prediction_zero, if (!(is.vector(approach) && is.atomic(approach) && (length(approach) == 1 | length(approach) == length(explainer$feature_list$labels)) && - all(is.element(approach, c("empirical", "gaussian", "copula", "ctree", "independence")))) + all(is.element(approach, c("empirical", "gaussian", "causal", "copula", "ctree", "independence")))) ) { stop( paste( "It seems that you passed a non-valid value for approach.", - "It should be either 'empirical', 'gaussian', 'copula', 'ctree', 'independence' or", + "It should be either 'empirical', 'gaussian', 'copula', 'ctree', 'independence', 'causal' or", "a vector of length=ncol(x) with only the above characters." ) ) @@ -284,11 +284,11 @@ explain.empirical <- function(x, explainer, approach, prediction_zero, #' @param mu Numeric vector. (Optional) Containing the mean of the data generating distribution. #' If \code{NULL} the expected values are estimated from the data. Note that this is only used -#' when \code{approach = "gaussian"}. +#' when \code{approach = "gaussian"} or \code{approach = "causal"}. #' #' @param cov_mat Numeric matrix. (Optional) Containing the covariance matrix of the data #' generating distribution. \code{NULL} means it is estimated from the data if needed -#' (in the Gaussian approach). +#' (in the Gaussian or causal approach). #' #' @rdname explain #' @@ -304,7 +304,6 @@ explain.gaussian <- function(x, explainer, approach, prediction_zero, n_samples explainer$approach <- approach explainer$n_samples <- n_samples - # If mu is not provided directly, use mean of training data if (is.null(mu)) { explainer$mu <- unname(colMeans(explainer$x_train)) @@ -331,7 +330,63 @@ explain.gaussian <- function(x, explainer, approach, prediction_zero, n_samples } +#' @param confounding Logical vector that specifies whether we assume confounding or not. +#' If a single value is specified, then the assumption is set globally for all components. +#' Otherwise, the logical vector must contain a value for each component in the ordering. +#' +#' @author Tom Heskes, Ioan Gabriel Bucur +#' +#' @references +#' Frye, C., Rowat, C., & Feige, I. (2020). Asymmetric Shapley values: +#' incorporating causal knowledge into model-agnostic explainability. +#' Advances in Neural Information Processing Systems, 33. +#' +#' Heskes, T., Sijben, E., Bucur, I. G., & Claassen, T. (2020). Causal Shapley Values: +#' Exploiting Causal Knowledge to Explain Individual Predictions of Complex Models. +#' Advances in Neural Information Processing Systems, 33. +#' +#' @rdname explain +#' +#' @export +explain.causal <- function(x, explainer, approach, prediction_zero, n_samples = 1e3, + mu = NULL, cov_mat = NULL, confounding = FALSE, ...) { + # Add arguments to explainer object + explainer$x_test <- as.matrix(preprocess_data(x, explainer$feature_list)$x_dt) + explainer$approach <- approach + explainer$n_samples <- n_samples + + # If mu is not provided directly, use mean of training data + if (is.null(mu)) { + explainer$mu <- unname(colMeans(explainer$x_train)) + } else { + explainer$mu <- mu + } + + # If cov_mat is not provided directly, use sample covariance of training data + if (is.null(cov_mat)) { + cov_mat <- stats::cov(explainer$x_train) + } + + # Make sure that covariance matrix is positive-definite + eigen_values <- eigen(cov_mat)$values + if (any(eigen_values <= 1e-06)) { + explainer$cov_mat <- as.matrix(Matrix::nearPD(cov_mat)$mat) + } else { + explainer$cov_mat <- cov_mat + } + + explainer$confounding <- confounding + + # Generate data + dt <- prepare_data(explainer, ...) + if (!is.null(explainer$return)) return(dt) + + # Predict + r <- prediction(dt, prediction_zero, explainer) + + return(r) +} #' @rdname explain @@ -478,7 +533,7 @@ explain.combined <- function(x, explainer, approach, prediction_zero, n_samples #' \code{length(n_features) <= 2^m}, where \code{m} equals the number #' of features. #' @param approach Character vector of length \code{m}. All elements should be -#' either \code{"empirical"}, \code{"gaussian"} or \code{"copula"}. +#' either \code{"empirical"}, \code{"causal"}, \code{"gaussian"} or \code{"copula"}. #' #' @keywords internal #' @@ -502,6 +557,12 @@ get_list_approaches <- function(n_features, approach) { l$empirical <- which(n_features %in% x) } + x <- which(approach == "causal") + if (length(x) > 0) { + if (approach[1] == "causal") x <- c(0, x) + l$causal <- which(n_features %in% x) + } + x <- which(approach == "gaussian") if (length(x) > 0) { if (approach[1] == "gaussian") x <- c(0, x) diff --git a/R/features.R b/R/features.R index 6e342c2ca..66280d382 100644 --- a/R/features.R +++ b/R/features.R @@ -11,6 +11,15 @@ #' @param group_num List. Contains vector of integers indicating the feature numbers for the #' different groups. #' +#' @param asymmetric Logical. The flag specifies whether we want to compute +#' asymmetric Shapley values. If so, a causal ordering also needs to be specified +#' and we only consider variable permutations with the given causal ordering. +#' +#' @param causal_ordering List. Contains vectors specifying (partial) causal ordering. +#' Each element in the list is a component in the order, which can contain one +#' or more variable indices in a vector. For example, in list(1, c(2, 3)), +#' 2 > 1 and 3 > 1, but 2 and 3 are not comparable. +#' #' @return A data.table that contains the following columns: #' \describe{ #' \item{id_combination}{Positive integer. Represents a unique key for each combination. Note that the table @@ -35,7 +44,8 @@ #' #' # Subsample of combinations #' x <- feature_combinations(exact = FALSE, m = 10, n_combinations = 1e2) -feature_combinations <- function(m, exact = TRUE, n_combinations = 200, weight_zero_m = 10^6, group_num = NULL) { +feature_combinations <- function(m, exact = TRUE, n_combinations = 200, weight_zero_m = 10^6, + group_num = NULL, asymmetric = FALSE, causal_ordering = list(1:m)) { m_group <- length(group_num) # The number of groups @@ -103,8 +113,14 @@ feature_combinations <- function(m, exact = TRUE, n_combinations = 200, weight_z if (m_group == 0) { # Here if feature-wise Shapley values if (exact) { - dt <- feature_exact(m, weight_zero_m) + dt <- feature_exact(m, weight_zero_m, asymmetric, causal_ordering) } else { + if (asymmetric) { + cat(paste0( + "Input asymmetric = TRUE is not supported in combination with exact = FALSE.\n", + "Changing to asymmetric = FALSE" + )) + } dt <- feature_not_exact(m, n_combinations, weight_zero_m) stopifnot( data.table::is.data.table(dt), @@ -126,12 +142,21 @@ feature_combinations <- function(m, exact = TRUE, n_combinations = 200, weight_z p <- NULL # due to NSE notes in R CMD check dt[, p := NULL] } + + # TODO: change flag explicitly somewhere? + if (asymmetric) { + cat(paste0( + "Input asymmetric = TRUE is not supported for group-wise Shapley values.\n", + "Changing to asymmetric = FALSE" + )) + } } return(dt) } #' @keywords internal -feature_exact <- function(m, weight_zero_m = 10^6) { +feature_exact <- function(m, weight_zero_m = 10^6, asymmetric = FALSE, causal_ordering = list(1:m)) { + features <- id_combination <- n_features <- shapley_weight <- N <- NULL # due to NSE notes in R CMD check dt <- data.table::data.table(id_combination = seq(2^m)) @@ -141,6 +166,14 @@ feature_exact <- function(m, weight_zero_m = 10^6) { dt[, N := .N, n_features] dt[, shapley_weight := shapley_weights(m = m, N = N, n_components = n_features, weight_zero_m)] + if (asymmetric) { + + # Filter out the features that do not agree with the order + dt <- dt[sapply(dt$features, respects_order, causal_ordering), ] + dt[, N := .(count = .N), by = n_features] + dt[, shapley_weight := .(shapley_weights(m, N, n_features))] + } + return(dt) } diff --git a/R/observations.R b/R/observations.R index e550026c4..1bf90bf3f 100644 --- a/R/observations.R +++ b/R/observations.R @@ -230,6 +230,11 @@ prepare_data.gaussian <- function(x, index_features = NULL, ...) { features <- x$X$features[index_features] } + # For asymmetric Shapley values, we filter out the features inconsistent with the causal ordering. + if (x$asymmetric) { + features <- features[sapply(features, respects_order, causal_ordering = x$causal_ordering)] + } + for (i in seq(n_xtest)) { l <- lapply( X = features, @@ -253,6 +258,54 @@ prepare_data.gaussian <- function(x, index_features = NULL, ...) { +#' @rdname prepare_data +#' @export +#' +#' @author Tom Heskes, Ioan Gabriel Bucur +prepare_data.causal <- function(x, seed = 1, index_features = NULL, ...) { + + id <- id_combination <- w <- NULL # due to NSE notes in R CMD check + + n_xtest <- nrow(x$x_test) + dt_l <- list() + if (!is.null(seed)) set.seed(seed) + if (is.null(index_features)) { + features <- x$X$features + } else { + features <- x$X$features[index_features] + } + + # For asymmetric Shapley values, we filter out the features inconsistent with the causal ordering. + if (x$asymmetric) { + features <- features[sapply(features, respects_order, x$causal_ordering)] + } + + for (i in seq(n_xtest)) { + + l <- lapply( + X = features, + FUN = sample_causal, + n_samples = x$n_samples, + mu = x$mu, + cov_mat = x$cov_mat, + m = ncol(x$x_test), + x_test = x$x_test[i, , drop = FALSE], + causal_ordering = x$causal_ordering, + confounding = x$confounding + ) + + dt_l[[i]] <- data.table::rbindlist(l, idcol = "id_combination") + dt_l[[i]][, w := 1 / x$n_samples] + dt_l[[i]][, id := i] + if (!is.null(index_features)) dt_l[[i]][, id_combination := index_features[id_combination]] + } + + dt <- data.table::rbindlist(dt_l, use.names = TRUE, fill = TRUE) + dt[id_combination %in% c(1, 2^ncol(x$x_test)), w := 1.0] + return(dt) +} + + #' @rdname prepare_data #' @export prepare_data.copula <- function(x, index_features = NULL, ...) { diff --git a/R/sampling.R b/R/sampling.R index b31840519..23995dbf6 100644 --- a/R/sampling.R +++ b/R/sampling.R @@ -60,7 +60,7 @@ sample_gaussian <- function(index_given, n_samples, mu, cov_mat, m, x_test) { # Check input stopifnot(is.matrix(x_test)) - # Handles the unconditional and full conditional separtely when predicting + # Handles the unconditional and full conditional separately when predicting cnms <- colnames(x_test) if (length(index_given) %in% c(0, m)) { return(data.table::as.data.table(x_test)) @@ -91,6 +91,112 @@ sample_gaussian <- function(index_given, n_samples, mu, cov_mat, m, x_test) { return(as.data.table(ret)) } + +#' Sample conditional Gaussian variables following a causal chain graph with do-calculus. +#' +#' @inheritParams sample_copula +#' +#' @param causal_ordering List of vectors specifying (partial) causal ordering. Each element in +#' the list is a component in the order, which can contain one or more variable indices in a vector. +#' For example, in list(1, c(2, 3)), 2 > 1 and 3 > 1, but 2 and 3 are not comparable. +#' @param confounding Logical vector specifying which variables are affected by confounding. +#' Confounding must be speficied globally with a single TRUE / FALSE value for all components, +#' or separately for each causal component in the causal ordering. +#' +#' @return data.table +#' +#' @keywords internal +#' +#' @author Tom Heskes, Ioan Gabriel Bucur +#' +#' @examples +#' m <- 10 +#' n_samples <- 50 +#' mu <- rep(1, m) +#' cov_mat <- cov(matrix(rnorm(n_samples * m), n_samples, m)) +#' x_test <- matrix(MASS::mvrnorm(1, mu, cov_mat), nrow = 1) +#' cnms <- paste0("x", seq(m)) +#' colnames(x_test) <- cnms +#' index_given <- c(4, 7) +#' causal_ordering <- list(c(1:3), c(4:6), c(7:10)) +#' confounding <- c(TRUE, FALSE, TRUE) +#' r <- shapr:::sample_causal( +#' index_given, n_samples, mu, cov_mat, m, x_test, +#' causal_ordering, confounding +#' ) +sample_causal <- function(index_given, n_samples, mu, cov_mat, m, x_test, + causal_ordering, confounding) { + + # Check input + stopifnot(is.matrix(x_test)) + stopifnot(is.list(causal_ordering)) + stopifnot(is.logical(confounding)) + + if (length(confounding) > 1 && length(confounding) != length(causal_ordering)) { + stop("Confounding must be specified globally (one value for all components), or separately for each component in the causal ordering.") + } + + # In case of global confounding value, replicate it across components. + if (length(confounding) == 1) { + confounding <- rep(confounding, length(causal_ordering)) + } + + if (!base::setequal(unlist(causal_ordering), seq(m))) { + stop(paste("Incomplete or incorrect partial causal_ordering specified for", m, "variables")) + } + + # Handles the unconditional and full conditional separately when predicting + if (length(index_given) %in% c(0, m)) { + return(data.table::as.data.table(x_test)) + } + + dependent_ind <- setdiff(1:length(mu), index_given) + xall <- matrix(NA, ncol = m, nrow = n_samples) + xall[, index_given] <- rep(x_test[index_given], each = n_samples) + + for(i in seq(length(causal_ordering))) { + + # check overlap between dependent_ind and component + to_be_sampled <- intersect(causal_ordering[[i]], dependent_ind) + if (length(to_be_sampled) > 0) { + # condition upon all variables in ancestor components + to_be_conditioned <- unlist(causal_ordering[0:(i-1)]) + + # back to conditioning if confounding is FALSE or no conditioning if confounding is TRUE + if (!confounding[i]) { + # add intervened variables in the same component + to_be_conditioned <- union(intersect(causal_ordering[[i]], index_given), to_be_conditioned) + } + if (length(to_be_conditioned) == 0) { + # draw new samples from marginal distribution + newsamples <- mvnfast::rmvn(n_samples, mu=mu[to_be_sampled], sigma=as.matrix(cov_mat[to_be_sampled,to_be_sampled])) + } else { + + # compute conditional Gaussian + C <- cov_mat[to_be_sampled,to_be_conditioned, drop=FALSE] + D <- cov_mat[to_be_conditioned, to_be_conditioned] + CDinv <- C %*% solve(D) + cVar <- cov_mat[to_be_sampled,to_be_sampled] - CDinv %*% t(C) + if (!isSymmetric(cVar)) { + cVar <- Matrix::symmpart(cVar) + } + + # draw new samples from conditional distribution + mu_sample <- matrix(rep(mu[to_be_sampled],each=n_samples),nrow=n_samples) + mu_cond <- matrix(rep(mu[to_be_conditioned],each=n_samples),nrow=n_samples) + cMU <- mu_sample + t(CDinv %*% t(xall[,to_be_conditioned] - mu_cond)) + newsamples <- mvnfast::rmvn(n_samples, mu=matrix(0,1,length(to_be_sampled)), sigma=as.matrix(cVar)) + newsamples <- newsamples + cMU + + } + xall[,to_be_sampled] <- newsamples + } + } + + colnames(xall) <- colnames(x_test) + return(as.data.table(xall)) +} + #' Helper function to sample a combination of training and testing rows, which does not risk #' getting the same observation twice. Need to improve this help file. #' diff --git a/R/shapley.R b/R/shapley.R index 858cc4469..b9087c10b 100644 --- a/R/shapley.R +++ b/R/shapley.R @@ -76,6 +76,15 @@ weight_matrix <- function(X, normalize_W_weights = TRUE, is_groupwise = FALSE) { #' the number of groups. The list element contains character vectors with the features included #' in each of the different groups. #' +#' @param asymmetric Logical. The flag specifies whether we want to compute +#' asymmetric Shapley values. If so, a causal ordering also needs to be specified +#' and we only consider variable permutations with the given causal ordering. +#' +#' @param causal_ordering List. Contains vectors specifying (partial) causal ordering. +#' Each element in the list is a component in the order, which can contain one +#' or more variable indices in a vector. For example, in list(1, c(2, 3)), +#' 2 > 1 and 3 > 1, but 2 and 3 are not comparable. +#' #' @return Named list that contains the following items: #' \describe{ #' \item{exact}{Boolean. Equals \code{TRUE} if \code{n_combinations = NULL} or @@ -140,10 +149,8 @@ weight_matrix <- function(X, normalize_W_weights = TRUE, is_groupwise = FALSE) { #' print(nrow(explainer_group$X)) #' # 4 (which equals 2^(#groups)) #' } -shapr <- function(x, - model, - n_combinations = NULL, - group = NULL) { +shapr <- function(x, model, n_combinations = NULL, group = NULL, + asymmetric = FALSE, causal_ordering = NULL) { # Checks input argument if (!is.matrix(x) & !is.data.frame(x)) { @@ -154,7 +161,6 @@ shapr <- function(x, explainer <- as.list(environment()) explainer$exact <- ifelse(is.null(n_combinations), TRUE, FALSE) - # Check features of training data against model specification feature_list_model <- get_model_specs(model) @@ -163,14 +169,11 @@ shapr <- function(x, feature_list = feature_list_model ) - - x_train <- processed_list$x_dt updated_feature_list <- processed_list$updated_feature_list explainer$n_features <- ncol(x_train) - # Processes groups if specified. Otherwise do nothing is_groupwise <- !is.null(group) if (is_groupwise) { @@ -203,7 +206,9 @@ shapr <- function(x, exact = explainer$exact, n_combinations = n_combinations, weight_zero_m = 10^6, - group_num = group_num + group_num = group_num, + asymmetric = asymmetric, + causal_ordering = causal_ordering ) # Get weighted matrix ---------------- @@ -224,6 +229,12 @@ shapr <- function(x, explainer$exact <- TRUE } + # If no causal ordering is specified, put all variables in a single component, + # in the order in which they are stored in the data.frame / matrix. + if (is.null(causal_ordering)) { + causal_ordering <- list(1:length(updated_feature_list$labels)) + } + explainer$S <- feature_matrix explainer$W <- weighted_mat explainer$X <- dt_combinations @@ -233,6 +244,8 @@ shapr <- function(x, explainer$group <- group explainer$is_groupwise <- is_groupwise explainer$n_combinations <- nrow(feature_matrix) + explainer$asymmetric <- asymmetric + explainer$causal_ordering <- causal_ordering attr(explainer, "class") <- c("explainer", "list") diff --git a/R/utils.R b/R/utils.R index 6788d7851..b5418eecf 100644 --- a/R/utils.R +++ b/R/utils.R @@ -6,3 +6,37 @@ unique_features <- function(x) { ) ) } + +#' Helper function that checks a conditioning index against a particular causal ordering. +#' +#' @param index Integer conditioning index to check against the causal ordering. +#' @param causal_ordering List of vectors specifying (partial) causal ordering. Each element in +#' the list is a component in the order, which can contain one or more variable indices in a vector. +#' For example, in list(1, c(2, 3)), 2 > 1 and 3 > 1, but 2 and 3 are not comparable. +#' +#' @keywords internal +#' +#' @author Tom Heskes, Ioan Gabriel Bucur +respects_order <- function(index, causal_ordering) { + + for (i in index) { + + idx_position <- Position(function(x) i %in% x, causal_ordering, nomatch = 0) + + stopifnot(idx_position > 0) # It should always be in the causal_ordering + + # check for precedents (only relevant if not root set) + if (idx_position > 1) { + + # get precedents + precedents <- unlist(causal_ordering[1:(idx_position-1)]) + + # all precedents must be in index + if (!setequal(precedents, intersect(precedents, index))) { + return(FALSE) + } + } + } + + return(TRUE) +} diff --git a/man/explain.Rd b/man/explain.Rd index 1dbb5c216..c00cd34ab 100644 --- a/man/explain.Rd +++ b/man/explain.Rd @@ -5,6 +5,7 @@ \alias{explain.independence} \alias{explain.empirical} \alias{explain.gaussian} +\alias{explain.causal} \alias{explain.copula} \alias{explain.ctree} \alias{explain.combined} @@ -63,6 +64,18 @@ explain( ... ) +\method{explain}{causal}( + x, + explainer, + approach, + prediction_zero, + n_samples = 1000, + mu = NULL, + cov_mat = NULL, + confounding = FALSE, + ... +) + \method{explain}{copula}( x, explainer, @@ -169,11 +182,15 @@ is only applicable when \code{approach = "empirical"}, and \code{type} is either \item{cov_mat}{Numeric matrix. (Optional) Containing the covariance matrix of the data generating distribution. \code{NULL} means it is estimated from the data if needed -(in the Gaussian approach).} +(in the Gaussian or causal approach).} \item{mu}{Numeric vector. (Optional) Containing the mean of the data generating distribution. If \code{NULL} the expected values are estimated from the data. Note that this is only used -when \code{approach = "gaussian"}.} +when \code{approach = "gaussian"} or \code{approach = "causal"}.} + +\item{confounding}{Logical vector that specifies whether we assume confounding or not. +If a single value is specified, then the assumption is set globally for all components. +Otherwise, the logical vector must contain a value for each component in the ordering.} \item{mincriterion}{Numeric value or vector where length of vector is the number of features in model. Value is equal to 1 - alpha where alpha is the nominal level of the conditional @@ -316,7 +333,17 @@ if (requireNamespace("MASS", quietly = TRUE)) { \references{ Aas, K., Jullum, M., & Løland, A. (2021). Explaining individual predictions when features are dependent: More accurate approximations to Shapley values. Artificial Intelligence, 298, 103502. + +Frye, C., Rowat, C., & Feige, I. (2020). Asymmetric Shapley values: +incorporating causal knowledge into model-agnostic explainability. +Advances in Neural Information Processing Systems, 33. + +Heskes, T., Sijben, E., Bucur, I. G., & Claassen, T. (2020). Causal Shapley Values: +Exploiting Causal Knowledge to Explain Individual Predictions of Complex Models. +Advances in Neural Information Processing Systems, 33. } \author{ Camilla Lingjaerde, Nikolai Sellereite, Martin Jullum, Annabelle Redelmeier + +Tom Heskes, Ioan Gabriel Bucur } diff --git a/man/feature_combinations.Rd b/man/feature_combinations.Rd index 810d2b865..b5d23a88e 100644 --- a/man/feature_combinations.Rd +++ b/man/feature_combinations.Rd @@ -9,7 +9,9 @@ feature_combinations( exact = TRUE, n_combinations = 200, weight_zero_m = 10^6, - group_num = NULL + group_num = NULL, + asymmetric = FALSE, + causal_ordering = list(1:m) ) } \arguments{ @@ -27,6 +29,15 @@ weights when doing numerical operations.} \item{group_num}{List. Contains vector of integers indicating the feature numbers for the different groups.} + +\item{asymmetric}{Logical. The flag specifies whether we want to compute +asymmetric Shapley values. If so, a causal ordering also needs to be specified +and we only consider variable permutations with the given causal ordering.} + +\item{causal_ordering}{List. Contains vectors specifying (partial) causal ordering. +Each element in the list is a component in the order, which can contain one +or more variable indices in a vector. For example, in list(1, c(2, 3)), +2 > 1 and 3 > 1, but 2 and 3 are not comparable.} } \value{ A data.table that contains the following columns: diff --git a/man/get_list_approaches.Rd b/man/get_list_approaches.Rd index 06218ec3f..997b34467 100644 --- a/man/get_list_approaches.Rd +++ b/man/get_list_approaches.Rd @@ -12,7 +12,7 @@ get_list_approaches(n_features, approach) of features.} \item{approach}{Character vector of length \code{m}. All elements should be -either \code{"empirical"}, \code{"gaussian"} or \code{"copula"}.} +either \code{"empirical"}, \code{"causal"}, \code{"gaussian"} or \code{"copula"}.} } \value{ List diff --git a/man/observation_impute.Rd b/man/observation_impute.Rd index dd48d6c1f..9eb61aaf5 100644 --- a/man/observation_impute.Rd +++ b/man/observation_impute.Rd @@ -42,6 +42,14 @@ Generate permutations of training data using test observations \references{ Aas, K., Jullum, M., & Løland, A. (2021). Explaining individual predictions when features are dependent: More accurate approximations to Shapley values. Artificial Intelligence, 298, 103502. + +Frye, C., Rowat, C., & Feige, I. (2020). Asymmetric Shapley values: +incorporating causal knowledge into model-agnostic explainability. +Advances in Neural Information Processing Systems, 33. + +Heskes, T., Sijben, E., Bucur, I. G., & Claassen, T. (2020). Causal Shapley Values: +Exploiting Causal Knowledge to Explain Individual Predictions of Complex Models. +Advances in Neural Information Processing Systems, 33. } \author{ Nikolai Sellereite diff --git a/man/prepare_data.Rd b/man/prepare_data.Rd index a4857a3e7..88c21fe91 100644 --- a/man/prepare_data.Rd +++ b/man/prepare_data.Rd @@ -5,6 +5,7 @@ \alias{prepare_data.independence} \alias{prepare_data.empirical} \alias{prepare_data.gaussian} +\alias{prepare_data.causal} \alias{prepare_data.copula} \alias{prepare_data.ctree} \title{Generate data used for predictions} @@ -17,6 +18,8 @@ prepare_data(x, ...) \method{prepare_data}{gaussian}(x, index_features = NULL, ...) +\method{prepare_data}{causal}(x, seed = 1, index_features = NULL, ...) + \method{prepare_data}{copula}(x, index_features = NULL, ...) \method{prepare_data}{ctree}( @@ -36,6 +39,8 @@ prepare_data(x, ...) \item{index_features}{List. Default is NULL but if either various methods are being used or various mincriterion are used for different numbers of conditioned features, this will be a list with the features to pass.} +\item{seed}{Positive integer. If \code{NULL} the seed will be inherited from the calling environment.} + \item{mc_cores}{Integer. Only for class \code{ctree} currently. The number of cores to use in paralellization of the tree building (\code{create_ctree}) and tree sampling (\code{sample_ctree}). Defaults to 1. Note: Uses parallel::mclapply which relies on forking, i.e. uses only 1 core on Windows systems.} @@ -46,8 +51,6 @@ parallel::mclapply which relies on forking, i.e. uses only 1 core on Windows sys \item{mc_cores_sample_ctree}{Integer. Same as \code{mc_cores}, but specific for the tree building prediction function. Defaults to \code{mc_cores}.} - -\item{seed}{Positive integer. If \code{NULL} the seed will be inherited from the calling environment.} } \value{ A `data.table` containing simulated data passed to \code{\link{prediction}}. @@ -55,4 +58,7 @@ A `data.table` containing simulated data passed to \code{\link{prediction}}. \description{ Generate data used for predictions } +\author{ +Tom Heskes, Ioan Gabriel Bucur +} \keyword{internal} diff --git a/man/respects_order.Rd b/man/respects_order.Rd new file mode 100644 index 000000000..374f493d4 --- /dev/null +++ b/man/respects_order.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils.R +\name{respects_order} +\alias{respects_order} +\title{Helper function that checks a conditioning index against a particular causal ordering.} +\usage{ +respects_order(index, causal_ordering) +} +\arguments{ +\item{index}{Integer conditioning index to check against the causal ordering.} + +\item{causal_ordering}{List of vectors specifying (partial) causal ordering. Each element in +the list is a component in the order, which can contain one or more variable indices in a vector. +For example, in list(1, c(2, 3)), 2 > 1 and 3 > 1, but 2 and 3 are not comparable.} +} +\description{ +Helper function that checks a conditioning index against a particular causal ordering. +} +\author{ +Tom Heskes, Ioan Gabriel Bucur +} +\keyword{internal} diff --git a/man/sample_causal.Rd b/man/sample_causal.Rd new file mode 100644 index 000000000..2d6f1906e --- /dev/null +++ b/man/sample_causal.Rd @@ -0,0 +1,60 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/sampling.R +\name{sample_causal} +\alias{sample_causal} +\title{Sample conditional Gaussian variables following a causal chain graph with do-calculus.} +\usage{ +sample_causal( + index_given, + n_samples, + mu, + cov_mat, + m, + x_test, + causal_ordering, + confounding +) +} +\arguments{ +\item{index_given}{Integer vector. The indices of the features to condition upon. Note that +\code{min(index_given) >= 1} and \code{max(index_given) <= m}.} + +\item{m}{Positive integer. The total number of features.} + +\item{x_test}{Numeric matrix. Contains the features of the observation whose +predictions ought to be explained (test data).} + +\item{causal_ordering}{List of vectors specifying (partial) causal ordering. Each element in +the list is a component in the order, which can contain one or more variable indices in a vector. +For example, in list(1, c(2, 3)), 2 > 1 and 3 > 1, but 2 and 3 are not comparable.} + +\item{confounding}{Logical vector specifying which variables are affected by confounding. +Confounding must be speficied globally with a single TRUE / FALSE value for all components, +or separately for each causal component in the causal ordering.} +} +\value{ +data.table +} +\description{ +Sample conditional Gaussian variables following a causal chain graph with do-calculus. +} +\examples{ +m <- 10 +n_samples <- 50 +mu <- rep(1, m) +cov_mat <- cov(matrix(rnorm(n_samples * m), n_samples, m)) +x_test <- matrix(MASS::mvrnorm(1, mu, cov_mat), nrow = 1) +cnms <- paste0("x", seq(m)) +colnames(x_test) <- cnms +index_given <- c(4, 7) +causal_ordering <- list(c(1:3), c(4:6), c(7:10)) +confounding <- c(TRUE, FALSE, TRUE) +r <- shapr:::sample_causal( + index_given, n_samples, mu, cov_mat, m, x_test, + causal_ordering, confounding +) +} +\author{ +Tom Heskes, Ioan Gabriel Bucur +} +\keyword{internal} diff --git a/man/shapr.Rd b/man/shapr.Rd index dac7775c7..9b2b816f0 100644 --- a/man/shapr.Rd +++ b/man/shapr.Rd @@ -4,7 +4,14 @@ \alias{shapr} \title{Create an explainer object with Shapley weights for test data.} \usage{ -shapr(x, model, n_combinations = NULL, group = NULL) +shapr( + x, + model, + n_combinations = NULL, + group = NULL, + asymmetric = FALSE, + causal_ordering = NULL +) } \arguments{ \item{x}{Numeric matrix or data.frame/data.table. Contains the data used to estimate the (conditional) @@ -22,6 +29,15 @@ combinations equals \code{2^ncol(x)}.} If provided, group wise Shapley values are computed. \code{group} then has length equal to the number of groups. The list element contains character vectors with the features included in each of the different groups.} + +\item{asymmetric}{Logical. The flag specifies whether we want to compute +asymmetric Shapley values. If so, a causal ordering also needs to be specified +and we only consider variable permutations with the given causal ordering.} + +\item{causal_ordering}{List. Contains vectors specifying (partial) causal ordering. +Each element in the list is a component in the order, which can contain one +or more variable indices in a vector. For example, in list(1, c(2, 3)), +2 > 1 and 3 > 1, but 2 and 3 are not comparable.} } \value{ Named list that contains the following items: diff --git a/tests/testthat/test-features.R b/tests/testthat/test-features.R index 667c06964..f2f10f63f 100644 --- a/tests/testthat/test-features.R +++ b/tests/testthat/test-features.R @@ -9,6 +9,12 @@ test_that("Test feature_combinations", { x1 <- feature_combinations(m = m, exact = exact, weight_zero_m = w) x2 <- feature_exact(m, w) + # Example 1a (asymmetric) + x1a <- feature_combinations(m = m, exact = exact, weight_zero_m = w, + asymmetric = TRUE, causal_ordering = split(1:m, 1:m)) + x2a <- feature_exact(m, w, TRUE, causal_ordering = split(1:m, 1:m)) + x3a <- feature_exact(m, w, TRUE) # default causal ordering allows all combinations + # Example 2 ----------- m <- 10 exact <- FALSE @@ -43,12 +49,26 @@ test_that("Test feature_combinations", { weight_zero_m = w ) + # Example 3a (asymmetric) + y3a <- feature_combinations( + m = m, + exact = exact, + n_combinations = n_combinations, + weight_zero_m = w, + asymmetric = TRUE, + causal_ordering = split(1:m, 1:m) + ) + # Test results ----------- expect_equal(x1, x2) + expect_equal(x1a, x2a) + expect_equal(x1, x3a) expect_equal(y1, y2) expect_equal(nrow(y3), 2^3) + expect_equal(nrow(y3a), 4) expect_error(feature_combinations(100)) expect_error(feature_combinations(100, n_combinations = NULL)) + expect_error(feature_combinations(10, asymmetric = TRUE, causal_ordering = NULL)) }) test_that("Test feature_exact", { @@ -82,6 +102,29 @@ test_that("Test feature_exact", { expect_equal(x[["features"]], lfeatures) expect_equal(x[["n_features"]], n_components) expect_equal(x[["N"]], n) + + # Example asymmetric + xa <- feature_exact(m, weight_zero_m, TRUE, split(1:m, 1:m)) + + # Define results ----------- + lfeatures <- list( + integer(0), + 1L, + c(1L, 2L), + c(1L, 2L, 3L) + ) + id_combination <- c(1, 2, 5, 8) + n_components <- 0:3 + n <- rep(1, 4) + + # Tests ----------- + expect_true(data.table::is.data.table(xa)) + expect_equal(names(xa), cnms) + expect_equal(unname(sapply(xa, typeof)), classes) + expect_equal(xa[["id_combination"]], id_combination) + expect_equal(xa[["features"]], lfeatures) + expect_equal(xa[["n_features"]], n_components) + expect_equal(xa[["N"]], n) }) test_that("Test feature_not_exact", { diff --git a/tests/testthat/test-sampling.R b/tests/testthat/test-sampling.R index 08597c291..9397c4e15 100644 --- a/tests/testthat/test-sampling.R +++ b/tests/testthat/test-sampling.R @@ -100,6 +100,63 @@ test_that("test sample_gaussian", { } }) +test_that("test sample_causal", { + if (requireNamespace("MASS", quietly = TRUE)) { + # Example ----------- + m <- 10 + n_samples <- 50 + mu <- rep(1, m) + cov_mat <- cov(matrix(rnorm(n_samples * m), n_samples, m)) + x_test <- matrix(MASS::mvrnorm(1, mu, cov_mat), nrow = 1) + cnms <- paste0("x", seq(m)) + colnames(x_test) <- cnms + index_given <- c(4, 7) + causal_ordering <- list(1:4, 8:10, 5:7) + confounding <- c(TRUE, TRUE, FALSE) + r <- sample_causal(index_given, n_samples, mu, cov_mat, m, x_test, causal_ordering, confounding) + + # Test output format ------------------ + expect_true(data.table::is.data.table(r)) + expect_equal(ncol(r), m) + expect_equal(nrow(r), n_samples) + expect_equal(colnames(r), cnms) + + # Check that the given features are not resampled, but kept as is. + for (i in seq(m)) { + var_name <- cnms[i] + + if (i %in% index_given) { + expect_equal( + unique(r[[var_name]]), x_test[, var_name][[1]] + ) + } else { + expect_true( + length(unique(r[[var_name]])) == n_samples + ) + } + } + + # Example 2 ------------- + # Check that conditioning upon all variables simply returns the test observation. + r <- sample_causal(1:m, n_samples, mu, cov_mat, m, x_test, causal_ordering, confounding) + expect_identical(r, data.table::as.data.table(x_test)) + + # Tests for errors ------------------ + expect_error( + sample_causal(m + 1, n_samples, mu, cov_mat, m, x_test, causal_ordering, confounding) + ) + expect_error( + sample_causal(m + 1, n_samples, mu, cov_mat, m, as.vector(x_test), causal_ordering, confounding) + ) + expect_error( + sample_causal(m, n_samples, mu, cov_mat, m, x_test, unlist(causal_ordering), confounding) + ) + expect_error( + sample_causal(m, n_samples, mu, cov_mat, m, x_test, causal_ordering, as.integer(confounding)) + ) + } +}) + test_that("test sample_copula", { if (requireNamespace("MASS", quietly = TRUE)) { # Example 1 -------------- diff --git a/tests/testthat/test_objects/shapley_explainer_group1_2_obj.rds b/tests/testthat/test_objects/shapley_explainer_group1_2_obj.rds index ae03351c3..f5004240d 100644 Binary files a/tests/testthat/test_objects/shapley_explainer_group1_2_obj.rds and b/tests/testthat/test_objects/shapley_explainer_group1_2_obj.rds differ diff --git a/tests/testthat/test_objects/shapley_explainer_group1_obj.rds b/tests/testthat/test_objects/shapley_explainer_group1_obj.rds index 176c9ef09..d922c3b6d 100644 Binary files a/tests/testthat/test_objects/shapley_explainer_group1_obj.rds and b/tests/testthat/test_objects/shapley_explainer_group1_obj.rds differ diff --git a/tests/testthat/test_objects/shapley_explainer_group2_2_obj.rds b/tests/testthat/test_objects/shapley_explainer_group2_2_obj.rds index 2d65d72a8..b84aa0c20 100644 Binary files a/tests/testthat/test_objects/shapley_explainer_group2_2_obj.rds and b/tests/testthat/test_objects/shapley_explainer_group2_2_obj.rds differ diff --git a/tests/testthat/test_objects/shapley_explainer_group2_obj.rds b/tests/testthat/test_objects/shapley_explainer_group2_obj.rds index 40e01b14f..ad4c4ef62 100644 Binary files a/tests/testthat/test_objects/shapley_explainer_group2_obj.rds and b/tests/testthat/test_objects/shapley_explainer_group2_obj.rds differ diff --git a/tests/testthat/test_objects/shapley_explainer_obj.rds b/tests/testthat/test_objects/shapley_explainer_obj.rds index 3572b2ca4..1e3be2e6c 100644 Binary files a/tests/testthat/test_objects/shapley_explainer_obj.rds and b/tests/testthat/test_objects/shapley_explainer_obj.rds differ