Skip to content

Commit

Permalink
Merge pull request #31 from seroanalytics/priors
Browse files Browse the repository at this point in the history
refactor priors
  • Loading branch information
hillalex authored Nov 15, 2024
2 parents edeefbb + 85979fc commit 1c2c50a
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 42 deletions.
37 changes: 16 additions & 21 deletions R/priors.R
Original file line number Diff line number Diff line change
@@ -1,32 +1,27 @@
gaussian_priors <- function(names, mu_values, sigma_values) {
if (length(names) != length(mu_values) ||
length(names) != length(sigma_values)) {
stop("The lengths of the vectors do not match.")
}
mu_names <- paste0("mu_", names)
sigma_names <- paste0("sigma_", names)
ret <- as.list(c(mu_values, sigma_values))
names(ret) <- c(mu_names, sigma_names)
class(ret) <- append("gaussian_priors", class(ret))
ret
}

#' @title Construct priors for the biomarker model.
#' @export
#' @description The biokinetics model has 6 parameters: t0, tp, ts, m1, m2, m3 corresponding to critical time points and
#' gradients. See the model vignette for details: \code{vignette("model", package = "epikinetics")}. Each of these
#' parameters has a Gaussian prior, and these can be specified by the user. This function takes means and standard
#' deviations for each prior and constructs an object of type 'biokinetics_priors' to be passed to the model.
#' @return A named list of type 'biokinetics_priors'.
#' @param mu_values Mean of Gaussian prior for each of t0, tp, ts, m1, m2, m3, in order.
#' @param sigma_values Standard deviation of Gaussian prior for each of t0, tp, ts, m1, m2, m3, in order.
#' @param mu_t0 Numeric. Mean for t0, baseline titre value. Default 4.0.
#' @param mu_tp Numeric. Mean for tp, time at peak titre. Default 10.
#' @param mu_ts Numeric. Mean for ts, time at start of warning. Default 60.
#' @param mu_m1 Numeric. Mean for m1, boosting rate. Default 0.25.
#' @param mu_m2 Numeric. Mean for m2, plateau rate. Default 0.25.
#' @param mu_m3 Numeric. Mean for m3, waning rate. Default -0.02.
#' @param sigma_t0 Numeric. Standard deviation for t0, baseline titre value. Default 2.0.
#' @param sigma_tp Numeric. Standard deviation for tp, time at peak titre. Default 2.0.
#' @param sigma_ts Numeric. Standard deviation for ts, time at start of warning. Default 3.0.
#' @param sigma_m1 Numeric. Standard deviation for m1, boosting rate. Default 0.01.
#' @param sigma_m2 Numeric. Standard deviation for m2, plateau rate. Default 0.01.
#' @param sigma_m3 Numeric. Standard deviation for m3, waning rate. Default 0.01.
#' @examples
#' priors <- biokinetics_priors(mu_values = c(4.0, 10, 60, 0.25, -0.02, 0),
#' sigma_values = c(2.0, 2.0, 3.0, 0.01, 0.01, 0.01))
biokinetics_priors <- function(mu_values = c(4.0, 10, 60, 0.25, -0.02, 0),
sigma_values = c(2.0, 2.0, 3.0, 0.01, 0.01, 0.01)) {
names <- c("t0", "tp", "ts", "m1", "m2", "m3")
ret <- gaussian_priors(names, mu_values, sigma_values)
#' priors <- biokinetics_priors(mu_t0 = 5.0, mu_ts = 61)
biokinetics_priors <- function(mu_t0 = 4.0, mu_tp = 10, mu_ts = 60, mu_m1 = 0.25, mu_m2 = -0.02, mu_m3 = 0,
sigma_t0 = 2.0, sigma_tp = 2.0, sigma_ts = 3.0, sigma_m1 = 0.01, sigma_m2 = 0.01, sigma_m3 = 0.01) {
ret <- as.list(environment())
class(ret) <- append("biokinetics_priors", class(ret))
ret
}
41 changes: 35 additions & 6 deletions man/biokinetics_priors.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/testthat/test-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ test_that("Can provide data directly", {

test_that("Can construct stan data", {
dat <- data.table::fread(system.file("delta_full.rds", package = "epikinetics"))
priors <- biokinetics_priors(mu_values = c(1, 2, 3, 4, 5, 6),
sigma_values = c(0.1, 0.2, 0.3, 0.4, 0.5, 0.6))
priors <- biokinetics_priors(1, 2, 3, 4, 5, 6,
0.1, 0.2, 0.3, 0.4, 0.5, 0.6)
mod <- biokinetics$new(data = dat, priors = priors)
stan_data <- mod$get_stan_data()
expect_true(is.list(stan_data))
Expand Down
8 changes: 4 additions & 4 deletions tests/testthat/test-plots.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ test_that("Can plot prior prediction with data points", {

test_that("Can plot prior predictions from model", {
data <- data.table::fread(system.file("delta_full.rds", package = "epikinetics"))
priors <- biokinetics_priors(mu_values = c(4.1, 11, 65, 0.2, -0.01, 0.01),
sigma_values = c(2.0, 2.0, 3.0, 0.01, 0.01, 0.001))
priors <- biokinetics_priors(4.1, 11, 65, 0.2, -0.01, 0.01,
2.0, 2.0, 3.0, 0.01, 0.01, 0.001)

mod <- biokinetics$new(priors = priors,
data = data)
Expand All @@ -29,8 +29,8 @@ test_that("Can plot prior predictions from model", {

test_that("Prior predictions from model are the same", {
data <- data.table::fread(system.file("delta_full.rds", package = "epikinetics"))
priors <- biokinetics_priors(mu_values = c(4.1, 11, 65, 0.2, -0.01, 0.01),
sigma_values = c(2.0, 2.0, 3.0, 0.01, 0.01, 0.001))
priors <- biokinetics_priors(4.1, 11, 65, 0.2, -0.01, 0.01,
2.0, 2.0, 3.0, 0.01, 0.01, 0.001)

mod <- biokinetics$new(priors = priors,
data = data)
Expand Down
10 changes: 1 addition & 9 deletions tests/testthat/test-priors.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,5 @@
test_that("Can construct named list of Gaussian prior parameters", {
priors <- gaussian_priors(names = c("a", "b"), mu_values = c(0.1, 0.2), sigma_values = c(0.5, 0.6))
expect_s3_class(priors, "gaussian_priors")
expect_true(is.list(priors))
expect_equal(unclass(priors), list("mu_a" = 0.1, "mu_b" = 0.2, "sigma_a" = 0.5, "sigma_b" = 0.6))
})

test_that("Can construct cab prior parameters", {
priors <- biokinetics_priors(mu_values = c(1, 2, 3, 4, 5, 6), sigma_values = c(7, 8, 9, 10, 11, 12))
expect_s3_class(priors, "gaussian_priors")
priors <- biokinetics_priors(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)
expect_s3_class(priors, "biokinetics_priors")
expect_true(is.list(priors))
expect_equal(unclass(priors), list("mu_t0" = 1, "mu_tp" = 2, "mu_ts" = 3,
Expand Down

0 comments on commit 1c2c50a

Please sign in to comment.