-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
i2: support multiple categorical covariates #4
Changes from 5 commits
161da12
f1fa3df
bfaaa9e
673a36b
f0e2cae
07df22f
9280ceb
44541ea
08c4c38
35fbfe0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,4 @@ src/stan/**/*.exe | |
src/stan/**/*.EXE | ||
inst/doc | ||
.idea | ||
*.png |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -63,41 +63,13 @@ | |
# Identify columns with no variance and remove them | ||
variance_per_column <- apply(mm, 2, var) | ||
relevant_columns <- which(variance_per_column != 0) | ||
mm_reduced <- mm[, relevant_columns] | ||
mm_reduced <- mm[, relevant_columns, drop = FALSE] | ||
private$design_matrix <- mm_reduced | ||
}, | ||
build_covariate_lookup_table = function() { | ||
# Extract column names | ||
col_names <- colnames(private$design_matrix) | ||
|
||
# Split column names based on the ':' delimiter | ||
split_data <- stringr::str_split(col_names, ":", simplify = TRUE) | ||
|
||
# Convert the matrix to a data.table | ||
dt <- data.table::as.data.table(split_data) | ||
|
||
# Set the new column names | ||
data.table::setnames(dt, private$all_formula_vars) | ||
|
||
for (col_name in names(dt)) { | ||
# Find the matching formula variable for current column | ||
matching_formula_var <- private$all_formula_vars[which(startsWith(col_name, private$all_formula_vars))] | ||
if (length(matching_formula_var) > 0) { | ||
pattern_to_remove <- paste0("^", matching_formula_var) | ||
dt[, (col_name) := stringr::str_remove_all(get(col_name), pattern_to_remove)] | ||
} | ||
} | ||
|
||
# Declare variables to suppress notes when compiling package | ||
# https://github.com/Rdatatable/data.table/issues/850#issuecomment-259466153 | ||
p <- NULL | ||
|
||
# .I is a special symbol in data.table for row number | ||
dt[, p := .I] | ||
|
||
# Reorder columns to have 'i' first | ||
data.table::setcolorder(dt, "p") | ||
private$covariate_lookup_table <- dt | ||
private$covariate_lookup_table <- build_covariate_lookup_table(private$data, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. moved logic out to own function, mostly for ease of unit testing (but also could be re-used if/when we add CT model) |
||
private$design_matrix, | ||
private$all_formula_vars) | ||
}, | ||
recover_covariate_names = function(dt) { | ||
# Declare variables to suppress notes when compiling package | ||
|
@@ -110,35 +82,46 @@ | |
|
||
dt_out <- dt[dt_titre_lookup, on = "k"][, `:=`(k = NULL)] | ||
if ("p" %in% colnames(dt)) { | ||
dt_out <- dt_out[private$covariate_lookup_table, on = "p"][, `:=`(p = NULL)] | ||
dt_out <- dt_out[private$covariate_lookup_table, on = "p", nomatch = NULL][, `:=`(p = NULL)] | ||
} | ||
dt_out | ||
}, | ||
summarise_pop_fit = function(time_range, | ||
summarise, | ||
n_draws) { | ||
|
||
has_covariates <- length(private$all_formula_vars) > 0 | ||
|
||
# Declare variables to suppress notes when compiling package | ||
# https://github.com/Rdatatable/data.table/issues/850#issuecomment-259466153 | ||
t0_pop <- tp_pop <- ts_pop <- m1_pop <- m2_pop <- m3_pop <- NULL | ||
beta_t0 <- beta_tp <- beta_ts <- beta_m1 <- beta_m2 <- beta_m3 <- NULL | ||
k <- p <- .draw <- t_id <- mu <- NULL | ||
|
||
params <- c("t0_pop[k]", "tp_pop[k]", "ts_pop[k]", | ||
"m1_pop[k]", "m2_pop[k]", "m3_pop[k]") | ||
if (has_covariates) { | ||
params <- c(params, "beta_t0[p]", "beta_tp[p]", "beta_ts[p]", | ||
"beta_m1[p]", "beta_m2[p]", "beta_m3[p]") | ||
} | ||
|
||
params_proc <- rlang::parse_exprs(params) | ||
|
||
dt_samples_wide <- tidybayes::spread_draws( | ||
private$fitted, | ||
t0_pop[k], tp_pop[k], ts_pop[k], | ||
m1_pop[k], m2_pop[k], m3_pop[k], | ||
beta_t0[p], beta_tp[p], beta_ts[p], | ||
beta_m1[p], beta_m2[p], beta_m3[p]) |> | ||
private$fitted, !!!params_proc) |> | ||
data.table() | ||
|
||
dt_samples_wide <- dt_samples_wide[.draw %in% 1:n_draws] | ||
|
||
dt_samples_wide[, `:=`(.chain = NULL, .iteration = NULL)] | ||
|
||
if (!has_covariates) { | ||
# there are no covariates, so add dummy column | ||
# that will be removed after processing | ||
dt_samples_wide$p <- 1 | ||
} | ||
|
||
data.table::setcolorder(dt_samples_wide, c("k", "p", ".draw")) | ||
|
||
if (length(private$all_formula_vars) > 0) { | ||
if (has_covariates) { | ||
logger::log_info("Adjusting by regression coefficients") | ||
dt_samples_wide <- private$adjust_parameters(dt_samples_wide) | ||
} | ||
|
@@ -165,6 +148,11 @@ | |
} | ||
|
||
data.table::setcolorder(dt_out, c("t", "p", "k")) | ||
|
||
if (!has_covariates) { | ||
dt_out[, p:= NULL] | ||
} | ||
dt_out | ||
}, | ||
prepare_stan_data = function() { | ||
stan_id <- titre <- censored <- titre_type_num <- titre_type <- obs_id <- t_since_last_exp <- t_since_min_date <- NULL | ||
|
@@ -280,6 +268,9 @@ | |
package = "epikinetics" | ||
) | ||
}, | ||
get_design_matrix = function() { | ||
private$design_matrix | ||
}, | ||
#' @description Fit the model and return CmdStanMCMC fitted model object. | ||
#' @return A CmdStanMCMC fitted model object: <https://mc-stan.org/cmdstanr/reference/CmdStanMCMC.html> | ||
#' @param ... Named arguments to the `sample()` method of CmdStan model. | ||
|
@@ -296,14 +287,23 @@ | |
extract_population_parameters = function(n_draws = 2500, | ||
human_readable_covariates = TRUE) { | ||
private$check_fitted() | ||
params <- c("t0_pop[k]", "tp_pop[k]", "ts_pop[k]", "m1_pop[k]", "m2_pop[k]", | ||
"m3_pop[k]", "beta_t0[p]", "beta_tp[p]", "beta_ts[p]", "beta_m1[p]", | ||
"beta_m2[p]", "beta_m3[p]") | ||
has_covariates <- length(private$all_formula_vars) > 0 | ||
|
||
params <- c("t0_pop[k]", "tp_pop[k]", "ts_pop[k]", "m1_pop[k]", "m2_pop[k]", "m3_pop[k]") | ||
|
||
if (has_covariates) { | ||
params <- c(params, "beta_t0[p]", "beta_tp[p]", "beta_ts[p]", "beta_m1[p]", "beta_m2[p]", "beta_m3[p]") | ||
} | ||
|
||
logger::log_info("Extracting parameters") | ||
dt_out <- private$extract_parameters(params, n_draws) | ||
|
||
data.table::setcolorder(dt_out, c("k", "p", ".draw")) | ||
if (has_covariates){ | ||
data.table::setcolorder(dt_out, c("p", "k", ".draw")) | ||
} else { | ||
data.table::setcolorder(dt_out, c("k", ".draw")) | ||
} | ||
|
||
data.table::setnames(dt_out, ".draw", "draw") | ||
|
||
if (length(private$all_formula_vars) > 0) { | ||
|
@@ -338,6 +338,7 @@ | |
dt_out <- private$extract_parameters(params, n_draws) | ||
|
||
data.table::setcolorder(dt_out, c("n", "k", ".draw")) | ||
|
||
data.table::setnames(dt_out, c("n", ".draw"), c("stan_id", "draw")) | ||
|
||
if (human_readable_covariates) { | ||
|
@@ -408,14 +409,20 @@ | |
human_readable_covariates = FALSE) | ||
|
||
logger::log_info("Calculating peak and switch titre values") | ||
|
||
by <- c("k", "draw") | ||
if ("p" %in% colnames(dt_peak_switch)) { | ||
by <- c("p", by) | ||
} | ||
|
||
dt_peak_switch[, `:=`( | ||
mu_0 = scova_simulate_trajectory( | ||
0, t0_pop, tp_pop, ts_pop, m1_pop, m2_pop, m3_pop), | ||
mu_p = scova_simulate_trajectory( | ||
tp_pop, t0_pop, tp_pop, ts_pop, m1_pop, m2_pop, m3_pop), | ||
mu_s = scova_simulate_trajectory( | ||
ts_pop, t0_pop, tp_pop, ts_pop, m1_pop, m2_pop, m3_pop)), | ||
by = c("p", "k", "draw")] | ||
by = by] | ||
|
||
logger::log_info("Recovering covariate names") | ||
dt_peak_switch <- private$recover_covariate_names(dt_peak_switch) | ||
|
@@ -482,7 +489,7 @@ | |
dt_params_ind_traj <- scova_simulate_trajectories(dt_params_ind_trim) | ||
|
||
dt_params_ind_traj <- data.table::setDT(convert_log_scale_inverse_cpp( | ||
dt_params_ind_traj, vars_to_transform = "mu")) | ||
dt_params_ind_traj, vars_to_transform = "mu")) | ||
|
||
logger::log_info("Recovering covariate names") | ||
dt_params_ind_traj <- private$recover_covariate_names(dt_params_ind_traj) | ||
|
@@ -510,7 +517,7 @@ | |
by = c("calendar_date", "titre_type")) | ||
} | ||
|
||
dt_out[, time_shift:= time_shift] | ||
dt_out[, time_shift := time_shift] | ||
} | ||
) | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
library(ggplot2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added this file just for manual testing |
||
library(epikinetics) | ||
# plot functions for manual testing | ||
|
||
mod <- scova$new(file_path = system.file("delta_full.rds", package = "epikinetics"), | ||
priors = scova_priors()) | ||
|
||
mod$fit(chains = 4, | ||
parallel_chains = 4, | ||
iter_warmup = 50, | ||
iter_sampling = 200, | ||
threads_per_chain = 4) | ||
|
||
dat <- mod$simulate_population_trajectories() | ||
dat[, titre_type := forcats::fct_relevel( | ||
titre_type, | ||
c("Ancestral", "Alpha", "Delta"))] | ||
|
||
ggplot(data = dat) + | ||
geom_line(aes(x = t, | ||
y = me, | ||
colour = titre_type)) + | ||
geom_ribbon(aes(x = t, | ||
ymin = lo, | ||
ymax = hi, | ||
fill = titre_type), alpha = 0.65) + | ||
coord_cartesian(clip = "off") + | ||
labs(x = "Time since last exposure (days)", | ||
y = expression(paste("Titre (IC"[50], ")"))) + | ||
scale_y_continuous( | ||
trans = "log2") + | ||
facet_wrap(~titre_type) | ||
|
||
dat <- mod$population_stationary_points() | ||
dat[, titre_type := forcats::fct_relevel( | ||
titre_type, | ||
c("Ancestral", "Alpha", "Delta"))] | ||
|
||
ggplot(data = dat, aes( | ||
x = mu_p, y = mu_s, | ||
colour = titre_type)) + | ||
geom_density_2d( | ||
aes( | ||
group = titre_type)) + | ||
geom_point(alpha = 0.05, size = 0.2) + | ||
geom_point(aes(x = mu_p_me, y = mu_s_me), | ||
colour = "black") + | ||
geom_path(aes(x = mu_p_me, y = mu_s_me, | ||
group = titre_type), | ||
colour = "black") + | ||
geom_vline(xintercept = 2560, linetype = "twodash", colour = "gray30") + | ||
scale_x_continuous( | ||
trans = "log2", | ||
breaks = c(40, 80, 160, 320, 640, 1280, 2560, 5120, 10240), | ||
labels = c(expression(" " <= 40), | ||
"80", "160", "320", "640", "1280", "2560", "5120", "10240"), | ||
limits = c(NA, 10240)) + | ||
geom_hline(yintercept = 2560, linetype = "twodash", colour = "gray30") + | ||
scale_y_continuous( | ||
trans = "log2", | ||
breaks = c(40, 80, 160, 320, 640, 1280, 2560, 5120, 10240), | ||
labels = c(expression(" " <= 40), | ||
"80", "160", "320", "640", "1280", "2560", "5120", "10240"), | ||
limits = c(NA, 5120)) + | ||
scale_shape_manual(values = c(1, 2, 3)) + | ||
labs(x = expression(paste("Population-level titre value at peak (IC"[50], ")")), | ||
y = expression(paste("Population-level titre value at set-point (IC"[50], ")"))) + | ||
guides(colour = guide_legend(title = "Titre type", override.aes = list(alpha = 1, size = 1)), | ||
shape = guide_legend(title = "Infection history")) | ||
|
||
dat <- mod$simulate_individual_trajectories(summarise = FALSE) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just in case this becomes a 1 column matrix, stop it from being converted to a vector with colnames removed