From f50de28dbfefdebe595c3b35e9bd329762d74dc2 Mon Sep 17 00:00:00 2001 From: "alex.hill@gmail.com" Date: Tue, 15 Oct 2024 22:13:31 +0100 Subject: [PATCH] fix titre type lookup --- R/biokinetics.R | 25 +++---- R/utils.R | 8 +++ tests/testthat/_snaps/snapshots.md | 92 +++++++++++++------------- tests/testthat/test-non-numeric-pids.R | 25 ------- tests/testthat/test-utils.R | 34 ++++++++++ 5 files changed, 97 insertions(+), 87 deletions(-) create mode 100644 tests/testthat/test-utils.R diff --git a/R/biokinetics.R b/R/biokinetics.R index 8dba2d2..041fa4b 100644 --- a/R/biokinetics.R +++ b/R/biokinetics.R @@ -19,6 +19,7 @@ biokinetics <- R6::R6Class( design_matrix = NULL, covariate_lookup_table = NULL, pid_lookup = NULL, + titre_type_lookup = NULL, check_fitted = function() { if (is.null(private$fitted)) { stop("Model has not been fitted yet. Call 'fit' before calling this function.") @@ -74,19 +75,11 @@ biokinetics <- R6::R6Class( build_pid_lookup = function() { private$pid_lookup <- build_pid_lookup(private$data) }, + build_titre_type_lookup = function() { + private$titre_type_lookup <- build_titre_type_lookup(private$data) + }, recover_covariate_names = function(dt) { - # Declare variables to suppress notes when compiling package - # https://github.com/Rdatatable/data.table/issues/850#issuecomment-259466153 - titre_type <- NULL - - titre_types <- as.factor(unique(private$data$titre_type)) - - dt_titre_lookup <- data.table( - k = as.numeric(titre_types), - titre_type = levels(titre_types) - ) - - dt_out <- dt[dt_titre_lookup, on = "k"][, `:=`(k = NULL)] + dt_out <- dt[, titre_type := names(private$titre_type_lookup)[k]][, `:=`(k = NULL)] if ("p" %in% colnames(dt)) { dt_out <- dt_out[private$covariate_lookup_table, on = "p", nomatch = NULL][, `:=`(p = NULL)] } @@ -162,14 +155,14 @@ biokinetics <- R6::R6Class( dt_out }, prepare_stan_data = function() { - pid <- value <- censored <- titre_type_num <- titre_type <- obs_id <- t_since_last_exp <- NULL + pid <- value <- censored <- titre_type <- obs_id <- t_since_last_exp <- NULL stan_data <- list( N = private$data[, .N], N_events = private$data[, data.table::uniqueN(pid)], id = private$data[, private$pid_lookup[pid]], value = private$data[, value], censored = private$data[, censored], - titre_type = private$data[, titre_type_num], + titre_type = private$data[, private$titre_type_lookup[titre_type]], preds_sd = private$preds_sd, K = private$data[, data.table::uniqueN(titre_type)], N_uncens = private$data[censored == 0, .N], @@ -249,8 +242,7 @@ biokinetics <- R6::R6Class( validate_formula_vars(private$all_formula_vars, private$data) logger::log_info("Preparing data for stan") private$data <- convert_log_scale(private$data, "value") - private$data[, `:=`(titre_type_num = as.numeric(as.factor(titre_type)), - obs_id = seq_len(.N), + private$data[, `:=`(obs_id = seq_len(.N), t_since_last_exp = as.integer(day - last_exp_day, units = "days"))] if (!("censored" %in% colnames(private$data))) { private$data$censored <- 0 @@ -258,6 +250,7 @@ biokinetics <- R6::R6Class( private$construct_design_matrix() private$build_covariate_lookup_table() private$build_pid_lookup() + private$build_titre_type_lookup() private$prepare_stan_data() logger::log_info("Retrieving compiled model") private$model <- instantiate::stan_package_model( diff --git a/R/utils.R b/R/utils.R index bae23d1..576cfc2 100644 --- a/R/utils.R +++ b/R/utils.R @@ -97,3 +97,11 @@ build_pid_lookup <- function(data) { names(pid_lookup) <- pids pid_lookup } + +build_titre_type_lookup <- function(data) { + titre_types <- unique(data$titre_type) + titre_type_nums <- seq_along(titre_types) + titre_type_lookup <- titre_type_nums + names(titre_type_lookup) <- titre_types + titre_type_lookup +} diff --git a/tests/testthat/_snaps/snapshots.md b/tests/testthat/_snaps/snapshots.md index ce37820..0798161 100644 --- a/tests/testthat/_snaps/snapshots.md +++ b/tests/testthat/_snaps/snapshots.md @@ -4,27 +4,27 @@ delta Output variable mean median sd mad q5 q95 rhat ess_bulk - lp__ -1174.70 -1176.14 50.78 53.93 -1251.74 -1090.28 1.08 34 - t0_pop[1] 4.13 4.13 0.28 0.27 3.66 4.55 1.02 201 - t0_pop[2] 4.80 4.83 0.26 0.27 4.37 5.22 1.02 173 - t0_pop[3] 3.52 3.51 0.28 0.27 3.09 3.99 1.02 189 - tp_pop[1] 9.52 9.53 0.65 0.66 8.54 10.56 1.01 200 - tp_pop[2] 10.72 10.74 0.63 0.59 9.68 11.70 1.02 215 - tp_pop[3] 8.91 8.91 0.73 0.75 7.71 10.11 1.00 253 - ts_pop_delta[1] 52.70 52.57 2.56 2.32 48.91 57.15 1.00 349 - ts_pop_delta[2] 61.50 61.35 2.65 2.72 57.32 65.69 1.00 327 - ts_pop_delta[3] 50.15 50.21 2.61 2.66 45.77 54.31 1.00 329 + lp__ -1195.29 -1182.41 78.03 45.74 -1340.68 -1107.94 1.07 48 + t0_pop[1] 4.80 4.79 0.24 0.24 4.45 5.21 1.04 97 + t0_pop[2] 4.12 4.11 0.25 0.27 3.75 4.57 1.02 233 + t0_pop[3] 3.52 3.52 0.26 0.27 3.10 3.92 1.02 205 + tp_pop[1] 10.72 10.73 0.60 0.60 9.81 11.68 1.01 304 + tp_pop[2] 9.55 9.55 0.64 0.65 8.61 10.57 1.01 349 + tp_pop[3] 8.88 8.87 0.77 0.80 7.68 10.09 1.02 250 + ts_pop_delta[1] 61.55 61.45 2.61 2.60 57.50 65.70 1.00 203 + ts_pop_delta[2] 52.55 52.44 2.56 2.50 48.64 56.37 1.01 304 + ts_pop_delta[3] 50.10 50.10 2.65 2.51 45.80 54.36 1.02 173 ess_tail - 111 - 340 - 301 - 264 - 247 - 199 - 368 - 403 - 360 - 331 + 20 + 351 + 248 + 359 + 415 + 412 + 203 + 239 + 330 + 341 # showing 10 of 10103 rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option) @@ -33,19 +33,19 @@ Code trajectories Output - time_since_last_exp me lo hi titre_type - - 1: 0 121.1892 94.45425 154.0367 Alpha - 2: 1 150.7446 121.39549 186.0631 Alpha - 3: 2 188.0531 155.65597 228.6746 Alpha - 4: 3 233.8627 196.06447 281.9184 Alpha - 5: 4 291.1750 244.87488 347.7999 Alpha - --- - 902: 146 162.4485 128.80162 200.2273 Delta - 903: 147 161.7977 128.30153 199.5112 Delta - 904: 148 161.1543 127.80338 198.8686 Delta - 905: 149 160.6696 127.30716 198.2282 Delta - 906: 150 159.9921 126.81288 197.5898 Delta + time_since_last_exp me lo hi titre_type + + 1: 0 120.1929 93.3773 154.1345 Ancestral + 2: 1 150.4494 118.7096 186.3689 Ancestral + 3: 2 187.4234 152.6142 227.3441 Ancestral + 4: 3 233.3859 192.6505 280.5822 Ancestral + 5: 4 291.5457 244.1629 346.3841 Ancestral + --- + 902: 146 160.6464 127.8557 205.7500 Delta + 903: 147 159.9748 127.1635 204.8898 Delta + 904: 148 159.2377 126.6307 204.0333 Delta + 905: 149 158.5859 126.1045 203.1803 Delta + 906: 150 157.9107 125.5806 202.4049 Delta infection_history 1: Infection naive @@ -65,17 +65,17 @@ Code trajectories Output - calendar_day titre_type me lo hi time_shift - - 1: 2021-03-08 Alpha 1179.41596 898.59408 1592.5041 0 - 2: 2021-03-09 Alpha 1158.69205 865.99919 1558.7359 0 - 3: 2021-03-10 Alpha 1208.48440 953.30901 1520.2822 0 - 4: 2021-03-11 Alpha 1154.03970 900.55636 1490.7134 0 - 5: 2021-03-12 Alpha 1166.98651 885.90001 1506.8642 0 - --- - 1775: 2022-08-07 Delta 82.89897 31.27422 308.8199 0 - 1776: 2022-08-08 Delta 84.04324 30.86336 312.8336 0 - 1777: 2022-08-09 Delta 84.28422 30.20377 319.3140 0 - 1778: 2022-08-10 Delta 86.52412 29.56398 320.8826 0 - 1779: 2022-08-11 Delta 86.69664 31.20098 317.9618 0 + calendar_day titre_type me lo hi time_shift + + 1: 2021-03-08 Ancestral 1172.55130 909.767640 1501.3053 0 + 2: 2021-03-09 Ancestral 1151.63927 882.770140 1488.7275 0 + 3: 2021-03-10 Ancestral 1184.73564 910.075137 1566.1617 0 + 4: 2021-03-11 Ancestral 1136.15980 888.320759 1485.9868 0 + 5: 2021-03-12 Ancestral 1153.10110 875.290893 1506.4791 0 + --- + 1775: 2022-08-07 Delta 84.12329 11.738509 387.2564 0 + 1776: 2022-08-08 Delta 84.26154 10.517460 382.8336 0 + 1777: 2022-08-09 Delta 85.79343 10.647211 386.2053 0 + 1778: 2022-08-10 Delta 85.99704 9.774407 385.3051 0 + 1779: 2022-08-11 Delta 84.49904 11.370238 395.8927 0 diff --git a/tests/testthat/test-non-numeric-pids.R b/tests/testthat/test-non-numeric-pids.R index d936f16..63529d3 100644 --- a/tests/testthat/test-non-numeric-pids.R +++ b/tests/testthat/test-non-numeric-pids.R @@ -1,28 +1,3 @@ -test_that("Can convert character pids to numeric ids and back again", { - dat <- data.table::fread(system.file("delta_full.rds", package = "epikinetics")) - - dat$pid <- paste0("ID", dat$pid) - lookup <- build_pid_lookup(dat) - - pids <- dat$pid - dat[, nid := lookup[pid]] - dat[, recovered := names(lookup)[nid]] - - expect_equal(dat$recovered, dat$pid) -}) - -test_that("Can convert numeric pids to numeric ids and back again", { - dat <- data.table::fread(system.file("delta_full.rds", package = "epikinetics")) - - lookup <- build_pid_lookup(dat) - - pids <- dat$pid - dat[, nid := lookup[pid]] - dat[, recovered := as.numeric(names(lookup)[nid])] - - expect_equal(dat$recovered, dat$pid) -}) - test_that("Using numeric and non-numeric pids gives the same answer", { # these take a while, so don't run on CI skip_on_ci() diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R new file mode 100644 index 0000000..6807221 --- /dev/null +++ b/tests/testthat/test-utils.R @@ -0,0 +1,34 @@ +test_that("Can convert character pids to numeric ids and back again", { + dat <- data.table::fread(system.file("delta_full.rds", package = "epikinetics")) + + dat$pid <- paste0("ID", dat$pid) + lookup <- build_pid_lookup(dat) + + dat[, nid := lookup[pid]] + dat[, recovered := names(lookup)[nid]] + + expect_equal(dat$recovered, dat$pid) +}) + +test_that("Can convert numeric pids to numeric ids and back again", { + dat <- data.table::fread(system.file("delta_full.rds", package = "epikinetics")) + + lookup <- build_pid_lookup(dat) + + dat[, nid := lookup[pid]] + dat[, recovered := as.numeric(names(lookup)[nid])] + + expect_equal(dat$recovered, dat$pid) +}) + +test_that("Can convert titre types to numbers and back again", { + dat <- data.table::fread(system.file("delta_full.rds", package = "epikinetics")) + + lookup <- build_titre_type_lookup(dat) + + dat[, titre_type_num := lookup[titre_type]] + dat[, recovered := names(lookup)[titre_type_num]] + + expect_equal(dat$recovered, dat$titre_type) + expect_equal(dat$recovered, dat$titre_type) +})