Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement the brmsframe layer #1653

Merged
merged 18 commits into from
May 14, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
make rename_pars compatible with brmsframe
  • Loading branch information
paul-buerkner committed May 7, 2024
commit 3baef7fe0f208e7389dfbc5a38159983960ab94b
19 changes: 15 additions & 4 deletions R/brmsframe.R
Original file line number Diff line number Diff line change
@@ -23,6 +23,9 @@ brmsframe.mvbrmsterms <- function(x, data, old_levels = NULL, ...) {
#' @export
brmsframe.brmsterms <- function(x, data, frame = NULL,
old_levels = NULL, ...) {
x$sdata <- list(
resp = data_response(x, data, check_response = FALSE)
)
if (is.null(frame)) {
# this is a univariate model so brmsterms is at the top level
x$frame <- list(
@@ -53,19 +56,20 @@ brmsframe.btl <- function(x, data, frame = NULL, ...) {
x$sdata <- list(
fe = data_fe(x, data),
sm = data_sm(x, data),
cs = data_cs(x, data),
gp = data_gp(x, data, internal = TRUE),
offset = data_offset(x, data)
)
px <- check_prefix(x)
x$frame <- list(
fe = frame_fe(x),
re = subset2(frame$re, ls = px),
sp = tidy_spef(x, data),
sp = tidy_spef(x, data = data),
me = frame$me,
cs = colnames(get_model_matrix(x$cs, data = data)),
gp = tidy_gpef(x, data),
cs = frame_cs(x),
gp = tidy_gpef(x, data = data),
sm = tidy_smef(x),
ac = tidy_acef(x)
ac = tidy_acef(x, data = data)
)
class(x) <- c("bfrl", class(x))
x
@@ -91,6 +95,7 @@ brmsframe.default <- function(x, ...) {

frame_resp <- function(x, data, ....) {
stopifnot(is.brmsterms(x))
# TODO use sdata$resp info
out <- list(
values = model.response(model.frame(x$respform, data, na.action = na.pass)),
bounds = trunc_bounds(x, data),
@@ -114,6 +119,12 @@ frame_fe <- function(x) {
out
}

frame_cs <- function(x) {
stopifnot(is.btl(x), !is.null(x$sdata))
out <- list(vars = colnames(x$sdata$cs$Xcs))
out
}

frame_cnl <- function(x, ...) {
stopifnot(is.btnl(x), !is.null(x$sdata))
covars <- all.vars(x$covars)
8 changes: 7 additions & 1 deletion R/formula-ac.R
Original file line number Diff line number Diff line change
@@ -475,7 +475,7 @@ tidy_acef.brmsterms <- function(x, ...) {
}

#' @export
tidy_acef.btl <- function(x, ...) {
tidy_acef.btl <- function(x, data = NULL, ...) {
form <- x[["ac"]]
if (!is.formula(form)) {
return(empty_acef())
@@ -558,6 +558,12 @@ tidy_acef.btl <- function(x, ...) {
stop2("Explicit covariance terms can only be specified on 'mu'.")
}
}
if (!is.null(data)) {
# optional such that this function can be applied
# without data before brmsframe is being created
time <- get_ac_vars(out, "time", dim = "time")
attr(out, "times") <- extract_levels(get(time, data))
}
out
}

11 changes: 3 additions & 8 deletions R/priors.R
Original file line number Diff line number Diff line change
@@ -546,16 +546,14 @@ default_prior.default <- function(object, data, family = gaussian(), autocor = N
# @param internal return priors for internal use?
# @return a brmsprior object
.default_prior <- function(bterms, internal = FALSE, ...) {
# ranef <- tidy_ranef(bterms, data)
# meef <- tidy_meef(bterms, data)
# initialize output
prior <- empty_prior()
# priors for distributional parameters
prior <- prior + prior_predictor(bterms, internal = internal)
# priors of group-level parameters
prior <- prior + prior_re(bterms, internal = internal)
# priors for noise-free variables
prior <- prior + prior_me(bterms, internal = internal)
prior <- prior + prior_Xme(bterms, internal = internal)
# explicitly label default priors as such
prior$source <- "default"
# apply 'unique' as the same prior may have been included multiple times
@@ -606,7 +604,6 @@ prior_predictor.mvbrmsterms <- function(x, internal = FALSE, ...) {

#' @export
prior_predictor.brmsterms <- function(x, internal = FALSE, ...) {
# data <- subset_data(data, x)
def_scale_prior <- def_scale_prior(x)
valid_dpars <- valid_dpars(x)
prior <- empty_prior()
@@ -771,7 +768,6 @@ prior_bhaz <- function(bterms, ...) {
prior_sp <- function(bterms, ...) {
prior <- empty_prior()
spef <- bterms$frame$sp
#spef <- tidy_spef(bterms, data)
if (nrow(spef)) {
px <- check_prefix(bterms)
prior <- prior + brmsprior(
@@ -791,8 +787,7 @@ prior_sp <- function(bterms, ...) {
# priors for category spcific effects parameters
prior_cs <- function(bterms, ...) {
prior <- empty_prior()
# csef <- colnames(get_model_matrix(bterms$cs, data = data))
csef <- bterms$frame$cs
csef <- bterms$frame$cs$vars
if (length(csef)) {
px <- check_prefix(bterms)
prior <- prior +
@@ -802,7 +797,7 @@ prior_cs <- function(bterms, ...) {
}

# default priors for hyper-parameters of noise-free variables
prior_me <- function(bterms, internal = FALSE, ...) {
prior_Xme <- function(bterms, internal = FALSE, ...) {
meef <- bterms$frame$me
prior <- empty_prior()
if (!NROW(meef)) {
75 changes: 36 additions & 39 deletions R/rename_pars.R
Original file line number Diff line number Diff line change
@@ -30,14 +30,13 @@ rename_pars <- function(x) {
if (!length(x$fit@sim)) {
return(x)
}
bterms <- brmsterms(x$formula)
meef <- tidy_meef(bterms, data = x$data)
bframe <- brmsframe(x$formula, x$data)
pars <- variables(x)
# find positions of parameters and define new names
to_rename <- c(
rename_predictor(bterms, data = x$data, pars = pars, prior = x$prior),
rename_re(x$ranef, pars = pars),
rename_Xme(meef, pars = pars)
rename_predictor(bframe, pars = pars, prior = x$prior),
rename_re(bframe, pars = pars),
rename_Xme(bframe, pars = pars)
)
# perform the actual renaming in x$fit@sim
x <- save_old_par_order(x)
@@ -60,10 +59,10 @@ rename_predictor.default <- function(x, ...) {
}

#' @export
rename_predictor.mvbrmsterms <- function(x, data, pars, ...) {
rename_predictor.mvbrmsterms <- function(x, pars, ...) {
out <- list()
for (i in seq_along(x$terms)) {
c(out) <- rename_predictor(x$terms[[i]], data = data, pars = pars, ...)
c(out) <- rename_predictor(x$terms[[i]], pars = pars, ...)
}
if (x$rescor) {
rescor_names <- get_cornames(
@@ -75,20 +74,19 @@ rename_predictor.mvbrmsterms <- function(x, data, pars, ...) {
}

#' @export
rename_predictor.brmsterms <- function(x, data, ...) {
data <- subset_data(data, x)
rename_predictor.brmsterms <- function(x, ...) {
out <- list()
for (dp in names(x$dpars)) {
c(out) <- rename_predictor(x$dpars[[dp]], data = data, ...)
c(out) <- rename_predictor(x$dpars[[dp]], ...)
}
for (nlp in names(x$nlpars)) {
c(out) <- rename_predictor(x$nlpars[[nlp]], data = data, ...)
c(out) <- rename_predictor(x$nlpars[[nlp]], ...)
}
if (is.formula(x$adforms$mi)) {
c(out) <- rename_Ymi(x, data = data, ...)
c(out) <- rename_Ymi(x, ...)
}
c(out) <- rename_thres(x, data = data, ...)
c(out) <- rename_family_cor_pars(x, data = data, ...)
c(out) <- rename_thres(x, ...)
c(out) <- rename_family_cor_pars(x, ...)
out
}

@@ -105,12 +103,9 @@ rename_predictor.btl <- function(x, ...) {
}

# helps in renaming fixed effects parameters
rename_fe <- function(bterms, data, pars, prior, ...) {
rename_fe <- function(bterms, pars, prior, ...) {
out <- list()
fixef <- colnames(data_fe(bterms, data)$X)
if (stan_center_X(bterms)) {
fixef <- setdiff(fixef, "Intercept")
}
fixef <- bterms$frame$fe$vars_stan
if (!length(fixef)) {
return(out)
}
@@ -131,9 +126,9 @@ rename_fe <- function(bterms, data, pars, prior, ...) {
}

# helps in renaming special effects parameters
rename_sp <- function(bterms, data, pars, prior, ...) {
rename_sp <- function(bterms, pars, prior, ...) {
out <- list()
spef <- tidy_spef(bterms, data)
spef <- bterms$frame$sp
if (!nrow(spef)) {
return(out)
}
@@ -164,9 +159,9 @@ rename_sp <- function(bterms, data, pars, prior, ...) {
}

# helps in renaming category specific effects parameters
rename_cs <- function(bterms, data, pars, ...) {
rename_cs <- function(bterms, pars, ...) {
out <- list()
csef <- colnames(data_cs(bterms, data)$Xcs)
csef <- bterms$frame$cs$vars
if (length(csef)) {
p <- usc(combine_prefix(bterms))
bcsp <- paste0("bcs", p)
@@ -205,8 +200,8 @@ rename_thres <- function(bterms, pars, ...) {

# helps in renaming global noise free variables
# @param meef data.frame returned by 'tidy_meef'
rename_Xme <- function(meef, pars, ...) {
stopifnot(is.meef_frame(meef))
rename_Xme <- function(bterms, pars, ...) {
meef <- bterms$frame$me
out <- list()
levels <- attr(meef, "levels")
groups <- unique(meef$grname)
@@ -252,16 +247,15 @@ rename_Xme <- function(meef, pars, ...) {
}

# helps in renaming estimated missing values
rename_Ymi <- function(bterms, data, pars, ...) {
rename_Ymi <- function(bterms, pars, ...) {
stopifnot(is.brmsterms(bterms))
out <- list()
if (is.formula(bterms$adforms$mi)) {
resp <- usc(combine_prefix(bterms))
resp_data <- data_response(bterms, data, check_response = FALSE)
Ymi <- paste0("Ymi", resp)
pos <- grepl(paste0("^", Ymi, "\\["), pars)
if (any(pos)) {
Jmi <- resp_data$Jmi
Jmi <- bterms$sdata$resp$Jmi
fnames <- paste0(Ymi, "[", Jmi, "]")
lc(out) <- rlist(pos, fnames)
}
@@ -270,10 +264,10 @@ rename_Ymi <- function(bterms, data, pars, ...) {
}

# helps in renaming parameters of gaussian processes
rename_gp <- function(bterms, data, pars, ...) {
rename_gp <- function(bterms, pars, ...) {
out <- list()
p <- usc(combine_prefix(bterms), "prefix")
gpef <- tidy_gpef(bterms, data)
gpef <- bterms$frame$gp
for (i in seq_rows(gpef)) {
# rename GP hyperparameters
sfx1 <- gpef$sfx1[[i]]
@@ -318,9 +312,9 @@ rename_gp <- function(bterms, data, pars, ...) {
}

# helps in renaming smoothing term parameters
rename_sm <- function(bterms, data, pars, prior, ...) {
rename_sm <- function(bterms, pars, prior, ...) {
out <- list()
smef <- tidy_smef(bterms, data)
smef <- bterms$frame$sm
if (NROW(smef)) {
p <- usc(combine_prefix(bterms))
Xs_names <- attr(smef, "Xs_names")
@@ -360,13 +354,14 @@ rename_sm <- function(bterms, data, pars, prior, ...) {
}

# helps in renaming autocorrelation parameters
rename_ac <- function(bterms, data, pars, ...) {
rename_ac <- function(bterms, pars, ...) {
out <- list()
acef <- tidy_acef(bterms)
acef <- bterms$frame$ac
resp <- usc(bterms$resp)
if (has_ac_class(acef, "unstr")) {
time <- get_ac_vars(acef, "time", dim = "time")
times <- extract_levels(get(time, data))
#time <- get_ac_vars(acef, "time", dim = "time")
#times <- extract_levels(get(time, data))
times <- attr(acef, "times")
corname <- paste0("cortime", resp)
regex <- paste0("^", corname, "\\[")
cortime_names <- get_cornames(times, type = corname, brackets = FALSE)
@@ -377,8 +372,9 @@ rename_ac <- function(bterms, data, pars, ...) {

# helps in renaming group-level parameters
# @param ranef: data.frame returned by 'tidy_ranef'
rename_re <- function(ranef, pars, ...) {
rename_re <- function(bterms, pars, ...) {
out <- list()
ranef <- bterms$frame$re
if (has_rows(ranef)) {
for (id in unique(ranef$id)) {
r <- subset2(ranef, id = id)
@@ -414,7 +410,7 @@ rename_re <- function(ranef, pars, ...) {
}
}
if (any(grepl("^r_", pars))) {
c(out) <- rename_re_levels(ranef, pars = pars)
c(out) <- rename_re_levels(bterms, pars = pars)
}
tranef <- get_dist_groups(ranef, "student")
for (i in seq_rows(tranef)) {
@@ -428,8 +424,9 @@ rename_re <- function(ranef, pars, ...) {

# helps in renaming varying effects parameters per level
# @param ranef: data.frame returned by 'tidy_ranef'
rename_re_levels <- function(ranef, pars, ...) {
rename_re_levels <- function(bterms, pars, ...) {
out <- list()
ranef <- bterms$frame$re
for (i in seq_rows(ranef)) {
r <- ranef[i, ]
p <- usc(combine_prefix(r))
4 changes: 2 additions & 2 deletions R/stan-predictor.R
Original file line number Diff line number Diff line change
@@ -819,7 +819,7 @@ stan_sm <- function(bterms, prior, threads, normalize, ...) {
# @note not implemented for non-linear models
stan_cs <- function(bterms, prior, threads, normalize, ...) {
out <- list()
csef <- bterms$frame$cs
csef <- bterms$frame$cs$vars
px <- check_prefix(bterms)
p <- usc(combine_prefix(px))
resp <- usc(bterms$resp)
@@ -1805,7 +1805,7 @@ stan_nl <- function(bterms, nlpars, threads, ...) {
}

# global Stan definitions for noise-free variables
stan_me <- function(bterms, prior, threads, normalize) {
stan_Xme <- function(bterms, prior, threads, normalize) {
meef <- bterms$frame$me
stopifnot(is.meef_frame(meef))
if (!nrow(meef)) {
2 changes: 1 addition & 1 deletion R/stancode.R
Original file line number Diff line number Diff line change
@@ -125,7 +125,7 @@ stancode.default <- function(object, data, family = gaussian(),
scode_ranef <- stan_re(
bterms, prior = prior, threads = threads, normalize = normalize
)
scode_Xme <- stan_me(
scode_Xme <- stan_Xme(
bterms, prior = prior, threads = threads, normalize = normalize
)
scode_global_defs <- stan_global_defs(