Check and update data before explanation (#245)
martinju authored Jan 26, 2021
1 parent f7d742c commit ba36e74
Showing 50 changed files with 2,830 additions and 1,816 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/R-CMD-check.yaml
@@ -2,7 +2,11 @@ on:
push:
branches:
- master
- cranversion
pull_request:
branches:
- master
- cranversion

name: R-CMD-check

2 changes: 2 additions & 0 deletions .github/workflows/lint.yaml
@@ -2,9 +2,11 @@ on:
push:
branches:
- master
- cranversion
pull_request:
branches:
- master
- cranversion

name: lint

5 changes: 3 additions & 2 deletions .github/workflows/pkgdown.yaml
@@ -1,7 +1,8 @@
on:
push:
branches: master

branches:
- master
- cranversion
name: pkgdown

jobs:
2 changes: 2 additions & 0 deletions CRAN-RELEASE
@@ -0,0 +1,2 @@
This package was submitted to CRAN on 2021-01-21.
Once it is accepted, delete this file and tag the release (commit 8bad7333).
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,5 +1,5 @@
Package: shapr
Version: 0.1.3
Version: 0.1.4.9000
Title: Prediction Explanation with Dependence-Aware Shapley Values
Description: Complex machine learning models are often hard to interpret. However, in
many situations it is crucial to understand and explain why a model made a specific
33 changes: 20 additions & 13 deletions NAMESPACE
@@ -6,17 +6,17 @@ S3method(explain,ctree)
S3method(explain,ctree_comb_mincrit)
S3method(explain,empirical)
S3method(explain,gaussian)
S3method(features,gam)
S3method(features,glm)
S3method(features,lm)
S3method(features,ranger)
S3method(features,xgb.Booster)
S3method(model_type,default)
S3method(model_type,gam)
S3method(model_type,glm)
S3method(model_type,lm)
S3method(model_type,ranger)
S3method(model_type,xgb.Booster)
S3method(get_model_specs,gam)
S3method(get_model_specs,glm)
S3method(get_model_specs,lm)
S3method(get_model_specs,ranger)
S3method(get_model_specs,xgb.Booster)
S3method(model_checker,default)
S3method(model_checker,gam)
S3method(model_checker,glm)
S3method(model_checker,lm)
S3method(model_checker,ranger)
S3method(model_checker,xgb.Booster)
S3method(plot,shapr)
S3method(predict_model,default)
S3method(predict_model,gam)
@@ -29,21 +29,25 @@ S3method(prepare_data,ctree)
S3method(prepare_data,empirical)
S3method(prepare_data,gaussian)
export(aicc_full_single_cpp)
export(check_features)
export(correction_matrix_cpp)
export(create_ctree)
export(explain)
export(feature_combinations)
export(feature_matrix_cpp)
export(features)
export(get_data_specs)
export(get_model_specs)
export(hat_matrix_cpp)
export(mahalanobis_distance_cpp)
export(make_dummies)
export(model_type)
export(model_checker)
export(observation_impute_cpp)
export(predict_model)
export(prepare_data)
export(preprocess_data)
export(rss_cpp)
export(shapr)
export(update_data)
export(weight_matrix_cpp)
importFrom(Rcpp,sourceCpp)
importFrom(data.table,":=")
Expand All @@ -65,9 +69,12 @@ importFrom(graphics,hist)
importFrom(graphics,plot)
importFrom(graphics,rect)
importFrom(stats,as.formula)
importFrom(stats,contrasts)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
importFrom(stats,predict)
importFrom(stats,setNames)
importFrom(utils,head)
importFrom(utils,methods)
importFrom(utils,tail)
useDynLib(shapr, .registration = TRUE)
4 changes: 4 additions & 0 deletions NEWS.md
@@ -1,4 +1,8 @@

# shapr 0.1.4

* Patch to fulfill CRAN policy of using packages under Suggests conditionally (in tests and examples)
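  For reference, the conditional-use pattern this CRAN policy item refers to is roughly the following minimal sketch in R; the ggplot2 call is illustrative only and is not taken from this commit:

  # Run examples/tests that need a suggested package only if it is installed
  if (requireNamespace("ggplot2", quietly = TRUE)) {
    library(ggplot2)
    ggplot(mtcars, aes(wt, mpg)) + geom_point()
  }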

# shapr 0.1.3

* Fix installation error on Solaris
76 changes: 17 additions & 59 deletions R/explanation.R
@@ -16,11 +16,11 @@
#'
#' @param ... Additional arguments passed to \code{\link{prepare_data}}
#'
#' @details The most important thing to notice is that \code{shapr} has implemented three different
#' @details The most important thing to notice is that \code{shapr} has implemented four different
#' approaches for estimating the conditional distributions of the data, namely \code{"empirical"},
#' \code{"gaussian"} and \code{"copula"}.
#' \code{"gaussian"}, \code{"copula"} and \code{"ctree"}.
#'
#' In addition to this the user will also have the option of combining the three approaches.
#' In addition, the user also has the option of combining the four approaches.
#' E.g. if you're in a situation where you have trained a model that consists of 10 features,
#' and you'd like to use the \code{"gaussian"} approach when you condition on a single feature,
#' the \code{"empirical"} approach if you condition on 2-5 features, and \code{"copula"} version
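A minimal sketch of the combined-approach usage described in the documentation above, assuming the shapr()/explain() signatures visible elsewhere in this commit; the Boston data split, the four-feature model, and the prediction_zero value are illustrative and not taken from the diff:

if (requireNamespace("MASS", quietly = TRUE)) {
  library(shapr)
  data("Boston", package = "MASS")
  x_var <- c("lstat", "rm", "dis", "indus")
  x_train <- as.matrix(Boston[-(1:6), x_var])
  x_test <- as.matrix(Boston[1:6, x_var])
  model <- stats::lm(medv ~ lstat + rm + dis + indus, data = Boston[-(1:6), ])
  explainer <- shapr(x_train, model)
  # One approach per number of conditioned-on features (4 features -> length-4 vector):
  # "gaussian" when conditioning on 1 feature, "empirical" for 2-3, "copula" for all 4.
  explanation <- explain(
    x_test, explainer,
    approach = c("gaussian", "empirical", "empirical", "copula"),
    prediction_zero = mean(Boston[-(1:6), "medv"])
  )
  print(explanation$dt)
}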
@@ -60,9 +60,10 @@
#'
#' @export
#'
#' @author Camilla Lingjaerde, Nikolai Sellereite
#' @author Camilla Lingjaerde, Nikolai Sellereite, Martin Jullum, Annabelle Redelmeier
#'
#' @examples
#' if (requireNamespace("MASS", quietly = TRUE)) {
#' # Load example data
#' data("Boston", package = "MASS")
#'
@@ -99,19 +100,22 @@
#' print(explain1$dt)
#'
#' # Plot the results
#' if (requireNamespace("ggplot2", quietly = TRUE)) {
#' plot(explain1)
#' }
#' }
explain <- function(x, explainer, approach, prediction_zero, ...) {
extras <- list(...)

# Check input for x
if (!is.matrix(x) & !is.data.frame(x)) {
stop("x should be a matrix or a dataframe.")
stop("x should be a matrix or a data.frame/data.table.")
}

# Check input for approach
if (!(is.vector(approach) &&
is.atomic(approach) &&
(length(approach) == 1 | length(approach) == length(explainer$feature_labels)) &&
(length(approach) == 1 | length(approach) == length(explainer$feature_list$labels)) &&
all(is.element(approach, c("empirical", "gaussian", "copula", "ctree"))))
) {
stop(
@@ -123,16 +127,7 @@ explain <- function(x, explainer, approach, prediction_zero, ...) {
)
}

# Check that x contains correct variables
if (!all(explainer$feature_labels %in% colnames(x))) {
stop(
paste0(
"\nThe test data, x, does not contain all features necessary for\n",
"generating predictions. Please modify x so that all labels given\n",
"by explainer$feature_labels is present in colnames(x)."
)
)
}


if (length(approach) > 1) {
class(x) <- "combined"
@@ -175,7 +170,7 @@ explain.empirical <- function(x, explainer, approach, prediction_zero,
start_aicc = 0.1, w_threshold = 0.95, ...) {

# Add arguments to explainer object
explainer$x_test <- explainer_x_test(x, explainer$feature_labels)
explainer$x_test <- as.matrix(preprocess_data(x, explainer$feature_list)$x_dt)
explainer$approach <- approach
explainer$type <- type
explainer$fixed_sigma_vec <- fixed_sigma_vec
@@ -207,8 +202,9 @@ explain.empirical <- function(x, explainer, approach, prediction_zero,
#' @export
explain.gaussian <- function(x, explainer, approach, prediction_zero, mu = NULL, cov_mat = NULL, ...) {


# Add arguments to explainer object
explainer$x_test <- explainer_x_test(x, explainer$feature_labels)
explainer$x_test <- as.matrix(preprocess_data(x, explainer$feature_list)$x_dt)
explainer$approach <- approach

# If mu is not provided directly, use mean of training data
@@ -246,7 +242,7 @@ explain.gaussian <- function(x, explainer, approach, prediction_zero, mu = NULL,
explain.copula <- function(x, explainer, approach, prediction_zero, ...) {

# Setup
explainer$x_test <- explainer_x_test(x, explainer$feature_labels)
explainer$x_test <- as.matrix(preprocess_data(x, explainer$feature_list)$x_dt)
explainer$approach <- approach

# Prepare transformed data
@@ -314,7 +310,7 @@ explain.ctree <- function(x, explainer, approach, prediction_zero,
}

# Add arguments to explainer object
explainer$x_test <- explainer_x_test_dt(x, explainer$feature_labels)
explainer$x_test <- preprocess_data(x, explainer$feature_list)$x_dt
explainer$approach <- approach
explainer$mincriterion <- mincriterion
explainer$minsplit <- minsplit
Expand All @@ -341,7 +337,7 @@ explain.combined <- function(x, explainer, approach, prediction_zero,
# Get indices of combinations
l <- get_list_approaches(explainer$X$n_features, approach)
explainer$return <- TRUE
explainer$x_test <- explainer_x_test(x, explainer$feature_labels)
explainer$x_test <- as.matrix(preprocess_data(x, explainer$feature_list)$x_dt)

dt_l <- list()
for (i in seq_along(l)) {
@@ -398,32 +394,6 @@ get_list_approaches <- function(n_features, approach) {
return(l)
}

#' @keywords internal
explainer_x_test <- function(x_test, feature_labels) {

# Remove variables that were not used for training
x <- data.table::as.data.table(x_test)
cnms_remove <- setdiff(colnames(x), feature_labels)
if (length(cnms_remove) > 0) x[, (cnms_remove) := NULL]
data.table::setcolorder(x, feature_labels)

return(as.matrix(x))
}

#' @keywords internal
explainer_x_test_dt <- function(x_test, feature_labels) {

# Remove variables that were not used for training
# Same as explainer_x_test() but doesn't convert to a matrix
# Useful for ctree method which sometimes takes categorical features
x <- data.table::as.data.table(x_test)
cnms_remove <- setdiff(colnames(x), feature_labels)
if (length(cnms_remove) > 0) x[, (cnms_remove) := NULL]
data.table::setcolorder(x, feature_labels)

return(x)
}


#' @rdname explain
#' @name explain
@@ -462,15 +432,3 @@ get_list_ctree_mincrit <- function(n_features, mincriterion) {
}
return(l)
}

#' @keywords internal
explainer_x_test <- function(x_test, feature_labels) {

# Remove variables that were not used for training
x <- data.table::as.data.table(x_test)
cnms_remove <- setdiff(colnames(x), feature_labels)
if (length(cnms_remove) > 0) x[, (cnms_remove) := NULL]
data.table::setcolorder(x, feature_labels)

return(as.matrix(x))
}