Skip to content

Commit

Permalink
Fix posterior_predict.stanemaxbin
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshidk6 committed Dec 6, 2024
1 parent 0ce3a6c commit 8effcaf
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 8 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Package: rstanemax
Version: 0.1.6
Version: 0.1.6.1
Title: Emax Model Analysis with 'Stan'
Description: Perform sigmoidal Emax model fit using 'Stan' in a formula notation, without writing 'Stan' model code.
Authors@R: c(
Expand Down
7 changes: 7 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@

# rstanemax 0.1.6.1

## Patch release

* Fix match.arg(newDataType) for posterior_predict.stanemaxbin
* Correctly return prediction for posterior_predict.stanemaxbin

# rstanemax 0.1.6

## Major changes
Expand Down
18 changes: 14 additions & 4 deletions R/posterior_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,16 @@ rstantools::posterior_predict
#'
#' For continuous endpoint model ([stan_emax()]),
#'
#' - `respHat`: prediction without considering residual variability and is intended to provide credible interval of "mean" response.
#' - `response`: include residual variability in its calculation, therefore the range represents prediction interval of observed response.
#' - `.linpred` & `.epred`: prediction without considering residual variability and is intended to provide credible interval of "mean" response.
#' - `.prediction`: include residual variability in its calculation, therefore the range represents prediction interval of observed response.
#' - (deprecated) `respHat`: replaced by `.linpred` & `.epred`
#' - (deprecated) `response`: replaced by `.prediction`
#'
#' For binary endpoint model ([stan_emax_binary()]),
#'
#' - `.linpred`: predicted probability on logit scale
#' - `.epred`: predicted probability on probability scale
#' - `.prediction`: predicted event (1) or non-event (0)
#'
#' The return object also contains exposure and parameter values used for calculation.
NULL
Expand Down Expand Up @@ -126,6 +129,7 @@ posterior_predict.stanemaxbin <- function(
newDataType = c("raw", "modelframe"),
...) {
returnType <- match.arg(returnType)
newDataType <- match.arg(newDataType)

if (is.null(newdata)) {
df.model <- object$prminput$df.model
Expand Down Expand Up @@ -158,7 +162,7 @@ posterior_predict.stanemaxbin <- function(
)

if (returnType == "matrix") {
return(matrix(pred.response$.epred, ncol = nrow(df.model), byrow = TRUE))
return(matrix(pred.response$.prediction, ncol = nrow(df.model), byrow = TRUE))
} else if (returnType == "dataframe") {
return(as.data.frame(pred.response))
} else if (returnType == "tibble") {
Expand Down Expand Up @@ -191,13 +195,19 @@ pp_calc <- function(stanfit, df.model,
respHat = e0 + emax * exposure^gamma / (ec50^gamma + exposure^gamma),
response = stats::rnorm(respHat, respHat, sigma)
) %>%
dplyr::mutate(
.linpred = respHat,
.epred = respHat,
.prediction = response
) %>%
dplyr::select(mcmcid, exposure, dplyr::everything())
} else if (mod_type == "stanemaxbin") {
out <-
df %>%
dplyr::mutate(
.linpred = e0 + emax * exposure^gamma / (ec50^gamma + exposure^gamma),
.epred = 1 / (1 + exp(-.linpred))
.epred = 1 / (1 + exp(-.linpred)),
.prediction = stats::rbinom(.epred, 1, .epred)
) %>%
dplyr::select(mcmcid, exposure, dplyr::everything())
}
Expand Down
7 changes: 5 additions & 2 deletions man/posterior_predict.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test-posterior_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ test_that("posterior prediction with new data with covariates", {

# Make sure posterior_predict works with covariates
test.pp.tibble <- posterior_predict.stanemax(test.fit.2cov, newdata = test.data.short, returnType = "tibble")
expect_equal(dim(test.pp.tibble), c(30000, 13))
expect_equal(dim(test.pp.tibble), c(30000, 16))

# Make sure data is not re-sorted
expect_equal(
Expand Down
1 change: 1 addition & 0 deletions tests/testthat/test-stan_emax_binary.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ test_that("emax model run", {


test_that("posterior prediction with original data", {
set.seed(1234)
test.pp.matrix <- posterior_predict(test.fit2)
test.pp.df <- posterior_predict(test.fit2, returnType = "dataframe")

Expand Down

0 comments on commit 8effcaf

Please sign in to comment.