-
Notifications
You must be signed in to change notification settings - Fork 0
/
functions.R
83 lines (65 loc) · 2.14 KB
/
functions.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
## Functions for analysis
## epred_df ------------------------------------------------
epred_df <- function(epred, newdata) {
  ## Convert epred output to a long-format data.table
  ##
  ## epred   = output from brms::posterior_epred; a 2-D draws x observations
  ##           object whose column i holds the draws for row i of `newdata`
  ## newdata = data used as input to brms::posterior_epred
  ##
  ## Returns a data.table with nrow(epred) rows per row of `newdata`: the
  ## `newdata` columns repeated once per draw plus a `sample` column holding
  ## the draw. Rows are grouped by observation (all draws for row 1, then all
  ## draws for row 2, ...).
  newdata <- as.data.table(newdata)
  ## A mismatch used to either silently drop extra epred columns or fail with
  ## an opaque "subscript out of bounds" error.
  if (length(dim(epred)) != 2L || ncol(epred) != nrow(newdata)) {
    stop("`epred` must be a draws x observations matrix with one column ",
         "per row of `newdata`", call. = FALSE)
  }
  ## Build every row in one step instead of one data.table per observation.
  ## as.vector() flattens column-major, so all draws of observation 1 come
  ## first -- matching rep(..., each = n_draws) on the newdata rows.
  row_idx <- rep(seq_len(nrow(newdata)), each = nrow(epred))
  data.table(newdata[row_idx], sample = as.vector(as.matrix(epred)))
}
## diagnostic checks ---------------------------------------
rhat_highest_dfa <- function(dfa, k = 4, pars = c("sigma", "x\\[", "Z\\[")) {
  ## Return the k largest Rhat values from a fitted DFA model
  ##
  ## dfa  = fitted model whose `monitor` table has parameter row names and an
  ##        "Rhat" column (presumably a bayesdfa::fit_dfa() result -- confirm)
  ## k    = number of values to return; fewer are returned if fewer match
  ## pars = regular expressions; a row is kept if its name matches any of them
  ##
  ## Returns the selected Rhat values in decreasing order. NA values are
  ## dropped (sort() removes them).
  keep <- grepl(paste(pars, collapse = "|"), rownames(dfa$monitor))
  rhat <- dfa$monitor[keep, "Rhat"]
  ## The previous `(length(rhat) - k):length(rhat)` indexing returned k + 1
  ## values and errored (negative indices) when fewer than k + 1 matched.
  head(sort(rhat, decreasing = TRUE), k)
}
neff_lowest_dfa <- function(dfa, k = 4, pars = c("sigma", "x\\[", "Z\\[")) {
  ## Return the k smallest effective sample sizes from a fitted DFA model
  ##
  ## dfa  = fitted model whose `monitor` table has parameter row names and an
  ##        "n_eff" column (presumably a bayesdfa::fit_dfa() result -- confirm)
  ## k    = number of values to return; fewer are returned if fewer match
  ## pars = regular expressions; a row is kept if its name matches any of them
  ##
  ## Returns the selected n_eff values in increasing order. NA values are
  ## dropped (sort() removes them).
  keep <- grepl(paste(pars, collapse = "|"), rownames(dfa$monitor))
  neff <- dfa$monitor[keep, "n_eff"]
  ## head() instead of [1:k]: no NA padding when fewer than k values exist,
  ## and k = 0 yields an empty vector rather than the first element.
  head(sort(neff), k)
}
rhat_highest <- function(stanfit, k = 4, pars) {
  ## Return the k largest Rhat values from a stanfit object
  ##
  ## stanfit = fitted model of class "stanfit" (checked by get_rhat())
  ## k       = number of values to return; fewer if fewer are available
  ## pars    = parameters to summarise, forwarded to get_rhat()
  ##
  ## Returns the selected Rhat values in decreasing order.
  rhat <- get_rhat(stanfit, pars = pars)
  ## The previous `(length(rhat) - k):length(rhat)` indexing returned k + 1
  ## values and errored (negative indices) when fewer than k + 1 existed.
  head(sort(rhat, decreasing = TRUE), k)
}
neff_lowest <- function(stanfit, k = 4, pars) {
  ## Return the k smallest effective sample sizes from a stanfit object
  ##
  ## stanfit = fitted model of class "stanfit" (checked by get_neff())
  ## k       = number of values to return; fewer if fewer are available
  ## pars    = parameters to summarise, forwarded to get_neff()
  ##
  ## Returns the selected n_eff values in increasing order (NAs dropped by
  ## sort()).
  neff <- get_neff(stanfit, pars = pars)
  ## head() instead of [1:k]: no NA padding when fewer than k values exist,
  ## and k = 0 yields an empty vector.
  head(sort(neff), k)
}
pairs_lowest <- function(stanfit, k = 4, pars) {
  ## Draw a pairs plot of the k parameters with the smallest n_eff
  ##
  ## stanfit = fitted model of class "stanfit" (checked by get_neff())
  ## k       = number of parameters to plot; fewer if fewer are available
  ## pars    = parameters to consider, forwarded to get_neff()
  ##
  ## Returns whatever pairs() returns for the stanfit object.
  n <- get_neff(stanfit, pars = pars)
  ## head() instead of [1:k] so pairs() never receives NA parameter names
  ## when fewer than k parameters are available.
  n.min <- names(head(sort(n), k))
  pairs(stanfit, pars = n.min)
}
get_rhat <- function(stanfit, pars) {
  ## Rhat column of summary(stanfit) for `pars`, with NA entries removed
  ## via na.omit(). Stops if `stanfit` is not a stanfit object.
  if (!is(stanfit, "stanfit")) stop("Input not of class stanfit")
  fit_summary <- summary(stanfit, pars = pars)$summary
  na.omit(fit_summary[, "Rhat"])
}
get_neff <- function(stanfit, pars) {
  ## n_eff column of summary(stanfit) for `pars`; NA entries are kept.
  ## Stops if `stanfit` is not a stanfit object.
  if (!is(stanfit, "stanfit")) stop("Input not of class stanfit")
  fit_summary <- summary(stanfit, pars = pars)$summary
  fit_summary[, "n_eff"]
}
total_draws <- function(stanfit) {
  ## Total post-warmup draws (N chains * N draws): the product of the first
  ## two dimensions of `stanfit`. Stops if `stanfit` is not a stanfit object.
  if (!is(stanfit, "stanfit")) stop("Input not of class stanfit")
  fit_dims <- dim(stanfit)
  fit_dims[1] * fit_dims[2]
}