Skip to content

Commit

Permalink
add prior predictive plot
Browse files Browse the repository at this point in the history
  • Loading branch information
hillalex committed Oct 21, 2024
1 parent 8155ae7 commit 9842a31
Show file tree
Hide file tree
Showing 10 changed files with 2,461 additions and 2 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ Suggests:
VignetteBuilder: knitr
LinkingTo:
cpp11
Config/testthat/edition: 3
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export(add_exposure_data)
export(biokinetics)
export(biokinetics_priors)
export(convert_log2_scale_inverse)
export(plot_prior_predictive)
importFrom(R6,R6Class)
importFrom(data.table,":=")
importFrom(data.table,.BY)
Expand Down
11 changes: 11 additions & 0 deletions R/biokinetics.R
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,17 @@ biokinetics <- R6::R6Class(
package = "epikinetics"
)
},
#' @description Plot the kinetics trajectory predicted by the model priors.
#' @return A ggplot2 object.
#' @param tmax Integer. The number of time points in each simulated trajectory. Default 150.
#' @param n_draws Integer. The number of trajectories to simulate. Default 2000.
plot_prior_predictive = function(tmax = 150,
n_draws = 2000) {
plot_prior_predictive(private$priors,
tmax = tmax,
n_draws = n_draws,
data = private$data)
},
#' @description View the data that is passed to the stan model, for debugging purposes.
#' @return A list of arguments that will be passed to the stan model.
get_stan_data = function() {
Expand Down
48 changes: 48 additions & 0 deletions R/plot.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#' @title Simulate biomarker kinetics predicted by the given biokinetics priors
#' and optionally compare to a dataset.
#' @export
#' @description Simulate trajectories by drawing random samples from the given
#' priors for each parameter in the biokinetics model.
#' @return A ggplot2 object.
#' @param priors A named list of type 'biokinetics_priors'.
#' @param tmax Integer. The number of time points in each simulated trajectory. Default 150.
#' @param n_draws Integer. The number of trajectories to simulate. Default 2000.
#' @param data Optional data.frame with columns t_since_last_exp and value. The raw data to compare to.
plot_prior_predictive <- function(priors,
tmax = 150,
n_draws = 2000,
data = NULL) {
validate_priors(priors)
if (!is.null(data)) {
validate_required_cols(data, c("t_since_last_exp", "value"))
}
params <- data.table(
t0 = rnorm(n_draws, priors$mu_t0, priors$sigma_t0), # titre value at t0
tp = rnorm(n_draws, priors$mu_tp, priors$sigma_tp), # time of peak
ts = rnorm(n_draws, priors$mu_ts, priors$sigma_ts), # time of set point
m1 = rnorm(n_draws, priors$mu_m1, priors$sigma_m1), # gradient 1
m2 = rnorm(n_draws, priors$mu_m2, priors$sigma_m2), # gradient 2
m3 = rnorm(n_draws, priors$mu_m3, priors$sigma_m3) # gradient 3
)

times <- data.table(t = 1:tmax)
params_and_times <- times[, as.list(params), by = times]

params_and_times[, mu := biokinetics_simulate_trajectory(t, t0, tp, ts, m1, m2, m3),
by = c("t", "t0", "tp", "ts", "m1", "m2", "m3")]

summary <- params_and_times %>%
group_by(t) %>%
summarise(me = quantile(mu, 0.5, names = FALSE),
lo = quantile(mu, 0.025, names = FALSE),
hi = quantile(mu, 0.975, names = FALSE))

plot <- ggplot(summary) +
geom_line(aes(x = t, y = me)) +
geom_ribbon(aes(x = t, ymin = lo, ymax = hi), alpha = 0.5)

if (!is.null(data)) {
plot <- plot + geom_point(data = data, aes(x = t_since_last_exp, y = value))
}
plot
}
23 changes: 23 additions & 0 deletions man/biokinetics.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions man/plot_prior_predictive.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions tests/testthat.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
library(testthat)
library(epikinetics)
library(vdiffr)

test_check("epikinetics")
2,310 changes: 2,310 additions & 0 deletions tests/testthat/_snaps/plots/priorpredictive.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
41 changes: 41 additions & 0 deletions tests/testthat/test-plots.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
test_that("Can plot prior prediction up to tmax", {
priors <- biokinetics_priors()
plot <- plot_prior_predictive(priors, tmax = 100, n_draws = 500)
expect_equal(nrow(plot$data), 100)
expect_equal(length(plot$layers), 2)
})

test_that("Can plot prior prediction with data points", {
data <- data.table::fread(system.file("delta_full.rds", package = "epikinetics"))
priors <- biokinetics_priors()
expect_error(plot_prior_predictive(priors, data = data), "Missing required columns: t_since_last_exp")
data[, `:=`(t_since_last_exp = as.integer(day - last_exp_day, units = "days"))]
plot <- plot_prior_predictive(priors, data = data, n_draws = 500)
expect_equal(length(plot$layers), 3)
})

test_that("Can plot prior predictions from model", {
data <- data.table::fread(system.file("delta_full.rds", package = "epikinetics"))
priors <- biokinetics_priors(mu_values = c(4.1, 11, 65, 0.2, -0.01, 0.01),
sigma_values = c(2.0, 2.0, 3.0, 0.01, 0.01, 0.001))

mod <- biokinetics$new(priors = priors,
data = data)
set.seed(1)
plot <- mod$plot_prior_predictive(tmax = 400, n_draws = 500)
expect_equal(nrow(plot$data), 400)
expect_equal(length(plot$layers), 3)
})

test_that("Prior predictions from model are the same", {
skip_on_ci()
data <- data.table::fread(system.file("delta_full.rds", package = "epikinetics"))
priors <- biokinetics_priors(mu_values = c(4.1, 11, 65, 0.2, -0.01, 0.01),
sigma_values = c(2.0, 2.0, 3.0, 0.01, 0.01, 0.001))

mod <- biokinetics$new(priors = priors,
data = data)
set.seed(1)
plot <- mod$plot_prior_predictive(tmax = 400, n_draws = 500)
vdiffr::expect_doppelganger("priorpredictive", plot)
})
2 changes: 0 additions & 2 deletions tests/testthat/test-snapshots.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ delta <- mod$fit(parallel_chains = 4,
iter_sampling = 100,
seed = 100)

local_edition(3)

test_that("Model fits are the same", {
skip_on_ci()
expect_snapshot(delta)
Expand Down

0 comments on commit 9842a31

Please sign in to comment.