Skip to content

Commit

Permalink
Merge branch 'main' into more_plots
Browse files Browse the repository at this point in the history
  • Loading branch information
hillalex authored Dec 2, 2024
2 parents 0d6660f + cf7a1e2 commit 9049cca
Show file tree
Hide file tree
Showing 21 changed files with 415 additions and 71 deletions.
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ Imports:
logger,
mosaic,
pkgload,
plotly,
R6,
shiny,
stats,
stringr,
tidybayes
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,6 @@ importFrom(ggplot2,scale_x_continuous)
importFrom(ggplot2,scale_x_date)
importFrom(ggplot2,scale_y_continuous)
importFrom(ggplot2,sec_axis)
importFrom(ggplot2,theme)
importFrom(ggplot2,unit)
useDynLib(epikinetics, .registration = TRUE)
32 changes: 19 additions & 13 deletions R/biokinetics.R
Original file line number Diff line number Diff line change
Expand Up @@ -347,37 +347,43 @@ biokinetics <- R6::R6Class(
#' this plot is of the data as provided to the Stan model so is on a log scale,
#' regardless of whether data was provided on a log or a natural scale.
#' @param tmax Integer. Maximum time since last exposure to include. Default 150.
#' @param ncol Optional number of cols to display facets in.
#' @return A ggplot2 object.
plot_model_inputs = function(tmax = 150) {
plot_model_inputs = function(tmax = 150, ncol = NULL) {
plot_sero_data(private$data,
tmax = tmax,
ncol = ncol,
covariates = private$all_formula_vars,
upper_censoring_limit = private$stan_input_data$upper_censoring_limit,
lower_censoring_limit = private$stan_input_data$lower_censoring_limit)
},
#' @description View the data that is passed to the stan model, for debugging purposes.
#' @return A list of arguments that will be passed to the stan model.
#' @description Opens an RShiny app to help with model diagnostics.
inspect = function() {
inspect_model(self, private)
},
#' @description View the data that is passed to the stan model, for debugging purposes.
#' @return A list of arguments that will be passed to the stan model.
get_stan_data = function() {
private$stan_input_data
},
#' @description View the mapping of human readable covariate names to the model variable p.
#' @return A data.table mapping the model variable p to human readable covariates.
#' @description View the mapping of human readable covariate names to the model variable p.
#' @return A data.table mapping the model variable p to human readable covariates.
get_covariate_lookup_table = function() {
private$covariate_lookup_table
},
#' @description Fit the model and return CmdStanMCMC fitted model object.
#' @return A CmdStanMCMC fitted model object: <https://mc-stan.org/cmdstanr/reference/CmdStanMCMC.html>
#' @param ... Named arguments to the `sample()` method of CmdStan model.
#' objects: <https://mc-stan.org/cmdstanr/reference/model-method-sample.html>
#' @description Fit the model and return CmdStanMCMC fitted model object.
#' @return A CmdStanMCMC fitted model object: <https://mc-stan.org/cmdstanr/reference/CmdStanMCMC.html>
#' @param ... Named arguments to the `sample()` method of CmdStan model.
#' objects: <https://mc-stan.org/cmdstanr/reference/model-method-sample.html>
fit = function(...) {
logger::log_info("Fitting model")
private$fitted <- private$model$sample(private$stan_input_data, ...)
private$fitted
},
#' @description Extract fitted population parameters
#' @return A data.table
#' @param n_draws Integer. Default 2000.
#' @param human_readable_covariates Logical. Default TRUE.
#' @description Extract fitted population parameters
#' @return A data.table
#' @param n_draws Integer. Default 2000.
#' @param human_readable_covariates Logical. Default TRUE.
extract_population_parameters = function(n_draws = 2000,
human_readable_covariates = TRUE) {
private$check_fitted()
Expand Down
6 changes: 3 additions & 3 deletions R/epikinetics-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
#' @importFrom data.table .NGRP
#' @importFrom data.table .SD
#' @importFrom data.table data.table
#' @importFrom ggplot2 aes annotate facet_wrap geom_point geom_ribbon geom_line geom_smooth geom_bar geom_density_2d
#' geom_vline geom_hline geom_path labs ggplot guides guide_legend scale_y_continuous
#' scale_x_continuous scale_x_date sec_axis
#' @importFrom ggplot2 aes annotate facet_wrap geom_point geom_ribbon geom_line geom_smooth geom_bar
#' geom_vline geom_hline geom_path labs ggplot guides guide_legend scale_y_continuous theme unit
#' geom_density_2d scale_x_continuous scale_x_date sec_axis
#' @useDynLib epikinetics, .registration = TRUE
## usethis namespace: end

Expand Down
199 changes: 199 additions & 0 deletions R/inspect-model.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
inspect_model <- function(mod, private) {

prior_inputs <- function(name, description) {
mu <- paste("mu", name, sep = "_")
sigma <- paste("sigma", name, sep = "_")
shiny::div(shiny::fluidRow(
shiny::column(4,
description
),
shiny::column(4,
shiny::fluidRow(class = "form-group",
shiny::tags$label(paste0("mean (", mu, ")"), class = "col-sm-6 col-form-label text-right"),
shiny::column(6, raw_numeric_input(mu, value = private$priors[[mu]]))
)
),
shiny::column(4,
shiny::fluidRow(class = "form-group",
shiny::tags$label(paste0("SD (", sigma, ")"), class = "col-sm-6 col-form-label text-right"),
shiny::column(6, raw_numeric_input(sigma, value = private$priors[[sigma]])),
)
))
)
}

all_covariates <- c("None", detect_covariates(private$data))

ui <- shiny::fluidPage(style = "margin: 0.5em",
shiny::fluidRow(
shiny::column(5,
shiny::h3("Prior predictive check"),
plotly::plotlyOutput("prior_predicted"),
shiny::tags$pre(style = "overflow: hidden; text-wrap: auto; word-break: keep-all; white-space: pre-line; margin-top: 20px;",
shiny::textOutput("prior_code", inline = TRUE)
),
prior_inputs("t0", "Baseline titre value"),
prior_inputs("tp", "Time to peak titre"),
prior_inputs("ts", "Time to start of waning"),
prior_inputs("m1", "Boosting rate"),
prior_inputs("m2", "Plateau rate"),
prior_inputs("m3", "Waning rate")
),
shiny::column(7,
shiny::h3("Model input data"),
shiny::uiOutput(
"data_plot"
),
shiny::div(style = "margin-top: 20px;",
shiny::fluidRow(class = "form-group",
shiny::column(2,
shiny::numericInput("ncol", label = "Number of columns", value = 3)
),
shiny::column(3,
shiny::selectInput("covariate", "Facet by",
choices = all_covariates,
selected = "None",
selectize = FALSE)
),
shiny::column(7,
shiny::div(class = "form-group",
shiny:::shinyInputLabel("filter", "Filter by"),
shiny::fluidRow(
shiny::column(5,
raw_select_input("filter",
choices = all_covariates,
selected = "None")
),
shiny::column(1, style = "padding-top: 5px;", "~="),
shiny::column(5,
raw_text_input("filter_value", placeholder = "regex")
)
)
)
)
)
)
)
),
shiny::fluidRow(style = "margin-top: 20px;",
shiny::column(12,
shiny::h3(shiny::textOutput("fitted"))
)
)
)

server <- function(input, output, session) {
# priors
prior <- shiny::reactive(
biokinetics_priors(mu_t0 = input$mu_t0, mu_tp = input$mu_tp,
mu_ts = input$mu_ts, mu_m1 = input$mu_m1,
mu_m2 = input$mu_m2, mu_m3 = input$mu_m3,
sigma_t0 = input$sigma_t0, sigma_tp = input$sigma_tp,
sigma_ts = input$sigma_ts, sigma_m1 = input$sigma_m1,
sigma_m2 = input$sigma_m2, sigma_m3 = input$sigma_m3)
)
output$prior_code <- shiny::renderText({
prior_code(input)
})
output$prior_predicted <- plotly::renderPlotly({
plotly::style(plotly::ggplotly(plot(prior(),
data = private$data,
upper_censoring_limit = private$stan_input_data$upper_censoring_limit,
lower_censoring_limit = private$stan_input_data$lower_censoring_limit)), textposition = "right")
})

# model inputs
cols <- shiny::reactive({
if (is.na(input$ncol)) {
return(NULL)
} else {
return(input$ncol)
}
})

selected_covariate <- shiny::reactive({
input$covariate
})

filter <- shiny::reactive({
input$filter
})

filter_value <- shiny::reactive({
input$filter_value
})

data <- shiny::reactive({
if (filter_value() != "" &&
!is.null(filter()) &&
filter() != "None") {
return(private$data[grepl(filter_value(), get(filter()), ignore.case = TRUE)])
} else {
return(private$data)
}
})

plot_inputs <- shiny::reactive({
selected <- selected_covariate()
if (is.null(selected) || selected == "None") {
selected <- character(0)
}
plot_sero_data(data(),
ncol = cols(),
covariates = selected,
upper_censoring_limit = private$stan_input_data$upper_censoring_limit,
lower_censoring_limit = private$stan_input_data$lower_censoring_limit) +
theme(plot.margin = unit(c(1, 0, 0, 0), "cm"))
})

output$data <- plotly::renderPlotly({
if (nrow(data()) > 0) {
gp <- plotly::style(plotly::ggplotly(plot_inputs()), textposition = "right")
if (selected_covariate() != "None") {
return(facet_strip_bigger(gp, 30))
} else {
return(gp)
}
}
})

output$data_plot <- shiny::renderUI({
if (nrow(data()) > 0) {
plotly::plotlyOutput("data")
} else {
shiny::h3("No rows selected. Please change your filter.")
}
})

# model outputs
output$fitted <- shiny::renderText({
if (is.null(private$fitted)) {
"Model has not been fitted yet. Once fitted, inspect the model again to see posterior predictions."
}
})
}

logger::log_info(
"Starting Shiny app for model review; use Ctrl + C to quit"
)
shiny::runApp(
shiny::shinyApp(ui, server),
quiet = TRUE,
launch.browser = shiny::paneViewer()
)
invisible()
}

# plotly can't handle multi-line facet titles, so manually make
# the facet titles a little bigger when there are covariates
facet_strip_bigger <- function(gp, size) {

n_facets <- c(1:length(gp[["x"]][["layout"]][["shapes"]]))

for (i in n_facets) {
gp[["x"]][["layout"]][["shapes"]][[i]][["y0"]] <- +as.numeric(size)
gp[["x"]][["layout"]][["shapes"]][[i]][["y1"]] <- 0
}

return(gp)
}
12 changes: 6 additions & 6 deletions R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,8 @@ plot.biokinetics_priors <- function(x,
geom_point(data = dat, size = 0.5,
aes(x = time_since_last_exp,
y = value))

plot <- add_limits(plot, upper_censoring_limit, lower_censoring_limit)
}
plot
add_limits(plot, upper_censoring_limit, lower_censoring_limit)
}

#' @title Plot serological data
Expand All @@ -69,11 +67,13 @@ plot.biokinetics_priors <- function(x,
#' @return A ggplot2 object.
#' @param data A data.table with required columns time_since_last_exp, value and titre_type.
#' @param tmax Integer. The number of time points in each simulated trajectory. Default 150.
#' @param ncol Integer. Optional number of columns to display facets in.
#' @param covariates Optional vector of covariate names to facet by (these must correspond to columns in the data.table)
#' @param upper_censoring_limit Optional upper detection limit.
#' @param lower_censoring_limit Optional lower detection limit.
plot_sero_data <- function(data,
tmax = 150,
ncol = NULL,
covariates = character(0),
upper_censoring_limit = NULL,
lower_censoring_limit = NULL) {
Expand All @@ -87,7 +87,7 @@ plot_sero_data <- function(data,
geom_point(aes(x = time_since_last_exp, y = value, colour = titre_type),
size = 0.5, alpha = 0.5) +
geom_smooth(aes(x = time_since_last_exp, y = value, colour = titre_type)) +
facet_wrap(eval(parse(text = facet_formula(covariates)))) +
facet_wrap(eval(parse(text = facet_formula(covariates))), ncol = ncol) +
guides(colour = guide_legend(title = "Titre type"))

add_limits(plot, upper_censoring_limit, lower_censoring_limit)
Expand Down Expand Up @@ -305,7 +305,7 @@ add_limits <- function(plot, upper_censoring_limit, lower_censoring_limit) {
linetype = 'dotted') +
annotate("text", x = 1,
y = lower_censoring_limit,
label = "Lower detection limit",
label = "Lower censoring limit",
vjust = -0.5,
hjust = 0,
size = 3)
Expand All @@ -316,7 +316,7 @@ add_limits <- function(plot, upper_censoring_limit, lower_censoring_limit) {
linetype = 'dotted') +
annotate("text", x = 1,
y = upper_censoring_limit,
label = "Upper detection limit",
label = "Upper censoring limit",
vjust = -0.5,
hjust = 0,
size = 3)
Expand Down
44 changes: 44 additions & 0 deletions R/shiny-utils.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
raw_numeric_input <- function(inputId, value, min = NA, max = NA, step = NA) {
value <- shiny::restoreInput(id = inputId, default = value)
inputTag <- shiny::tags$input(id = inputId, type = "number", class = "shiny-input-number form-control", value = shiny:::formatNoSci(value))
if (!is.na(min)) inputTag$attribs$min <- min
if (!is.na(max)) inputTag$attribs$max <- max
if (!is.na(step)) inputTag$attribs$step <- step
inputTag
}

raw_text_input <- function(inputId, value = "", placeholder = NULL) {
value <- shiny::restoreInput(id = inputId, default = value)
shiny::tags$input(id = inputId, type = "text", class = "shiny-input-text form-control", value = value, placeholder = placeholder)
}

raw_select_input <- function(inputId, choices, selected = NULL, multiple = FALSE) {
selected <- shiny::restoreInput(id = inputId, default = selected)
choices <- shiny:::choicesWithNames(choices)
if (is.null(selected)) {
if (!multiple) selected <- shiny:::firstChoice(choices)
} else selected <- as.character(selected)
shiny::tags$select(id = inputId, class = "shiny-input-select", class = "form-control", shiny:::selectOptions(choices, selected, inputId))
}


prior_code <- function(input) {
deparse(substitute(biokinetics_priors(mu_t0 = a, mu_tp = b,
mu_ts = c, mu_m1 = d,
mu_m2 = e, mu_m3 = f,
sigma_t0 = g, sigma_tp = h,
sigma_ts = i, sigma_m1 = j,
sigma_m2 = k, sigma_m3 = l),
list(a = input$mu_t0, b = input$mu_tp,
c = input$mu_ts, d = input$mu_m1,
e = input$mu_m2, f = input$mu_m3,
g = input$sigma_t0, h = input$sigma_tp,
i = input$sigma_ts, j = input$sigma_m1,
k = input$sigma_m2, l = input$sigma_m3)), width.cutoff = 500L)
}

detect_covariates <- function(data) {
setdiff(colnames(data), c("pid", "day", "last_exp_day",
"titre_type", "value", "censored",
"obs_id", "time_since_last_exp"))
}
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ To run all tests locally:
devtools::test()
```

To run tests in a single file:

```{r}
devtools::test(filter="filename")
```

Some tests are skipped on CI to avoid exorbitantly long build times, but this means
it is important to run all tests locally at least once before merging a pull request.

Expand Down
Loading

0 comments on commit 9049cca

Please sign in to comment.