Skip to content

Commit

Permalink
Merge pull request #85 from kapsner/feat-treeshap
Browse files Browse the repository at this point in the history
Feature: support SurvSHAP computation with {treeshap}
  • Loading branch information
krzyzinskim authored Oct 2, 2023
2 parents 3864b87 + 1d275eb commit 0b2f4f5
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 19 deletions.
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Authors@R:
person("Mateusz", "Krzyziński", role = c("aut"), comment = c(ORCID = "0000-0001-6143-488X")),
person("Sophie", "Langbein", role = c("aut")),
person("Hubert", "Baniecki", role = c("aut"), comment = c(ORCID = "0000-0001-6661-5364")),
person("Lorenz A.", "Kapsner", role = c("ctb"), comment = c(ORCID = "0000-0003-1866-860X")),
person("Przemyslaw", "Biecek", role = c("aut"), comment = c(ORCID = "0000-0001-8423-1823"))
)
Description: Survival analysis models are commonly used in medicine and other areas. Many of them
Expand All @@ -25,6 +26,7 @@ Imports:
DALEX (>= 2.2.1),
ggplot2 (>= 3.4.0),
kernelshap,
treeshap,
pec,
survival,
patchwork
Expand Down
142 changes: 128 additions & 14 deletions R/surv_shap.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#' @param output_type a character, either `"survival"` or `"chf"`. Determines which type of prediction should be used for explanations.
#' @param ... additional parameters, passed to internal functions
#' @param y_true a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting
#' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements) or `"exact_kernel"` for exact Kernel SHAP estimation
#' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements), `"exact_kernel"` for exact Kernel SHAP estimation, or `"treeshap"` for use of `treeshap` library (efficient implementation to compute SHAP values for tree-based models).
#' @param aggregation_method a character, either `"integral"`, `"integral_absolute"`, `"mean_absolute"`, `"max_absolute"`, or `"sum_of_squares"`
#'
#' @return A list, containing the calculated SurvSHAP(t) results in the `result` field
Expand All @@ -19,10 +19,15 @@ surv_shap <- function(explainer,
output_type,
...,
y_true = NULL,
calculation_method = "kernelshap",
aggregation_method = "integral") {
calculation_method = c("kernelshap", "exact_kernel", "treeshap"),
aggregation_method = c("integral", "mean_absolute", "max_absolute", "sum_of_squares")
) {
calculation_method <- match.arg(calculation_method)
aggregation_method <- match.arg(aggregation_method)

# make this code work for multiple observations
stopifnot(
"`y_true` must be either a matrix with one per observation in `new_observation` or a vector of length == 2" = ifelse(
"`y_true` must be either a matrix with one row per observation in `new_observation` or a vector of length == 2" = ifelse(
!is.null(y_true),
ifelse(
is.matrix(y_true),
Expand All @@ -33,14 +38,40 @@ surv_shap <- function(explainer,
)
)

if (calculation_method == "kernelshap") {
if (!requireNamespace("kernelshap", quietly = TRUE)) {
stop(
paste0(
"Package \"kernelshap\" must be installed to use ",
"'calculation_method = \"kernelshap\"'."
),
call. = FALSE
)
}
}
if (calculation_method == "treeshap") {
if (!requireNamespace("treeshap", quietly = TRUE)) {
stop(
paste0(
"Package \"treeshap\" must be installed to use ",
"'calculation_method = \"treeshap\"'."
),
call. = FALSE
)
}
}

test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE, has_survival = TRUE)

# make this code also work for 1-row matrix
col_index <- which(colnames(new_observation) %in% colnames(explainer$data))
if (is.matrix(new_observation) && nrow(new_observation) == 1) {
new_observation <- as.matrix(t(new_observation[, col_index]))
new_observation <- data.frame(as.matrix(t(new_observation[, col_index])))
} else {
new_observation <- new_observation[, col_index]
if (!inherits(new_observation, "data.frame")) {
new_observation <- data.frame(new_observation)
}
}

if (ncol(explainer$data) != ncol(new_observation)) {
Expand All @@ -59,14 +90,25 @@ surv_shap <- function(explainer,
}
}

if (calculation_method == "treeshap") {
if (!inherits(explainer$model, "ranger")) {
stop("Calculation method `treeshap` is currently only implemented for `ranger` survival models.")
}
}

res <- list()
res$eval_times <- explainer$times
# to display final object correctly, when is.matrix(new_observation) == TRUE
res$variable_values <- as.data.frame(new_observation)
res$result <- switch(calculation_method,
"exact_kernel" = use_exact_shap(explainer, new_observation, output_type, ...),
"kernelshap" = use_kernelshap(explainer, new_observation, output_type, ...),
stop("Only `exact_kernel` and `kernelshap` calculation methods are implemented")
"exact_kernel" = use_exact_shap(explainer, new_observation, output_type, ...),
"kernelshap" = use_kernelshap(explainer, new_observation, output_type, ...),
"treeshap" = use_treeshap(explainer, new_observation, ...),
stop("Only `exact_kernel`, `kernelshap` and `treeshap` calculation methods are implemented"))
# quality-check here
stopifnot(
"Number of rows of SurvSHAP table are not identical with length(eval_times)" =
nrow(res$result) == length(res$eval_times)
)

if (!is.null(y_true)) res$y_true <- c(y_true_time = y_true_time, y_true_ind = y_true_ind)
Expand All @@ -86,7 +128,7 @@ surv_shap <- function(explainer,
return(res)
}

use_exact_shap <- function(explainer, new_observation, output_type, observation_aggregation_method, ...) {
use_exact_shap <- function(explainer, new_observation, output_type, ...) {
shap_values <- sapply(
X = as.character(seq_len(nrow(new_observation))),
FUN = function(i) {
Expand Down Expand Up @@ -123,11 +165,8 @@ shap_kernel <- function(explainer, new_observation, output_type, ...) {
timestamps
)



shap_values <- as.data.frame(shap_values, row.names = colnames(explainer$data))
colnames(shap_values) <- paste("t=", timestamps, sep = "")

return(t(shap_values))
}

Expand Down Expand Up @@ -204,19 +243,31 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_
times = explainer$times
)
}
}

stopifnot(
"new_observation must be a data.frame" = inherits(
new_observation, "data.frame")
)

# get explainer data to be able to make class checks and transformations
explainer_data <- explainer$data
# ensure that classes of explainer$data and new_observation are equal
if (!inherits(explainer_data, "data.frame")) {
explainer_data <- data.frame(explainer_data)
}

shap_values <- sapply(
X = as.character(seq_len(nrow(new_observation))),
FUN = function(i) {
tmp_res <- kernelshap::kernelshap(
object = explainer$model,
X = new_observation[as.integer(i), ],
bg_X = explainer$data,
X = new_observation[as.integer(i), ], # data.frame
bg_X = explainer_data, # data.frame
pred_fun = predfun,
verbose = FALSE
)
# kernelshap-test: is.matrix(X) == is.matrix(bg_X) should evaluate to `TRUE`
tmp_shap_values <- data.frame(t(sapply(tmp_res$S, cbind)))
colnames(tmp_shap_values) <- colnames(tmp_res$X)
rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "")
Expand All @@ -229,6 +280,69 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_
return(shap_values)
}

use_treeshap <- function(explainer, new_observation, ...){

stopifnot(
"new_observation must be a data.frame" = inherits(
new_observation, "data.frame")
)

# init unify_append_args
unify_append_args <- list()

if (inherits(explainer$model, "ranger")) {
# UNIFY_FUN to prepare code for easy Integration of other ml algorithms
# that are supported by treeshap
UNIFY_FUN <- treeshap::ranger_surv.unify
unify_append_args <- list(type = "survival", times = explainer$times)
} else {
stop("Support for `treeshap` is currently only implemented for `ranger`.")
}

unify_args <- list(
rf_model = explainer$model,
data = explainer$data
)

if (length(unify_append_args) > 0) {
unify_args <- c(unify_args, unify_append_args)
}

tmp_unified <- do.call(UNIFY_FUN, unify_args)

shap_values <- sapply(
X = as.character(seq_len(nrow(new_observation))),
FUN = function(i) {
tmp_res <- do.call(
rbind,
lapply(
tmp_unified,
function(m) {
new_obs_mat <- new_observation[as.integer(i), ]
# ensure that matrix has expected dimensions; as.integer is
# necessary for valid comparison with "identical"
stopifnot(identical(dim(new_obs_mat), as.integer(c(1L, ncol(new_observation)))))
treeshap::treeshap(
unified_model = m,
x = new_obs_mat
)$shaps
}
)
)

tmp_shap_values <- data.frame(tmp_res)
colnames(tmp_shap_values) <- colnames(tmp_res)
rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "")
tmp_shap_values
},
USE.NAMES = TRUE,
simplify = FALSE
)

return(shap_values)

}

#' @keywords internal
aggregate_shap_multiple_observations <- function(shap_res_list, feature_names, aggregation_function) {
if (length(shap_res_list) > 1) {
Expand Down
3 changes: 2 additions & 1 deletion man/model_survshap.surv_explainer.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions man/surv_shap.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions tests/testthat/test-model_survshap.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

# create objects here so that they do not have to be created redundantly
veteran <- survival::veteran
rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE)
Expand Down Expand Up @@ -68,3 +69,25 @@ test_that("global survshap explanations with kernelshap work for coxph, using ex
expect_equal(length(cph_global_survshap$eval_times), length(cph_exp$times))
expect_true(all(names(cph_global_survshap$variable_values) == colnames(cph_exp$data)))
})

# testing if matrix works as input
rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE)

test_that("global survshap explanations with treeshap work for ranger", {

new_obs <- model.matrix(~ -1 + ., veteran[1:40, setdiff(colnames(veteran), c("time", "status"))])
ranger_global_survshap_tree <- model_survshap(
rsf_ranger_exp_matrix,
new_observation = new_obs,
y_true = survival::Surv(veteran$time[1:40], veteran$status[1:40]),
aggregation_method = "mean_absolute",
calculation_method = "treeshap"
)
plot(ranger_global_survshap_tree)

expect_s3_class(ranger_global_survshap_tree, c("aggregated_surv_shap", "surv_shap"))
expect_equal(length(ranger_global_survshap_tree$eval_times), length(rsf_ranger_exp_matrix$times))
expect_true(all(names(ranger_global_survshap_tree$variable_values) == colnames(rsf_ranger_exp_matrix$data)))

})
42 changes: 41 additions & 1 deletion tests/testthat/test-predict_parts.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ test_that("survshap explanations work", {
rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)

cph_exp <- explain(cph, verbose = FALSE)
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE)
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = survival::Surv(veteran$time, veteran$status), verbose = FALSE)
rsf_src_exp <- explain(rsf_src, verbose = FALSE)

