report roc-auc in assess_model() plot title
lgessl committed Dec 18, 2023
1 parent 0f61b71 commit 15902e6
Showing 5 changed files with 58 additions and 42 deletions.
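
In short: calculate_perf_metric() and add_benchmark_perf_metric() no longer return a bare tibble. They now hand back the enriched perf_plot_spec, with the performance tibble stored in $data (or $bm_data for the benchmark) and the ROC-AUC appended to $title, which plot_perf_metric() then picks up. A minimal sketch of the new calling pattern, assuming the package is loaded and a perf_plot_spec object has already been built elsewhere; the predicted/actual vectors are invented for illustration:

    # Hypothetical inputs; only the function name and the $data/$title fields
    # are taken from this commit, the rest is made up for the sketch.
    predicted <- c(0.9, 0.2, 0.7, 0.4)   # model risk scores
    actual    <- c(1, 0, 1, 0)           # true binary outcomes
    perf_plot_spec <- calculate_perf_metric(
      predicted = predicted,
      actual = actual,
      perf_plot_spec = perf_plot_spec
    )
    perf_tbl <- perf_plot_spec$data   # performance tibble now lives on the spec
    perf_plot_spec$title              # title now ends in ", ROC-AUC = <value>"
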
38 changes: 20 additions & 18 deletions R/assess_model.R
@@ -57,53 +57,55 @@ assess_model <- function(
)
actual <- pred_act[["actual"]]
predicted <- pred_act[["predicted"]]
perf_tbl_model <- calculate_perf_metric(
perf_plot_spec <- calculate_perf_metric(
predicted = predicted,
actual = actual,
perf_plot_spec = perf_plot_spec
)
perf_tbl_model <- perf_plot_spec$data

# (b) For benchmark (if given)
perf_tbl_bm <- NULL
if(!is.null(perf_plot_spec$benchmark)){
perf_tbl_bm <- add_benchmark_perf_metric(
perf_plot_spec <- add_benchmark_perf_metric(
pheno_tbl = pheno_tbl,
data_spec = data_spec,
perf_plot_spec = perf_plot_spec,
model_spec = model_spec
)
perf_tbl_bm <- perf_plot_spec$bm_data
}

# Combine both into tibble
# Combine both performance tibbles into one
perf_tbl_list <- list()
perf_tbl_list[[model_spec$name]] <- perf_tbl_model
if(!is.null(perf_tbl_bm)){
perf_tbl_list[[perf_plot_spec$benchmark]] <- perf_tbl_bm
}
perf_tbl <- dplyr::bind_rows(perf_tbl_list, .id = "model")
perf_plot_spec$data <- perf_tbl

# Plot
if(plots){
plot_perf_metric(
perf_tbl = perf_tbl,
perf_plot_spec = perf_plot_spec,
quiet = quiet
)
}

if(perf_plot_spec$scores_plot){
perf_plot_spec$title <- stringr::str_c(
data_spec$name, ", pfs < ", model_spec$pfs_leq
)
perf_plot_spec$fname <- file.path(
dirname(perf_plot_spec$fname),
"scores.pdf"
)
plot_scores(
predicted = predicted,
actual = actual,
perf_plot_spec = perf_plot_spec
)
if(perf_plot_spec$scores_plot){
perf_plot_spec$title <- stringr::str_c(
data_spec$name, ", pfs < ", model_spec$pfs_leq
)
perf_plot_spec$fname <- file.path(
dirname(perf_plot_spec$fname),
"scores.pdf"
)
plot_scores(
predicted = predicted,
actual = actual,
perf_plot_spec = perf_plot_spec
)
}
}

return(perf_tbl_list)
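
The model and benchmark tibbles are still combined the same way afterwards: each goes into a named list and dplyr::bind_rows() with .id = "model" turns the list names into a model column before the result is attached to perf_plot_spec$data. A tiny self-contained sketch of that step with toy tibbles (the rpp/prec/cutoff columns match the tests below; the model names are invented):

    library(tibble)
    library(dplyr)

    perf_tbl_model <- tibble(rpp = c(0.2, 0.5), prec = c(0.9, 0.7), cutoff = c(0.8, 0.4))
    perf_tbl_bm    <- tibble(rpp = c(0.3, 0.6), prec = c(0.8, 0.6), cutoff = c(0.7, 0.3))

    perf_tbl_list <- list("my_model" = perf_tbl_model, "benchmark" = perf_tbl_bm)
    perf_tbl <- bind_rows(perf_tbl_list, .id = "model")
    # the new 'model' column labels which curve each row belongs to,
    # so the combined tibble can be plotted as a single object
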
20 changes: 16 additions & 4 deletions R/assess_model_helpers.R
@@ -1,9 +1,9 @@
#' @importFrom rlang .data
plot_perf_metric <- function(
perf_tbl,
perf_plot_spec,
quiet = FALSE
){
perf_tbl <- perf_plot_spec$data
plt <- ggplot2::ggplot(
perf_tbl,
ggplot2::aes(
@@ -65,6 +65,11 @@ calculate_perf_metric <- function(
measure = y_metric,
x.measure = x_metric
)
# By default, also infer ROC-AUC
roc_auc <- ROCR::performance(
rocr_prediction,
measure = "auc"
)@y.values[[1]]

# Store them in a tibble
tbl <- tibble::tibble(
@@ -82,7 +87,13 @@
any_na <- apply(tbl, 1, function(x) any(is.na(x)))
tbl <- tbl[!any_na, ]

return(tbl)
perf_plot_spec$data <- tbl
perf_plot_spec$title <- stringr::str_c(
perf_plot_spec$title,
", ROC-AUC = ", round(roc_auc, 3)
)

return(perf_plot_spec)
}
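
The added AUC lines read the scalar straight out of ROCR's S4 performance object. A standalone sketch of just that computation, with invented scores and labels (only the ROCR calls mirror the diff):

    library(ROCR)

    predicted <- c(0.9, 0.1, 0.8, 0.3, 0.6)   # hypothetical model scores
    actual    <- c(1, 0, 1, 0, 1)             # hypothetical true labels

    rocr_prediction <- ROCR::prediction(predictions = predicted, labels = actual)
    # performance() returns an S4 object; the AUC scalar sits in its y.values slot
    roc_auc <- ROCR::performance(rocr_prediction, measure = "auc")@y.values[[1]]
    round(roc_auc, 3)   # the rounded value that gets pasted into the plot title
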


@@ -118,13 +129,14 @@ add_benchmark_perf_metric <- function(
rm_na = TRUE
)

perf_tbl <- calculate_perf_metric(
pps_bm <- calculate_perf_metric(
predicted = pred_act[[1]],
actual = pred_act[[2]],
perf_plot_spec = perf_plot_spec
)
perf_plot_spec$bm_data <- pps_bm$data

return(perf_tbl)
return(perf_plot_spec)
}


6 changes: 3 additions & 3 deletions R/assess_multiple_models.R
@@ -100,11 +100,11 @@ assess_multiple_models <- function(
}

perf_tbl <- dplyr::bind_rows(perf_tbls, .id = "model")
perf_plot_spec$data <- perf_tbl
if(comparison_plot){
plot_perf_metric(
perf_tbl = perf_tbl,
perf_plot_spec = perf_plot_spec,
quiet = TRUE
perf_plot_spec = perf_plot_spec,
quiet = TRUE
)
message("Saving comparative performance plot to ", perf_plot_spec$fname)
}
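
Because plot_perf_metric() no longer has a perf_tbl argument, callers such as assess_multiple_models() attach the combined tibble to the spec before plotting. A minimal sketch of the updated call, assuming perf_tbl and perf_plot_spec exist as above:

    perf_plot_spec$data <- perf_tbl   # the data now travels on the spec
    plot_perf_metric(
      perf_plot_spec = perf_plot_spec,
      quiet = TRUE
    )
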
16 changes: 9 additions & 7 deletions tests/testthat/test-assess_model.R
@@ -45,13 +45,15 @@ test_that("assess_model() works", {
model_spec = list(model_spec_1, model_spec_2)
)

expect_no_error(assess_model(
expr_mat = expr_mat,
pheno_tbl = pheno_tbl,
data_spec = data_spec,
model_spec = model_spec_1,
perf_plot_spec = perf_plot_spec
))
expect_no_error(
assess_model(
expr_mat = expr_mat,
pheno_tbl = pheno_tbl,
data_spec = data_spec,
model_spec = model_spec_1,
perf_plot_spec = perf_plot_spec
)
)
expect_no_error(assess_model(
expr_mat = expr_mat,
pheno_tbl = pheno_tbl,
20 changes: 10 additions & 10 deletions tests/testthat/test-assess_model_helpers.R
@@ -20,10 +20,10 @@ test_that("plot_perf_metric() works", {
x_lab = "this x lab",
y_lab = "that y lab"
)
perf_plot_spec$data <- perf_tbl

expect_silent(
plot_perf_metric(
perf_tbl = perf_tbl,
perf_plot_spec = perf_plot_spec,
quiet = TRUE
)
@@ -45,12 +45,13 @@
)

expect_silent(
perf_tbl <- calculate_perf_metric(
perf_plot_spec <- calculate_perf_metric(
predicted = predicted,
actual = actual,
perf_plot_spec = perf_plot_spec
)
)
perf_tbl <- perf_plot_spec$data
expect_equal(names(perf_tbl), c("rpp", "prec", "cutoff"))
expect_s3_class(perf_tbl, "tbl_df")
expect_equal(dim(perf_tbl), c(n_samples, 3))
@@ -79,14 +80,13 @@ test_that("add_benchmark_perf_metric() works", {
y_metric = "prec"
)

# expect_silent(
perf_tbl <- add_benchmark_perf_metric(
pheno_tbl = pheno_tbl,
data_spec = data_spec,
perf_plot_spec = perf_plot_spec,
model_spec = model_spec
)
# )
perf_plot_spec <- add_benchmark_perf_metric(
pheno_tbl = pheno_tbl,
data_spec = data_spec,
perf_plot_spec = perf_plot_spec,
model_spec = model_spec
)
perf_tbl <- perf_plot_spec$bm_data

expect_equal(names(perf_tbl), c("rpp", "prec", "cutoff"))
expect_s3_class(perf_tbl, "tbl_df")
