intersect for availability in assess-model helpers now
lgessl committed Jan 15, 2024
1 parent b864cb8 commit c4114ef
Showing 2 changed files with 51 additions and 6 deletions.
39 changes: 33 additions & 6 deletions R/assess_model_helpers.R
@@ -5,21 +5,35 @@ calculate_2d_metric <- function(
model_spec,
benchmark = NULL
){
if(xor(is.null(benchmark), is.null(perf_plot_spec$benchmark))){
stop("`perf_plot_spec$benchmark` and `benchmark` must be both NULL or both ",
"not NULL")
}
# Extract most frequently used values
x_metric <- perf_plot_spec$x_metric
y_metric <- perf_plot_spec$y_metric
# Prepare for loop
tbl_list <- list()
aucs <- rep(NA, length(actual))
estimate_list <- list(predicted, benchmark)
if(is.null(benchmark)) estimate_list <- list(predicted)
names(estimate_list) <- c(model_spec$name, perf_plot_spec$benchmark)
estimate_list <- list()
estimate_list[[model_spec$name]] <- predicted
if(!is.null(benchmark)){
estimate_list[[perf_plot_spec$benchmark]] <- benchmark
}

for(i in model_spec$split_index){
for(estimate_name in names(estimate_list)){
estimate <- estimate_list[[estimate_name]]
estimate_actual <- intersect_by_names(
estimate[[i]],
actual[[i]],
rm_na = TRUE
)
# Calculate performance measures
rocr_prediction <- ROCR::prediction(estimate[[i]], actual[[i]])
rocr_prediction <- ROCR::prediction(
estimate_actual[[1]],
estimate_actual[[2]]
)
rocr_perf <- ROCR::performance(
rocr_prediction,
measure = y_metric,
@@ -83,8 +97,11 @@ plot_2d_metric <- function(
data[[y_metric]] >= perf_plot_spec$ylim[1] &
data[[y_metric]] <= perf_plot_spec$ylim[2],
]
bm_data <- data[data[["model"]] == perf_plot_spec$benchmark, ]
data <- data[data[["model"]] != perf_plot_spec$benchmark, ]
bm_data <- NULL
if(!is.null(perf_plot_spec$benchmark)){
bm_data <- data[data[["model"]] == perf_plot_spec$benchmark, ]
data <- data[data[["model"]] != perf_plot_spec$benchmark, ]
}

plt <- ggplot2::ggplot(
data = data,
@@ -165,6 +182,16 @@ plot_risk_scores <- function(
ncol = 2,
quiet = FALSE
){
# Get rid of NAs
for(i in seq_along(predicted)){
pred_act <- intersect_by_names(
predicted[[i]],
actual[[i]],
rm_na = TRUE
)
predicted[[i]] <- pred_act[[1]]
actual[[i]] <- pred_act[[2]]
}
tbl <- tibble::tibble(
patient_id = names(unlist(predicted)),
rank = unlist(lapply(predicted, function(x) rank(-x))),
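For reference, a minimal sketch of what intersect_by_names() is assumed to do in the hunks above: keep only the elements whose names occur in both named vectors and, with rm_na = TRUE, drop pairs where either value is NA. The real helper is defined elsewhere in the package and may differ.

# Hypothetical sketch, not the package's actual implementation
intersect_by_names <- function(x, y, rm_na = FALSE){
  # Restrict both vectors to the names they share
  common <- intersect(names(x), names(y))
  x <- x[common]
  y <- y[common]
  if(rm_na){
    # Drop pairs where either value is missing
    keep <- !is.na(x) & !is.na(y)
    x <- x[keep]
    y <- y[keep]
  }
  list(x, y)
}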
18 changes: 18 additions & 0 deletions tests/testthat/test-assess_model_helpers.R
@@ -37,6 +37,17 @@ test_that("calculate_2d_metric() works", {
expect_s3_class(perf_tbl, "tbl_df")
expect_true(all(perf_tbl[["model"]] %in% c(model_spec$name, perf_plot_spec$benchmark)))
expect_true(perf_tbl[, 1:4] |> as.matrix() |> is.numeric() |> all())

perf_plot_spec$benchmark <- NULL
expect_silent(
perf_plot_spec <- calculate_2d_metric(
actual = actual,
predicted = predicted,
benchmark = NULL,
perf_plot_spec = perf_plot_spec,
model_spec = model_spec
)
)
})


@@ -73,6 +84,13 @@ test_that("plot_2d_metric() works", {
quiet = TRUE
)
)
perf_plot_spec$benchmark <- NULL
expect_no_error(
plot_2d_metric(
perf_plot_spec = perf_plot_spec,
quiet = TRUE
)
)
})