parts_cph <- predict_parts(cph_exp, veteran[1, !colnames(veteran) %in% c("time", "status")], y_true = matrix(c(100, 1), ncol = 2), aggregation_method = "sum_of_squares")
Expand All @@ -19,6 +19,19 @@ test_that("survshap explanations work", {
parts_ranger <- predict_parts(rsf_ranger_exp, veteran[2, !colnames(veteran) %in% c("time", "status")], y_true = c(100, 1), aggregation_method = "mean_absolute")
plot(parts_ranger)

# test ranger with kernelshap when using a matrix as input for data and new observation
rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE)
new_obs <- model.matrix(~ -1 + ., veteran[2, !colnames(veteran) %in% c("time", "status")])
parts_ranger_kernelshap <- predict_parts(
rsf_ranger_exp_matrix,
new_observation = new_obs,
y_true = c(100, 1),
aggregation_method = "mean_absolute",
calculation_method = "kernelshap"
)
plot(parts_ranger_kernelshap)

parts_src <- predict_parts(rsf_src_exp, veteran[3, !colnames(veteran) %in% c("time", "status")])
plot(parts_src)

Expand Down Expand Up @@ -46,6 +59,29 @@ test_that("survshap explanations work", {
expect_error(predict_parts(cph_exp, veteran[1, ], calculation_method = "nonexistent"))
expect_error(predict_parts(cph_exp, veteran[1, c(1, 1, 1, 1, 1)], calculation_method = "nonexistent"))

})

test_that("local survshap explanations with treeshap work for ranger", {

veteran <- survival::veteran

rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE)


new_obs <- data.frame(model.matrix(~ -1 + ., veteran[2, setdiff(colnames(veteran), c("time", "status"))]))
parts_ranger <- model_survshap(
rsf_ranger_exp_matrix,
new_obs,
y_true = c(veteran$time[2], veteran$status[2]),
aggregation_method = "mean_absolute",
calculation_method = "treeshap"
)
plot(parts_ranger)

expect_s3_class(parts_ranger, c("predict_parts_survival", "surv_shap"))
expect_equal(nrow(parts_ranger$result), length(rsf_ranger_exp_matrix$times))
expect_true(all(colnames(parts_ranger$result) == colnames(rsf_ranger_exp_matrix$data)))

})

Expand All @@ -67,6 +103,10 @@ test_that("survshap explanations with output_type = 'chf' work", {
plot(parts_cph, rug = "censors")
plot(parts_cph, rug = "none")

# test global exact
parts_cph_glob <- predict_parts(cph_exp, veteran[1:3, !colnames(veteran) %in% c("time", "status")], y_true = as.matrix(veteran[1:3, c("time", "status")]), calculation_method = "exact_kernel", aggregation_method = "max_absolute", output_type = "chf")
plot(parts_cph_glob)

parts_ranger <- predict_parts(rsf_ranger_exp, veteran[2, !colnames(veteran) %in% c("time", "status")], y_true = c(100, 1), aggregation_method = "mean_absolute", output_type = "chf")
plot(parts_ranger)

Expand Down

0 comments on commit 0b2f4f5

Please sign in to comment.