diff --git a/DESCRIPTION b/DESCRIPTION index 8a65d264..a958d5fc 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,12 +1,13 @@ Package: survex Title: Explainable Machine Learning in Survival Analysis -Version: 1.1.3.9000 +Version: 1.2.0 Authors@R: c( person("Mikołaj", "Spytek", email = "mikolajspytek@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-7111-2286")), person("Mateusz", "Krzyziński", role = c("aut"), comment = c(ORCID = "0000-0001-6143-488X")), person("Sophie", "Langbein", role = c("aut")), person("Hubert", "Baniecki", role = c("aut"), comment = c(ORCID = "0000-0001-6661-5364")), + person("Lorenz A.", "Kapsner", role = c("ctb"), comment = c(ORCID = "0000-0003-1866-860X")), person("Przemyslaw", "Biecek", role = c("aut"), comment = c(ORCID = "0000-0001-8423-1823")) ) Description: Survival analysis models are commonly used in medicine and other areas. Many of them @@ -46,6 +47,7 @@ Suggests: rmarkdown, rms, testthat (>= 3.0.0), + treeshap (>= 0.3.0), withr, xgboost Config/testthat/edition: 3 diff --git a/NEWS.md b/NEWS.md index e9328333..9329a134 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,4 +1,8 @@ -# survex (development version) +# survex 1.2.0 +* added new `calculation_method` for `surv_shap()` called `"treeshap"` that uses the `treeshap` package ([#75](https://github.com/ModelOriented/survex/issues/75)) +* enable to calculate SurvSHAP(t) explanations based on subsample of the explainer's data +* changed default kernel width in SurvLIME from sqrt(p * 0.75) to sqrt(p) * 0.75 +* fixed error in SurvLIME when non-factor `categorical_variables` were provided # survex 1.1.3 diff --git a/R/metrics.R b/R/metrics.R index e9687f9b..9e95dc49 100644 --- a/R/metrics.R +++ b/R/metrics.R @@ -12,7 +12,7 @@ utils::globalVariables(c("PredictionSurv")) #' @return a function that can be used to calculate metrics (with parameters `y_true`, `risk`, `surv`, and `times`) #' #' @section References: -#' - \[1\] Graf, Erika, et al. ["Assessment and comparison of prognostic classification schemes for survival data."](https://onlinelibrary.wiley.com/doi/abs/10.1002/%28SICI%291097-0258%2819990915/30%2918%3A17/18%3C2529%3A%3AAID-SIM274%3E3.0.CO%3B2-5) Statistics in Medicine 18.17‐18 (1999): 2529-2545. +#' - \[1\] Graf, Erika, et al. "Assessment and comparison of prognostic classification schemes for survival data." Statistics in Medicine 18.17‐18 (1999): 2529-2545. #' #' @export loss_integrate <- function(loss_function, ..., normalization = NULL, max_quantile = 1) { @@ -57,7 +57,7 @@ loss_integrate <- function(loss_function, ..., normalization = NULL, max_quantil #' @return numeric from 0 to 1, higher values indicate better performance #' #' @section References: -#' - \[1\] Harrell, F.E., Jr., et al. ["Regression modelling strategies for improved prognostic prediction."](https://onlinelibrary.wiley.com/doi/10.1002/sim.4780030207) Statistics in Medicine 3.2 (1984): 143-152. +#' - \[1\] Harrell, F.E., Jr., et al. "Regression modelling strategies for improved prognostic prediction." Statistics in Medicine 3.2 (1984): 143-152. #' #' @rdname c_index #' @seealso [loss_one_minus_c_index()] @@ -109,7 +109,7 @@ attr(c_index, "loss_type") <- "risk-based" #' @return numeric from 0 to 1, lower values indicate better performance #' #' @section References: -#' - \[1\] Harrell, F.E., Jr., et al. ["Regression modelling strategies for improved prognostic prediction."](https://onlinelibrary.wiley.com/doi/10.1002/sim.4780030207) Statistics in Medicine 3.2 (1984): 143-152. +#' - \[1\] Harrell, F.E., Jr., et al. "Regression modelling strategies for improved prognostic prediction." Statistics in Medicine 3.2 (1984): 143-152. #' #' @rdname loss_one_minus_c_index #' @seealso [c_index()] @@ -152,8 +152,8 @@ attr(loss_one_minus_c_index, "loss_type") <- "risk-based" #' @return numeric from 0 to 1, lower scores are better (Brier score of 0.25 represents a model which returns always returns 0.5 as the predicted survival function) #' #' @section References: -#' - \[1\] Brier, Glenn W. ["Verification of forecasts expressed in terms of probability."](https://journals.ametsoc.org/view/journals/mwre/78/1/1520-0493_1950_078_0001_vofeit_2_0_co_2.xml) Monthly Weather Review 78.1 (1950): 1-3. -#' - \[2\] Graf, Erika, et al. ["Assessment and comparison of prognostic classification schemes for survival data."](https://onlinelibrary.wiley.com/doi/10.1002/(SICI)1097-0258(19990915/30)18:17/18%3C2529::AID-SIM274%3E3.0.CO;2-5) Statistics in Medicine 18.17‐18 (1999): 2529-2545. +#' - \[1\] Brier, Glenn W. "Verification of forecasts expressed in terms of probability." Monthly Weather Review 78.1 (1950): 1-3. +#' - \[2\] Graf, Erika, et al. "Assessment and comparison of prognostic classification schemes for survival data." Statistics in Medicine 18.17‐18 (1999): 2529-2545. #' #' @rdname brier_score #' @seealso [cd_auc()] @@ -217,8 +217,8 @@ attr(loss_brier_score, "loss_type") <- "time-dependent" #' Calculate Cumulative/Dynamic AUC #' #' This function calculates the Cumulative/Dynamic AUC metric for a survival model. It is done using the -#' estimator proposed proposed by Uno et al. \[[1](https://www.jstor.org/stable/27639883)\], -#' and Hung and Chang \[[2](https://www.jstor.org/stable/41000414)\]. +#' estimator proposed proposed by Uno et al. \[1\], +#' and Hung and Chang \[2\]. #' #' C/D AUC is an extension of the AUC metric known from classification models. #' Its values represent the model's performance at specific time points. @@ -232,8 +232,8 @@ attr(loss_brier_score, "loss_type") <- "time-dependent" #' @return a numeric vector of length equal to the length of the times vector, each value (from the range from 0 to 1) represents the AUC metric at a specific time point, with higher values indicating better performance. #' #' @section References: -#' - \[1\] Uno, Hajime, et al. ["Evaluating prediction rules for t-year survivors with censored regression models."](https://www.jstor.org/stable/27639883) Journal of the American Statistical Association 102.478 (2007): 527-537. -#' - \[2\] Hung, Hung, and Chin‐Tsang Chiang. ["Optimal composite markers for time dependent receiver operating characteristic curves with censored survival data."](https://www.jstor.org/stable/41000414) Scandinavian Journal of Statistics 37.4 (2010): 664-679. +#' - \[1\] Uno, Hajime, et al. "Evaluating prediction rules for t-year survivors with censored regression models." Journal of the American Statistical Association 102.478 (2007): 527-537. +#' - \[2\] Hung, Hung, and Chin‐Tsang Chiang. "Optimal composite markers for time dependent receiver operating characteristic curves with censored survival data." Scandinavian Journal of Statistics 37.4 (2010): 664-679. #' #' @rdname cd_auc #' @seealso [loss_one_minus_cd_auc()] [integrated_cd_auc()] [brier_score()] @@ -297,8 +297,8 @@ attr(cd_auc, "loss_type") <- "time-dependent" #' @return a numeric vector of length equal to the length of the times vector, each value (from the range from 0 to 1) represents 1 - AUC metric at a specific time point, with lower values indicating better performance. #' #' #' @section References: -#' - \[1\] Uno, Hajime, et al. ["Evaluating prediction rules for t-year survivors with censored regression models."](https://www.jstor.org/stable/27639883) Journal of the American Statistical Association 102.478 (2007): 527-537. -#' - \[2\] Hung, Hung, and Chin‐Tsang Chiang. ["Optimal composite markers for time‐dependent receiver operating characteristic curves with censored survival data."](https://www.jstor.org/stable/41000414) Scandinavian Journal of Statistics 37.4 (2010): 664-679. +#' - \[1\] Uno, Hajime, et al. "Evaluating prediction rules for t-year survivors with censored regression models." Journal of the American Statistical Association 102.478 (2007): 527-537. +#' - \[2\] Hung, Hung, and Chin‐Tsang Chiang. "Optimal composite markers for time‐dependent receiver operating characteristic curves with censored survival data." Scandinavian Journal of Statistics 37.4 (2010): 664-679. #' #' @rdname loss_one_minus_cd_auc #' @seealso [cd_auc()] @@ -337,8 +337,8 @@ attr(loss_one_minus_cd_auc, "loss_type") <- "time-dependent" #' @return numeric from 0 to 1, higher values indicate better performance #' #' #' @section References: -#' - \[1\] Uno, Hajime, et al. ["Evaluating prediction rules for t-year survivors with censored regression models."](https://www.jstor.org/stable/27639883) Journal of the American Statistical Association 102.478 (2007): 527-537. -#' - \[2\] Hung, Hung, and Chin‐Tsang Chiang. ["Optimal composite markers for time‐dependent receiver operating characteristic curves with censored survival data."](https://www.jstor.org/stable/41000414) Scandinavian Journal of Statistics 37.4 (2010): 664-679. +#' - \[1\] Uno, Hajime, et al. "Evaluating prediction rules for t-year survivors with censored regression models." Journal of the American Statistical Association 102.478 (2007): 527-537. +#' - \[2\] Hung, Hung, and Chin‐Tsang Chiang. "Optimal composite markers for time‐dependent receiver operating characteristic curves with censored survival data." Scandinavian Journal of Statistics 37.4 (2010): 664-679. #' #' @rdname integrated_cd_auc #' @seealso [cd_auc()] [loss_one_minus_cd_auc()] @@ -373,8 +373,8 @@ attr(integrated_cd_auc, "loss_type") <- "integrated" #' @return numeric from 0 to 1, lower values indicate better performance #' #' #' @section References: -#' - \[1\] Uno, Hajime, et al. ["Evaluating prediction rules for t-year survivors with censored regression models."](https://www.jstor.org/stable/27639883) Journal of the American Statistical Association 102.478 (2007): 527-537. -#' - \[2\] Hung, Hung, and Chin‐Tsang Chiang. ["Optimal composite markers for time‐dependent receiver operating characteristic curves with censored survival data."](https://www.jstor.org/stable/41000414) Scandinavian Journal of Statistics 37.4 (2010): 664-679. +#' - \[1\] Uno, Hajime, et al. "Evaluating prediction rules for t-year survivors with censored regression models." Journal of the American Statistical Association 102.478 (2007): 527-537. +#' - \[2\] Hung, Hung, and Chin‐Tsang Chiang. "Optimal composite markers for time‐dependent receiver operating characteristic curves with censored survival data." Scandinavian Journal of Statistics 37.4 (2010): 664-679. #' #' @rdname loss_one_minus_integrated_cd_auc #' @seealso [integrated_cd_auc()] [cd_auc()] [loss_one_minus_cd_auc()] @@ -417,8 +417,8 @@ attr(loss_one_minus_integrated_cd_auc, "loss_type") <- "integrated" #' @return numeric from 0 to 1, lower values indicate better performance #' #' @section References: -#' - \[1\] Brier, Glenn W. ["Verification of forecasts expressed in terms of probability."](https://journals.ametsoc.org/view/journals/mwre/78/1/1520-0493_1950_078_0001_vofeit_2_0_co_2.xml) Monthly Weather Review 78.1 (1950): 1-3. -#' - \[2\] Graf, Erika, et al. ["Assessment and comparison of prognostic classification schemes for survival data."](https://onlinelibrary.wiley.com/doi/10.1002/(SICI)1097-0258(19990915/30)18:17/18%3C2529::AID-SIM274%3E3.0.CO;2-5) Statistics in Medicine 18.17‐18 (1999): 2529-2545. +#' - \[1\] Brier, Glenn W. "Verification of forecasts expressed in terms of probability." Monthly Weather Review 78.1 (1950): 1-3. +#' - \[2\] Graf, Erika, et al. "Assessment and comparison of prognostic classification schemes for survival data." Statistics in Medicine 18.17‐18 (1999): 2529-2545. #' #' @rdname integrated_brier_score #' @seealso [brier_score()] [integrated_cd_auc()] [loss_one_minus_integrated_cd_auc()] @@ -458,6 +458,7 @@ attr(loss_integrated_brier_score, "loss_type") <- "integrated" #' #' @return a function with standardized parameters (`y_true`, `risk`, `surv`, `times`) that can be used to calculate loss #' +#' @examples #' if(FALSE){ #' measure <- msr("surv.calib_beta") #' mlr_measure <- loss_adapt_mlr3proba(measure) @@ -483,7 +484,6 @@ loss_adapt_mlr3proba <- function(measure, reverse = FALSE, ...) { return(output) } - if (reverse) { attr(loss_function, "loss_name") <- paste("one minus", measure$id) } else { diff --git a/R/model_performance.R b/R/model_performance.R index dd388b67..f8b356b0 100644 --- a/R/model_performance.R +++ b/R/model_performance.R @@ -9,17 +9,17 @@ #' @param times a numeric vector of times. If `type == "metrics"` then the survival function is evaluated at these times, if `type == "roc"` then the ROC curves are calculated at these times. #' #' @return An object of class `"model_performance_survival"`. It's a list of metric values calculated for the model. It contains: -#' - Harrell's concordance index \[[1](https://onlinelibrary.wiley.com/doi/abs/10.1002/sim.4780030207)\] -#' - Brier score \[[2](https://journals.ametsoc.org/view/journals/mwre/78/1/1520-0493_1950_078_0001_vofeit_2_0_co_2.xml), [3](https://onlinelibrary.wiley.com/doi/abs/10.1002/%28SICI%291097-0258%2819990915/30%2918%3A17/18%3C2529%3A%3AAID-SIM274%3E3.0.CO%3B2-5)\] -#' - C/D AUC using the estimator proposed by Uno et. al \[[4](https://www.jstor.org/stable/27639883#metadata_info_tab_contents)\] +#' - Harrell's concordance index \[1\] +#' - Brier score \[2, 3\] +#' - C/D AUC using the estimator proposed by Uno et. al \[4\] #' - integral of the Brier score #' - integral of the C/D AUC #' #' @section References: -#' - \[1\] Harrell, F.E., Jr., et al. ["Regression modelling strategies for improved prognostic prediction."](https://onlinelibrary.wiley.com/doi/abs/10.1002/sim.4780030207) Statistics in Medicine 3.2 (1984): 143-152. -#' - \[2\] Brier, Glenn W. ["Verification of forecasts expressed in terms of probability."](https://journals.ametsoc.org/view/journals/mwre/78/1/1520-0493_1950_078_0001_vofeit_2_0_co_2.xml) Monthly Weather Review 78.1 (1950): 1-3. -#' - \[3\] Graf, Erika, et al. ["Assessment and comparison of prognostic classification schemes for survival data."](https://onlinelibrary.wiley.com/doi/abs/10.1002/%28SICI%291097-0258%2819990915/30%2918%3A17/18%3C2529%3A%3AAID-SIM274%3E3.0.CO%3B2-5) Statistics in Medicine 18.17‐18 (1999): 2529-2545. -#' - \[4\] Uno, Hajime, et al. ["Evaluating prediction rules for t-year survivors with censored regression models."](https://www.jstor.org/stable/27639883#metadata_info_tab_contents) Journal of the American Statistical Association 102.478 (2007): 527-537. +#' - \[1\] Harrell, F.E., Jr., et al. "Regression modelling strategies for improved prognostic prediction." Statistics in Medicine 3.2 (1984): 143-152. +#' - \[2\] Brier, Glenn W. "Verification of forecasts expressed in terms of probability." Monthly Weather Review 78.1 (1950): 1-3. +#' - \[3\] Graf, Erika, et al. "Assessment and comparison of prognostic classification schemes for survival data." Statistics in Medicine 18.17‐18 (1999): 2529-2545. +#' - \[4\] Uno, Hajime, et al. "Evaluating prediction rules for t-year survivors with censored regression models." Journal of the American Statistical Association 102.478 (2007): 527-537. #' #' @examples #' \donttest{ diff --git a/R/model_survshap.R b/R/model_survshap.R index a3b9b0e1..ebdbd034 100644 --- a/R/model_survshap.R +++ b/R/model_survshap.R @@ -59,6 +59,7 @@ model_survshap <- function(explainer, ...) { model_survshap.surv_explainer <- function(explainer, new_observation = NULL, y_true = NULL, + N = NULL, calculation_method = "kernelshap", aggregation_method = "integral", output_type = "survival", @@ -98,9 +99,11 @@ model_survshap.surv_explainer <- function(explainer, explainer = explainer, new_observation = observations, output_type = output_type, + N = N, y_true = y_true, calculation_method = calculation_method, - aggregation_method = aggregation_method + aggregation_method = aggregation_method, + ... ) attr(shap_values, "label") <- explainer$label diff --git a/R/plot_model_profile_survival.R b/R/plot_model_profile_survival.R index d36638c7..60d433f7 100644 --- a/R/plot_model_profile_survival.R +++ b/R/plot_model_profile_survival.R @@ -230,7 +230,7 @@ plot2_mp <- function(x, if (!is.null(subtitle) && subtitle == "default") { subtitle <- paste0("created for the ", unique(variable), " variable") if (single_timepoint && !marginalize_over_time) { - subtitle <- paste0(subtitle, " and time =", times) + subtitle <- paste0(subtitle, " and time = ", times) } } diff --git a/R/plot_predict_profile_survival.R b/R/plot_predict_profile_survival.R index d5eb64e0..f9bfb62b 100644 --- a/R/plot_predict_profile_survival.R +++ b/R/plot_predict_profile_survival.R @@ -192,7 +192,7 @@ plot2_cp <- function(x, if (!is.null(subtitle) && subtitle == "default") { subtitle <- paste0("created for the ", unique(variable), " variable") if (single_timepoint && !marginalize_over_time) { - subtitle <- paste0(subtitle, " and time =", times) + subtitle <- paste0(subtitle, " and time = ", times) } } diff --git a/R/plot_surv_shap.R b/R/plot_surv_shap.R index 3cde5f3f..94b6df1e 100644 --- a/R/plot_surv_shap.R +++ b/R/plot_surv_shap.R @@ -121,11 +121,11 @@ plot.surv_shap <- function(x, #' * `color_variable` - variable used to denote the color, by default equal to `variable` #' #' -#'#' ## `plot.aggregated_surv_shap(geom = "curves")` +#' ## `plot.aggregated_surv_shap(geom = "curves")` #' #' * `variable` - variable for which SurvSHAP(t) curves are to be plotted, by default first from result data #' * `boxplot` - whether to plot functional boxplot with marked outliers or all curves colored by variable value -#' +#' * `coef` - length of the functional boxplot's whiskers as multiple of IQR, by default 1.5 #' #' @examples #' \donttest{ @@ -293,7 +293,7 @@ plot_shap_global_beeswarm <- function(x, max_vars = 7, colors = NULL) { df <- as.data.frame(do.call(rbind, x$aggregate)) - cols <- names(sort(colMeans(abs(df))))[1:min(max_vars, length(df))] + cols <- names(sort(colMeans(abs(df)), decreasing = TRUE))[1:min(max_vars, length(df))] df <- df[, cols] df <- stack(df) colnames(df) <- c("shap_value", "variable") @@ -325,6 +325,7 @@ plot_shap_global_beeswarm <- function(x, ggplot(data = df, aes(x = shap_value, y = variable, color = var_value)) + geom_vline(xintercept = 0, color = "#ceced9", linetype = "solid") + geom_jitter(width = 0, height = 0.15) + + scale_y_discrete(limits=rev) + scale_color_gradient2( name = "Variable value", low = colors[1], diff --git a/R/predict_parts.R b/R/predict_parts.R index d57fd459..10c84386 100644 --- a/R/predict_parts.R +++ b/R/predict_parts.R @@ -5,9 +5,9 @@ #' @param explainer an explainer object - model preprocessed by the `explain()` function #' @param new_observation a new observation for which prediction need to be explained #' @param ... other parameters which are passed to `iBreakDown::break_down` if `output_type=="risk"`, or if `output_type=="survival"` to `surv_shap()` or `surv_lime()` functions depending on the selected type -#' @param N the maximum number of observations used for calculation of attributions. If `NULL` (default) all observations will be used. +#' @param N the number of observations used for calculation of attributions. If `NULL` (default) all explainer data will be used for SurvSHAP(t) and 100 neigbours for SurvLIME. #' @param type if `output_type == "survival"` must be either `"survshap"` or `"survlime"`, otherwise refer to the `DALEX::predict_parts` -#' @param output_type either `"survival"` or `"risk"` the type of survival model output that should be considered for explanations. If `"survival"` the explanations are based on the survival function. Otherwise the scalar risk predictions are used by the `DALEX::predict_parts` function. +#' @param output_type either `"survival"`, `"chf"` or `"risk"` the type of survival model output that should be considered for explanations. If `"survival"` the explanations are based on the survival function. If `"chf"` the explanations are based on the cumulative hazard function. Otherwise the scalar risk predictions are used by the `DALEX::predict_parts` function. #' @param explanation_label a label that can overwrite explainer label (useful for multiple explanations for the same explainer/model) #' #' @return An object of class `"predict_parts_survival"` and additional classes depending on the type of explanations. It is a list with the element `result` containing the results of the calculation. @@ -27,7 +27,6 @@ #' * `categorical_variables` - character vector, names of variables that should be treated as categories (factors are included by default) #' * `k` - a small positive number > 1, added to chf before taking log, so that weigths aren't negative #' * for `survshap` -#' * `timestamps` - a numeric vector, time points at which the survival function will be evaluated #' * `y_true` - a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting #' * `calculation_method` - a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements) or `"exact_kernel"` for exact Kernel SHAP estimation #' * `aggregation_method` - a character, either `"mean_absolute"` or `"integral"`, `"max_absolute"`, `"sum_of_squares"` @@ -75,8 +74,8 @@ predict_parts.surv_explainer <- function(explainer, new_observation, ..., N = NU )) } else { res <- switch(type, - "survshap" = surv_shap(explainer, new_observation, output_type, ...), - "survlime" = surv_lime(explainer, new_observation, ...), + "survshap" = surv_shap(explainer, new_observation, output_type, ..., N = N), + "survlime" = surv_lime(explainer, new_observation, ..., N = N), stop("Only `survshap` and `survlime` methods are implemented for now") ) } diff --git a/R/surv_lime.R b/R/surv_lime.R index cc06e64c..3cb7fb9d 100644 --- a/R/surv_lime.R +++ b/R/surv_lime.R @@ -33,6 +33,7 @@ surv_lime <- function(explainer, new_observation, test_explainer(explainer, "surv_lime", has_data = TRUE, has_y = TRUE, has_chf = TRUE) new_observation <- new_observation[, colnames(new_observation) %in% colnames(explainer$data)] if (ncol(explainer$data) != ncol(new_observation)) stop("New observation and data have different number of columns (variables)") + if (is.null(N)) N <- 100 predicted_sf <- explainer$predict_survival_function(explainer$model, new_observation, explainer$times) @@ -57,12 +58,11 @@ surv_lime <- function(explainer, new_observation, distances <- apply(scaled_data, 1, dist, scaled_data[1, ]) - if (is.null(kernel_width)) kernel_width <- sqrt(ncol(scaled_data) * 0.75) + if (is.null(kernel_width)) kernel_width <- sqrt(ncol(scaled_data)) * 0.75 weights <- sqrt(exp(-(distances^2) / (kernel_width^2))) na_est <- survival::basehaz(survival::coxph(explainer$y ~ 1)) - model_chfs <- explainer$predict_cumulative_hazard_function(explainer$model, neighbourhood$inverse, na_est$time) + k log_chfs <- log(model_chfs) weights_v <- model_chfs / log_chfs @@ -175,10 +175,13 @@ generate_neighbourhood <- function(data_org, data <- data[, colnames(data_row)] if (length(categorical_variables) > 0) { + inverse_as_factor <- inverse + inverse_as_factor[additional_categorical_variables] <- + lapply(inverse_as_factor[additional_categorical_variables], as.factor) expr <- paste0("~", paste(categorical_variables, collapse = "+")) - categorical_matrix <- model.matrix(as.formula(expr), data = inverse)[, -1] + categorical_matrix <- model.matrix(as.formula(expr), data = inverse_as_factor)[, -1] inverse_ohe <- cbind(inverse, categorical_matrix) - inverse_ohe[, factor_variables] <- NULL + inverse_ohe[, categorical_variables] <- NULL } else { inverse_ohe <- inverse } diff --git a/R/surv_shap.R b/R/surv_shap.R index f8fd2048..75737bac 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -4,8 +4,9 @@ #' @param new_observation new observations for which predictions need to be explained #' @param output_type a character, either `"survival"` or `"chf"`. Determines which type of prediction should be used for explanations. #' @param ... additional parameters, passed to internal functions +#' @param N a positive integer, number of observations used as the background data #' @param y_true a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting -#' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements) or `"exact_kernel"` for exact Kernel SHAP estimation +#' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements), `"exact_kernel"` for exact Kernel SHAP estimation, or `"treeshap"` for use of `treeshap` library (efficient implementation to compute SHAP values for tree-based models). #' @param aggregation_method a character, either `"integral"`, `"integral_absolute"`, `"mean_absolute"`, `"max_absolute"`, or `"sum_of_squares"` #' #' @return A list, containing the calculated SurvSHAP(t) results in the `result` field @@ -18,11 +19,17 @@ surv_shap <- function(explainer, new_observation, output_type, ..., + N = NULL, y_true = NULL, - calculation_method = "kernelshap", - aggregation_method = "integral") { + calculation_method = c("kernelshap", "exact_kernel", "treeshap"), + aggregation_method = c("integral", "mean_absolute", "max_absolute", "sum_of_squares") +) { + calculation_method <- match.arg(calculation_method) + aggregation_method <- match.arg(aggregation_method) + + # make this code work for multiple observations stopifnot( - "`y_true` must be either a matrix with one per observation in `new_observation` or a vector of length == 2" = ifelse( + "`y_true` must be either a matrix with one row per observation in `new_observation` or a vector of length == 2" = ifelse( !is.null(y_true), ifelse( is.matrix(y_true), @@ -33,14 +40,39 @@ surv_shap <- function(explainer, ) ) - test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE, has_survival = TRUE) + if (calculation_method == "kernelshap") { + if (!requireNamespace("kernelshap", quietly = TRUE)) { + stop( + paste0( + "Package \"kernelshap\" must be installed to use ", + "'calculation_method = \"kernelshap\"'." + ), + call. = FALSE + ) + } + } + if (calculation_method == "treeshap") { + if (!requireNamespace("treeshap", quietly = TRUE)) { + stop( + paste0( + "Package \"treeshap\" must be installed to use ", + "'calculation_method = \"treeshap\"'." + ), + call. = FALSE + ) + } + } + test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE, has_survival = TRUE) # make this code also work for 1-row matrix col_index <- which(colnames(new_observation) %in% colnames(explainer$data)) if (is.matrix(new_observation) && nrow(new_observation) == 1) { - new_observation <- as.matrix(t(new_observation[, col_index])) + new_observation <- data.frame(as.matrix(t(new_observation[, col_index]))) } else { new_observation <- new_observation[, col_index] + if (!inherits(new_observation, "data.frame")) { + new_observation <- data.frame(new_observation) + } } if (ncol(explainer$data) != ncol(new_observation)) { @@ -59,14 +91,25 @@ surv_shap <- function(explainer, } } + if (calculation_method == "treeshap") { + if (!inherits(explainer$model, "ranger")) { + stop("Calculation method `treeshap` is currently only implemented for `ranger` survival models.") + } + } + res <- list() res$eval_times <- explainer$times # to display final object correctly, when is.matrix(new_observation) == TRUE res$variable_values <- as.data.frame(new_observation) res$result <- switch(calculation_method, - "exact_kernel" = use_exact_shap(explainer, new_observation, output_type, ...), - "kernelshap" = use_kernelshap(explainer, new_observation, output_type, ...), - stop("Only `exact_kernel` and `kernelshap` calculation methods are implemented") + "exact_kernel" = use_exact_shap(explainer, new_observation, output_type, N, ...), + "kernelshap" = use_kernelshap(explainer, new_observation, output_type, N, ...), + "treeshap" = use_treeshap(explainer, new_observation, output_type, ...), + stop("Only `exact_kernel`, `kernelshap` and `treeshap` calculation methods are implemented")) + # quality-check here + stopifnot( + "Number of rows of SurvSHAP table are not identical with length(eval_times)" = + nrow(res$result) == length(res$eval_times) ) if (!is.null(y_true)) res$y_true <- c(y_true_time = y_true_time, y_true_ind = y_true_ind) @@ -86,11 +129,12 @@ surv_shap <- function(explainer, return(res) } -use_exact_shap <- function(explainer, new_observation, output_type, observation_aggregation_method, ...) { + +use_exact_shap <- function(explainer, new_observation, output_type, N, ...) { shap_values <- sapply( X = as.character(seq_len(nrow(new_observation))), FUN = function(i) { - as.data.frame(shap_kernel(explainer, new_observation[as.integer(i), ], output_type, ...)) + as.data.frame(shap_kernel(explainer, new_observation[as.integer(i), ], output_type, N, ...)) }, USE.NAMES = TRUE, simplify = FALSE @@ -100,16 +144,16 @@ use_exact_shap <- function(explainer, new_observation, output_type, observation_ } -shap_kernel <- function(explainer, new_observation, output_type, ...) { +shap_kernel <- function(explainer, new_observation, output_type, N, ...) { timestamps <- explainer$times p <- ncol(explainer$data) - + if (is.null(N)) N <- nrow(explainer$data) + background_data <- explainer$data[sample(1:nrow(explainer$data), N),] target_sf <- predict(explainer, new_observation, times = timestamps, output_type = output_type) - sfs <- predict(explainer, explainer$data, times = timestamps, output_type = output_type) + sfs <- predict(explainer, background_data, times = timestamps, output_type = output_type) baseline_sf <- apply(sfs, 2, mean) - permutations <- expand.grid(rep(list(0:1), p)) kernel_weights <- generate_shap_kernel_weights(permutations, p) @@ -117,17 +161,14 @@ shap_kernel <- function(explainer, new_observation, output_type, ...) { explainer, explainer$model, baseline_sf, - as.data.frame(explainer$data), + as.data.frame(background_data), permutations, kernel_weights, as.data.frame(new_observation), timestamps ) - - shap_values <- as.data.frame(shap_values, row.names = colnames(explainer$data)) colnames(shap_values) <- paste("t=", timestamps, sep = "") - return(t(shap_values)) } @@ -188,7 +229,7 @@ aggregate_surv_shap <- function(survshap, times, method, ...) { } -use_kernelshap <- function(explainer, new_observation, output_type, observation_aggregation_method, ...) { +use_kernelshap <- function(explainer, new_observation, output_type, N, ...) { predfun <- function(model, newdata) { if (output_type == "survival"){ @@ -204,7 +245,18 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_ times = explainer$times ) } + } + + stopifnot( + "new_observation must be a data.frame" = inherits( + new_observation, "data.frame") + ) + if (is.null(N)) N <- nrow(explainer$data) + background_data <- explainer$data[sample(1:nrow(explainer$data), N),] + # ensure that classes of explainer$data and new_observation are equal + if (!inherits(background_data, "data.frame")) { + background_data <- data.frame(background_data) } shap_values <- sapply( @@ -212,11 +264,12 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_ FUN = function(i) { tmp_res <- kernelshap::kernelshap( object = explainer$model, - X = new_observation[as.integer(i), ], - bg_X = explainer$data, + X = new_observation[as.integer(i), ], # data.frame + bg_X = background_data, # data.frame pred_fun = predfun, verbose = FALSE ) + # kernelshap-test: is.matrix(X) == is.matrix(bg_X) should evaluate to `TRUE` tmp_shap_values <- data.frame(t(sapply(tmp_res$S, cbind))) colnames(tmp_shap_values) <- colnames(tmp_res$X) rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "") @@ -229,6 +282,51 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_ return(shap_values) } +use_treeshap <- function(explainer, new_observation, output_type, ...){ + + stopifnot( + "new_observation must be a data.frame" = inherits( + new_observation, "data.frame") + ) + + # init unify_append_args + unify_append_args <- list() + + if (!inherits(explainer$model, "ranger")) { + stop("Support for `treeshap` is currently only implemented for `ranger`.") + } + + tmp_unified <- treeshap::unify(explainer$model, + explainer$data, + type = output_type, + times = explainer$times) + + shap_values <- sapply( + X = as.character(seq_len(nrow(new_observation))), + FUN = function(i) { + # ensure that matrix has expected dimensions; as.integer is + # necessary for valid comparison with "identical" + new_obs_mat <- new_observation[as.integer(i), ] + stopifnot(identical(dim(new_obs_mat), as.integer(c(1L, ncol(new_observation))))) + + tmp_res <- do.call( + rbind, + lapply(treeshap::treeshap(tmp_unified, x = new_obs_mat, ...), function(x) x$shaps) + ) + + tmp_shap_values <- data.frame(tmp_res) + colnames(tmp_shap_values) <- colnames(tmp_res) + rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "") + tmp_shap_values + }, + USE.NAMES = TRUE, + simplify = FALSE + ) + + return(shap_values) + +} + #' @keywords internal aggregate_shap_multiple_observations <- function(shap_res_list, feature_names, aggregation_function) { if (length(shap_res_list) > 1) { diff --git a/cran-comments.md b/cran-comments.md index 167790c5..de68aaf5 100644 --- a/cran-comments.md +++ b/cran-comments.md @@ -1,18 +1,8 @@ -## Resubmission 05/09/2023 v.1.1.3 - -* Fix notes about long running examples by wrapping - in \donttest{} blocks. -* Fix notes about CPU time being much longer than elapsed - time when running tests and building vignettes. +## Submission 24/10/2023 v.1.2.0 ## R CMD check results -0 errors | 0 warnings | 1 note - -* Found possibly invalid URLs -- the URLs were - checked to be valid and working, seems to be - a false positive - +0 errors | 0 warnings | 0 notes ## Reverse dependencies diff --git a/man/brier_score.Rd b/man/brier_score.Rd index 8dde9ea7..ae99e583 100644 --- a/man/brier_score.Rd +++ b/man/brier_score.Rd @@ -30,8 +30,8 @@ Brier score is used to evaluate the performance of a survival model, based on th \section{References}{ \itemize{ -\item [1] Brier, Glenn W. \href{https://journals.ametsoc.org/view/journals/mwre/78/1/1520-0493_1950_078_0001_vofeit_2_0_co_2.xml}{"Verification of forecasts expressed in terms of probability."} Monthly Weather Review 78.1 (1950): 1-3. -\item [2] Graf, Erika, et al. \href{https://onlinelibrary.wiley.com/doi/10.1002/(SICI)1097-0258(19990915/30)18:17/18\%3C2529::AID-SIM274\%3E3.0.CO;2-5}{"Assessment and comparison of prognostic classification schemes for survival data."} Statistics in Medicine 18.17‐18 (1999): 2529-2545. +\item [1] Brier, Glenn W. "Verification of forecasts expressed in terms of probability." Monthly Weather Review 78.1 (1950): 1-3. +\item [2] Graf, Erika, et al. "Assessment and comparison of prognostic classification schemes for survival data." Statistics in Medicine 18.17‐18 (1999): 2529-2545. } } diff --git a/man/c_index.Rd b/man/c_index.Rd index 3c346222..371844fb 100644 --- a/man/c_index.Rd +++ b/man/c_index.Rd @@ -24,7 +24,7 @@ A function to compute the Harrell's concordance index of a survival model. \section{References}{ \itemize{ -\item [1] Harrell, F.E., Jr., et al. \href{https://onlinelibrary.wiley.com/doi/10.1002/sim.4780030207}{"Regression modelling strategies for improved prognostic prediction."} Statistics in Medicine 3.2 (1984): 143-152. +\item [1] Harrell, F.E., Jr., et al. "Regression modelling strategies for improved prognostic prediction." Statistics in Medicine 3.2 (1984): 143-152. } } diff --git a/man/cd_auc.Rd b/man/cd_auc.Rd index b0c922a3..9e2e78cb 100644 --- a/man/cd_auc.Rd +++ b/man/cd_auc.Rd @@ -20,8 +20,8 @@ a numeric vector of length equal to the length of the times vector, each value ( } \description{ This function calculates the Cumulative/Dynamic AUC metric for a survival model. It is done using the -estimator proposed proposed by Uno et al. [\href{https://www.jstor.org/stable/27639883}{1}], -and Hung and Chang [\href{https://www.jstor.org/stable/41000414}{2}]. +estimator proposed proposed by Uno et al. [1], +and Hung and Chang [2]. } \details{ C/D AUC is an extension of the AUC metric known from classification models. @@ -31,8 +31,8 @@ It can be integrated over the considered time range. \section{References}{ \itemize{ -\item [1] Uno, Hajime, et al. \href{https://www.jstor.org/stable/27639883}{"Evaluating prediction rules for t-year survivors with censored regression models."} Journal of the American Statistical Association 102.478 (2007): 527-537. -\item [2] Hung, Hung, and Chin‐Tsang Chiang. \href{https://www.jstor.org/stable/41000414}{"Optimal composite markers for time dependent receiver operating characteristic curves with censored survival data."} Scandinavian Journal of Statistics 37.4 (2010): 664-679. +\item [1] Uno, Hajime, et al. "Evaluating prediction rules for t-year survivors with censored regression models." Journal of the American Statistical Association 102.478 (2007): 527-537. +\item [2] Hung, Hung, and Chin‐Tsang Chiang. "Optimal composite markers for time dependent receiver operating characteristic curves with censored survival data." Scandinavian Journal of Statistics 37.4 (2010): 664-679. } } diff --git a/man/integrated_brier_score.Rd b/man/integrated_brier_score.Rd index 0e4ed5b3..48d4adf5 100644 --- a/man/integrated_brier_score.Rd +++ b/man/integrated_brier_score.Rd @@ -35,8 +35,8 @@ It is useful to see how a model performs as a whole, not at specific time points \section{References}{ \itemize{ -\item [1] Brier, Glenn W. \href{https://journals.ametsoc.org/view/journals/mwre/78/1/1520-0493_1950_078_0001_vofeit_2_0_co_2.xml}{"Verification of forecasts expressed in terms of probability."} Monthly Weather Review 78.1 (1950): 1-3. -\item [2] Graf, Erika, et al. \href{https://onlinelibrary.wiley.com/doi/10.1002/(SICI)1097-0258(19990915/30)18:17/18\%3C2529::AID-SIM274\%3E3.0.CO;2-5}{"Assessment and comparison of prognostic classification schemes for survival data."} Statistics in Medicine 18.17‐18 (1999): 2529-2545. +\item [1] Brier, Glenn W. "Verification of forecasts expressed in terms of probability." Monthly Weather Review 78.1 (1950): 1-3. +\item [2] Graf, Erika, et al. "Assessment and comparison of prognostic classification schemes for survival data." Statistics in Medicine 18.17‐18 (1999): 2529-2545. } } diff --git a/man/integrated_cd_auc.Rd b/man/integrated_cd_auc.Rd index 0a7dfa6e..4188292c 100644 --- a/man/integrated_cd_auc.Rd +++ b/man/integrated_cd_auc.Rd @@ -20,8 +20,8 @@ numeric from 0 to 1, higher values indicate better performance #' @section References: \itemize{ -\item [1] Uno, Hajime, et al. \href{https://www.jstor.org/stable/27639883}{"Evaluating prediction rules for t-year survivors with censored regression models."} Journal of the American Statistical Association 102.478 (2007): 527-537. -\item [2] Hung, Hung, and Chin‐Tsang Chiang. \href{https://www.jstor.org/stable/41000414}{"Optimal composite markers for time‐dependent receiver operating characteristic curves with censored survival data."} Scandinavian Journal of Statistics 37.4 (2010): 664-679. +\item [1] Uno, Hajime, et al. "Evaluating prediction rules for t-year survivors with censored regression models." Journal of the American Statistical Association 102.478 (2007): 527-537. +\item [2] Hung, Hung, and Chin‐Tsang Chiang. "Optimal composite markers for time‐dependent receiver operating characteristic curves with censored survival data." Scandinavian Journal of Statistics 37.4 (2010): 664-679. } } \description{ diff --git a/man/loss_adapt_mlr3proba.Rd b/man/loss_adapt_mlr3proba.Rd index 592f4d6d..5c3b9a83 100644 --- a/man/loss_adapt_mlr3proba.Rd +++ b/man/loss_adapt_mlr3proba.Rd @@ -21,12 +21,14 @@ loss_adapt_mlr3proba(measure, reverse = FALSE, ...) } \value{ a function with standardized parameters (\code{y_true}, \code{risk}, \code{surv}, \code{times}) that can be used to calculate loss - -if(FALSE){ -measure <- msr("surv.calib_beta") -mlr_measure <- loss_adapt_mlr3proba(measure) -} } \description{ This function allows for usage of standardized measures from the mlr3proba package with \code{survex}. } +\examples{ +if(FALSE){ + measure <- msr("surv.calib_beta") + mlr_measure <- loss_adapt_mlr3proba(measure) +} + +} diff --git a/man/loss_integrate.Rd b/man/loss_integrate.Rd index e3d9cac9..53344ab7 100644 --- a/man/loss_integrate.Rd +++ b/man/loss_integrate.Rd @@ -32,7 +32,7 @@ This function allows for creating a function for calculation of integrated metri \section{References}{ \itemize{ -\item [1] Graf, Erika, et al. \href{https://onlinelibrary.wiley.com/doi/abs/10.1002/\%28SICI\%291097-0258\%2819990915/30\%2918\%3A17/18\%3C2529\%3A\%3AAID-SIM274\%3E3.0.CO\%3B2-5}{"Assessment and comparison of prognostic classification schemes for survival data."} Statistics in Medicine 18.17‐18 (1999): 2529-2545. +\item [1] Graf, Erika, et al. "Assessment and comparison of prognostic classification schemes for survival data." Statistics in Medicine 18.17‐18 (1999): 2529-2545. } } diff --git a/man/loss_one_minus_c_index.Rd b/man/loss_one_minus_c_index.Rd index 87b04457..67e0e0c3 100644 --- a/man/loss_one_minus_c_index.Rd +++ b/man/loss_one_minus_c_index.Rd @@ -24,7 +24,7 @@ This function subtracts the C-index metric from one to obtain a loss function wh \section{References}{ \itemize{ -\item [1] Harrell, F.E., Jr., et al. \href{https://onlinelibrary.wiley.com/doi/10.1002/sim.4780030207}{"Regression modelling strategies for improved prognostic prediction."} Statistics in Medicine 3.2 (1984): 143-152. +\item [1] Harrell, F.E., Jr., et al. "Regression modelling strategies for improved prognostic prediction." Statistics in Medicine 3.2 (1984): 143-152. } } diff --git a/man/loss_one_minus_cd_auc.Rd b/man/loss_one_minus_cd_auc.Rd index 677d00c4..5f39b6e8 100644 --- a/man/loss_one_minus_cd_auc.Rd +++ b/man/loss_one_minus_cd_auc.Rd @@ -20,8 +20,8 @@ a numeric vector of length equal to the length of the times vector, each value ( #' @section References: \itemize{ -\item [1] Uno, Hajime, et al. \href{https://www.jstor.org/stable/27639883}{"Evaluating prediction rules for t-year survivors with censored regression models."} Journal of the American Statistical Association 102.478 (2007): 527-537. -\item [2] Hung, Hung, and Chin‐Tsang Chiang. \href{https://www.jstor.org/stable/41000414}{"Optimal composite markers for time‐dependent receiver operating characteristic curves with censored survival data."} Scandinavian Journal of Statistics 37.4 (2010): 664-679. +\item [1] Uno, Hajime, et al. "Evaluating prediction rules for t-year survivors with censored regression models." Journal of the American Statistical Association 102.478 (2007): 527-537. +\item [2] Hung, Hung, and Chin‐Tsang Chiang. "Optimal composite markers for time‐dependent receiver operating characteristic curves with censored survival data." Scandinavian Journal of Statistics 37.4 (2010): 664-679. } } \description{ diff --git a/man/loss_one_minus_integrated_cd_auc.Rd b/man/loss_one_minus_integrated_cd_auc.Rd index ab15112e..928959a4 100644 --- a/man/loss_one_minus_integrated_cd_auc.Rd +++ b/man/loss_one_minus_integrated_cd_auc.Rd @@ -25,8 +25,8 @@ numeric from 0 to 1, lower values indicate better performance #' @section References: \itemize{ -\item [1] Uno, Hajime, et al. \href{https://www.jstor.org/stable/27639883}{"Evaluating prediction rules for t-year survivors with censored regression models."} Journal of the American Statistical Association 102.478 (2007): 527-537. -\item [2] Hung, Hung, and Chin‐Tsang Chiang. \href{https://www.jstor.org/stable/41000414}{"Optimal composite markers for time‐dependent receiver operating characteristic curves with censored survival data."} Scandinavian Journal of Statistics 37.4 (2010): 664-679. +\item [1] Uno, Hajime, et al. "Evaluating prediction rules for t-year survivors with censored regression models." Journal of the American Statistical Association 102.478 (2007): 527-537. +\item [2] Hung, Hung, and Chin‐Tsang Chiang. "Optimal composite markers for time‐dependent receiver operating characteristic curves with censored survival data." Scandinavian Journal of Statistics 37.4 (2010): 664-679. } } \description{ diff --git a/man/model_performance.surv_explainer.Rd b/man/model_performance.surv_explainer.Rd index 3b885366..e5787112 100644 --- a/man/model_performance.surv_explainer.Rd +++ b/man/model_performance.surv_explainer.Rd @@ -31,9 +31,9 @@ model_performance(explainer, ...) \value{ An object of class \code{"model_performance_survival"}. It's a list of metric values calculated for the model. It contains: \itemize{ -\item Harrell's concordance index [\href{https://onlinelibrary.wiley.com/doi/abs/10.1002/sim.4780030207}{1}] -\item Brier score [\href{https://journals.ametsoc.org/view/journals/mwre/78/1/1520-0493_1950_078_0001_vofeit_2_0_co_2.xml}{2}, \href{https://onlinelibrary.wiley.com/doi/abs/10.1002/\%28SICI\%291097-0258\%2819990915/30\%2918\%3A17/18\%3C2529\%3A\%3AAID-SIM274\%3E3.0.CO\%3B2-5}{3}] -\item C/D AUC using the estimator proposed by Uno et. al [\href{https://www.jstor.org/stable/27639883#metadata_info_tab_contents}{4}] +\item Harrell's concordance index [1] +\item Brier score [2, 3] +\item C/D AUC using the estimator proposed by Uno et. al [4] \item integral of the Brier score \item integral of the C/D AUC } @@ -44,10 +44,10 @@ This function calculates metrics for survival models. The metrics calculated are \section{References}{ \itemize{ -\item [1] Harrell, F.E., Jr., et al. \href{https://onlinelibrary.wiley.com/doi/abs/10.1002/sim.4780030207}{"Regression modelling strategies for improved prognostic prediction."} Statistics in Medicine 3.2 (1984): 143-152. -\item [2] Brier, Glenn W. \href{https://journals.ametsoc.org/view/journals/mwre/78/1/1520-0493_1950_078_0001_vofeit_2_0_co_2.xml}{"Verification of forecasts expressed in terms of probability."} Monthly Weather Review 78.1 (1950): 1-3. -\item [3] Graf, Erika, et al. \href{https://onlinelibrary.wiley.com/doi/abs/10.1002/\%28SICI\%291097-0258\%2819990915/30\%2918\%3A17/18\%3C2529\%3A\%3AAID-SIM274\%3E3.0.CO\%3B2-5}{"Assessment and comparison of prognostic classification schemes for survival data."} Statistics in Medicine 18.17‐18 (1999): 2529-2545. -\item [4] Uno, Hajime, et al. \href{https://www.jstor.org/stable/27639883#metadata_info_tab_contents}{"Evaluating prediction rules for t-year survivors with censored regression models."} Journal of the American Statistical Association 102.478 (2007): 527-537. +\item [1] Harrell, F.E., Jr., et al. "Regression modelling strategies for improved prognostic prediction." Statistics in Medicine 3.2 (1984): 143-152. +\item [2] Brier, Glenn W. "Verification of forecasts expressed in terms of probability." Monthly Weather Review 78.1 (1950): 1-3. +\item [3] Graf, Erika, et al. "Assessment and comparison of prognostic classification schemes for survival data." Statistics in Medicine 18.17‐18 (1999): 2529-2545. +\item [4] Uno, Hajime, et al. "Evaluating prediction rules for t-year survivors with censored regression models." Journal of the American Statistical Association 102.478 (2007): 527-537. } } diff --git a/man/model_survshap.surv_explainer.Rd b/man/model_survshap.surv_explainer.Rd index 840f6460..4f4773d5 100644 --- a/man/model_survshap.surv_explainer.Rd +++ b/man/model_survshap.surv_explainer.Rd @@ -11,6 +11,7 @@ model_survshap(explainer, ...) explainer, new_observation = NULL, y_true = NULL, + N = NULL, calculation_method = "kernelshap", aggregation_method = "integral", output_type = "survival", @@ -26,7 +27,9 @@ model_survshap(explainer, ...) \item{y_true}{a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting} -\item{calculation_method}{a character, either \code{"kernelshap"} for use of \code{kernelshap} library (providing faster Kernel SHAP with refinements) or \code{"exact_kernel"} for exact Kernel SHAP estimation} +\item{N}{a positive integer, number of observations used as the background data} + +\item{calculation_method}{a character, either \code{"kernelshap"} for use of \code{kernelshap} library (providing faster Kernel SHAP with refinements), \code{"exact_kernel"} for exact Kernel SHAP estimation, or \code{"treeshap"} for use of \code{treeshap} library (efficient implementation to compute SHAP values for tree-based models).} \item{aggregation_method}{a character, either \code{"integral"}, \code{"integral_absolute"}, \code{"mean_absolute"}, \code{"max_absolute"}, or \code{"sum_of_squares"}} diff --git a/man/plot.aggregated_surv_shap.Rd b/man/plot.aggregated_surv_shap.Rd index 22d7b8ff..6b3585ec 100644 --- a/man/plot.aggregated_surv_shap.Rd +++ b/man/plot.aggregated_surv_shap.Rd @@ -57,11 +57,13 @@ explanations of survival models created using the \code{model_survshap()} functi \item \code{variable} - variable for which the profile is to be plotted, by default first from result data \item \code{color_variable} - variable used to denote the color, by default equal to \code{variable} } +} -#' ## \code{plot.aggregated_surv_shap(geom = "curves")} +\subsection{\code{plot.aggregated_surv_shap(geom = "curves")}}{ \itemize{ \item \code{variable} - variable for which SurvSHAP(t) curves are to be plotted, by default first from result data \item \code{boxplot} - whether to plot functional boxplot with marked outliers or all curves colored by variable value +\item \code{coef} - length of the functional boxplot's whiskers as multiple of IQR, by default 1.5 } } } diff --git a/man/predict_parts.surv_explainer.Rd b/man/predict_parts.surv_explainer.Rd index 2b705be7..c79249e5 100644 --- a/man/predict_parts.surv_explainer.Rd +++ b/man/predict_parts.surv_explainer.Rd @@ -24,11 +24,11 @@ predict_parts(explainer, ...) \item{new_observation}{a new observation for which prediction need to be explained} -\item{N}{the maximum number of observations used for calculation of attributions. If \code{NULL} (default) all observations will be used.} +\item{N}{the number of observations used for calculation of attributions. If \code{NULL} (default) all explainer data will be used for SurvSHAP(t) and 100 neigbours for SurvLIME.} \item{type}{if \code{output_type == "survival"} must be either \code{"survshap"} or \code{"survlime"}, otherwise refer to the \code{DALEX::predict_parts}} -\item{output_type}{either \code{"survival"} or \code{"risk"} the type of survival model output that should be considered for explanations. If \code{"survival"} the explanations are based on the survival function. Otherwise the scalar risk predictions are used by the \code{DALEX::predict_parts} function.} +\item{output_type}{either \code{"survival"}, \code{"chf"} or \code{"risk"} the type of survival model output that should be considered for explanations. If \code{"survival"} the explanations are based on the survival function. If \code{"chf"} the explanations are based on the cumulative hazard function. Otherwise the scalar risk predictions are used by the \code{DALEX::predict_parts} function.} \item{explanation_label}{a label that can overwrite explainer label (useful for multiple explanations for the same explainer/model)} } @@ -56,7 +56,6 @@ There are additional parameters that are passed to internal functions } \item for \code{survshap} \itemize{ -\item \code{timestamps} - a numeric vector, time points at which the survival function will be evaluated \item \code{y_true} - a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting \item \code{calculation_method} - a character, either \code{"kernelshap"} for use of \code{kernelshap} library (providing faster Kernel SHAP with refinements) or \code{"exact_kernel"} for exact Kernel SHAP estimation \item \code{aggregation_method} - a character, either \code{"mean_absolute"} or \code{"integral"}, \code{"max_absolute"}, \code{"sum_of_squares"} diff --git a/man/surv_shap.Rd b/man/surv_shap.Rd index 885e64f6..53dffe1c 100644 --- a/man/surv_shap.Rd +++ b/man/surv_shap.Rd @@ -9,9 +9,10 @@ surv_shap( new_observation, output_type, ..., + N = NULL, y_true = NULL, - calculation_method = "kernelshap", - aggregation_method = "integral" + calculation_method = c("kernelshap", "exact_kernel", "treeshap"), + aggregation_method = c("integral", "mean_absolute", "max_absolute", "sum_of_squares") ) } \arguments{ @@ -23,9 +24,11 @@ surv_shap( \item{...}{additional parameters, passed to internal functions} +\item{N}{a positive integer, number of observations used as the background data} + \item{y_true}{a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting} -\item{calculation_method}{a character, either \code{"kernelshap"} for use of \code{kernelshap} library (providing faster Kernel SHAP with refinements) or \code{"exact_kernel"} for exact Kernel SHAP estimation} +\item{calculation_method}{a character, either \code{"kernelshap"} for use of \code{kernelshap} library (providing faster Kernel SHAP with refinements), \code{"exact_kernel"} for exact Kernel SHAP estimation, or \code{"treeshap"} for use of \code{treeshap} library (efficient implementation to compute SHAP values for tree-based models).} \item{aggregation_method}{a character, either \code{"integral"}, \code{"integral_absolute"}, \code{"mean_absolute"}, \code{"max_absolute"}, or \code{"sum_of_squares"}} } diff --git a/tests/testthat/test-model_survshap.R b/tests/testthat/test-model_survshap.R index d56b75cc..4ee3be45 100644 --- a/tests/testthat/test-model_survshap.R +++ b/tests/testthat/test-model_survshap.R @@ -1,4 +1,5 @@ +# create objects here so that they do not have to be created redundantly veteran <- survival::veteran rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) @@ -68,3 +69,25 @@ test_that("global survshap explanations with kernelshap work for coxph, using ex expect_equal(length(cph_global_survshap$eval_times), length(cph_exp$times)) expect_true(all(names(cph_global_survshap$variable_values) == colnames(cph_exp$data))) }) + +# testing if matrix works as input +rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) +rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE) + +test_that("global survshap explanations with treeshap work for ranger", { + + new_obs <- model.matrix(~ -1 + ., veteran[1:40, setdiff(colnames(veteran), c("time", "status"))]) + ranger_global_survshap_tree <- model_survshap( + rsf_ranger_exp_matrix, + new_observation = new_obs, + y_true = survival::Surv(veteran$time[1:40], veteran$status[1:40]), + aggregation_method = "mean_absolute", + calculation_method = "treeshap" + ) + plot(ranger_global_survshap_tree) + + expect_s3_class(ranger_global_survshap_tree, c("aggregated_surv_shap", "surv_shap")) + expect_equal(length(ranger_global_survshap_tree$eval_times), length(rsf_ranger_exp_matrix$times)) + expect_true(all(names(ranger_global_survshap_tree$variable_values) == colnames(rsf_ranger_exp_matrix$data))) + +}) diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index 450fb1cf..6e8c733f 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -6,7 +6,7 @@ test_that("survshap explanations work", { rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran) cph_exp <- explain(cph, verbose = FALSE) - rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) + rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = survival::Surv(veteran$time, veteran$status), verbose = FALSE) rsf_src_exp <- explain(rsf_src, verbose = FALSE) parts_cph <- predict_parts(cph_exp, veteran[1, !colnames(veteran) %in% c("time", "status")], y_true = matrix(c(100, 1), ncol = 2), aggregation_method = "sum_of_squares") @@ -19,6 +19,19 @@ test_that("survshap explanations work", { parts_ranger <- predict_parts(rsf_ranger_exp, veteran[2, !colnames(veteran) %in% c("time", "status")], y_true = c(100, 1), aggregation_method = "mean_absolute") plot(parts_ranger) + # test ranger with kernelshap when using a matrix as input for data and new observation + rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) + rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE) + new_obs <- model.matrix(~ -1 + ., veteran[2, !colnames(veteran) %in% c("time", "status")]) + parts_ranger_kernelshap <- predict_parts( + rsf_ranger_exp_matrix, + new_observation = new_obs, + y_true = c(100, 1), + aggregation_method = "mean_absolute", + calculation_method = "kernelshap" + ) + plot(parts_ranger_kernelshap) + parts_src <- predict_parts(rsf_src_exp, veteran[3, !colnames(veteran) %in% c("time", "status")]) plot(parts_src) @@ -46,6 +59,29 @@ test_that("survshap explanations work", { expect_error(predict_parts(cph_exp, veteran[1, ], calculation_method = "nonexistent")) expect_error(predict_parts(cph_exp, veteran[1, c(1, 1, 1, 1, 1)], calculation_method = "nonexistent")) +}) + +test_that("local survshap explanations with treeshap work for ranger", { + + veteran <- survival::veteran + + rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) + rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE) + + + new_obs <- data.frame(model.matrix(~ -1 + ., veteran[2, setdiff(colnames(veteran), c("time", "status"))])) + parts_ranger <- model_survshap( + rsf_ranger_exp_matrix, + new_obs, + y_true = c(veteran$time[2], veteran$status[2]), + aggregation_method = "mean_absolute", + calculation_method = "treeshap" + ) + plot(parts_ranger) + + expect_s3_class(parts_ranger, c("predict_parts_survival", "surv_shap")) + expect_equal(nrow(parts_ranger$result), length(rsf_ranger_exp_matrix$times)) + expect_true(all(colnames(parts_ranger$result) == colnames(rsf_ranger_exp_matrix$data))) }) @@ -67,6 +103,10 @@ test_that("survshap explanations with output_type = 'chf' work", { plot(parts_cph, rug = "censors") plot(parts_cph, rug = "none") + # test global exact + parts_cph_glob <- predict_parts(cph_exp, veteran[1:3, !colnames(veteran) %in% c("time", "status")], y_true = as.matrix(veteran[1:3, c("time", "status")]), calculation_method = "exact_kernel", aggregation_method = "max_absolute", output_type = "chf") + plot(parts_cph_glob) + parts_ranger <- predict_parts(rsf_ranger_exp, veteran[2, !colnames(veteran) %in% c("time", "status")], y_true = c(100, 1), aggregation_method = "mean_absolute", output_type = "chf") plot(parts_ranger)