Skip to content

Commit

Permalink
add tests and fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
hillalex committed Jul 12, 2024
1 parent 5c6c031 commit d84efd7
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 19 deletions.
6 changes: 2 additions & 4 deletions R/scova.R
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,7 @@ scova <- R6::R6Class(
stop("'priors' must be of type 'scova_priors'")
}
private$priors <- priors
if (!is.numeric(preds_sd)) {
stop("'preds_sd' must be a number")
}
validate_numeric(preds_sd)
private$preds_sd <- preds_sd
validate_time_type(time_type)
private$time_type <- time_type
Expand Down Expand Up @@ -359,7 +357,7 @@ scova <- R6::R6Class(

if (time_type == "absolute") {
logger::log_info("Converting to absolute time")
dt_out[, date := dt[, unique(min(date))] + t,
dt_out[, date := private$data[, unique(min(date))] + t,
by = c(private$all_formula_vars, "titre_type")]
}

Expand Down
15 changes: 0 additions & 15 deletions R/validation.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,3 @@ validate_time_type <- function(time_type) {
stop("'time_type' must be one of 'relative' or 'absolute'")
}
}

validate_scale <- function(scale) {
if (!(scale %in% c("natural", "log"))) {
stop("'scale' must be one of 'natural' or 'log'")
}
}

validate_covariates <- function(vec, formula) {
all_vars <- all.vars(formula)
res <- vec[!sapply(vec, function(v) v %in% all_vars)]
if (length(res) > 0) {
stop(paste0("'by' must contain variables present in hierarchical model. '",
paste0(res, collapse = ", "), "' not present in model."))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ test_that("Priors must be of type 'scova_priors'", {
"'priors' must be of type 'scova_priors'")
})

test_that("preds_sd must be numeric", {
expect_error(scova$new(preds_sd = "bad"),
"'preds_sd' must be numeric")
})

test_that("Time type must be 'absolute' or 'relative'", {
expect_error(scova$new(time_type = "bad"),
"'time_type' must be one of 'relative' or 'absolute'")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@ test_that("Cannot retrieve trajectories until model is fitted", {
expect_error(mod$simulate_population_trajectories(), "Model has not been fitted yet. Call 'fit' before calling this function.")
})

test_that("Validates inputs", {
mod <- scova$new(file_path = system.file("delta_full.rds", package = "epikinetics"),
covariate_formula = ~0 + infection_history)
mod$fit()
expect_error(mod$simulate_population_trajectories(summarise = "bad"), "'summarise' must be logical")
expect_error(mod$simulate_population_trajectories(n_draws = "bad"), "'n_draws' must be numeric")
expect_error(mod$simulate_population_trajectories(time_type = "bad"), "'time_type' must be one of 'relative' or 'absolute'")
expect_error(mod$simulate_population_trajectories(t_max = "bad"), "'t_max' must be numeric")
})

test_that("Can retrieve summarised trajectories", {
mod <- scova$new(file_path = system.file("delta_full.rds", package = "epikinetics"),
covariate_formula = ~0 + infection_history)
Expand All @@ -28,3 +38,20 @@ test_that("Can retrieve un-summarised trajectories", {
"m3_pop", "beta_t0", "beta_tp", "beta_ts", "beta_m1", "beta_m2",
"beta_m3", "mu", "infection_history", "titre_type"))
})

test_that("Absolute dates are returned if time_type is 'absolute'", {
mod <- scova$new(file_path = system.file("delta_full.rds", package = "epikinetics"),
covariate_formula = ~0 + infection_history)
mod$fit()
trajectories <- mod$simulate_population_trajectories(summarise = TRUE, time_type = "absolute")
expect_equal(class(trajectories$date), c("IDate", "Date"))
expect_equal(trajectories$date, as.IDate("2021-01-29") + trajectories$t)
})

test_that("Only times up to t_max are returned", {
mod <- scova$new(file_path = system.file("delta_full.rds", package = "epikinetics"),
covariate_formula = ~0 + infection_history)
mod$fit()
trajectories <- mod$simulate_population_trajectories(summarise = TRUE, t_max = 10)
expect_true(all(trajectories$t <= 10))
})
57 changes: 57 additions & 0 deletions tests/testthat/test-simulate-population-trajectories.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
mock_model <- function(name, package) {
list(sample = function(x, ...) readRDS(test_path("testdata", "testdraws.rds")))
}

local_mocked_bindings(
stan_package_model = mock_model, .package = "instantiate"
)

test_that("Cannot retrieve trajectories until model is fitted", {
mod <- scova$new(file_path = system.file("delta_full.rds", package = "epikinetics"))
expect_error(mod$simulate_population_trajectories(), "Model has not been fitted yet. Call 'fit' before calling this function.")
})

test_that("Validates inputs", {
mod <- scova$new(file_path = system.file("delta_full.rds", package = "epikinetics"),
covariate_formula = ~0 + infection_history)
mod$fit()
expect_error(mod$simulate_population_trajectories(summarise = "bad"), "'summarise' must be logical")
expect_error(mod$simulate_population_trajectories(n_draws = "bad"), "'n_draws' must be numeric")
expect_error(mod$simulate_population_trajectories(time_type = "bad"), "'time_type' must be one of 'relative' or 'absolute'")
expect_error(mod$simulate_population_trajectories(t_max = "bad"), "'t_max' must be numeric")
})

test_that("Can retrieve summarised trajectories", {
mod <- scova$new(file_path = system.file("delta_full.rds", package = "epikinetics"),
covariate_formula = ~0 + infection_history)
mod$fit()
trajectories <- mod$simulate_population_trajectories(summarise = TRUE)
expect_equal(names(trajectories), c("t", "p", "k", "me", "lo", "hi", "infection_history", "titre_type"))
})

test_that("Can retrieve un-summarised trajectories", {
mod <- scova$new(file_path = system.file("delta_full.rds", package = "epikinetics"),
covariate_formula = ~0 + infection_history)
mod$fit()
trajectories <- mod$simulate_population_trajectories(summarise = FALSE)
expect_equal(names(trajectories), c("t", "p", "k", ".draw", "t0_pop", "tp_pop", "ts_pop", "m1_pop", "m2_pop",
"m3_pop", "beta_t0", "beta_tp", "beta_ts", "beta_m1", "beta_m2",
"beta_m3", "mu", "infection_history", "titre_type"))
})

test_that("Absolute dates are returned if time_type is 'absolute'", {
mod <- scova$new(file_path = system.file("delta_full.rds", package = "epikinetics"),
covariate_formula = ~0 + infection_history)
mod$fit()
trajectories <- mod$simulate_population_trajectories(summarise = TRUE, time_type = "absolute")
expect_equal(class(trajectories$date), c("IDate", "Date"))
expect_equal(trajectories$date, as.IDate("2021-01-29") + trajectories$t)
})

test_that("Only times up to t_max are returned", {
mod <- scova$new(file_path = system.file("delta_full.rds", package = "epikinetics"),
covariate_formula = ~0 + infection_history)
mod$fit()
trajectories <- mod$simulate_population_trajectories(summarise = TRUE, t_max = 10)
expect_true(all(trajectories$t <= 10))
})

0 comments on commit d84efd7

Please sign in to comment.