forked from avehtari/BDA_R_demos
-
Notifications
You must be signed in to change notification settings - Fork 0
/
stan_utility.R
59 lines (50 loc) · 2.27 KB
/
stan_utility.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# Copyright: Michael Betancourt <https://betanalpha.github.io/writing/>
# License: BSD (3 clause)
# See also http://mc-stan.org/users/documentation/case-studies/rstan_workflow.html
#' Check transitions that ended with a divergence
#'
#' Stacks the post-warmup sampler diagnostics of every chain, counts the
#' transitions flagged in the `divergent__` column, and prints that count
#' together with its percentage of all transitions. A hint about
#' `adapt_delta` is printed as well when at least one divergence was found.
#'
#' @param fit A fit accepted by `get_sampler_params()` (presumably an rstan
#'   `stanfit`; not checked here).
#' @return Used for its printed output. The value is `NULL` when no
#'   divergence was found, otherwise the hint string returned by `print()`.
check_div <- function(fit) {
  all_draws <- do.call(rbind, get_sampler_params(fit, inc_warmup = FALSE))
  is_divergent <- all_draws[, "divergent__"]
  n_divergent <- sum(is_divergent)
  n_total <- length(is_divergent)
  summary_line <- sprintf("%s of %s iterations ended with a divergence (%s%%)",
                          n_divergent, n_total, 100 * n_divergent / n_total)
  print(summary_line)
  if (n_divergent > 0) {
    print("Try running with larger adapt_delta to remove the divergences")
  }
}
#' Check transitions that ended prematurely due to maximum tree depth limit
#'
#' Stacks the post-warmup sampler diagnostics of every chain, counts the
#' transitions whose `treedepth__` equals `max_depth`, and prints that count
#' with its percentage of all transitions. A hint is printed as well when at
#' least one transition saturated the limit.
#'
#' @param fit A fit accepted by `get_sampler_params()` (presumably an rstan
#'   `stanfit`; not checked here).
#' @param max_depth The tree-depth limit to test against (default 10).
#'   NOTE(review): this is not read from `fit`; it must match the limit the
#'   sampler was actually run with.
#' @return Used for its printed output. The value is `NULL` when no
#'   transition saturated, otherwise the hint string returned by `print()`.
check_treedepth <- function(fit, max_depth = 10) {
  sampler_params <- get_sampler_params(fit, inc_warmup = FALSE)
  treedepths <- do.call(rbind, sampler_params)[, "treedepth__"]
  # Vectorised count. The previous sapply() returned list() for zero
  # iterations, which is an invalid subscript and made the check error out.
  n <- sum(treedepths == max_depth)
  N <- length(treedepths)
  print(sprintf("%s of %s iterations saturated the maximum tree depth of %s (%s%%)",
                n, N, max_depth, 100 * n / N))
  if (n > 0) {
    # NOTE(review): rstan's sampler control option is spelled `max_treedepth`;
    # confirm whether this hint should name it instead.
    print("Run again with max_depth set to a larger value to avoid saturation")
  }
}
#' Check the energy Bayesian fraction of missing information (E-BFMI)
#'
#' For each chain, takes the post-warmup `energy__` values and computes
#' `E-BFMI = (sum(diff(energy)^2) / length(energy)) / var(energy)`. For every
#' chain whose E-BFMI is below 0.2, it prints the value and a
#' reparameterisation hint.
#'
#' @param fit A fit accepted by `get_sampler_params()` (presumably an rstan
#'   `stanfit`; not checked here).
#' @return Used for its printed output; the loop itself yields `NULL`.
check_energy <- function(fit) {
  sampler_params <- get_sampler_params(fit, inc_warmup = FALSE)
  # seq_along() runs zero times for an empty list; 1:length() would visit a
  # non-existent chain 1 and fail on NULL[, "energy__"].
  for (n in seq_along(sampler_params)) {
    energies <- sampler_params[[n]][, "energy__"]
    numer <- sum(diff(energies)^2) / length(energies)
    denom <- var(energies)
    ebfmi <- numer / denom
    # NOTE(review): with fewer than two draws, or perfectly constant energies,
    # ebfmi is NA/NaN and this `if` errors; confirm whether that should be
    # reported instead.
    if (ebfmi < 0.2) {
      print(sprintf("Chain %s: E-BFMI = %s", n, ebfmi))
      print("E-BFMI below 0.2 indicates you may need to reparameterize your model")
    }
  }
}
#' Returns parameter draws separated into divergent and non-divergent transitions
#'
#' Stacks the post-warmup draws of all chains into one data frame, with chain
#' 1's iterations first, then chain 2's, and so on. Each row is tagged with the
#' matching `divergent__` flag from the stacked sampler diagnostics, and the
#' rows are then split on that flag.
#'
#' @param fit A fit accepted by `extract()` and `get_sampler_params()`
#'   (presumably an rstan `stanfit`; not checked here).
#' @return An unnamed list of two data frames: first the rows with
#'   `divergent == 1`, then the rows with `divergent == 0`. Each has one
#'   column per parameter plus a `divergent` column.
partition_div <- function(fit) {
  nom_params <- extract(fit, permuted = FALSE)
  # Assumed to be an iterations x chains x parameters array (rstan's
  # documented layout for permuted = FALSE).
  dims <- dim(nom_params)
  # Column-major storage already lists chain 1's iterations, then chain 2's,
  # ... for each parameter, so one reshape stacks the chains. Unlike
  # rbind()-ing nom_params[, n, ], it never drops a dimension of extent 1
  # (a single parameter would otherwise become a vector and transpose).
  stacked <- matrix(nom_params, nrow = dims[1] * dims[2], ncol = dims[3],
                    dimnames = list(NULL, dimnames(nom_params)[[3]]))
  params <- as.data.frame(stacked)
  sampler_params <- get_sampler_params(fit, inc_warmup = FALSE)
  params$divergent <- do.call(rbind, sampler_params)[, "divergent__"]
  div_params <- params[params$divergent == 1, ]
  nondiv_params <- params[params$divergent == 0, ]
  list(div_params, nondiv_params)
}