From 66e4f0567f161a3640b9e2b85927181e2d86a943 Mon Sep 17 00:00:00 2001 From: "alex.hill@gmail.com" Date: Mon, 15 Jul 2024 19:39:41 +0100 Subject: [PATCH] make function interfaces more consistent --- R/scova.R | 66 +++++++++++-------- man/scova.Rd | 16 +++-- tests/testthat/test-extract-parameters.R | 49 ++++++++++++-- .../test-simulate-individual-trajectories.R | 2 +- .../test-simulate-population-trajectories.R | 6 +- 5 files changed, 97 insertions(+), 42 deletions(-) diff --git a/R/scova.R b/R/scova.R index cfe2f44..4919f8e 100644 --- a/R/scova.R +++ b/R/scova.R @@ -65,7 +65,6 @@ scova <- R6::R6Class( relevant_columns <- which(variance_per_column != 0) mm_reduced <- mm[, relevant_columns] private$design_matrix <- mm_reduced - mm_reduced }, build_covariate_lookup_table = function() { # Extract column names @@ -99,7 +98,6 @@ scova <- R6::R6Class( # Reorder columns to have 'i' first data.table::setcolorder(dt, "p") private$covariate_lookup_table <- dt - dt }, recover_covariate_names = function(dt) { # Declare variables to suppress notes when compiling package @@ -110,11 +108,11 @@ scova <- R6::R6Class( k = 1:private$data[, length(unique(titre_type))], titre_type = private$data[, unique(titre_type)]) + dt_out <- dt[dt_titre_lookup, on = "k"][, `:=`(k = NULL)] if ("p" %in% colnames(dt)) { - return(dt[private$covariate_lookup_table, on = "p"][dt_titre_lookup, on = "k"]) - } else { - return(dt[dt_titre_lookup, on = "k"]) + dt_out <- dt_out[private$covariate_lookup_table, on = "p"][, `:=`(p = NULL)] } + dt_out }, summarise_pop_fit = function(time_range, summarise, @@ -194,11 +192,10 @@ scova <- R6::R6Class( stan_data$t <- private$data[, t_since_min_date] } - X <- private$construct_design_matrix() - stan_data$X <- X - stan_data$P <- ncol(X) + stan_data$X <- private$design_matrix + stan_data$P <- ncol(private$design_matrix) - c(stan_data, private$priors) + private$stan_input_data <- c(stan_data, private$priors) }, adjust_parameters = function(dt) { params_to_adjust <- c( @@ -273,8 +270,9 @@ scova <- R6::R6Class( } logger::log_info("Preparing data for stan") private$data <- convert_log_scale(private$data, "titre") - private$stan_input_data <- private$prepare_stan_data() + private$construct_design_matrix() private$build_covariate_lookup_table() + private$prepare_stan_data() logger::log_info("Retrieving compiled model") private$model <- instantiate::stan_package_model( name = "antibody_kinetics_main", @@ -293,27 +291,38 @@ scova <- R6::R6Class( #' @description Extract fitted population parameters #' @return A data.table #' @param n_draws Numeric - extract_population_parameters = function(n_draws = 2500) { + #' @param human_readable_covariates Logical. Default TRUE. + extract_population_parameters = function(n_draws = 2500, + human_readable_covariates = TRUE) { private$check_fitted() params <- c("t0_pop[k]", "tp_pop[k]", "ts_pop[k]", "m1_pop[k]", "m2_pop[k]", "m3_pop[k]", "beta_t0[p]", "beta_tp[p]", "beta_ts[p]", "beta_m1[p]", "beta_m2[p]", "beta_m3[p]") + logger::log_info("Extracting parameters") dt_out <- private$extract_parameters(params, n_draws) data.table::setcolorder(dt_out, c("k", "p", ".draw")) + data.table::setnames(dt_out, ".draw", "draw") if (length(private$all_formula_vars) > 0) { + logger::log_info("Adjusting by covariates") dt_out <- private$adjust_parameters(dt_out) } - private$recover_covariate_names(dt_out) + if (human_readable_covariates) { + logger::log_info("Recovering covariate names") + dt_out <- private$recover_covariate_names(dt_out) + } + dt_out }, #' @description Extract fitted individual parameters #' @return A data.table #' @param n_draws Numeric #' @param include_variation_params Logical - extract_individual_parameters = function(include_variation_params = FALSE, - n_draws = 2500) { + #' @param human_readable_covariates Logical. Default TRUE. + extract_individual_parameters = function(n_draws = 2500, + include_variation_params = TRUE, + human_readable_covariates = TRUE) { private$check_fitted() params <- c("t0_ind[n, k]", "tp_ind[n, k]", "ts_ind[n, k]", "m1_ind[n, k]", "m2_ind[n, k]", "m3_ind[n, k]") @@ -324,12 +333,17 @@ scova <- R6::R6Class( params <- c(params, ind_var_params) } + logger::log_info("Extracting parameters") dt_out <- private$extract_parameters(params, n_draws) data.table::setcolorder(dt_out, c("n", "k", ".draw")) data.table::setnames(dt_out, c("n", ".draw"), c("stan_id", "draw")) - private$recover_covariate_names(dt_out) + if (human_readable_covariates) { + logger::log_info("Recovering covariate names") + dt_out <- private$recover_covariate_names(dt_out) + } + dt_out }, #' @description Process the model results into a data table of titre values over time. #' @return A data.table containing titre values at time points. If summarise = TRUE, columns are t, p, k, me, lo, hi, @@ -386,7 +400,8 @@ scova <- R6::R6Class( private$check_fitted() validate_numeric(n_draws) - dt_peak_switch <- self$extract_population_parameters(n_draws) + dt_peak_switch <- self$extract_population_parameters(n_draws, + human_readable_covariates = FALSE) logger::log_info("Calculating peak and switch titre values") dt_peak_switch[, `:=`( @@ -396,10 +411,10 @@ scova <- R6::R6Class( tp_pop, t0_pop, tp_pop, ts_pop, m1_pop, m2_pop, m3_pop), mu_s = scova_simulate_trajectory( ts_pop, t0_pop, tp_pop, ts_pop, m1_pop, m2_pop, m3_pop)), - by = c("p", "k", ".draw")] + by = c("p", "k", "draw")] - # logger::log_info("Recovering covariate names") - # dt_peak_switch <- private$recover_covariate_names(dt_peak_switch) + logger::log_info("Recovering covariate names") + dt_peak_switch <- private$recover_covariate_names(dt_peak_switch) dt_peak_switch <- convert_log_scale_inverse( dt_peak_switch, vars_to_transform = c("mu_0", "mu_p", "mu_s")) @@ -436,7 +451,9 @@ scova <- R6::R6Class( validate_numeric(time_shift) # Extracting parameters from fit - dt_params_ind <- self$extract_individual_parameters()[!is.nan(t0_ind)] + dt_params_ind <- self$extract_individual_parameters(n_draws, + human_readable_covariates = FALSE, + include_variation_params = FALSE)[!is.nan(t0_ind)] # Calculating the maximum time each individual has data for after the # exposure of interest @@ -462,14 +479,7 @@ scova <- R6::R6Class( dt_params_ind_traj <- data.table::setDT(convert_log_scale_inverse_cpp( dt_params_ind_traj, vars_to_transform = "mu")) - # dt_titre_types <- data.table( - # titre_type = private$data[, unique(titre_type)], - # titre_type_num = dt_params_ind_traj[, unique(titre_type_num)]) - # - # dt_params_ind_traj <- merge( - # dt_params_ind_traj, - # dt_titre_types, - # by = "titre_type_num")[, titre_type_num := NULL] + logger::log_info("Recovering covariate names") dt_params_ind_traj <- private$recover_covariate_names(dt_params_ind_traj) logger::log_info(paste("Calculating exposure dates. Adjusting exposures by", time_shift, "days")) diff --git a/man/scova.Rd b/man/scova.Rd index e2f586e..499fb5f 100644 --- a/man/scova.Rd +++ b/man/scova.Rd @@ -83,13 +83,18 @@ A CmdStanMCMC fitted model object: \url{https://mc-stan.org/cmdstanr/reference/C \subsection{Method \code{extract_population_parameters()}}{ Extract fitted population parameters \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{scova$extract_population_parameters(n_draws = 2500)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{scova$extract_population_parameters( + n_draws = 2500, + human_readable_covariates = TRUE +)}\if{html}{\out{
}} } \subsection{Arguments}{ \if{html}{\out{
}} \describe{ \item{\code{n_draws}}{Numeric} + +\item{\code{human_readable_covariates}}{Logical. Default TRUE.} } \if{html}{\out{
}} } @@ -104,17 +109,20 @@ A data.table Extract fitted individual parameters \subsection{Usage}{ \if{html}{\out{
}}\preformatted{scova$extract_individual_parameters( - include_variation_params = FALSE, - n_draws = 2500 + n_draws = 2500, + include_variation_params = TRUE, + human_readable_covariates = TRUE )}\if{html}{\out{
}} } \subsection{Arguments}{ \if{html}{\out{
}} \describe{ +\item{\code{n_draws}}{Numeric} + \item{\code{include_variation_params}}{Logical} -\item{\code{n_draws}}{Numeric} +\item{\code{human_readable_covariates}}{Logical. Default TRUE.} } \if{html}{\out{
}} } diff --git a/tests/testthat/test-extract-parameters.R b/tests/testthat/test-extract-parameters.R index e8bf3d6..1690e5b 100644 --- a/tests/testthat/test-extract-parameters.R +++ b/tests/testthat/test-extract-parameters.R @@ -16,21 +16,58 @@ test_that("Cannot retrieve individual params until model is fitted", { expect_error(mod$extract_individual_parameters(), "Model has not been fitted yet. Call 'fit' before calling this function.") }) -test_that("Can extract population parameters", { +test_that("Can extract population parameters without human readable covariates", { mod <- scova$new(file_path = system.file("delta_full.rds", package = "epikinetics"), covariate_formula = ~0 + infection_history) mod$fit() - params <- mod$extract_population_parameters() - expect_equal(names(params), c("k", "p", ".draw", "t0_pop", "tp_pop", "ts_pop", "m1_pop", "m2_pop", "m3_pop", + params <- mod$extract_population_parameters(n_draws = 10, human_readable_covariates = FALSE) + expect_equal(names(params), c("k", "p", "draw", "t0_pop", "tp_pop", "ts_pop", "m1_pop", "m2_pop", "m3_pop", + "beta_t0", "beta_tp", "beta_ts", "beta_m1", "beta_m2", "beta_m3")) +}) + +test_that("Can extract population parameters with human readable covariates", { + mod <- scova$new(file_path = system.file("delta_full.rds", package = "epikinetics"), + covariate_formula = ~0 + infection_history) + mod$fit() + params <- mod$extract_population_parameters(n_draws = 10, human_readable_covariates = TRUE) + expect_equal(names(params), c("draw", "t0_pop", "tp_pop", "ts_pop", "m1_pop", "m2_pop", "m3_pop", "beta_t0", "beta_tp", "beta_ts", "beta_m1", "beta_m2", "beta_m3", - "infection_history", "titre_type")) + "titre_type", "infection_history")) }) -test_that("Can extract individual parameters", { +test_that("Can extract individual parameters without human readable covariates", { mod <- scova$new(file_path = system.file("delta_full.rds", package = "epikinetics"), covariate_formula = ~0 + infection_history) mod$fit() - params <- mod$extract_individual_parameters(n_draws = 10) + params <- mod$extract_individual_parameters(n_draws = 10, + human_readable_covariates = FALSE, + include_variation_params = FALSE) expect_equal(names(params), c("stan_id", "k", "draw", "t0_ind", "tp_ind", "ts_ind", + "m1_ind", "m2_ind", "m3_ind")) +}) + +test_that("Can extract individual parameters with human readable covariates", { + mod <- scova$new(file_path = system.file("delta_full.rds", package = "epikinetics"), + covariate_formula = ~0 + infection_history) + mod$fit() + params <- mod$extract_individual_parameters(n_draws = 10, + human_readable_covariates = TRUE, + include_variation_params = FALSE) + expect_equal(names(params), c("stan_id", "draw", "t0_ind", "tp_ind", "ts_ind", "m1_ind", "m2_ind", "m3_ind", "titre_type")) }) + +test_that("Can extract individual parameters with variation params", { + mod <- scova$new(file_path = system.file("delta_full.rds", package = "epikinetics"), + covariate_formula = ~0 + infection_history) + mod$fit() + params <- mod$extract_individual_parameters(n_draws = 10, + human_readable_covariates = TRUE, + include_variation_params = TRUE) + expect_equal(names(params), c("stan_id", "draw", + "t0_ind", "tp_ind", "ts_ind", + "m1_ind", "m2_ind", "m3_ind", + "z_t0", "z_tp", "z_ts", + "z_m1", "z_m2", "z_m3", + "titre_type")) +}) diff --git a/tests/testthat/test-simulate-individual-trajectories.R b/tests/testthat/test-simulate-individual-trajectories.R index fbdce3c..208e54d 100644 --- a/tests/testthat/test-simulate-individual-trajectories.R +++ b/tests/testthat/test-simulate-individual-trajectories.R @@ -60,7 +60,7 @@ test_that("Can retrieve un-summarised trajectories", { covariate_formula = ~0 + infection_history) mod$fit() trajectories <- mod$simulate_individual_trajectories(summarise = FALSE, n_draws = 10) - expect_equal(names(trajectories), c("stan_id", "k", "draw", "t", "mu", "titre_type", "infection_history", + expect_equal(names(trajectories), c("stan_id", "draw", "t", "mu", "titre_type", "infection_history", "exposure_date", "calendar_date", "time_shift")) }) diff --git a/tests/testthat/test-simulate-population-trajectories.R b/tests/testthat/test-simulate-population-trajectories.R index 07fddde..cf69004 100644 --- a/tests/testthat/test-simulate-population-trajectories.R +++ b/tests/testthat/test-simulate-population-trajectories.R @@ -26,7 +26,7 @@ test_that("Can retrieve summarised trajectories", { covariate_formula = ~0 + infection_history) mod$fit() trajectories <- mod$simulate_population_trajectories(summarise = TRUE) - expect_equal(names(trajectories), c("t", "p", "k", "me", "lo", "hi", "infection_history", "titre_type")) + expect_equal(names(trajectories), c("t", "me", "lo", "hi", "titre_type", "infection_history")) }) test_that("Can retrieve un-summarised trajectories", { @@ -34,9 +34,9 @@ test_that("Can retrieve un-summarised trajectories", { covariate_formula = ~0 + infection_history) mod$fit() trajectories <- mod$simulate_population_trajectories(summarise = FALSE) - expect_equal(names(trajectories), c("t", "p", "k", ".draw", "t0_pop", "tp_pop", "ts_pop", "m1_pop", "m2_pop", + expect_equal(names(trajectories), c("t", ".draw", "t0_pop", "tp_pop", "ts_pop", "m1_pop", "m2_pop", "m3_pop", "beta_t0", "beta_tp", "beta_ts", "beta_m1", "beta_m2", - "beta_m3", "mu", "infection_history", "titre_type")) + "beta_m3", "mu", "titre_type", "infection_history")) }) test_that("Absolute dates are returned if time_type is 'absolute'", {