diff --git a/DESCRIPTION b/DESCRIPTION index 8a65d264..634701c5 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -7,6 +7,7 @@ Authors@R: person("Mateusz", "KrzyziƄski", role = c("aut"), comment = c(ORCID = "0000-0001-6143-488X")), person("Sophie", "Langbein", role = c("aut")), person("Hubert", "Baniecki", role = c("aut"), comment = c(ORCID = "0000-0001-6661-5364")), + person("Lorenz A.", "Kapsner", role = c("ctb"), comment = c(ORCID = "0000-0003-1866-860X")), person("Przemyslaw", "Biecek", role = c("aut"), comment = c(ORCID = "0000-0001-8423-1823")) ) Description: Survival analysis models are commonly used in medicine and other areas. Many of them @@ -25,6 +26,7 @@ Imports: DALEX (>= 2.2.1), ggplot2 (>= 3.4.0), kernelshap, + treeshap, pec, survival, patchwork diff --git a/R/surv_shap.R b/R/surv_shap.R index f8fd2048..e0d47593 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -5,7 +5,7 @@ #' @param output_type a character, either `"survival"` or `"chf"`. Determines which type of prediction should be used for explanations. #' @param ... additional parameters, passed to internal functions #' @param y_true a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting -#' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements) or `"exact_kernel"` for exact Kernel SHAP estimation +#' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements), `"exact_kernel"` for exact Kernel SHAP estimation, or `"treeshap"` for use of `treeshap` library (efficient implementation to compute SHAP values for tree-based models). #' @param aggregation_method a character, either `"integral"`, `"integral_absolute"`, `"mean_absolute"`, `"max_absolute"`, or `"sum_of_squares"` #' #' @return A list, containing the calculated SurvSHAP(t) results in the `result` field @@ -19,10 +19,15 @@ surv_shap <- function(explainer, output_type, ..., y_true = NULL, - calculation_method = "kernelshap", - aggregation_method = "integral") { + calculation_method = c("kernelshap", "exact_kernel", "treeshap"), + aggregation_method = c("integral", "mean_absolute", "max_absolute", "sum_of_squares") +) { + calculation_method <- match.arg(calculation_method) + aggregation_method <- match.arg(aggregation_method) + + # make this code work for multiple observations stopifnot( - "`y_true` must be either a matrix with one per observation in `new_observation` or a vector of length == 2" = ifelse( + "`y_true` must be either a matrix with one row per observation in `new_observation` or a vector of length == 2" = ifelse( !is.null(y_true), ifelse( is.matrix(y_true), @@ -33,14 +38,40 @@ surv_shap <- function(explainer, ) ) + if (calculation_method == "kernelshap") { + if (!requireNamespace("kernelshap", quietly = TRUE)) { + stop( + paste0( + "Package \"kernelshap\" must be installed to use ", + "'calculation_method = \"kernelshap\"'." + ), + call. = FALSE + ) + } + } + if (calculation_method == "treeshap") { + if (!requireNamespace("treeshap", quietly = TRUE)) { + stop( + paste0( + "Package \"treeshap\" must be installed to use ", + "'calculation_method = \"treeshap\"'." + ), + call. = FALSE + ) + } + } + test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE, has_survival = TRUE) # make this code also work for 1-row matrix col_index <- which(colnames(new_observation) %in% colnames(explainer$data)) if (is.matrix(new_observation) && nrow(new_observation) == 1) { - new_observation <- as.matrix(t(new_observation[, col_index])) + new_observation <- data.frame(as.matrix(t(new_observation[, col_index]))) } else { new_observation <- new_observation[, col_index] + if (!inherits(new_observation, "data.frame")) { + new_observation <- data.frame(new_observation) + } } if (ncol(explainer$data) != ncol(new_observation)) { @@ -59,14 +90,25 @@ surv_shap <- function(explainer, } } + if (calculation_method == "treeshap") { + if (!inherits(explainer$model, "ranger")) { + stop("Calculation method `treeshap` is currently only implemented for `ranger` survival models.") + } + } + res <- list() res$eval_times <- explainer$times # to display final object correctly, when is.matrix(new_observation) == TRUE res$variable_values <- as.data.frame(new_observation) res$result <- switch(calculation_method, - "exact_kernel" = use_exact_shap(explainer, new_observation, output_type, ...), - "kernelshap" = use_kernelshap(explainer, new_observation, output_type, ...), - stop("Only `exact_kernel` and `kernelshap` calculation methods are implemented") + "exact_kernel" = use_exact_shap(explainer, new_observation, output_type, ...), + "kernelshap" = use_kernelshap(explainer, new_observation, output_type, ...), + "treeshap" = use_treeshap(explainer, new_observation, ...), + stop("Only `exact_kernel`, `kernelshap` and `treeshap` calculation methods are implemented")) + # quality-check here + stopifnot( + "Number of rows of SurvSHAP table are not identical with length(eval_times)" = + nrow(res$result) == length(res$eval_times) ) if (!is.null(y_true)) res$y_true <- c(y_true_time = y_true_time, y_true_ind = y_true_ind) @@ -86,7 +128,7 @@ surv_shap <- function(explainer, return(res) } -use_exact_shap <- function(explainer, new_observation, output_type, observation_aggregation_method, ...) { +use_exact_shap <- function(explainer, new_observation, output_type, ...) { shap_values <- sapply( X = as.character(seq_len(nrow(new_observation))), FUN = function(i) { @@ -123,11 +165,8 @@ shap_kernel <- function(explainer, new_observation, output_type, ...) { timestamps ) - - shap_values <- as.data.frame(shap_values, row.names = colnames(explainer$data)) colnames(shap_values) <- paste("t=", timestamps, sep = "") - return(t(shap_values)) } @@ -204,7 +243,18 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_ times = explainer$times ) } + } + + stopifnot( + "new_observation must be a data.frame" = inherits( + new_observation, "data.frame") + ) + # get explainer data to be able to make class checks and transformations + explainer_data <- explainer$data + # ensure that classes of explainer$data and new_observation are equal + if (!inherits(explainer_data, "data.frame")) { + explainer_data <- data.frame(explainer_data) } shap_values <- sapply( @@ -212,11 +262,12 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_ FUN = function(i) { tmp_res <- kernelshap::kernelshap( object = explainer$model, - X = new_observation[as.integer(i), ], - bg_X = explainer$data, + X = new_observation[as.integer(i), ], # data.frame + bg_X = explainer_data, # data.frame pred_fun = predfun, verbose = FALSE ) + # kernelshap-test: is.matrix(X) == is.matrix(bg_X) should evaluate to `TRUE` tmp_shap_values <- data.frame(t(sapply(tmp_res$S, cbind))) colnames(tmp_shap_values) <- colnames(tmp_res$X) rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "") @@ -229,6 +280,69 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_ return(shap_values) } +use_treeshap <- function(explainer, new_observation, ...){ + + stopifnot( + "new_observation must be a data.frame" = inherits( + new_observation, "data.frame") + ) + + # init unify_append_args + unify_append_args <- list() + + if (inherits(explainer$model, "ranger")) { + # UNIFY_FUN to prepare code for easy Integration of other ml algorithms + # that are supported by treeshap + UNIFY_FUN <- treeshap::ranger_surv.unify + unify_append_args <- list(type = "survival", times = explainer$times) + } else { + stop("Support for `treeshap` is currently only implemented for `ranger`.") + } + + unify_args <- list( + rf_model = explainer$model, + data = explainer$data + ) + + if (length(unify_append_args) > 0) { + unify_args <- c(unify_args, unify_append_args) + } + + tmp_unified <- do.call(UNIFY_FUN, unify_args) + + shap_values <- sapply( + X = as.character(seq_len(nrow(new_observation))), + FUN = function(i) { + tmp_res <- do.call( + rbind, + lapply( + tmp_unified, + function(m) { + new_obs_mat <- new_observation[as.integer(i), ] + # ensure that matrix has expected dimensions; as.integer is + # necessary for valid comparison with "identical" + stopifnot(identical(dim(new_obs_mat), as.integer(c(1L, ncol(new_observation))))) + treeshap::treeshap( + unified_model = m, + x = new_obs_mat + )$shaps + } + ) + ) + + tmp_shap_values <- data.frame(tmp_res) + colnames(tmp_shap_values) <- colnames(tmp_res) + rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "") + tmp_shap_values + }, + USE.NAMES = TRUE, + simplify = FALSE + ) + + return(shap_values) + +} + #' @keywords internal aggregate_shap_multiple_observations <- function(shap_res_list, feature_names, aggregation_function) { if (length(shap_res_list) > 1) { diff --git a/man/model_survshap.surv_explainer.Rd b/man/model_survshap.surv_explainer.Rd index 840f6460..dc85050f 100644 --- a/man/model_survshap.surv_explainer.Rd +++ b/man/model_survshap.surv_explainer.Rd @@ -26,7 +26,8 @@ model_survshap(explainer, ...) \item{y_true}{a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting} -\item{calculation_method}{a character, either \code{"kernelshap"} for use of \code{kernelshap} library (providing faster Kernel SHAP with refinements) or \code{"exact_kernel"} for exact Kernel SHAP estimation} +\item{calculation_method}{a character, either \code{"kernelshap"} for use of \code{kernelshap} library (providing faster Kernel SHAP with refinements), \code{"exact_kernel"} for exact Kernel SHAP estimation, +or \code{"treeshap"} for use of \code{treeshap} library (efficient implementation to compute SHAP values for tree-based models).} \item{aggregation_method}{a character, either \code{"integral"}, \code{"integral_absolute"}, \code{"mean_absolute"}, \code{"max_absolute"}, or \code{"sum_of_squares"}} diff --git a/man/surv_shap.Rd b/man/surv_shap.Rd index 885e64f6..073dfc58 100644 --- a/man/surv_shap.Rd +++ b/man/surv_shap.Rd @@ -10,8 +10,8 @@ surv_shap( output_type, ..., y_true = NULL, - calculation_method = "kernelshap", - aggregation_method = "integral" + calculation_method = c("kernelshap", "exact_kernel", "treeshap"), + aggregation_method = c("integral", "mean_absolute", "max_absolute", "sum_of_squares") ) } \arguments{ @@ -25,7 +25,8 @@ surv_shap( \item{y_true}{a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting} -\item{calculation_method}{a character, either \code{"kernelshap"} for use of \code{kernelshap} library (providing faster Kernel SHAP with refinements) or \code{"exact_kernel"} for exact Kernel SHAP estimation} +\item{calculation_method}{a character, either \code{"kernelshap"} for use of \code{kernelshap} library (providing faster Kernel SHAP with refinements), \code{"exact_kernel"} for exact Kernel SHAP estimation, +or \code{"treeshap"} for use of \code{treeshap} library (efficient implementation to compute SHAP values for tree-based models).} \item{aggregation_method}{a character, either \code{"integral"}, \code{"integral_absolute"}, \code{"mean_absolute"}, \code{"max_absolute"}, or \code{"sum_of_squares"}} } diff --git a/tests/testthat/test-model_survshap.R b/tests/testthat/test-model_survshap.R index d56b75cc..4ee3be45 100644 --- a/tests/testthat/test-model_survshap.R +++ b/tests/testthat/test-model_survshap.R @@ -1,4 +1,5 @@ +# create objects here so that they do not have to be created redundantly veteran <- survival::veteran rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) @@ -68,3 +69,25 @@ test_that("global survshap explanations with kernelshap work for coxph, using ex expect_equal(length(cph_global_survshap$eval_times), length(cph_exp$times)) expect_true(all(names(cph_global_survshap$variable_values) == colnames(cph_exp$data))) }) + +# testing if matrix works as input +rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) +rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE) + +test_that("global survshap explanations with treeshap work for ranger", { + + new_obs <- model.matrix(~ -1 + ., veteran[1:40, setdiff(colnames(veteran), c("time", "status"))]) + ranger_global_survshap_tree <- model_survshap( + rsf_ranger_exp_matrix, + new_observation = new_obs, + y_true = survival::Surv(veteran$time[1:40], veteran$status[1:40]), + aggregation_method = "mean_absolute", + calculation_method = "treeshap" + ) + plot(ranger_global_survshap_tree) + + expect_s3_class(ranger_global_survshap_tree, c("aggregated_surv_shap", "surv_shap")) + expect_equal(length(ranger_global_survshap_tree$eval_times), length(rsf_ranger_exp_matrix$times)) + expect_true(all(names(ranger_global_survshap_tree$variable_values) == colnames(rsf_ranger_exp_matrix$data))) + +}) diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index 450fb1cf..6e8c733f 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -6,7 +6,7 @@ test_that("survshap explanations work", { rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) cph_exp <- explain(cph, verbose = FALSE) - rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) + rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = survival::Surv(veteran$time, veteran$status), verbose = FALSE) rsf_src_exp <- explain(rsf_src, verbose = FALSE) parts_cph <- predict_parts(cph_exp, veteran[1, !colnames(veteran) %in% c("time", "status")], y_true = matrix(c(100, 1), ncol = 2), aggregation_method = "sum_of_squares") @@ -19,6 +19,19 @@ test_that("survshap explanations work", { parts_ranger <- predict_parts(rsf_ranger_exp, veteran[2, !colnames(veteran) %in% c("time", "status")], y_true = c(100, 1), aggregation_method = "mean_absolute") plot(parts_ranger) + # test ranger with kernelshap when using a matrix as input for data and new observation + rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) + rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE) + new_obs <- model.matrix(~ -1 + ., veteran[2, !colnames(veteran) %in% c("time", "status")]) + parts_ranger_kernelshap <- predict_parts( + rsf_ranger_exp_matrix, + new_observation = new_obs, + y_true = c(100, 1), + aggregation_method = "mean_absolute", + calculation_method = "kernelshap" + ) + plot(parts_ranger_kernelshap) + parts_src <- predict_parts(rsf_src_exp, veteran[3, !colnames(veteran) %in% c("time", "status")]) plot(parts_src) @@ -46,6 +59,29 @@ test_that("survshap explanations work", { expect_error(predict_parts(cph_exp, veteran[1, ], calculation_method = "nonexistent")) expect_error(predict_parts(cph_exp, veteran[1, c(1, 1, 1, 1, 1)], calculation_method = "nonexistent")) +}) + +test_that("local survshap explanations with treeshap work for ranger", { + + veteran <- survival::veteran + + rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) + rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE) + + + new_obs <- data.frame(model.matrix(~ -1 + ., veteran[2, setdiff(colnames(veteran), c("time", "status"))])) + parts_ranger <- model_survshap( + rsf_ranger_exp_matrix, + new_obs, + y_true = c(veteran$time[2], veteran$status[2]), + aggregation_method = "mean_absolute", + calculation_method = "treeshap" + ) + plot(parts_ranger) + + expect_s3_class(parts_ranger, c("predict_parts_survival", "surv_shap")) + expect_equal(nrow(parts_ranger$result), length(rsf_ranger_exp_matrix$times)) + expect_true(all(colnames(parts_ranger$result) == colnames(rsf_ranger_exp_matrix$data))) }) @@ -67,6 +103,10 @@ test_that("survshap explanations with output_type = 'chf' work", { plot(parts_cph, rug = "censors") plot(parts_cph, rug = "none") + # test global exact + parts_cph_glob <- predict_parts(cph_exp, veteran[1:3, !colnames(veteran) %in% c("time", "status")], y_true = as.matrix(veteran[1:3, c("time", "status")]), calculation_method = "exact_kernel", aggregation_method = "max_absolute", output_type = "chf") + plot(parts_cph_glob) + parts_ranger <- predict_parts(rsf_ranger_exp, veteran[2, !colnames(veteran) %in% c("time", "status")], y_true = c(100, 1), aggregation_method = "mean_absolute", output_type = "chf") plot(parts_ranger)