---
title: "Bayesian Titanic Data Analysis"
author: "Nathan T. James"
date: "<small>`r Sys.Date()`</small>"
output:
  html_document:
    toc: no
    toc_depth: 3
    number_sections: false
    toc_float:
      collapsed: false
    code_folding: hide
    theme: paper
editor_options:
  chunk_output_type: inline
---
```{r setup, echo=FALSE, include=FALSE}
rm(list=ls())
# setup switch for Windows/Linux
sn <- Sys.info()['sysname']
wd <- switch(sn,
Windows=file.path("C:/Users/nj115/Dropbox/njames/school/PhD/courses/2018_19/ta_rms/bayes_ex/titanic_bayes_ex"),
Linux=file.path("/home/nathan/Dropbox/njames/school/PhD/courses/2018_19/ta_rms/bayes_ex/titanic_bayes_ex"))
knitr::opts_chunk$set(echo = TRUE)
knitr::opts_knit$set(root.dir=file.path(wd))
# load libraries
libs <- c("rms", "brms","splines")
invisible(lapply(libs, library, character.only = TRUE))
#knitrSet(lang='markdown', fig.path='png/', fig.align='center', w=6.5, h=4.5, cache=TRUE)
set.seed(3742)
```
```{r load_data}
require(rms)
getHdata(titanic3)
# List variables to be analyzed
v <- c('pclass', 'survived', 'age', 'sex', 'sibsp', 'parch')
t3 <- titanic3[, v]
units(t3$age) <- 'years'
```
```{r desc, warning=FALSE}
html(describe(t3))
```
## Casewise deletion of missing values
The models are implemented using the `brms` (Bayesian Regression Models using 'Stan') package, which is a wrapper for the more general probabilistic programming language `Stan` and its R implementation `rstan`. The main `brms` function, `brm()`, uses `R` formula notation and is similar to other regression functions such as `glm()`.
The code below produces a Bayesian logistic model for the binary survival outcome with a linear term for age and indicators for sex and passenger class (pclass). For simplicity the siblings/spouse/parent variables are not used. Notice that the `brm()` function produces a warning that 'Rows containing NAs were excluded from the model' which indicates casewise deletion was used.
```{r mod1, cache=TRUE}
require(brms)
# Recommended rstan option for execution on a local, multicore CPU with excess RAM
options(mc.cores = parallel::detectCores())
fit1 <- brm(survived ~ age + sex + pclass,
            family = bernoulli(), data = t3,
            chains = 2, iter = 2000, refresh = 0)
```
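To see what casewise deletion does to the data before any sampling happens, the sketch below applies the same idea with base R on a made-up data frame. The data frame `toy` and its values are hypothetical and are not taken from `titanic3`; `model.frame()` with `na.action = na.omit` is used here only as a stand-in for the row filtering that the `brm()` warning describes.

```r
# Hypothetical toy data (not titanic3): two of six ages are missing
toy <- data.frame(
  survived = c(1, 0, 1, 0, 1, 0),
  age      = c(29, NA, 2, 30, NA, 48),
  sex      = factor(c("female", "male", "female", "male", "female", "male"))
)

# Rows with any NA in the model variables are dropped (casewise deletion)
mf <- model.frame(survived ~ age + sex, data = toy, na.action = na.omit)

n_total   <- nrow(toy)
n_used    <- nrow(mf)
n_dropped <- n_total - n_used
cat("rows in data:", n_total, " rows used:", n_used, " dropped:", n_dropped, "\n")

# The same count obtained directly from the missingness pattern
n_complete <- sum(complete.cases(toy[, c("survived", "age", "sex")]))
```

Running the equivalent `complete.cases()` check on `t3` before fitting is one way to know in advance how many passengers the casewise-deletion model will ignore.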
The summary and trace plots don't show any major convergence issues. Using `prior_summary()` we can see that a (weakly informative) Student-t prior was used for the intercept and (improper) flat priors were used for the other coefficients.
```{r mod1_summ, fig.width=10}
# summarize posterior
summary(fit1)
# posterior densities and MCMC traces
plot(fit1)
# priors
prior_summary(fit1)
```
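One of the convergence checks reported by `summary()` is the `Rhat` column, where values close to 1 suggest the chains agree. As a rough illustration of what that statistic measures, the sketch below computes the classic Gelman-Rubin potential scale reduction factor on simulated draws. The function `rhat_basic()` and the simulated chains are hypothetical, and the Stan tooling behind `brms` is understood to use a more refined split-chain variant, so the numbers are not expected to match `summary(fit1)` exactly.

```r
# Classic Gelman-Rubin potential scale reduction factor for one parameter.
# `draws` is an iterations x chains matrix of posterior draws.
rhat_basic <- function(draws) {
  n <- nrow(draws)
  W <- mean(apply(draws, 2, var))      # mean within-chain variance
  B <- n * var(colMeans(draws))        # between-chain variance
  var_hat <- (n - 1) / n * W + B / n   # pooled variance estimate
  sqrt(var_hat / W)
}

set.seed(1)
# Hypothetical draws: two chains sampling the same distribution
mixed <- cbind(rnorm(1000), rnorm(1000))
# Hypothetical draws: second chain stuck in a different region
stuck <- cbind(rnorm(1000), rnorm(1000, mean = 3))

rhat_mixed <- rhat_basic(mixed)   # close to 1
rhat_stuck <- rhat_basic(stuck)   # well above 1
```

When the chains explore the same region, the between-chain variance adds almost nothing and the ratio stays near 1; when one chain sits elsewhere, the ratio grows well beyond it.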
We can also add non-linear terms for age using natural splines and all two-way interactions between sex, pclass, and the splines for age.
```{r mod3, cache=TRUE}
require(splines)
# fit model with natural splines
fit3 <- brm(survived ~ (sex + pclass + ns(age,df=4))^2,
family = bernoulli(), data=t3,
chains = 2, iter=2000, refresh = 0)
```
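The formula `(sex + pclass + ns(age,df=4))^2` expands into many more coefficients than the main-effects model. The sketch below counts them with `model.matrix()` on simulated covariates. The data frame `toy` is hypothetical, and the count assumes that `brms` builds its population-level design matrix the same way base R does.

```r
library(splines)

set.seed(2)
# Hypothetical stand-in for the Titanic covariates (values are simulated)
toy <- data.frame(
  sex    = factor(sample(c("female", "male"), 200, replace = TRUE)),
  pclass = factor(sample(c("1st", "2nd", "3rd"), 200, replace = TRUE)),
  age    = runif(200, 1, 80)
)

# Natural spline basis with 4 degrees of freedom
basis   <- ns(toy$age, df = 4)
n_basis <- ncol(basis)                   # 4 basis columns
n_knots <- length(attr(basis, "knots"))  # 3 interior knots at age quantiles

# Design matrix for all main effects and two-way interactions
X <- model.matrix(~ (sex + pclass + ns(age, df = 4))^2, data = toy)
n_coef <- ncol(X)  # intercept + 7 main-effect + 14 interaction columns
```

The 7 main-effect columns are 1 for sex, 2 for pclass, and 4 for the age spline; the 14 interaction columns are 2 for sex by pclass, 4 for sex by age, and 8 for pclass by age.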
The marginal effects plot shows the posterior median of the predicted survival probability by age, sex, and class, along with 95% credible intervals.
```{r mod3_summ, fig.width=10}
summary(fit3)
# similar to Fig 12.5
marginal_effects(fit3, effects ='age:sex', robust = TRUE,
conditions = data.frame(pclass=c('1st','2nd','3rd')))
```
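The Bayesian model predictions from `brm()` and the frequentist model predictions from `lrm()` are nearly identical. Both models use a natural (restricted) cubic spline for age with 4 degrees of freedom, `ns(age,df=4)` in `brm()` and `rcs(age,5)` in `lrm()`, although the default knot locations of the two functions differ.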
```{r mod3_lrm, fig.width=10}
dd <- datadist(t3)
options(datadist='dd')
fit3_lrm <- lrm(survived ~ (sex + pclass + rcs(age,5))^2, data=t3)
plt_lrm <- Predict(fit3_lrm, age, sex, pclass,fun=plogis)
ggplot(plt_lrm)
```
The Bayesian model can also be used to get the posterior survival probability distribution for an individual passenger with a given set of covariates. In the plots below the posterior survival probabilities for 3 types of passenger are shown. The red line is the point estimate from the frequentist model.
```{r ind_pred, cache=TRUE, fig.height=6, fig.width=10}
combos <- expand.grid(age=c(2,21,50), sex=levels(t3$sex),
pclass=levels(t3$pclass))
phat_brm <- fitted(fit3, newdata = combos, summary=FALSE)
phat_lrm <- predict(fit3_lrm, combos, type="fitted")
op<-par(mfrow=c(3,1))
# 2yr, female, 1st class
hist(phat_brm[,1], xlab="Pr(survive)",
main="Predicted survival for 2yr, female, 1st class")
abline(v=phat_lrm[1], col="red")
# 21yr, male, 2nd class
hist(phat_brm[,11], xlab="Pr(survive)",
main="Predicted survival for 21yr, male, 2nd class")
abline(v=phat_lrm[11], col="red")
# 50yr, female, 3rd class
hist(phat_brm[,15], xlab="Pr(survive)",
main="Predicted survival for 50yr, female, 3rd class")
abline(v=phat_lrm[15], col="red")
par(op)
```
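The column indices 1, 11, and 15 used above depend on the row order produced by `expand.grid()`, in which the first argument varies fastest. The sketch below looks the rows up by covariate values instead of hard-coding them. The helper `find_row()` is hypothetical, and the factor levels are written out explicitly on the assumption that they match `levels(t3$sex)` and `levels(t3$pclass)`.

```r
# Same grid as in the chunk above, with the factor levels written out
# explicitly (assumed to match levels(t3$sex) and levels(t3$pclass))
combos <- expand.grid(age    = c(2, 21, 50),
                      sex    = c("female", "male"),
                      pclass = c("1st", "2nd", "3rd"))

# Look up a row index by covariate values instead of hard-coding it
find_row <- function(grid, age, sex, pclass) {
  which(grid$age == age & grid$sex == sex & grid$pclass == pclass)
}

idx_a <- find_row(combos, 2,  "female", "1st")
idx_b <- find_row(combos, 21, "male",   "2nd")
idx_c <- find_row(combos, 50, "female", "3rd")
c(idx_a, idx_b, idx_c)
```

Because each row of `combos` corresponds to one column of the matrix returned by `fitted(..., summary=FALSE)`, the same indices select the matching posterior draws.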
## Imputation during model fitting
Rather than performing casewise deletion, we can impute the values of missing age during the model fitting. The overall model includes two sub-models, one for the survival outcome and one for the missing ages. To start, we specify a model for missing age with only an intercept term. This is similar to replacing all the missing ages with the mean of the observed ages.
```{r mod4, cache=TRUE}
# outcome model
bf_outcome <- bf(survived ~ mi(age) + sex + pclass) + bernoulli()
# imputation model for age with just intercept
bf_imp <- bf(age|mi() ~ 1) + gaussian()
# set_rescor(FALSE): no residual correlation between the multivariate response variables (age and survived)
fit4 <- brm(bf_outcome + bf_imp + set_rescor(FALSE), data=t3,
            chains=2, iter=2000, refresh = 0)
```
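The comparison with mean imputation can be checked with a small frequentist analogue: in an intercept-only Gaussian model the estimated intercept is the mean of the observed values. The sketch below uses a hypothetical `age` vector (not the `titanic3` data) and `lm()` in place of the Bayesian sub-model. Unlike this plug-in version, the `mi()` approach treats each missing age as an unknown with its own posterior, so the uncertainty about the missing values carries through to the outcome model.

```r
# Hypothetical ages with missing values (not the titanic3 data)
age <- c(22, 38, NA, 35, 54, NA, 2, 27)

# Intercept-only Gaussian model; lm() drops the NA rows by default
fit_int <- lm(age ~ 1)
mu_hat  <- unname(coef(fit_int)[1])

# Filling every missing age with the estimated intercept is mean imputation
age_filled <- ifelse(is.na(age), mu_hat, age)
```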
The summary for this model looks similar to previous models, but includes coefficients for both the survival outcome model and the age imputation model.
```{r mod4_summ}
summary(fit4)
```
We can specify a more complex imputation for age by modifying the second sub-model.
```{r mod5, cache=TRUE}
# imputation model for age with all main effects and 2-way interactions
bf_imp2 <- bf(age|mi() ~ (sex+pclass+survived)^2) + gaussian()
fit5 <- brm(bf_outcome + bf_imp2 + set_rescor(FALSE), data=t3,
chains=2, iter=2000, refresh = 0)
```
The outcome model includes linear terms for age, sex, and pclass while the imputation model for age includes main effects and two-way interactions for sex, pclass, and survival status.
```{r mod5_summ}
summary(fit5)
```
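The distribution of the imputed ages, summarized here by the posterior mean fitted age from the imputation sub-model for each passenger with a missing age, looks reasonable compared to the observed ages given that the imputation model used only 3 variables.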
```{r, fig.align="center"}
# compare imputed and observed ages
imp_age <- fitted(fit5, newdata = t3[is.na(t3$age), ])[, , 'age']
obs_age <- t3[!is.na(t3$age), 'age']
ages <- data.frame(age = c(imp_age[, 1], obs_age),
                   imputed = factor(c(rep(1, nrow(imp_age)),
                                      rep(0, length(obs_age))),
                                    labels = c("no", "yes")))
ggplot(data = ages, aes(x = imputed, y = age)) + geom_boxplot()
```
Unfortunately the `brms` formula syntax doesn't currently allow imputed variables to have non-linear effects in the main outcome model (`survived ~ ns(mi(age),df=4) + sex + pclass` won't work). One possible solution is to examine and modify the underlying Stan code. Running `stancode(fit3)` will show the Stan code underlying the `brms` model for `fit3` (see `?rstan` for more details on using Stan directly in `R`). This option can be tricky because the `R` function used to define the splines needs to be rewritten as a function *within* Stan (see https://mc-stan.org/users/documentation/case-studies/splines_in_stan.html).
## Imputation before model fitting
The second possible solution is to separate the imputation from the model by first performing traditional multiple imputation and then combining the results with `brm_multiple()`. Posterior draws from the models fit to each imputed dataset are 'stacked' and inference is performed using this combined set of posterior draws. The downside of this approach is that the model needs to be re-fit for each imputed dataset, which can be computationally intensive with a complex model and many imputations. Below, we use `aregImpute` to get 10 datasets with imputed values for the missing age variable.
```{r mi, cache=TRUE}
# number of imputation datasets
n_imp <- 10
aregimp <- aregImpute(~ age + sex + pclass + survived,
data=t3, n.impute=n_imp, nk=4, pr=FALSE)
# format imputation datasets as list of data.frames
imputed <- lapply(1:n_imp, function(x) impute.transcan(aregimp, imputation=x, data=t3,
list.out=TRUE, pr=FALSE, check=FALSE))
mk_df <- function(x){
data.frame(age=imputed[[x]][1],sex=imputed[[x]][2],
pclass=imputed[[x]][3],survived=imputed[[x]][4])
}
imp_lst <- lapply(1:n_imp, mk_df)
```
```{r mod6, cache=TRUE}
fit_imp1 <- brm_multiple(survived ~ (sex + pclass + ns(age,df=4))^2,
family = bernoulli(), data = imp_lst,
chains = 2, iter = 2000, refresh = 0)
```
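The 'stacking' of posterior draws described above can be illustrated without fitting any models. The sketch below simulates draws for a single coefficient from 10 fits whose posteriors are centred at slightly different values, then concatenates them. All numbers (`centres`, `n_draws`, the standard deviation of 0.2) are hypothetical and chosen only to make the effect visible.

```r
set.seed(3)
n_imp   <- 10     # number of imputed datasets, as in the text
n_draws <- 1000   # hypothetical number of posterior draws per fit

# Hypothetical posterior centres that differ slightly between imputations
centres <- seq(-0.45, 0.45, length.out = n_imp)

# One vector of draws per imputed dataset
draws_list <- lapply(centres, function(m) rnorm(n_draws, mean = m, sd = 0.2))

# "Stacking": concatenate the draws from all fits into one posterior sample
stacked <- unlist(draws_list)

within_sd  <- mean(sapply(draws_list, sd))  # average spread within one fit
stacked_sd <- sd(stacked)                   # spread after stacking
```

The stacked sample is wider than any single fit because it also carries the variation between imputed datasets, which is how the extra uncertainty from the missing ages enters the combined posterior.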
The results of the combined multiple imputations and the casewise deletion model are similar. Note that the combined model may issue false positive convergence warnings (see `?brm_multiple`).
### Casewise deletion model
```{r, fig.width=10}
summary(fit3)
marginal_effects(fit3, effects ='age:sex', robust = TRUE,
conditions = data.frame(pclass=c('1st','2nd','3rd')))
```
### Multiple imputation model
```{r, fig.width=10}
summary(fit_imp1)
marginal_effects(fit_imp1, effects ='age:sex', robust = TRUE,
conditions = data.frame(pclass=c('1st','2nd','3rd')))
```