diff --git a/DESCRIPTION b/DESCRIPTION index 59eb49380..2fb2f44e2 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,7 +1,7 @@ Type: Package Package: insight Title: Easy Access to Model Information for Various Model Objects -Version: 0.20.2 +Version: 0.20.2.1 Authors@R: c(person(given = "Daniel", family = "Lüdecke", diff --git a/NAMESPACE b/NAMESPACE index 91477c592..a52271336 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -322,6 +322,7 @@ S3method(find_predictors,selection) S3method(find_random,afex_aov) S3method(find_random,default) S3method(find_response,bfsl) +S3method(find_response,brmsfit) S3method(find_response,default) S3method(find_response,joint) S3method(find_response,logitr) diff --git a/R/find_response.R b/R/find_response.R index 320c2936d..737fd5f4a 100644 --- a/R/find_response.R +++ b/R/find_response.R @@ -49,6 +49,35 @@ find_response.default <- function(x, combine = TRUE, ...) { } +#' @export +find_response.brmsfit <- function(x, combine = TRUE, ...) { + f <- find_formula(x, verbose = FALSE) + + if (is.null(f)) { + return(NULL) + } + + # this is for multivariate response models, + # where we have a list of formulas + if (is_multivariate(f)) { + resp <- unlist(lapply(f, function(i) { + resp_formula <- safe_deparse(i$conditional[[2L]]) + if (grepl("|", resp_formula, fixed = TRUE)) { + resp_formula <- all.vars(i$conditional[[2L]]) + } + resp_formula + })) + } else { + resp <- safe_deparse(f$conditional[[2L]]) + if (grepl("|", resp, fixed = TRUE)) { + resp <- all.vars(f$conditional[[2L]]) + } + } + + check_cbind(resp, combine, model = x) +} + + #' @export find_response.logitr <- function(x, ...) { get_call(x)$outcome diff --git a/tests/testthat/test-brms_missing.R b/tests/testthat/test-brms_missing.R new file mode 100644 index 000000000..137818b4b --- /dev/null +++ b/tests/testthat/test-brms_missing.R @@ -0,0 +1,35 @@ +skip_on_cran() +skip_if_offline() +skip_on_os("mac") +skip_if_not_installed("brms") +skip_if_not_installed("httr") + +# Model fitting ----------------------------------------------------------- + +miss_1 <- suppressWarnings(download_model("brms_miss_1")) +skip_if(is.null(miss_1)) + +# Tests ------------------------------------------------------------------- +test_that("get_response brms aterms-trials 1", { + expect_equal( + find_formula(miss_1), + structure( + list( + survived = list(conditional = survived ~ woman * mi(age) + passengerClass), + age = list(conditional = age | mi() ~ passengerClass + woman) + ), + is_mv = "1", + class = c("insight_formula", "list") + ), + ignore_attr = TRUE + ) + expect_identical( + find_response(miss_1), + c(survived = "survived", age = "age") + ) + expect_true(is_multivariate(miss_1)) + out <- get_response(miss_1) + expect_named(out, c("survived", "age")) + expect_equal(head(out$age), c(29, 0.9167, 2, 30, 25, 48), tolerance = 1e-4, ignore_attr = TRUE) + expect_equal(head(out$survived), c(1, 1, 0, 0, 0, 1), tolerance = 1e-4, ignore_attr = TRUE) +})