Skip to content

Commit

Permalink
Merge pull request #18 from seroanalytics/relative_dates
Browse files Browse the repository at this point in the history
fix titre type lookup
  • Loading branch information
hillalex authored Oct 16, 2024
2 parents d74abf2 + f50de28 commit bc4da25
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 87 deletions.
25 changes: 9 additions & 16 deletions R/biokinetics.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ biokinetics <- R6::R6Class(
design_matrix = NULL,
covariate_lookup_table = NULL,
pid_lookup = NULL,
titre_type_lookup = NULL,
check_fitted = function() {
if (is.null(private$fitted)) {
stop("Model has not been fitted yet. Call 'fit' before calling this function.")
Expand Down Expand Up @@ -74,19 +75,11 @@ biokinetics <- R6::R6Class(
build_pid_lookup = function() {
private$pid_lookup <- build_pid_lookup(private$data)
},
build_titre_type_lookup = function() {
private$titre_type_lookup <- build_titre_type_lookup(private$data)
},
recover_covariate_names = function(dt) {
# Declare variables to suppress notes when compiling package
# https://github.com/Rdatatable/data.table/issues/850#issuecomment-259466153
titre_type <- NULL

titre_types <- as.factor(unique(private$data$titre_type))

dt_titre_lookup <- data.table(
k = as.numeric(titre_types),
titre_type = levels(titre_types)
)

dt_out <- dt[dt_titre_lookup, on = "k"][, `:=`(k = NULL)]
dt_out <- dt[, titre_type := names(private$titre_type_lookup)[k]][, `:=`(k = NULL)]
if ("p" %in% colnames(dt)) {
dt_out <- dt_out[private$covariate_lookup_table, on = "p", nomatch = NULL][, `:=`(p = NULL)]
}
Expand Down Expand Up @@ -162,14 +155,14 @@ biokinetics <- R6::R6Class(
dt_out
},
prepare_stan_data = function() {
pid <- value <- censored <- titre_type_num <- titre_type <- obs_id <- t_since_last_exp <- NULL
pid <- value <- censored <- titre_type <- obs_id <- t_since_last_exp <- NULL
stan_data <- list(
N = private$data[, .N],
N_events = private$data[, data.table::uniqueN(pid)],
id = private$data[, private$pid_lookup[pid]],
value = private$data[, value],
censored = private$data[, censored],
titre_type = private$data[, titre_type_num],
titre_type = private$data[, private$titre_type_lookup[titre_type]],
preds_sd = private$preds_sd,
K = private$data[, data.table::uniqueN(titre_type)],
N_uncens = private$data[censored == 0, .N],
Expand Down Expand Up @@ -249,15 +242,15 @@ biokinetics <- R6::R6Class(
validate_formula_vars(private$all_formula_vars, private$data)
logger::log_info("Preparing data for stan")
private$data <- convert_log_scale(private$data, "value")
private$data[, `:=`(titre_type_num = as.numeric(as.factor(titre_type)),
obs_id = seq_len(.N),
private$data[, `:=`(obs_id = seq_len(.N),
t_since_last_exp = as.integer(day - last_exp_day, units = "days"))]
if (!("censored" %in% colnames(private$data))) {
private$data$censored <- 0
}
private$construct_design_matrix()
private$build_covariate_lookup_table()
private$build_pid_lookup()
private$build_titre_type_lookup()
private$prepare_stan_data()
logger::log_info("Retrieving compiled model")
private$model <- instantiate::stan_package_model(
Expand Down
8 changes: 8 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,11 @@ build_pid_lookup <- function(data) {
names(pid_lookup) <- pids
pid_lookup
}

# Build a lookup from titre type label to a 1-based numeric code.
#
# `data` must have a `titre_type` column. Returns an integer vector with one
# entry per distinct titre type, numbered in order of first appearance and
# named by the titre type, so that `lookup[label]` gives the code and
# `names(lookup)[code]` gives the label back.
#
# NOTE(review): the lookup is meant to be indexed by label. Indexing with a
# factor would use the factor's integer codes rather than its labels, so this
# presumably relies on `titre_type` being character - confirm against callers.
build_titre_type_lookup <- function(data) {
  labels <- unique(data$titre_type)
  lookup <- seq_along(labels)
  names(lookup) <- labels
  lookup
}
92 changes: 46 additions & 46 deletions tests/testthat/_snaps/snapshots.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,27 @@
delta
Output
variable mean median sd mad q5 q95 rhat ess_bulk
lp__ -1174.70 -1176.14 50.78 53.93 -1251.74 -1090.28 1.08 34
t0_pop[1] 4.13 4.13 0.28 0.27 3.66 4.55 1.02 201
t0_pop[2] 4.80 4.83 0.26 0.27 4.37 5.22 1.02 173
t0_pop[3] 3.52 3.51 0.28 0.27 3.09 3.99 1.02 189
tp_pop[1] 9.52 9.53 0.65 0.66 8.54 10.56 1.01 200
tp_pop[2] 10.72 10.74 0.63 0.59 9.68 11.70 1.02 215
tp_pop[3] 8.91 8.91 0.73 0.75 7.71 10.11 1.00 253
ts_pop_delta[1] 52.70 52.57 2.56 2.32 48.91 57.15 1.00 349
ts_pop_delta[2] 61.50 61.35 2.65 2.72 57.32 65.69 1.00 327
ts_pop_delta[3] 50.15 50.21 2.61 2.66 45.77 54.31 1.00 329
lp__ -1195.29 -1182.41 78.03 45.74 -1340.68 -1107.94 1.07 48
t0_pop[1] 4.80 4.79 0.24 0.24 4.45 5.21 1.04 97
t0_pop[2] 4.12 4.11 0.25 0.27 3.75 4.57 1.02 233
t0_pop[3] 3.52 3.52 0.26 0.27 3.10 3.92 1.02 205
tp_pop[1] 10.72 10.73 0.60 0.60 9.81 11.68 1.01 304
tp_pop[2] 9.55 9.55 0.64 0.65 8.61 10.57 1.01 349
tp_pop[3] 8.88 8.87 0.77 0.80 7.68 10.09 1.02 250
ts_pop_delta[1] 61.55 61.45 2.61 2.60 57.50 65.70 1.00 203
ts_pop_delta[2] 52.55 52.44 2.56 2.50 48.64 56.37 1.01 304
ts_pop_delta[3] 50.10 50.10 2.65 2.51 45.80 54.36 1.02 173
ess_tail
111
340
301
264
247
199
368
403
360
331
20
351
248
359
415
412
203
239
330
341
# showing 10 of 10103 rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)

Expand All @@ -33,19 +33,19 @@
Code
trajectories
Output
time_since_last_exp me lo hi titre_type
<int> <num> <num> <num> <char>
1: 0 121.1892 94.45425 154.0367 Alpha
2: 1 150.7446 121.39549 186.0631 Alpha
3: 2 188.0531 155.65597 228.6746 Alpha
4: 3 233.8627 196.06447 281.9184 Alpha
5: 4 291.1750 244.87488 347.7999 Alpha
---
902: 146 162.4485 128.80162 200.2273 Delta
903: 147 161.7977 128.30153 199.5112 Delta
904: 148 161.1543 127.80338 198.8686 Delta
905: 149 160.6696 127.30716 198.2282 Delta
906: 150 159.9921 126.81288 197.5898 Delta
time_since_last_exp me lo hi titre_type
<int> <num> <num> <num> <char>
1: 0 120.1929 93.3773 154.1345 Ancestral
2: 1 150.4494 118.7096 186.3689 Ancestral
3: 2 187.4234 152.6142 227.3441 Ancestral
4: 3 233.3859 192.6505 280.5822 Ancestral
5: 4 291.5457 244.1629 346.3841 Ancestral
---
902: 146 160.6464 127.8557 205.7500 Delta
903: 147 159.9748 127.1635 204.8898 Delta
904: 148 159.2377 126.6307 204.0333 Delta
905: 149 158.5859 126.1045 203.1803 Delta
906: 150 157.9107 125.5806 202.4049 Delta
infection_history
<char>
1: Infection naive
Expand All @@ -65,17 +65,17 @@
Code
trajectories
Output
calendar_day titre_type me lo hi time_shift
<IDat> <char> <num> <num> <num> <num>
1: 2021-03-08 Alpha 1179.41596 898.59408 1592.5041 0
2: 2021-03-09 Alpha 1158.69205 865.99919 1558.7359 0
3: 2021-03-10 Alpha 1208.48440 953.30901 1520.2822 0
4: 2021-03-11 Alpha 1154.03970 900.55636 1490.7134 0
5: 2021-03-12 Alpha 1166.98651 885.90001 1506.8642 0
---
1775: 2022-08-07 Delta 82.89897 31.27422 308.8199 0
1776: 2022-08-08 Delta 84.04324 30.86336 312.8336 0
1777: 2022-08-09 Delta 84.28422 30.20377 319.3140 0
1778: 2022-08-10 Delta 86.52412 29.56398 320.8826 0
1779: 2022-08-11 Delta 86.69664 31.20098 317.9618 0
calendar_day titre_type me lo hi time_shift
<IDat> <char> <num> <num> <num> <num>
1: 2021-03-08 Ancestral 1172.55130 909.767640 1501.3053 0
2: 2021-03-09 Ancestral 1151.63927 882.770140 1488.7275 0
3: 2021-03-10 Ancestral 1184.73564 910.075137 1566.1617 0
4: 2021-03-11 Ancestral 1136.15980 888.320759 1485.9868 0
5: 2021-03-12 Ancestral 1153.10110 875.290893 1506.4791 0
---
1775: 2022-08-07 Delta 84.12329 11.738509 387.2564 0
1776: 2022-08-08 Delta 84.26154 10.517460 382.8336 0
1777: 2022-08-09 Delta 85.79343 10.647211 386.2053 0
1778: 2022-08-10 Delta 85.99704 9.774407 385.3051 0
1779: 2022-08-11 Delta 84.49904 11.370238 395.8927 0

25 changes: 0 additions & 25 deletions tests/testthat/test-non-numeric-pids.R
Original file line number Diff line number Diff line change
@@ -1,28 +1,3 @@
test_that("Can convert character pids to numeric ids and back again", {
dat <- data.table::fread(system.file("delta_full.rds", package = "epikinetics"))

dat$pid <- paste0("ID", dat$pid)
lookup <- build_pid_lookup(dat)

pids <- dat$pid
dat[, nid := lookup[pid]]
dat[, recovered := names(lookup)[nid]]

expect_equal(dat$recovered, dat$pid)
})

test_that("Can convert numeric pids to numeric ids and back again", {
dat <- data.table::fread(system.file("delta_full.rds", package = "epikinetics"))

lookup <- build_pid_lookup(dat)

pids <- dat$pid
dat[, nid := lookup[pid]]
dat[, recovered := as.numeric(names(lookup)[nid])]

expect_equal(dat$recovered, dat$pid)
})

test_that("Using numeric and non-numeric pids gives the same answer", {
# these take a while, so don't run on CI
skip_on_ci()
Expand Down
34 changes: 34 additions & 0 deletions tests/testthat/test-utils.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
test_that("Can convert character pids to numeric ids and back again", {
  dat <- data.table::fread(system.file("delta_full.rds", package = "epikinetics"))
  # Force the ids to be non-numeric strings.
  dat$pid <- paste0("ID", dat$pid)

  lookup <- build_pid_lookup(dat)

  # Round trip: pid -> numeric id -> pid.
  numeric_ids <- lookup[dat$pid]
  round_tripped <- names(lookup)[numeric_ids]

  expect_equal(round_tripped, dat$pid)
})

test_that("Can convert numeric pids to numeric ids and back again", {
  dat <- data.table::fread(system.file("delta_full.rds", package = "epikinetics"))

  lookup <- build_pid_lookup(dat)

  # Round trip: pid -> numeric id -> pid. Lookup names are character, so they
  # are converted back to numeric before comparing with the original column.
  numeric_ids <- lookup[dat$pid]
  round_tripped <- as.numeric(names(lookup)[numeric_ids])

  expect_equal(round_tripped, dat$pid)
})

test_that("Can convert titre types to numbers and back again", {
  dat <- data.table::fread(system.file("delta_full.rds", package = "epikinetics"))

  lookup <- build_titre_type_lookup(dat)

  dat[, titre_type_num := lookup[titre_type]]
  dat[, recovered := names(lookup)[titre_type_num]]

  # Round trip: label -> numeric code -> label reproduces the original column.
  expect_equal(dat$recovered, dat$titre_type)
  # The codes should be exactly 1..K, where K is the number of distinct titre
  # types, with no gaps or repeats. (This replaces a verbatim duplicate of the
  # assertion above, which added no coverage.)
  expect_equal(
    sort(unname(lookup)),
    seq_len(data.table::uniqueN(dat$titre_type))
  )
})

0 comments on commit bc4da25

Please sign in to comment.