Improve user experience #267

Merged · 18 commits · Aug 17, 2021 · Changes from 16 commits
19 changes: 12 additions & 7 deletions .github/workflows/lint.yaml
@@ -18,27 +18,32 @@ jobs:
steps:
- uses: actions/checkout@v2

- uses: r-lib/actions/setup-r@master
- uses: r-lib/actions/setup-r@v1

- name: Query dependencies
run: |
install.packages('remotes')
saveRDS(remotes::dev_package_deps(dependencies = TRUE), ".github/depends.Rds", version = 2)
writeLines(sprintf("R-%i.%i", getRversion()$major, getRversion()$minor), ".github/R-version")
shell: Rscript {0}

- name: Cache R packages
uses: actions/cache@v1
- name: Restore R package cache
uses: actions/cache@v2
with:
path: ${{ env.R_LIBS_USER }}
key: macOS-r-4.0-1-${{ hashFiles('.github/depends.Rds') }}
restore-keys: macOS-r-4.0-1-
key: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-${{ hashFiles('.github/depends.Rds') }}
restore-keys: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-

- name: Install dependencies
run: |
remotes::install_deps(dependencies = TRUE, type = "binary")
remotes::install_cran("lintr", type = "binary")
install.packages(c("remotes"))
remotes::install_deps(dependencies = TRUE)
remotes::install_cran("lintr")
shell: Rscript {0}

- name: Install package
run: R CMD INSTALL .

- name: Lint
run: lintr::lint_package()
shell: Rscript {0}
41 changes: 32 additions & 9 deletions R/explanation.R
@@ -11,6 +11,9 @@
#' either be \code{"gaussian"}, \code{"copula"}, \code{"empirical"}, or \code{"ctree"}. See details for more
#' information.
#'
#' @param n_samples Positive integer. The maximum number of samples to use in the
#' Monte Carlo integration for every conditional expectation. See also details.
#'
#' @param prediction_zero Numeric. The prediction value for unseen data, typically equal to the mean of
#' the response.
#'
@@ -19,7 +22,6 @@
#' @details The most important thing to notice is that \code{shapr} has implemented four different
#' approaches for estimating the conditional distributions of the data, namely \code{"empirical"},
#' \code{"gaussian"}, \code{"copula"} and \code{"ctree"}.
#'
#' In addition, the user also has the option of combining the four approaches.
#' E.g. if you're in a situation where you have trained a model that consists of 10 features,
#' and you'd like to use the \code{"gaussian"} approach when you condition on a single feature,
@@ -29,6 +31,13 @@
#' \code{"approach[i]" = "gaussian"} it means that you'd like to use the \code{"gaussian"} approach
#' when conditioning on \code{i} features.
#'
#' For \code{approach="ctree"}, \code{n_samples} corresponds to the number of samples
#' from the leaf node (see an exception related to the \code{sample} argument).
#' For \code{approach="empirical"}, \code{n_samples} is the \eqn{K} parameter in equations (14-15) of
#' Aas et al. (2021), i.e. the maximum number of observations (with the largest weights) that are used; see also the
#' \code{w_threshold} argument.
#'
#'
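To make the combined-approach vector and the new n_samples argument concrete, here is a minimal usage sketch. It is not part of this diff: model, explainer, x_test and p0 are placeholder objects assumed to come from a fitted 4-feature model, a call to shapr(), and the mean of the training response.

# Sketch only: "empirical" is used when conditioning on 1 or 2 features,
# "gaussian" when conditioning on 3 or 4 features, and at most 500 Monte Carlo
# samples are drawn per conditional expectation.
explanation <- explain(
  x_test,
  explainer = explainer,
  approach = c("empirical", "empirical", "gaussian", "gaussian"),
  prediction_zero = p0,
  n_samples = 5e2
)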
#' @return Object of class \code{c("shapr", "list")}. Contains the following items:
#' \describe{
#' \item{dt}{data.table}
@@ -62,6 +71,10 @@
#'
#' @author Camilla Lingjaerde, Nikolai Sellereite, Martin Jullum, Annabelle Redelmeier
#'
#' @references
#' Aas, K., Jullum, M., & Løland, A. (2021). Explaining individual predictions when features are dependent:
#' More accurate approximations to Shapley values. Artificial Intelligence, 298, 103502.
#'
#' @examples
#' if (requireNamespace("MASS", quietly = TRUE)) {
#' # Load example data
Expand Down Expand Up @@ -131,7 +144,7 @@
#' )
#' print(explain_groups$dt)
#' }
explain <- function(x, explainer, approach, prediction_zero, ...) {
explain <- function(x, explainer, approach, prediction_zero, n_samples = 1e3, ...) {
extras <- list(...)

# Check input for x
@@ -186,15 +199,19 @@ explain <- function(x, explainer, approach, prediction_zero, ...) {
#' is only applicable when \code{approach = "empirical"}, and \code{type} is either equal to
#' \code{"AICc_each_k"} or \code{"AICc_full"}
#'
#' @param w_threshold Positive integer between 0 and 1.
#' @param w_threshold Numeric vector of length 1, with \code{0 < w_threshold <= 1}, representing the minimum proportion
#' of the total empirical weight that the retained data samples must account for. If e.g. \code{w_threshold = .8} we will choose the
#' \code{K} samples with the largest weights so that the sum of the weights accounts for 80\% of the total weight.
#' \code{w_threshold} is the \eqn{\eta} parameter in equation (15) of Aas et al. (2021).
#'
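As a purely illustrative sketch of how w_threshold trims the weighted sample set in the empirical approach (the weights below are made up, not computed by shapr):

# Hypothetical empirical weights for five training observations
w <- c(0.40, 0.30, 0.15, 0.10, 0.05)
w_threshold <- 0.8

# K = number of largest-weight samples needed to cover at least 80% of the total weight
w_sorted <- sort(w, decreasing = TRUE)
K <- which(cumsum(w_sorted) / sum(w_sorted) >= w_threshold)[1]
K  # 3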
#' @rdname explain
#'
#' @export
explain.empirical <- function(x, explainer, approach, prediction_zero,
n_samples = 1e3, w_threshold = 0.95,
type = "fixed_sigma", fixed_sigma_vec = 0.1,
n_samples_aicc = 1000, eval_max_aicc = 20,
start_aicc = 0.1, w_threshold = 0.95, ...) {
start_aicc = 0.1, ...) {

# Add arguments to explainer object
explainer$x_test <- as.matrix(preprocess_data(x, explainer$feature_list)$x_dt)
@@ -205,6 +222,7 @@ explain.empirical <- function(x, explainer, approach, prediction_zero,
explainer$eval_max_aicc <- eval_max_aicc
explainer$start_aicc <- start_aicc
explainer$w_threshold <- w_threshold
explainer$n_samples <- n_samples

# Generate data
dt <- prepare_data(explainer, ...)
@@ -229,12 +247,14 @@
#' @rdname explain
#'
#' @export
explain.gaussian <- function(x, explainer, approach, prediction_zero, mu = NULL, cov_mat = NULL, ...) {
explain.gaussian <- function(x, explainer, approach, prediction_zero, n_samples = 1e3, mu = NULL, cov_mat = NULL, ...) {


# Add arguments to explainer object
explainer$x_test <- as.matrix(preprocess_data(x, explainer$feature_list)$x_dt)
explainer$approach <- approach
explainer$n_samples <- n_samples


# If mu is not provided directly, use mean of training data
if (is.null(mu)) {
@@ -270,11 +290,12 @@ explain.gaussian <- function(x, explainer, approach, prediction_zero, mu = NULL,

#' @rdname explain
#' @export
explain.copula <- function(x, explainer, approach, prediction_zero, ...) {
explain.copula <- function(x, explainer, approach, prediction_zero, n_samples = 1e3, ...) {

# Setup
explainer$x_test <- as.matrix(preprocess_data(x, explainer$feature_list)$x_dt)
explainer$approach <- approach
explainer$n_samples <- n_samples

# Prepare transformed data
x_train <- apply(
@@ -334,7 +355,7 @@ explain.copula <- function(x, explainer, approach, prediction_zero, ...) {
#' @name explain
#'
#' @export
explain.ctree <- function(x, explainer, approach, prediction_zero,
explain.ctree <- function(x, explainer, approach, prediction_zero, n_samples = 1e3,
mincriterion = 0.95, minsplit = 20,
minbucket = 7, sample = TRUE, ...) {
# Checks input argument
@@ -349,6 +370,7 @@ explain.ctree <- function(x, explainer, approach, prediction_zero,
explainer$minsplit <- minsplit
explainer$minbucket <- minbucket
explainer$sample <- sample
explainer$n_samples <- n_samples

# Generate data
dt <- prepare_data(explainer, ...)
@@ -367,12 +389,13 @@
#' @name explain
#'
#' @export
explain.combined <- function(x, explainer, approach, prediction_zero,
explain.combined <- function(x, explainer, approach, prediction_zero, n_samples = 1e3,
mu = NULL, cov_mat = NULL, ...) {
# Get indices of combinations
l <- get_list_approaches(explainer$X$n_features, approach)
explainer$return <- TRUE
explainer$x_test <- as.matrix(preprocess_data(x, explainer$feature_list)$x_dt)
explainer$n_samples <- n_samples

dt_l <- list()
for (i in seq_along(l)) {
@@ -435,7 +458,7 @@ get_list_approaches <- function(n_features, approach) {
#'
#' @export
explain.ctree_comb_mincrit <- function(x, explainer, approach,
prediction_zero, mincriterion, ...) {
prediction_zero, n_samples, mincriterion, ...) {

# Get indices of combinations
l <- get_list_ctree_mincrit(explainer$X$n_features, mincriterion)
3 changes: 2 additions & 1 deletion R/features.R
@@ -73,7 +73,8 @@ feature_combinations <- function(m, exact = TRUE, n_combinations = 200, weight_z
exact <- TRUE
message(
paste0(
"\nn_combinations is larger than or equal to 2^m = ", 2^m, ". \n",
"\nSuccess with message:\n",
"n_combinations is larger than or equal to 2^m = ", 2^m, ". \n",
"Using exact instead."
)
)
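For intuition about when this message is triggered, a small sketch with assumed values (not taken from the PR):

m <- 3                  # number of features
n_combinations <- 200   # requested number of sampled feature combinations
2^m                     # 8 possible feature subsets in total
n_combinations >= 2^m   # TRUE, so exact enumeration is used and the message above is emitted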
35 changes: 15 additions & 20 deletions R/observations.R
@@ -7,9 +7,9 @@
#' the total number of unique features, respectively. Note that \code{m = ncol(x_train)}.
#' @param x_train Numeric matrix
#' @param x_test Numeric matrix
#' @param w_threshold Numeric vector of length 1, where \code{w_threshold > 0} and
#' \code{w_threshold <= 1}. If \code{w_threshold = .8} we will choose the \code{K} samples with
#' the largest weight so that the sum of the weights accounts for 80\% of the total weight.
#'
#' @inheritParams explain
#' @inherit explain references
#'
#' @return data.table
#'
@@ -71,9 +71,6 @@ observation_impute <- function(W_kernel, S, x_train, x_test, w_threshold = .7, n
#'
#' @param x Explainer object. See \code{\link{explain}} for more information.
#'
#' @param n_samples Positive integer. Indicating the maximum number of samples to use in the
#' Monte Carlo integration for every conditional expectation.
#'
#' @param seed Positive integer. If \code{NULL} the seed will be inherited from the calling environment.
#'
#' @param index_features Positive integer vector. Specifies the indices of combinations to apply to the present method.
@@ -94,7 +91,7 @@ prepare_data <- function(x, ...) {

#' @rdname prepare_data
#' @export
prepare_data.empirical <- function(x, seed = 1, n_samples = 1e3, index_features = NULL, ...) {
prepare_data.empirical <- function(x, seed = 1, index_features = NULL, ...) {
id <- id_combination <- w <- NULL # due to NSE notes in R CMD check

# Get distance matrix ----------------
@@ -158,7 +155,7 @@ prepare_data.empirical <- function(x, seed = 1, n_samples = 1e3, index_features
x_train = as.matrix(x$x_train),
x_test = x$x_test[i, , drop = FALSE],
w_threshold = x$w_threshold,
n_samples = n_samples
n_samples = x$n_samples
)

dt_l[[i]][, id := i]
@@ -171,7 +168,7 @@

#' @rdname prepare_data
#' @export
prepare_data.gaussian <- function(x, seed = 1, n_samples = 1e3, index_features = NULL, ...) {
prepare_data.gaussian <- function(x, seed = 1, index_features = NULL, ...) {
id <- id_combination <- w <- NULL # due to NSE notes in R CMD check

n_xtest <- nrow(x$x_test)
@@ -187,15 +184,15 @@
l <- lapply(
X = features,
FUN = sample_gaussian,
n_samples = n_samples,
n_samples = x$n_samples,
mu = x$mu,
cov_mat = x$cov_mat,
m = ncol(x$x_test),
x_test = x$x_test[i, , drop = FALSE]
)

dt_l[[i]] <- data.table::rbindlist(l, idcol = "id_combination")
dt_l[[i]][, w := 1 / n_samples]
dt_l[[i]][, w := 1 / x$n_samples]
dt_l[[i]][, id := i]
if (!is.null(index_features)) dt_l[[i]][, id_combination := index_features[id_combination]]
}
@@ -206,7 +203,7 @@

#' @rdname prepare_data
#' @export
prepare_data.copula <- function(x, x_test_gaussian = 1, seed = 1, n_samples = 1e3, index_features = NULL, ...) {
prepare_data.copula <- function(x, x_test_gaussian = 1, seed = 1, index_features = NULL, ...) {
id <- id_combination <- w <- NULL # due to NSE notes in R CMD check
n_xtest <- nrow(x$x_test)
dt_l <- list()
@@ -221,7 +218,7 @@ prepare_data.copula <- function(x, x_test_gaussian = 1, seed = 1, n_samples = 1e
l <- lapply(
X = features,
FUN = sample_copula,
n_samples = n_samples,
n_samples = x$n_samples,
mu = x$mu,
cov_mat = x$cov_mat,
m = ncol(x$x_test),
@@ -231,17 +228,15 @@
)

dt_l[[i]] <- data.table::rbindlist(l, idcol = "id_combination")
dt_l[[i]][, w := 1 / n_samples]
dt_l[[i]][, w := 1 / x$n_samples]
dt_l[[i]][, id := i]
if (!is.null(index_features)) dt_l[[i]][, id_combination := index_features[id_combination]]
}
dt <- data.table::rbindlist(dt_l, use.names = TRUE, fill = TRUE)
return(dt)
}

#' @param n_samples Integer. The number of obs to sample from the leaf if \code{sample} = TRUE or if \code{sample}
#' = FALSE but \code{n_samples} is less than the number of obs in the leaf.
#'

#' @param index_features List. Default is NULL but if either various methods are being used or various mincriterion are
#' used for different numbers of conditioned features, this will be a list with the features to pass.
#'
@@ -258,7 +253,7 @@
#'
#' @rdname prepare_data
#' @export
prepare_data.ctree <- function(x, seed = 1, n_samples = 1e3, index_features = NULL,
prepare_data.ctree <- function(x, seed = 1, index_features = NULL,
mc_cores = 1, mc_cores_create_ctree = mc_cores,
mc_cores_sample_ctree = mc_cores, ...) {
id <- id_combination <- w <- NULL # due to NSE notes in R CMD check
@@ -290,7 +285,7 @@ prepare_data.ctree <- function(x, seed = 1, n_samples = 1e3, index_features = NU
l <- parallel::mclapply(
X = all_trees,
FUN = sample_ctree,
n_samples = n_samples,
n_samples = x$n_samples,
x_test = x$x_test[i, , drop = FALSE],
x_train = x$x_train,
p = ncol(x$x_test),
Expand All @@ -300,7 +295,7 @@ prepare_data.ctree <- function(x, seed = 1, n_samples = 1e3, index_features = NU
)

dt_l[[i]] <- data.table::rbindlist(l, idcol = "id_combination")
dt_l[[i]][, w := 1 / n_samples]
dt_l[[i]][, w := 1 / x$n_samples]
dt_l[[i]][, id := i]
if (!is.null(index_features)) dt_l[[i]][, id_combination := index_features[id_combination]]
}
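The common thread in this file is that the prepare_data.* methods no longer take n_samples as their own argument; they read it from the explainer object that explain() populated. A simplified sketch of that flow (a plain list stands in for the real explainer object):

explainer <- list()          # stand-in for a shapr explainer
explainer$n_samples <- 1e3   # explain() now stores the argument here ...
1 / explainer$n_samples      # ... and prepare_data.* reads x$n_samples, e.g. for the sample weights w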
15 changes: 11 additions & 4 deletions R/preprocess_data.R
@@ -156,6 +156,7 @@ check_features <- function(f_list_1, f_list_2,
}
if (NULL_1 & use_1_as_truth) {
message(paste0(
"\nSuccess with message:\n",
"The specified ", name_1, " provides NULL feature labels. ",
"The labels of ", name_2, " are taken as the truth."
))
@@ -170,6 +171,7 @@
}
if ((NA_1 & use_1_as_truth)) {
message(paste0(
"\nSuccess with message:\n",
"The specified ", name_1, " provides feature labels that are NA. ",
"The labels of ", name_2, " are taken as the truth."
))
@@ -245,6 +247,7 @@
#### Checking classes ####
if (any(is.na(f_list_1$classes)) & use_1_as_truth) { # Only relevant when f_list_1 is a model
message(paste0(
"\nSuccess with message:\n",
"The specified ", name_1, " provides feature classes that are NA. ",
"The classes of ", name_2, " are taken as the truth."
))
@@ -272,6 +275,7 @@
is_NULL <- any(is.null(relevant_factor_levels))
if ((is_NA | is_NULL) & use_1_as_truth) {
message(paste0(
"\nSuccess with message:\n",
"The specified ", name_1, " provides factor feature levels that are NULL or NA. ",
"The factor levels of ", name_2, " are taken as the truth."
))
@@ -330,9 +334,9 @@ update_data <- function(data, updater) {
# Reorder and delete unused columns
cnms_remove <- setdiff(colnames(data), new_labels)
if (length(cnms_remove) > 0) {
message(
paste0(
"The columns(s) ",
message(paste0(
"\nSuccess with message:\n",
"The columns(s) ",
paste0(cnms_remove, collapse = ", "),
" is not used by the model and thus removed from the data."
)
@@ -348,6 +352,7 @@
if (any(!identical_levels)) {
changed_levels <- which(!identical_levels)
message(paste0(
"\nSuccess with message:\n",
"Levels are reordered for the factor feature(s) ",
paste0(new_labels[changed_levels], collapse = ", "), "."
))
@@ -383,7 +388,9 @@ process_groups <- function(group, feature_labels) {

# Make group names if not existing
if (is.null(names(group))) {
message("Group names not provided. Assigning them the default names 'group1', 'group2', 'group3' etc.")
message(
"\nSuccess with message:\n
Group names not provided. Assigning them the default names 'group1', 'group2', 'group3' etc.")
names(group) <- paste0("group", seq_along(group))
}
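A small sketch of the default group naming this message describes (the feature names x1, x2 and x3 are made up for illustration):

group <- list(c("x1", "x2"), "x3")                 # unnamed groups supplied by the user
is.null(names(group))                              # TRUE, so the message above is emitted
names(group) <- paste0("group", seq_along(group))
names(group)                                       # "group1" "group2"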
