From a796a99c14921fe09ffb7096b0ed906feb128f04 Mon Sep 17 00:00:00 2001 From: kapsner Date: Tue, 4 Apr 2023 17:26:18 +0200 Subject: [PATCH 01/21] feat: adding support for treeshap calculation of survshap for ranger algorithm --- DESCRIPTION | 2 +- R/surv_shap.R | 59 +++++++++++++++++++++++++++-- tests/testthat/test-predict_parts.R | 14 +++++++ 3 files changed, 70 insertions(+), 5 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 93e5581c..0be37ae5 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -18,7 +18,7 @@ Description: Survival analysis models are commonly used in medicine and other ar License: GPL (>= 3) Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.1 +RoxygenNote: 7.2.3 Depends: R (>= 3.5.0) Imports: DALEX (>= 2.2.1), diff --git a/R/surv_shap.R b/R/surv_shap.R index bbb36b36..41e098bb 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -25,7 +25,14 @@ surv_shap <- function(explainer, exact = FALSE ) { test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE, has_survival = TRUE) - new_observation <- new_observation[, colnames(new_observation) %in% colnames(explainer$data)] + # make that this also works for 1-row matrix + col_index <- which(colnames(new_observation) %in% colnames(explainer$data)) + if (is.matrix(new_observation) && nrow(new_observation) == 1) { + new_observation <- as.matrix(t(new_observation[, col_index])) + } else { + new_observation <- new_observation[, col_index] + } + if (ncol(explainer$data) != ncol(new_observation)) stop("New observation and data have different number of columns (variables)") if (!is.null(y_true)) { @@ -38,6 +45,17 @@ surv_shap <- function(explainer, } } + # hack to use rf-model death times as explainer death times, as + # treeshap::ranger_surv_fun.unify extracts survival times directly + # from the ranger object for calculating the predictions + if (calculation_method == "treeshap") { + if (inherits(explainer$model, "ranger")) { + explainer$times <- explainer$model$unique.death.times + 
} else { + stop("Calculation method `treeshap` is currently only implemented for `ranger`.") + } + } + res <- list() res$eval_times <- explainer$times res$variable_values <- new_observation @@ -45,7 +63,8 @@ surv_shap <- function(explainer, res$result <- switch(calculation_method, "exact_kernel" = shap_kernel(explainer, new_observation, ...), "kernelshap" = use_kernelshap(explainer, new_observation, ...), - stop("Only `exact_kernel` and `kernelshap` calculation methods are implemented")) + "treeshap" = use_treeshap(explainer, new_observation, ...), + stop("Only `exact_kernel`, `kernelshap` and `treeshap` calculation methods are implemented")) if (!is.null(y_true)) res$y_true <- c(y_true_time = y_true_time, y_true_ind = y_true_ind) @@ -148,14 +167,46 @@ aggregate_surv_shap <- function(survshap, method) { use_kernelshap <- function(explainer, new_observation, ...){ predfun <- function(model, newdata){ - explainer$predict_survival_function(model, newdata, times=explainer$times) + explainer$predict_survival_function(model, newdata, times = explainer$times) } tmp_res <- kernelshap::kernelshap(explainer$model, new_observation, bg_X = explainer$data, - pred_fun = predfun, verbose=FALSE) + pred_fun = predfun, verbose = FALSE) shap_values <- data.frame(t(sapply(tmp_res$S, cbind))) colnames(shap_values) <- colnames(tmp_res$X) rownames(shap_values) <- paste("t=", explainer$times, sep = "") return(shap_values) } + +use_treeshap <- function(explainer, new_observation, ...){ + + if (inherits(explainer$model, "ranger")) { + UNIFY_FUN <- treeshap::ranger_surv_fun.unify + } else { + stop("Support for `treeshap` is currently only implemented for `ranger`.") + } + + tmp_unified <- UNIFY_FUN( + rf_model = explainer$model, + data = explainer$data + ) + + tmp_res <- do.call( + rbind, + lapply( + tmp_unified, + function(m) { + treeshap::treeshap( + unified_model = m, + x = new_observation + )$shaps + } + ) + ) + + shap_values <- data.frame(tmp_res) + colnames(shap_values) <- 
colnames(tmp_res) + rownames(shap_values) <- paste("t=", explainer$times, sep = "") + return(shap_values) +} diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index 20539a9f..681d4ab0 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -18,6 +18,20 @@ test_that("survshap explanations work", { parts_ranger <- predict_parts(rsf_ranger_exp, veteran[2, !colnames(veteran) %in% c("time", "status")], y_true = c(100, 1), aggregation_method = "mean_absolute") plot(parts_ranger) + # test ranger with treeshap (we need the data as matrix) + rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) + rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = Surv(veteran$time, veteran$status), verbose = FALSE) + new_obs <- model.matrix(~ -1 + ., veteran[2, !colnames(veteran) %in% c("time", "status")]) + parts_ranger_treeshap <- predict_parts( + rsf_ranger_exp_matrix, + new_observation = new_obs, + y_true = c(100, 1), + aggregation_method = "mean_absolute", + calculation_method = "treeshap" + ) + plot(parts_ranger_treeshap) + + parts_src <- predict_parts(rsf_src_exp, veteran[3, !colnames(veteran) %in% c("time", "status")]) plot(parts_src) From d500fd631600cd096043263826274c1a017de251 Mon Sep 17 00:00:00 2001 From: kapsner Date: Wed, 5 Apr 2023 09:32:59 +0200 Subject: [PATCH 02/21] fix: added constraint that new-observation for predict surv_shap has exactly one row due to indexing of y_true to first element and if not using y_true the function produces strange results --- R/surv_shap.R | 10 +++++++++- tests/testthat/test-predict_parts.R | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/R/surv_shap.R b/R/surv_shap.R index 41e098bb..b9273a16 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -24,7 +24,15 @@ 
surv_shap <- function(explainer, B = 25, exact = FALSE ) { + # if providing y_true, it must be exactly one single new observation, + # otherwise the indexing of y_true doesn't make any sense + stopifnot( + ifelse(!is.null(y_true), nrow(new_observation) == 1, TRUE), + nrow(new_observation) == 1 # produces nonesense, if more than on new observation + ) + test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE, has_survival = TRUE) + # make that this also works for 1-row matrix col_index <- which(colnames(new_observation) %in% colnames(explainer$data)) if (is.matrix(new_observation) && nrow(new_observation) == 1) { @@ -58,7 +66,7 @@ surv_shap <- function(explainer, res <- list() res$eval_times <- explainer$times - res$variable_values <- new_observation + res$variable_values <- as.data.frame(new_observation) res$result <- switch(calculation_method, "exact_kernel" = shap_kernel(explainer, new_observation, ...), diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index 681d4ab0..be9d8ee7 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -27,7 +27,7 @@ test_that("survshap explanations work", { new_observation = new_obs, y_true = c(100, 1), aggregation_method = "mean_absolute", - calculation_method = "treeshap" + calculation_method = "kernelshap" ) plot(parts_ranger_treeshap) From 44f56dbee9eb55b7e349adf24e01d4a841c2d947 Mon Sep 17 00:00:00 2001 From: kapsner Date: Wed, 5 Apr 2023 11:04:03 +0200 Subject: [PATCH 03/21] chore: updated description --- DESCRIPTION | 3 +++ 1 file changed, 3 insertions(+) diff --git a/DESCRIPTION b/DESCRIPTION index 0be37ae5..dd65b816 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -24,6 +24,7 @@ Imports: DALEX (>= 2.2.1), ggplot2, kernelshap, + treeshap, pec, survival, patchwork @@ -44,6 +45,8 @@ Suggests: testthat (>= 3.0.0), withr, xgboost +Remotes: + github::kapsner/treeshap Config/testthat/edition: 3 VignetteBuilder: knitr URL: 
https://modeloriented.github.io/survex/ From b25c468e8c51be7aa35a9ee8f907543711cca05e Mon Sep 17 00:00:00 2001 From: kapsner Date: Wed, 5 Apr 2023 22:40:08 +0200 Subject: [PATCH 04/21] chore: comment for clarifying implementation of treeshap --- R/surv_shap.R | 2 ++ 1 file changed, 2 insertions(+) diff --git a/R/surv_shap.R b/R/surv_shap.R index 227d1bc5..09efed42 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -216,6 +216,8 @@ use_kernelshap <- function(explainer, new_observation, ...){ use_treeshap <- function(explainer, new_observation, ...){ if (inherits(explainer$model, "ranger")) { + # UNIFY_FUN to prepare code for easy Integration of other ml algorithms + # that are supported by treeshap UNIFY_FUN <- treeshap::ranger_surv_fun.unify } else { stop("Support for `treeshap` is currently only implemented for `ranger`.") From dc7d33cd471a668949a07d4f28d0c513b482cdfa Mon Sep 17 00:00:00 2001 From: kapsner Date: Thu, 6 Apr 2023 10:12:31 +0200 Subject: [PATCH 05/21] chore: moved treeshap and kernelshap to suggests and added respective error handling --- DESCRIPTION | 4 ++-- R/surv_shap.R | 51 +++++++++++++++++++++++++++++++++++++++++------- man/surv_shap.Rd | 9 +++++---- 3 files changed, 51 insertions(+), 13 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index a37d8466..a9bf39a8 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -23,8 +23,6 @@ Depends: R (>= 3.5.0) Imports: DALEX (>= 2.2.1), ggplot2, - kernelshap, - treeshap, pec, survival, patchwork, @@ -36,6 +34,7 @@ Suggests: generics, glmnet, ingredients, + kernelshap, knitr, mboost, parsnip, @@ -44,6 +43,7 @@ Suggests: ranger, rmarkdown, testthat (>= 3.0.0), + treeshap, withr, xgboost Remotes: diff --git a/R/surv_shap.R b/R/surv_shap.R index 09efed42..7fed92bf 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -4,8 +4,9 @@ #' @param new_observation a new observation for which predictions need to be explained #' @param ... 
additional parameters, passed to internal functions #' @param y_true a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting -#' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements) or `"exact_kernel"` for exact Kernel SHAP estimation -#' @param aggregation_method a character, either `"mean_absolute"` or `"integral"`, `"max_absolute"`, `"sum_of_squares"` +#' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements), `"exact_kernel"` for exact Kernel SHAP estimation, +#' or `"treeshap"` for use of `treeshap` library (efficient implementation to compute SHAP values for tree-based models). +#' @param aggregation_method a character, either `"mean_absolute"` or `"integral"` (default), `"max_absolute"`, `"sum_of_squares"` #' #' @return A list, containing the calculated SurvSHAP(t) results in the `result` field #' @@ -18,12 +19,15 @@ surv_shap <- function(explainer, ..., y_true = NULL, - calculation_method = "kernelshap", - aggregation_method = "integral", + calculation_method = c("kernelshap", "exact_kernel", "treeshap"), + aggregation_method = c("integral", "mean_absolute", "max_absolute", "sum_of_squares"), path = "average", B = 25, exact = FALSE ) { + calculation_method <- match.arg(calculation_method) + aggregation_method <- match.arg(aggregation_method) + # make this code work for multiple observations stopifnot(ifelse(!is.null(y_true), ifelse(is.matrix(y_true), @@ -31,6 +35,29 @@ surv_shap <- function(explainer, is.null(dim(y_true)) && length(y_true) == 2L), TRUE)) + if (calculation_method == "kernelshap") { + if (!requireNamespace("kernelshap", quietly = TRUE)) { + stop( + paste0( + "Package \"kernelshap\" must be installed to use ", + "'calculation_method = \"kernelshap\"'." 
+ ), + call. = FALSE + ) + } + } + if (calculation_method == "treeshap") { + if (!requireNamespace("treeshap", quietly = TRUE)) { + stop( + paste0( + "Package \"treeshap\" must be installed to use ", + "'calculation_method = \"treeshap\"'." + ), + call. = FALSE + ) + } + } + test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE, has_survival = TRUE) # make this code also work for 1-row matrix @@ -56,8 +83,8 @@ surv_shap <- function(explainer, } } - # hack to use rf-model death times as explainer death times, as - # treeshap::ranger_surv_fun.unify extracts survival times directly + # hack to use rf-model's death times as explainer death times, as + # treeshap::ranger_surv_fun.unify extracts survival time-points directly # from the ranger object for calculating the predictions if (calculation_method == "treeshap") { if (inherits(explainer$model, "ranger")) { @@ -215,19 +242,29 @@ use_kernelshap <- function(explainer, new_observation, ...){ use_treeshap <- function(explainer, new_observation, ...){ + # init unify_append_args + unify_append_args <- list() + if (inherits(explainer$model, "ranger")) { # UNIFY_FUN to prepare code for easy Integration of other ml algorithms # that are supported by treeshap UNIFY_FUN <- treeshap::ranger_surv_fun.unify + unify_append_args <- list(type = "survival") } else { stop("Support for `treeshap` is currently only implemented for `ranger`.") } - tmp_unified <- UNIFY_FUN( + unify_args <- list( rf_model = explainer$model, data = explainer$data ) + if (length(unify_append_args) > 0) { + unify_args <- c(unify_args, unify_append_args) + } + + tmp_unified <- do.call(UNIFY_FUN, unify_args) + tmp_res_list <- sapply( X = as.character(seq_len(nrow(new_observation))), FUN = function(i) { diff --git a/man/surv_shap.Rd b/man/surv_shap.Rd index 0e8457d9..bd00fdcc 100644 --- a/man/surv_shap.Rd +++ b/man/surv_shap.Rd @@ -9,8 +9,8 @@ surv_shap( new_observation, ..., y_true = NULL, - calculation_method = "kernelshap", - aggregation_method 
= "integral", + calculation_method = c("kernelshap", "exact_kernel", "treeshap"), + aggregation_method = c("integral", "mean_absolute", "max_absolute", "sum_of_squares"), path = "average", B = 25, exact = FALSE @@ -25,9 +25,10 @@ surv_shap( \item{y_true}{a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting} -\item{calculation_method}{a character, either \code{"kernelshap"} for use of \code{kernelshap} library (providing faster Kernel SHAP with refinements) or \code{"exact_kernel"} for exact Kernel SHAP estimation} +\item{calculation_method}{a character, either \code{"kernelshap"} for use of \code{kernelshap} library (providing faster Kernel SHAP with refinements), \code{"exact_kernel"} for exact Kernel SHAP estimation, +or \code{"treeshap"} for use of \code{treeshap} library (efficient implementation to compute SHAP values for tree-based models).} -\item{aggregation_method}{a character, either \code{"mean_absolute"} or \code{"integral"}, \code{"max_absolute"}, \code{"sum_of_squares"}} +\item{aggregation_method}{a character, either \code{"mean_absolute"} or \code{"integral"} (default), \code{"max_absolute"}, \code{"sum_of_squares"}} } \value{ A list, containing the calculated SurvSHAP(t) results in the \code{result} field From b8135ee2ca810bbf93d86c29edc8a58a17ac71a8 Mon Sep 17 00:00:00 2001 From: kapsner Date: Thu, 6 Apr 2023 10:46:43 +0200 Subject: [PATCH 06/21] fix: fixed issues when providing matrix to kernelshap --- R/surv_shap.R | 15 +++++++++++++-- tests/testthat/test-predict_parts.R | 2 +- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/R/surv_shap.R b/R/surv_shap.R index 7fed92bf..76918d13 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -63,9 +63,12 @@ surv_shap <- function(explainer, # make this code also work for 1-row matrix col_index <- which(colnames(new_observation) %in% colnames(explainer$data)) if 
(is.matrix(new_observation) && nrow(new_observation) == 1) { - new_observation <- as.matrix(t(new_observation[, col_index])) + new_observation <- data.frame(as.matrix(t(new_observation[, col_index]))) } else { new_observation <- new_observation[, col_index] + if (!inherits(new_observation, "data.frame")) { + new_observation <- data.frame(new_observation) + } } if (ncol(explainer$data) != ncol(new_observation)) stop("New observation and data have different number of columns (variables)") @@ -205,6 +208,14 @@ aggregate_surv_shap <- function(survshap, method) { use_kernelshap <- function(explainer, new_observation, ...){ + stopifnot(inherits(new_observation, "data.frame")) + # get explainer data to be able to make class checks and transformations + explainer_data <- explainer$data + # ensure that classes of explainer$data and new_observation are equal + if (!inherits(explainer_data, "data.frame")) { + explainer_data <- data.frame(explainer_data) + } + predfun <- function(model, newdata){ explainer$predict_survival_function( model, @@ -219,7 +230,7 @@ use_kernelshap <- function(explainer, new_observation, ...){ tmp_res <- kernelshap::kernelshap( object = explainer$model, X = new_observation[as.integer(i), ], - bg_X = explainer$data, + bg_X = explainer_data, pred_fun = predfun, verbose = FALSE ) diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index 7223f020..e871e452 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -22,7 +22,7 @@ test_that("survshap explanations work", { rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = Surv(veteran$time, veteran$status), verbose = FALSE) new_obs <- model.matrix(~ -1 + ., veteran[2, !colnames(veteran) %in% 
c("time", "status")]) - parts_ranger_treeshap <- predict_parts( + parts_ranger_kernelshap <- predict_parts( rsf_ranger_exp_matrix, new_observation = new_obs, y_true = c(100, 1), From ad886bb4bd08b70bcaf65acd2c5778db50ae29be Mon Sep 17 00:00:00 2001 From: kapsner Date: Thu, 6 Apr 2023 11:54:25 +0200 Subject: [PATCH 07/21] feat: fully functional treeshap integration including support for global survshap values and unit-tests --- R/surv_shap.R | 22 +++++++---- tests/testthat/test-predict_parts.R | 61 +++++++++++++++++++++++++++-- 2 files changed, 72 insertions(+), 11 deletions(-) diff --git a/R/surv_shap.R b/R/surv_shap.R index 76918d13..512bdf12 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -86,11 +86,11 @@ surv_shap <- function(explainer, } } - # hack to use rf-model's death times as explainer death times, as - # treeshap::ranger_surv_fun.unify extracts survival time-points directly - # from the ranger object for calculating the predictions if (calculation_method == "treeshap") { if (inherits(explainer$model, "ranger")) { + # hack to use rf-model's death times as explainer death times, as + # treeshap::ranger_surv_fun.unify extracts survival time-points directly + # from the ranger object for calculating the predictions explainer$times <- explainer$model$unique.death.times } else { stop("Calculation method `treeshap` is currently only implemented for `ranger`.") @@ -253,13 +253,15 @@ use_kernelshap <- function(explainer, new_observation, ...){ use_treeshap <- function(explainer, new_observation, ...){ + stopifnot(inherits(new_observation, "data.frame")) + # init unify_append_args unify_append_args <- list() if (inherits(explainer$model, "ranger")) { # UNIFY_FUN to prepare code for easy Integration of other ml algorithms # that are supported by treeshap - UNIFY_FUN <- treeshap::ranger_surv_fun.unify + UNIFY_FUN <- treeshap::ranger_surv.unify unify_append_args <- list(type = "survival") } else { stop("Support for `treeshap` is currently only implemented for 
`ranger`.") @@ -284,10 +286,14 @@ use_treeshap <- function(explainer, new_observation, ...){ lapply( tmp_unified, function(m) { - treeshap::treeshap( - unified_model = m, - x = new_observation - )$shaps + new_obs_mat <- as.matrix(new_observation[as.integer(i), ]) + # ensure that matrix has expected dimensions; as.integer is + # necessary for valid comparison with "identical" + stopifnot(identical(dim(new_obs_mat), as.integer(c(1L, ncol(new_observation))))) + treeshap::treeshap( + unified_model = m, + x = new_obs_mat + )$shaps } ) ) diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index 8d1e80b3..08c58e13 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -18,7 +18,7 @@ test_that("survshap explanations work", { parts_ranger <- predict_parts(rsf_ranger_exp, veteran[2, !colnames(veteran) %in% c("time", "status")], y_true = c(100, 1), aggregation_method = "mean_absolute") plot(parts_ranger) - # test ranger with treeshap (we need the data as matrix) + # test ranger with kernelshap when using a matrix as input for data and new observation rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = Surv(veteran$time, veteran$status), verbose = FALSE) new_obs <- model.matrix(~ -1 + ., veteran[2, !colnames(veteran) %in% c("time", "status")]) @@ -29,7 +29,7 @@ test_that("survshap explanations work", { aggregation_method = "mean_absolute", calculation_method = "kernelshap" ) - plot(parts_ranger_treeshap) + plot(parts_ranger_kernelshap) parts_src <- predict_parts(rsf_src_exp, veteran[3, !colnames(veteran) %in% c("time", "status")]) @@ -59,9 +59,9 @@ test_that("survshap explanations work", { expect_error(predict_parts(cph_exp, veteran[1, ], calculation_method = 
"nonexistent")) expect_error(predict_parts(cph_exp, veteran[1, c(1, 1, 1, 1, 1)], calculation_method = "nonexistent")) - }) + test_that("global survshap explanations with kernelshap work for ranger", { veteran <- survival::veteran @@ -84,6 +84,61 @@ test_that("global survshap explanations with kernelshap work for ranger", { }) +# dont need to compute common code multiple times +veteran <- survival::veteran + +rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) +rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = Surv(veteran$time, veteran$status), verbose = FALSE) + +test_that("local survshap explanations with treeshap work for ranger", { + + new_obs <- model.matrix(~ -1 + ., veteran[2, setdiff(colnames(veteran), c("time", "status"))]) + parts_ranger <- predict_parts( + rsf_ranger_exp_matrix, + new_obs, + y_true = c(veteran$time[2], veteran$status[2]), + aggregation_method = "mean_absolute", + calculation_method = "treeshap" + ) + plot(parts_ranger) + + expect_s3_class(parts_ranger, c("predict_parts_survival", "surv_shap")) + # treeshap does not use time-points from explainer but instead time points provided by the ranger-model + expect_equal(nrow(parts_ranger$result), length(rsf_ranger_exp_matrix$model$unique.death.times)) + expect_true(all(colnames(parts_ranger$result) == colnames(rsf_ranger_exp_matrix$data))) + +}) + + +test_that("global survshap explanations with treeshap work for ranger", { + + new_obs <- model.matrix(~ -1 + ., veteran[1:40, setdiff(colnames(veteran), c("time", "status"))]) + parts_ranger_tree <- predict_parts( + rsf_ranger_exp_matrix, + new_obs, + y_true = Surv(veteran$time[1:40], veteran$status[1:40]), + aggregation_method = "mean_absolute", + calculation_method = "treeshap" + ) + plot(parts_ranger_tree) + + expect_s3_class(parts_ranger_tree, 
c("predict_parts_survival", "surv_shap")) + # treeshap does not use time-points from explainer but instead time points provided by the ranger-model + expect_equal(nrow(parts_ranger_tree$result), length(rsf_ranger_exp_matrix$model$unique.death.times)) + expect_true(all(colnames(parts_ranger_tree$result) == colnames(rsf_ranger_exp_matrix$data))) + + # to compare plots, compute kernelshap with dummified features + parts_ranger_kernel <- predict_parts( + rsf_ranger_exp_matrix, + new_obs, + y_true = Surv(veteran$time[1:40], veteran$status[1:40]), + aggregation_method = "mean_absolute", + calculation_method = "kernelshap" + ) + plot(parts_ranger_kernel) +}) + + test_that("survlime explanations work", { veteran <- survival::veteran From d2c16a608792c6387b2e65890446521890d088e0 Mon Sep 17 00:00:00 2001 From: kapsner Date: Sat, 8 Apr 2023 17:10:16 +0200 Subject: [PATCH 08/21] feat: code adaptions to treeshap computation using pre-defined survival times from explain-object addresses #75 --- DESCRIPTION | 2 +- R/surv_shap.R | 9 ++------- tests/testthat/test-predict_parts.R | 1 - 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index a9bf39a8..6364d21e 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: survex Title: Explainable Machine Learning in Survival Analysis -Version: 1.0.0.9001 +Version: 1.0.0.9002 Authors@R: c( person("Mikołaj", "Spytek", email = "mikolajspytek@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-7111-2286")), diff --git a/R/surv_shap.R b/R/surv_shap.R index 512bdf12..fa3cdbac 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -87,12 +87,7 @@ surv_shap <- function(explainer, } if (calculation_method == "treeshap") { - if (inherits(explainer$model, "ranger")) { - # hack to use rf-model's death times as explainer death times, as - # treeshap::ranger_surv_fun.unify extracts survival time-points directly - # from the ranger object for calculating the predictions - explainer$times <- 
explainer$model$unique.death.times - } else { + if (!inherits(explainer$model, "ranger")) { stop("Calculation method `treeshap` is currently only implemented for `ranger`.") } } @@ -262,7 +257,7 @@ use_treeshap <- function(explainer, new_observation, ...){ # UNIFY_FUN to prepare code for easy Integration of other ml algorithms # that are supported by treeshap UNIFY_FUN <- treeshap::ranger_surv.unify - unify_append_args <- list(type = "survival") + unify_append_args <- list(type = "survival", times = explainer$times) } else { stop("Support for `treeshap` is currently only implemented for `ranger`.") } diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index 08c58e13..c137bbbc 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -31,7 +31,6 @@ test_that("survshap explanations work", { ) plot(parts_ranger_kernelshap) - parts_src <- predict_parts(rsf_src_exp, veteran[3, !colnames(veteran) %in% c("time", "status")]) plot(parts_src) From 1329254ad6a092bfd38575eb2f3fb9fa48b399d1 Mon Sep 17 00:00:00 2001 From: kapsner Date: Sun, 9 Apr 2023 12:56:57 +0200 Subject: [PATCH 09/21] refactor: implemented quality checks and now explicitly control row order for kernelshap and treeshap --- R/surv_shap.R | 26 +++++++++++++++++--------- tests/testthat/test-predict_parts.R | 18 ++++++++---------- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/R/surv_shap.R b/R/surv_shap.R index fa3cdbac..9a9c2401 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -102,6 +102,8 @@ surv_shap <- function(explainer, "kernelshap" = use_kernelshap(explainer, new_observation, ...), "treeshap" = use_treeshap(explainer, new_observation, ...), stop("Only `exact_kernel`, `kernelshap` and `treeshap` calculation methods are implemented")) + # quality-check here + stopifnot(nrow(res$result) == length(res$eval_times)) if (!is.null(y_true)) res$y_true <- c(y_true_time = y_true_time, y_true_ind = y_true_ind) @@ -129,7 +131,7 
@@ shap_kernel <- function(explainer, new_observation, ...) { shap_values <- as.data.frame(shap_values, row.names = colnames(explainer$data)) colnames(shap_values) <- paste("t=", timestamps, sep = "") - return (t(shap_values)) + return(t(shap_values)) } generate_shap_kernel_weights <- function(permutations, p) { @@ -229,10 +231,12 @@ use_kernelshap <- function(explainer, new_observation, ...){ pred_fun = predfun, verbose = FALSE ) - tmp_shap_values <- data.frame(t(sapply(tmp_res$S, cbind))) + tmp_shap_values <- data.table::as.data.table( + t(sapply(tmp_res$S, cbind)) + ) colnames(tmp_shap_values) <- colnames(tmp_res$X) - rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "") - data.table::as.data.table(tmp_shap_values, keep.rownames = TRUE) + tmp_shap_values$rn <- explainer$times + return(tmp_shap_values) }, USE.NAMES = TRUE, simplify = FALSE @@ -293,10 +297,10 @@ use_treeshap <- function(explainer, new_observation, ...){ ) ) - tmp_shap_values <- data.frame(tmp_res) + tmp_shap_values <- data.table::as.data.table(tmp_res) colnames(tmp_shap_values) <- colnames(tmp_res) - rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "") - data.table::as.data.table(tmp_shap_values, keep.rownames = TRUE) + tmp_shap_values$rn <- explainer$times + return(tmp_shap_values) }, USE.NAMES = TRUE, simplify = FALSE @@ -330,10 +334,14 @@ aggregate_shap_multiple_observations <- function(shap_res_list, feature_names) { # no aggregation required tmp_res <- shap_res_list[[1]] } - shap_values <- tmp_res[, .SD, .SDcols = setdiff(colnames(tmp_res), "rn")] + + # rn == explainer$times -> now sort everything by these times so that time-points are + # then named correctly + shap_values <- tmp_res[order(get("rn"))][, .SD, .SDcols = setdiff(colnames(tmp_res), "rn")] + # transform to data.frame to make everything compatible with # previous code shap_values <- data.frame(shap_values) - rownames(shap_values) <- tmp_res$rn + rownames(shap_values) <- paste("t=", tmp_res$rn, sep = 
"") return(shap_values) } diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index c137bbbc..f6e88df7 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -6,7 +6,7 @@ test_that("survshap explanations work", { rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) cph_exp <- explain(cph, verbose = FALSE) - rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) + rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = survival::Surv(veteran$time, veteran$status), verbose = FALSE) rsf_src_exp <- explain(rsf_src, verbose = FALSE) parts_cph <- predict_parts(cph_exp, veteran[1, !colnames(veteran) %in% c("time", "status")], y_true = matrix(c(100, 1), ncol = 2), aggregation_method = "sum_of_squares") @@ -20,7 +20,7 @@ test_that("survshap explanations work", { # test ranger with kernelshap when using a matrix as input for data and new observation rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) - rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = Surv(veteran$time, veteran$status), verbose = FALSE) + rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE) new_obs <- model.matrix(~ -1 + ., veteran[2, !colnames(veteran) %in% c("time", "status")]) parts_ranger_kernelshap <- predict_parts( rsf_ranger_exp_matrix, @@ -65,7 +65,7 @@ test_that("global survshap explanations with kernelshap work for ranger", { veteran <- survival::veteran rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) - 
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) + rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = survival::Surv(veteran$time, veteran$status), verbose = FALSE) parts_ranger <- predict_parts( rsf_ranger_exp, @@ -87,7 +87,7 @@ test_that("global survshap explanations with kernelshap work for ranger", { veteran <- survival::veteran rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) -rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = Surv(veteran$time, veteran$status), verbose = FALSE) +rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE) test_that("local survshap explanations with treeshap work for ranger", { @@ -102,8 +102,7 @@ test_that("local survshap explanations with treeshap work for ranger", { plot(parts_ranger) expect_s3_class(parts_ranger, c("predict_parts_survival", "surv_shap")) - # treeshap does not use time-points from explainer but instead time points provided by the ranger-model - expect_equal(nrow(parts_ranger$result), length(rsf_ranger_exp_matrix$model$unique.death.times)) + expect_equal(nrow(parts_ranger$result), length(rsf_ranger_exp_matrix$times)) expect_true(all(colnames(parts_ranger$result) == colnames(rsf_ranger_exp_matrix$data))) }) @@ -115,22 +114,21 @@ test_that("global survshap explanations with treeshap work for ranger", { parts_ranger_tree <- predict_parts( rsf_ranger_exp_matrix, new_obs, - y_true = Surv(veteran$time[1:40], veteran$status[1:40]), + y_true = survival::Surv(veteran$time[1:40], veteran$status[1:40]), aggregation_method = "mean_absolute", calculation_method = "treeshap" ) plot(parts_ranger_tree) 
expect_s3_class(parts_ranger_tree, c("predict_parts_survival", "surv_shap")) - # treeshap does not use time-points from explainer but instead time points provided by the ranger-model - expect_equal(nrow(parts_ranger_tree$result), length(rsf_ranger_exp_matrix$model$unique.death.times)) + expect_equal(nrow(parts_ranger_tree$result), length(rsf_ranger_exp_matrix$times)) expect_true(all(colnames(parts_ranger_tree$result) == colnames(rsf_ranger_exp_matrix$data))) # to compare plots, compute kernelshap with dummified features parts_ranger_kernel <- predict_parts( rsf_ranger_exp_matrix, new_obs, - y_true = Surv(veteran$time[1:40], veteran$status[1:40]), + y_true = survival::Surv(veteran$time[1:40], veteran$status[1:40]), aggregation_method = "mean_absolute", calculation_method = "kernelshap" ) From 60e3b91d395e29e0861dcc665a027319f8d024a7 Mon Sep 17 00:00:00 2001 From: kapsner Date: Sun, 9 Apr 2023 13:17:01 +0200 Subject: [PATCH 10/21] chore: rounding time-points for row names to a maximum of 2 digits --- R/surv_shap.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/surv_shap.R b/R/surv_shap.R index 9a9c2401..378b7208 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -342,6 +342,6 @@ aggregate_shap_multiple_observations <- function(shap_res_list, feature_names) { # transform to data.frame to make everything compatible with # previous code shap_values <- data.frame(shap_values) - rownames(shap_values) <- paste("t=", tmp_res$rn, sep = "") + rownames(shap_values) <- paste("t=", round(tmp_res$rn, 2), sep = "") return(shap_values) } From 5b1a810ee111a0ce68fb00fe12d1473e16ceb6e5 Mon Sep 17 00:00:00 2001 From: kapsner Date: Thu, 13 Apr 2023 16:59:42 +0200 Subject: [PATCH 11/21] chore: more informative error messages in stopifnot statements --- R/surv_shap.R | 26 ++++++++++++++++++-------- tests/testthat/test-predict_parts.R | 2 +- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/R/surv_shap.R b/R/surv_shap.R index 378b7208..647edc25 100644 --- 
a/R/surv_shap.R +++ b/R/surv_shap.R @@ -29,11 +29,12 @@ surv_shap <- function(explainer, aggregation_method <- match.arg(aggregation_method) # make this code work for multiple observations - stopifnot(ifelse(!is.null(y_true), - ifelse(is.matrix(y_true), - nrow(new_observation) == nrow(y_true), - is.null(dim(y_true)) && length(y_true) == 2L), - TRUE)) + stopifnot("y_true must be either a 2-column matrix of same length as new_observation, or a 2-element vector" = ifelse( + !is.null(y_true), + ifelse(is.matrix(y_true), + nrow(new_observation) == nrow(y_true), + is.null(dim(y_true)) && length(y_true) == 2L), + TRUE)) if (calculation_method == "kernelshap") { if (!requireNamespace("kernelshap", quietly = TRUE)) { @@ -103,7 +104,10 @@ surv_shap <- function(explainer, "treeshap" = use_treeshap(explainer, new_observation, ...), stop("Only `exact_kernel`, `kernelshap` and `treeshap` calculation methods are implemented")) # quality-check here - stopifnot(nrow(res$result) == length(res$eval_times)) + stopifnot( + "Number of rows of SurvSHAP table are not identical with length(eval_times)" = + nrow(res$result) == length(res$eval_times) + ) if (!is.null(y_true)) res$y_true <- c(y_true_time = y_true_time, y_true_ind = y_true_ind) @@ -205,7 +209,10 @@ aggregate_surv_shap <- function(survshap, method) { use_kernelshap <- function(explainer, new_observation, ...){ - stopifnot(inherits(new_observation, "data.frame")) + stopifnot( + "new_observation must be a data.frame" = inherits( + new_observation, "data.frame") + ) # get explainer data to be able to make class checks and transformations explainer_data <- explainer$data # ensure that classes of explainer$data and new_observation are equal @@ -252,7 +259,10 @@ use_kernelshap <- function(explainer, new_observation, ...){ use_treeshap <- function(explainer, new_observation, ...){ - stopifnot(inherits(new_observation, "data.frame")) + stopifnot( + "new_observation must be a data.frame" = inherits( + new_observation, "data.frame") + ) 
# init unify_append_args unify_append_args <- list() diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index f6e88df7..c1d02a31 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -70,7 +70,7 @@ test_that("global survshap explanations with kernelshap work for ranger", { parts_ranger <- predict_parts( rsf_ranger_exp, veteran[1:40, !colnames(veteran) %in% c("time", "status")], - y_true = Surv(veteran$time[1:40], veteran$status[1:40]), + y_true = survival::Surv(veteran$time[1:40], veteran$status[1:40]), aggregation_method = "mean_absolute", calculation_method = "kernelshap" ) From 6a4c9268076f725fa63595ee5bda597ac30ccdbf Mon Sep 17 00:00:00 2001 From: kapsner Date: Tue, 25 Jul 2023 14:08:44 +0200 Subject: [PATCH 12/21] feat: adaptions of treeshap feature branch to new global survshap - removed data.table parts - updated error handling - added unit tests for treeshap (both, local and global shap) addresses #75 --- DESCRIPTION | 7 ++++--- R/surv_shap.R | 18 +++++++----------- tests/testthat/test-model_survshap.R | 10 +++++++--- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 742bfaf1..41ce5324 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,12 +1,13 @@ Package: survex Title: Explainable Machine Learning in Survival Analysis -Version: 1.0.0.9002 +Version: 1.0.0.9003 Authors@R: c( person("Mikołaj", "Spytek", email = "mikolajspytek@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-7111-2286")), person("Mateusz", "Krzyziński", role = c("aut"), comment = c(ORCID = "0000-0001-6143-488X")), person("Hubert", "Baniecki", role = c("aut"), comment = c(ORCID = "0000-0001-6661-5364")), - person("Przemyslaw", "Biecek", role = c("aut"), comment = c(ORCID = "0000-0001-8423-1823")) + person("Przemyslaw", "Biecek", role = c("aut"), comment = c(ORCID = "0000-0001-8423-1823")), + person("Lorenz A.", "Kapsner", role = c("ctb"), comment = c(ORCID = 
"0000-0003-1866-860X")) ) Description: Survival analysis models are commonly used in medicine and other areas. Many of them are too complex to be interpreted by human. Exploration and explanation is needed, but @@ -46,7 +47,7 @@ Suggests: withr, xgboost Remotes: - github::kapsner/treeshap + github::ModelOriented/treeshap Config/testthat/edition: 3 VignetteBuilder: knitr URL: https://modeloriented.github.io/survex/ diff --git a/R/surv_shap.R b/R/surv_shap.R index fb295cc7..24d73292 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -160,8 +160,6 @@ shap_kernel <- function(explainer, new_observation, ...) { shap_values <- calculate_shap_values(explainer, explainer$model, baseline_sf, as.data.frame(explainer$data), permutations, kernel_weights, as.data.frame(new_observation), timestamps) - - shap_values <- as.data.frame(shap_values, row.names = colnames(explainer$data)) colnames(shap_values) <- paste("t=", timestamps, sep = "") return(t(shap_values)) @@ -259,13 +257,11 @@ use_kernelshap <- function(explainer, new_observation, observation_aggregation_m tmp_res <- kernelshap::kernelshap( object = explainer$model, X = new_observation[as.integer(i), ], - bg_X = explainer_data, + bg_X = explainer$data, pred_fun = predfun, verbose = FALSE ) - tmp_shap_values <- data.table::as.data.table( - t(sapply(tmp_res$S, cbind)) - ) + tmp_shap_values <- data.frame(t(sapply(tmp_res$S, cbind))) colnames(tmp_shap_values) <- colnames(tmp_res$X) rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "") tmp_shap_values @@ -307,7 +303,7 @@ use_treeshap <- function(explainer, new_observation, ...){ tmp_unified <- do.call(UNIFY_FUN, unify_args) - tmp_res_list <- sapply( + shap_values <- sapply( X = as.character(seq_len(nrow(new_observation))), FUN = function(i) { tmp_res <- do.call( @@ -327,10 +323,10 @@ use_treeshap <- function(explainer, new_observation, ...){ ) ) - tmp_shap_values <- data.table::as.data.table(tmp_res) + tmp_shap_values <- data.frame(tmp_res) 
colnames(tmp_shap_values) <- colnames(tmp_res) - tmp_shap_values$rn <- explainer$times - return(tmp_shap_values) + rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "") + tmp_shap_values }, USE.NAMES = TRUE, simplify = FALSE @@ -340,7 +336,7 @@ use_treeshap <- function(explainer, new_observation, ...){ } -#'@internal +# @internal aggregate_shap_multiple_observations <- function(shap_res_list, feature_names, aggregation_function) { if (length(shap_res_list) > 1) { diff --git a/tests/testthat/test-model_survshap.R b/tests/testthat/test-model_survshap.R index abb05c9c..d3f6d85c 100644 --- a/tests/testthat/test-model_survshap.R +++ b/tests/testthat/test-model_survshap.R @@ -1,4 +1,5 @@ +# create objects here so that they do not have to be created redundantly veteran <- survival::veteran rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) @@ -41,13 +42,16 @@ test_that("global survshap explanations with kernelshap work for ranger, using e }) +# testing if matrix works as input +rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) +rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE) test_that("global survshap explanations with treeshap work for ranger", { new_obs <- model.matrix(~ -1 + ., veteran[1:40, setdiff(colnames(veteran), c("time", "status"))]) ranger_global_survshap_tree <- model_survshap( - rsf_ranger_exp, - new_obs, + rsf_ranger_exp_matrix, + new_observation = new_obs, y_true = survival::Surv(veteran$time[1:40], veteran$status[1:40]), aggregation_method = "mean_absolute", 
calculation_method = "treeshap" @@ -56,6 +60,6 @@ test_that("global survshap explanations with treeshap work for ranger", { expect_s3_class(ranger_global_survshap_tree, c("aggregated_surv_shap", "surv_shap")) expect_equal(length(ranger_global_survshap_tree$eval_times), length(rsf_ranger_exp$times)) - expect_true(all(names(ranger_global_survshap_tree$variable_values) == colnames(rsf_ranger_exp$data))) + expect_true(all(names(ranger_global_survshap_tree$variable_values) == colnames(rsf_ranger_exp_matrix$data))) }) From a9269f093c0708aba59ab4bf5f12efe901e85566 Mon Sep 17 00:00:00 2001 From: kapsner Date: Thu, 31 Aug 2023 11:32:13 +0200 Subject: [PATCH 13/21] chore: updated messages --- R/surv_shap.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/surv_shap.R b/R/surv_shap.R index 2bff61ab..2478e397 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -93,7 +93,7 @@ surv_shap <- function(explainer, if (calculation_method == "treeshap") { if (!inherits(explainer$model, "ranger")) { - stop("Calculation method `treeshap` is currently only implemented for `ranger`.") + stop("Calculation method `treeshap` is currently only implemented for `ranger` survival models.") } } From 0e8c85c24b8b499bba9343a22b848bbd2f9e6eda Mon Sep 17 00:00:00 2001 From: kapsner Date: Thu, 31 Aug 2023 12:02:34 +0200 Subject: [PATCH 14/21] fix: added missing output_type arguments to switch of calculation method in surv_shap which caused error of github action unittests when using 'exact_kernel' --- R/surv_shap.R | 8 ++++---- tests/testthat/test-predict_parts.R | 4 ++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/R/surv_shap.R b/R/surv_shap.R index 2478e397..9c4f6ac0 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -28,7 +28,7 @@ surv_shap <- function(explainer, # make this code work for multiple observations stopifnot( - "`y_true` must be either a matrix with one per observation in `new_observation` or a vector of length == 2" = ifelse( + "`y_true` must be either a 
matrix with one row per observation in `new_observation` or a vector of length == 2" = ifelse( !is.null(y_true), ifelse( is.matrix(y_true), @@ -102,8 +102,8 @@ surv_shap <- function(explainer, # to display final object correctly, when is.matrix(new_observation) == TRUE res$variable_values <- as.data.frame(new_observation) res$result <- switch(calculation_method, - "exact_kernel" = use_exact_shap(explainer, new_observation, ...), - "kernelshap" = use_kernelshap(explainer, new_observation, ...), + "exact_kernel" = use_exact_shap(explainer, new_observation, output_type, ...), + "kernelshap" = use_kernelshap(explainer, new_observation, output_type, ...), "treeshap" = use_treeshap(explainer, new_observation, ...), stop("Only `exact_kernel`, `kernelshap` and `treeshap` calculation methods are implemented")) # quality-check here @@ -129,7 +129,7 @@ surv_shap <- function(explainer, return(res) } -use_exact_shap <- function(explainer, new_observation, output_type, observation_aggregation_method, ...) { +use_exact_shap <- function(explainer, new_observation, output_type, ...) 
{ shap_values <- sapply( X = as.character(seq_len(nrow(new_observation))), FUN = function(i) { diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index ec4f4807..f8399d1b 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -103,6 +103,10 @@ test_that("survshap explanations with output_type = 'chf' work", { plot(parts_cph, rug = "censors") plot(parts_cph, rug = "none") + # test global exact + parts_cph_glob <- predict_parts(cph_exp, veteran[1:3, !colnames(veteran) %in% c("time", "status")], y_true = as.matrix(veteran[1:3, c("time", "status")]), calculation_method = "exact_kernel", aggregation_method = "max_absolute", output_type = "chf") + plot(parts_cph_glob) + parts_ranger <- predict_parts(rsf_ranger_exp, veteran[2, !colnames(veteran) %in% c("time", "status")], y_true = c(100, 1), aggregation_method = "mean_absolute", output_type = "chf") plot(parts_ranger) From 40a076eb647e6ec178c7e830266bd2c913796acb Mon Sep 17 00:00:00 2001 From: kapsner Date: Thu, 31 Aug 2023 12:46:17 +0200 Subject: [PATCH 15/21] fix: new_observation now as.matrix for kernelshap --- R/surv_shap.R | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/R/surv_shap.R b/R/surv_shap.R index 9c4f6ac0..2b0e5d37 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -257,20 +257,12 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_ explainer_data <- data.frame(explainer_data) } - predfun <- function(model, newdata){ - explainer$predict_survival_function( - model, - newdata, - times = explainer$times - ) - } - shap_values <- sapply( X = as.character(seq_len(nrow(new_observation))), FUN = function(i) { tmp_res <- kernelshap::kernelshap( object = explainer$model, - X = new_observation[as.integer(i), ], + X = as.matrix(new_observation[as.integer(i), ]), bg_X = explainer$data, pred_fun = predfun, verbose = FALSE From a496a673421ee73af367a9d88976325cc39f09da Mon Sep 17 00:00:00 2001 
From: kapsner Date: Thu, 31 Aug 2023 12:57:58 +0200 Subject: [PATCH 16/21] fix: removed data.frame conversion of explainer in kernelshap --- R/surv_shap.R | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/R/surv_shap.R b/R/surv_shap.R index 2b0e5d37..b46f967f 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -250,12 +250,6 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_ "new_observation must be a data.frame" = inherits( new_observation, "data.frame") ) - # get explainer data to be able to make class checks and transformations - explainer_data <- explainer$data - # ensure that classes of explainer$data and new_observation are equal - if (!inherits(explainer_data, "data.frame")) { - explainer_data <- data.frame(explainer_data) - } shap_values <- sapply( X = as.character(seq_len(nrow(new_observation))), @@ -263,7 +257,7 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_ tmp_res <- kernelshap::kernelshap( object = explainer$model, X = as.matrix(new_observation[as.integer(i), ]), - bg_X = explainer$data, + bg_X = as.matrix(explainer$data), pred_fun = predfun, verbose = FALSE ) From 23e57da7ddfcf1b5c3aaebd0e117c1a6b7e28ac8 Mon Sep 17 00:00:00 2001 From: kapsner Date: Thu, 31 Aug 2023 14:19:07 +0200 Subject: [PATCH 17/21] fix: another try to fix kernelshap data, now X and bg_X as data.frame --- R/surv_shap.R | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/R/surv_shap.R b/R/surv_shap.R index b46f967f..46e8d5e9 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -251,16 +251,24 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_ new_observation, "data.frame") ) + # get explainer data to be able to make class checks and transformations + explainer_data <- explainer$data + # ensure that classes of explainer$data and new_observation are equal + if (!inherits(explainer_data, "data.frame")) { + explainer_data <- 
data.frame(explainer_data) + } + shap_values <- sapply( X = as.character(seq_len(nrow(new_observation))), FUN = function(i) { tmp_res <- kernelshap::kernelshap( object = explainer$model, - X = as.matrix(new_observation[as.integer(i), ]), - bg_X = as.matrix(explainer$data), + X = new_observation[as.integer(i), ], # data.frame + bg_X = explainer_data, # data.frame pred_fun = predfun, verbose = FALSE ) + # kernelshap-test: is.matrix(X) == is.matrix(bg_X) should evaluate to `TRUE` tmp_shap_values <- data.frame(t(sapply(tmp_res$S, cbind))) colnames(tmp_shap_values) <- colnames(tmp_res$X) rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "") From f61aad55b1351e8ed66167a0348ef167e70de51e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Krzyzi=C5=84ski?= Date: Mon, 2 Oct 2023 11:20:50 +0200 Subject: [PATCH 18/21] Update surv_shap.R description --- R/surv_shap.R | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/R/surv_shap.R b/R/surv_shap.R index 46e8d5e9..e2f84b1d 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -5,8 +5,7 @@ #' @param output_type a character, either `"survival"` or `"chf"`. Determines which type of prediction should be used for explanations. #' @param ... additional parameters, passed to internal functions #' @param y_true a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting -#' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements), `"exact_kernel"` for exact Kernel SHAP estimation, -#' or `"treeshap"` for use of `treeshap` library (efficient implementation to compute SHAP values for tree-based models). 
+#' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements), `"exact_kernel"` for exact Kernel SHAP estimation, or `"treeshap"` for use of `treeshap` library (efficient implementation to compute SHAP values for tree-based models). #' @param aggregation_method a character, either `"integral"`, `"integral_absolute"`, `"mean_absolute"`, `"max_absolute"`, or `"sum_of_squares"` #' #' @return A list, containing the calculated SurvSHAP(t) results in the `result` field From a3a83114e66c60001222ea77496dde43856c1e0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Krzyzi=C5=84ski?= Date: Mon, 2 Oct 2023 11:21:42 +0200 Subject: [PATCH 19/21] remove treeshap from Remotes --- DESCRIPTION | 2 -- 1 file changed, 2 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 4bad3e28..634701c5 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -50,8 +50,6 @@ Suggests: testthat (>= 3.0.0), withr, xgboost -Remotes: - github::ModelOriented/treeshap Config/testthat/edition: 3 VignetteBuilder: knitr URL: https://modeloriented.github.io/survex/ From 885a27b8f9037fc48a2c052b5d9a71d1e2182528 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Krzyzi=C5=84ski?= Date: Mon, 2 Oct 2023 12:54:26 +0200 Subject: [PATCH 20/21] Update surv_shap.R --- R/surv_shap.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/surv_shap.R b/R/surv_shap.R index e2f84b1d..e0d47593 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -318,7 +318,7 @@ use_treeshap <- function(explainer, new_observation, ...){ lapply( tmp_unified, function(m) { - new_obs_mat <- as.matrix(new_observation[as.integer(i), ]) + new_obs_mat <- new_observation[as.integer(i), ] # ensure that matrix has expected dimensions; as.integer is # necessary for valid comparison with "identical" stopifnot(identical(dim(new_obs_mat), as.integer(c(1L, ncol(new_observation))))) From 1d275eb8f9a50600185220eced9b5e7f1c6b704f Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Mateusz=20Krzyzi=C5=84ski?= Date: Mon, 2 Oct 2023 12:54:59 +0200 Subject: [PATCH 21/21] Update test-predict_parts.R --- tests/testthat/test-predict_parts.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index f8399d1b..6e8c733f 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -69,7 +69,7 @@ test_that("local survshap explanations with treeshap work for ranger", { rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE) - new_obs <- model.matrix(~ -1 + ., veteran[2, setdiff(colnames(veteran), c("time", "status"))]) + new_obs <- data.frame(model.matrix(~ -1 + ., veteran[2, setdiff(colnames(veteran), c("time", "status"))])) parts_ranger <- model_survshap( rsf_ranger_exp_matrix, new_obs,