---
title: "Appendix"
subtitle: "Models of retrieval in sentence comprehension"
author:
- name: B. Nicenboim
- name: S. Vasishth
date: "January 10th, 2017"
output:
html_document:
toc: true
number_sections: true
fig_caption: true
css: styles.css
bibliography: bibliography.bib
---
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
TeX: { equationNumbers: { autoNumber: "AMS" } }
});
</script>
```{r, include = FALSE}
knitr::opts_chunk$set(tidy = TRUE, cache = TRUE)
```
```{r echo=FALSE}
# From http://stackoverflow.com/questions/37116632/rmarkdown-html-number-figures
#Determine the output format of the document
outputFormat = knitr::opts_knit$get("rmarkdown.pandoc.to")
#Figure and Table Caption Numbering, for HTML do it manually
capTabNo = 1; capFigNo = 1;
#Function to add the Table Number
capTab = function(x){
if(outputFormat == 'html'){
x = paste0("Table ",capTabNo,". ",x)
capTabNo <<- capTabNo + 1
}; x
}
#Function to add the Figure Number
capFig = function(x){
if(outputFormat == 'html'){
x = paste0("Figure ",capFigNo,". ",x)
capFigNo <<- capFigNo + 1
}; x
}
```
# Appendix: K-fold cross validation
Since a number of observations had Pareto shape estimates $\hat{k} > 0.7$, which
indicates an unreliable calculation of $\hat{elpd}$ using PSIS-LOO for both
models, we perform K-fold cross validation with $K=10$ following @VehtariEtAl2016.
```{r data-loading, message=FALSE, warning=FALSE, results="hide"}
# Load R packages
library(ggplot2)
library(scales)
library(hexbin)
library(tidyr)
library(dplyr)
library(MASS)
library(rstan)
library(loo)
library(matrixStats)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
set.seed(42)
iter <- 2000
chains <- 3
```
We use the same data as before, but we create 10 lists, so that each list has
9/10 of the observations as training data and 1/10 of the observations as
held-out data. To avoid having very few or no observations of a given
participant in the training set of some list, we group the data by
participant before doing the split.
```{r folding}
load("dataNIG-SRC.Rda")
# save the order of the data frame
dexp$row <- as.numeric(as.character(row.names(dexp)))
if(!file.exists("ldata.Rda")){
# The following code ensures that no fold is missing a participant
# (or has a participant underrepresented)
# (There might be a simpler way)
#
# Assuming K = 10, the procedure is the following:
# 1) We extract 1/10 of the data and save it in the list G with k = 1;
# 2) We extract 1/9 of the remaining data and save it with k = 2;
# 3) We extract 1/8 of the remaining data and save it with k = 3;
# 4) ...;
# 10) We extract all the remaining data and save it with k = 10
K <- 10
d <- dexp
G <- list()
for(i in 1:K){
G[[i]] <- sample_frac(group_by(d,participant),(1/(K+1-i)))
G[[i]]$k <- i
d <- anti_join(d,G[[i]], by = c("participant", "item", "winner", "RT", "condition", "row"))
}
# We create a dataframe again:
dK <- bind_rows(G)
# We save the order of the dataframe
dK <- dK[order(dK$row), ]
ldata <- plyr::llply(1:K, function(i) {
list(N_obs=nrow(dK),
winner = dK$winner,
RT = dK$RT,
N_choices = max(dK$winner),
subj = as.numeric(as.factor(dK$participant)),
N_subj = length(unique(dK$participant)),
item = as.numeric(as.factor(dK$item)),
N_item = length(unique(dK$item)),
holdout = ifelse(dK$k == i, 1, 0))
})
save(ldata, file = "ldata.Rda")
} else {
load("ldata.Rda")
}
# All the folds include all the participants
for(k in 1:10){
print(paste0("Fold = ",k,"; N_subj = ", ldata[[k]]$N_subj, "; N_heldout = ", sum(ldata[[k]]$holdout)))
}
```
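The sequential fractions used above can be checked on a toy example. The following is a minimal sketch in base R, assuming a hypothetical data set of 5 participants with 20 observations each (the names `toy`, `remaining`, `picked` and `fold_table` are illustrative and not part of the original analysis). It mirrors the 1/(K+1-i) scheme, sampling within participant, and tabulates how often each participant appears in each fold.

```r
set.seed(1)
K <- 10
# Hypothetical toy data: 5 participants x 20 observations
toy <- data.frame(participant = rep(1:5, each = 20), row = 1:100)
toy$fold <- NA_integer_
for (i in 1:K) {
  # Rows not yet assigned to a fold
  remaining <- toy[is.na(toy$fold), ]
  # Within each participant, take 1/(K + 1 - i) of the remaining rows
  picked <- unlist(lapply(split(remaining$row, remaining$participant),
                          function(r) {
                            n_take <- round(length(r) / (K + 1 - i))
                            r[sample.int(length(r), n_take)]
                          }))
  toy$fold[toy$row %in% picked] <- i
}
# Each cell is the number of observations of a participant in a fold
fold_table <- table(toy$participant, toy$fold)
print(fold_table)
```

Because 1/10 of the full data, 1/9 of the remainder, 1/8 of what is left after that, and so on, are all the same absolute amount, every fold ends up with an equal share of each participant in this toy case.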
In order to apply K-fold cross validation, we modify the *Stan* models so
that they increment the log-likelihood only when the data are not held out. We
use the `generated quantities` section to calculate the pointwise likelihood
for every observation, but we will later use only the likelihood of the
held-out data for the calculation of $\hat{elpd}$. See also the files
`activation-based_h_Kfold.stan` and `direct_access_h_Kfold.stan` for the details.
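In symbols (a sketch; the indicator notation $h^{(k)}_n$ for the `holdout` value of observation $n$ in fold $k$ is introduced here for illustration and does not appear in the Stan files), the `model` block of fold $k$ targets the posterior given only the training observations:

$$
\log p(\theta \mid y_{\text{train}(k)}) = \log p(\theta) + \sum_{n:\, h^{(k)}_n = 0} \log p(y_n \mid \theta) + \text{const},
$$

while `generated quantities` stores $\log p(y_n \mid \theta^{(k,s)})$ for every observation $n$ and every posterior draw $s$, of which only the entries with $h^{(k)}_n = 1$ are used later.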
For the activation-based model:
```
(...)
model {
(...)
for (n in 1:N_obs) {
if(holdout[n] == 0){
vector[N_choices] alpha;
real psi;
alpha = alpha_0 + u[, subj[n]] + w[, item[n]];
psi = exp(psi_0 + u_psi[subj[n]]);
target += race(winner[n], RT[n], alpha, b, sigma, psi);
}
}
}
generated quantities {
(...)
vector[N_obs] log_lik;
(...)
for (n in 1:N_obs) {
vector[N_choices] alpha;
real psi;
alpha = alpha_0 + u[, subj[n]] + w[, item[n]];
psi = exp(psi_0 + u_psi[subj[n]]);
log_lik[n] = race(winner[n], RT[n], alpha, b, sigma, psi);
}
}
```
For the direct access model:
```
(...)
model {
(...)
for (n in 1:N_obs) {
if(holdout[n] == 0){
real mu_da;
real mu_b;
vector[N_choices] beta;
real psi;
mu_da = mu_da_0 + u_RT[1,subj[n]] + w_RT[1,item[n]];
mu_b = mu_b_0 + u_RT[2,subj[n]] + w_RT[2,item[n]];
beta = beta_0 + u[,subj[n]] + w[,item[n]];
psi = exp(psi_0 + u_psi[subj[n]]);
target += da(winner[n], RT[n], beta, P_b, mu_da, mu_b, sigma, psi);
}
}
}
generated quantities {
(...)
vector[N_obs] log_lik;
(...)
for (n in 1:N_obs) {
real mu_da;
real mu_b;
vector[N_choices] beta;
real psi;
vector[2] gen;
mu_da = mu_da_0 + u_RT[1,subj[n]] + w_RT[1,item[n]];
mu_b = mu_b_0 + u_RT[2,subj[n]] + w_RT[2,item[n]];
beta = beta_0 + u[,subj[n]] + w[,item[n]];
psi = exp(psi_0 + u_psi[subj[n]]);
log_lik[n] = da(winner[n], RT[n], beta, P_b, mu_da, mu_b, sigma, psi);
}
}
```
Then we use the following function to parallelize the 10 runs of both models:
```{r samplingfunction}
# The following function can run all the chains of all the folds of the model in parallel:
stan_kfold <- function(file, list_of_datas, chains, cores,...){
library(pbmcapply)
badRhat <- 1.1
K <- length(list_of_datas)
model <- stan_model(file=file)
# First parallelize all chains:
sflist <-
pbmclapply(1:(K*chains), mc.cores = cores,
function(i){
# Fold number:
k <- ceiling(i / chains)
s <- sampling(model, data = list_of_datas[[k]],
chains = 1, chain_id = i, ...)
return(s)
})
# Then merge the K * chains to create K stanfits:
stanfit <- list()
for(k in 1:K){
inchains <- (chains*(k - 1) + 1):(chains*k)
# Merge `chains` of each fold
stanfit[[k]] <- sflist2stanfit(sflist[inchains])
}
return(stanfit)
}
```
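The bookkeeping between the K × chains parallel jobs and the folds can be sketched in base R as follows. This is an illustrative sketch rather than the authors' code: `fold_of_job` and `jobs_of_fold` are hypothetical names, and the mapping shown is one that works for any number of chains per fold.

```r
chains <- 3
K <- 10
jobs <- 1:(K * chains)

# Fold that each job belongs to: jobs 1..chains -> fold 1, and so on
fold_of_job <- ceiling(jobs / chains)

# Positions in the list of single-chain fits that make up fold k
jobs_of_fold <- function(k, chains) (chains * (k - 1) + 1):(chains * k)

# Jobs grouped by fold (first two folds only)
print(split(jobs, fold_of_job)[1:2])
# Jobs to merge for fold 2
print(jobs_of_fold(2, chains))
```

The two mappings are inverses of each other, so the chains merged for fold k are exactly the ones that were sampled with the data of fold k.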
We run the models and extract the log-likelihood evaluated at the
posterior simulations of the parameter values:
```{r sampling}
# Wrapper function to extract the log_lik of the held-out data, given a list of stanfits, and a list which indicates with 1 and 0 whether the observation was held out or not:
extract_log_lik_K <- function(list_of_stanfits, list_of_holdout, ...){
K <- length(list_of_stanfits)
list_of_log_liks <- plyr::llply(1:K, function(k){
extract_log_lik(list_of_stanfits[[k]],...)
})
# `log_lik_heldout` will include the log-likelihood of all the held-out data of all the folds.
# We define `log_lik_heldout` as a (samples x N_obs) matrix
# (similar to each log_lik matrix)
log_lik_heldout <- list_of_log_liks[[1]] * NA
for(k in 1:K){
log_lik <- list_of_log_liks[[k]]
samples <- dim(log_lik)[1]
N_obs <- dim(log_lik)[2]
# This is a matrix with the same size as log_lik_heldout
# with 1 if the data was held out in the fold k
heldout <- matrix(rep(list_of_holdout[[k]], each = samples), nrow = samples)
# Sanity check that the previous log_lik is not being overwritten:
if(any(!is.na(log_lik_heldout[heldout==1]))){
warning("Heldout log_lik has been overwritten!!!!")
}
# We save here the log_lik of the fold k in the matrix:
log_lik_heldout[heldout==1] <- log_lik[heldout==1]
}
return(log_lik_heldout)
}
# We apply the function to both models:
if(!file.exists("log_lik_ab.Rda")){
# We run all the chains of all the folds of the activation-based model in parallel:
# (We are using 30 cores of a server)
ab10Kfits <- stan_kfold("activation-based_h_Kfold.stan", list_of_datas = ldata, chains = 3, cores =30, seed = 42, iter = iter)
holdout <- lapply(ldata, '[[', "holdout")
# We extract all the held_out log_lik of all the folds
log_lik_ab <- extract_log_lik_K(ab10Kfits, holdout, "log_lik")
save(log_lik_ab, file = "log_lik_ab.Rda")
} else {
load("log_lik_ab.Rda")
}
if(!file.exists("log_lik_da.Rda")){
# We run all the chains of all the folds of the direct access model in parallel:
da10Kfits <- stan_kfold("direct_access_h_Kfold.stan", list_of_datas = ldata, chains = 3, cores = 30, seed = 42, iter = iter)
holdout <- lapply(ldata, '[[', "holdout")
# We extract all the held_out log_lik of all the folds
log_lik_da <- extract_log_lik_K(da10Kfits, holdout, "log_lik")
save(log_lik_da, file = "log_lik_da.Rda")
} else {
load("log_lik_da.Rda")
}
```
The following function is an adaptation of the `loo` function from the `R` package `loo`; it calculates the pointwise and total $\hat{elpd}$ of K-fold cross validation:
```{r kfold}
kfold <- function(log_lik_heldout) {
library(matrixStats)
logColMeansExp <- function(x) {
# should be more stable than log(colMeans(exp(x)))
S <- nrow(x)
colLogSumExps(x) - log(S)
}
# See equation (20) of @VehtariEtAl2016
pointwise <- matrix(logColMeansExp(log_lik_heldout), ncol= 1)
colnames(pointwise) <- "elpd"
# See equation (21) of @VehtariEtAl2016
elpd_kfold <- sum(pointwise)
se_elpd_kfold <- sqrt(ncol(log_lik_heldout) * var(pointwise))
out <- list(
pointwise = pointwise,
elpd_kfold = elpd_kfold,
se_elpd_kfold = se_elpd_kfold)
structure(out, class = "loo")
}
```
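To make the computation concrete, the sketch below applies the same steps to a small simulated matrix in base R. The matrix `toy_log_lik` and the helper `log_col_means_exp` are hypothetical stand-ins (the latter replaces `colLogSumExps` from `matrixStats` with a max-shifted base R version), so this illustrates the idea rather than reproducing the function above.

```r
set.seed(1)
# Hypothetical held-out log-likelihood matrix: 1000 draws x 5 observations
toy_log_lik <- matrix(rnorm(1000 * 5, mean = -3, sd = 0.5), nrow = 1000)

# Column-wise log(mean(exp(x))), shifting by the column maximum for stability
log_col_means_exp <- function(x) {
  m <- apply(x, 2, max)
  m + log(colMeans(exp(sweep(x, 2, m))))
}

# Pointwise elpd, its sum, and the standard error of the sum
pointwise <- log_col_means_exp(toy_log_lik)
elpd_kfold <- sum(pointwise)
se_elpd_kfold <- sqrt(length(pointwise) * var(pointwise))
print(c(elpd_kfold = elpd_kfold, se_elpd_kfold = se_elpd_kfold))

# The stable version agrees with the naive one when nothing underflows
naive <- log(colMeans(exp(toy_log_lik)))
print(all.equal(pointwise, naive))

# With very negative log-likelihoods the naive version underflows,
# while the shifted version stays finite
print(log(colMeans(exp(matrix(-1000, 2, 1)))))
print(log_col_means_exp(matrix(-1000, 2, 1)))
```

Working on the log scale matters here because the likelihood of a single observation under some posterior draws can be small enough to underflow when exponentiated directly.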
We can now repeat the same analysis using K-fold cross validation instead of PSIS-LOO:
```{r kfold-calc}
(kfold_ab <- kfold(log_lik_ab))
(kfold_da <- kfold(log_lik_da))
```
Comparing the models with K-fold cross validation also reveals an estimated
difference in $\hat{elpd}$ in favor of the direct access model over the
activation-based model:
```{r kfold_comparison}
(kfold_comparison <- compare(kfold_ab, kfold_da))
```
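The difference and its standard error can be read as a paired comparison on the pointwise values. The sketch below uses simulated numbers (not the results of this analysis) and assumes the usual paired formula: the sum of the pointwise differences, with a standard error of $\sqrt{N}$ times their standard deviation. The names `elpd_a` and `elpd_b` are hypothetical, and the exact computation performed by `compare` may depend on the installed version of `loo`.

```r
set.seed(1)
# Hypothetical pointwise elpd values of two models for the same 200 observations
elpd_a <- rnorm(200, mean = -3.0, sd = 1)
elpd_b <- elpd_a + rnorm(200, mean = 0.1, sd = 0.3)

# Paired difference: positive values favor model b
diff <- elpd_b - elpd_a
elpd_diff <- sum(diff)
se_diff <- sqrt(length(diff)) * sd(diff)
print(c(elpd_diff = elpd_diff, se = se_diff))
```

Because both models are evaluated on the same held-out observations, the standard error of the difference is computed from the pointwise differences and is typically smaller than what the two separate standard errors would suggest.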
Below we compare the models point by point in terms of their $\hat{elpd}$, as
calculated with K-fold cross validation. These graphs are very similar to the
ones using $\hat{elpd}$ calculated with PSIS-LOO. (The code producing the
graphs is available in the Rmd file.)
```{r kfold_elpd_comp, echo=F, fig.cap = capFig('Comparison of the activation-based and direct access models in terms of their predictive accuracy for each observation. Each axis shows the expected pointwise contributions to k-fold cross validation for each model ($\\hat{elpd}$ stands for the expected log pointwise predictive density of each observation). Higher (or less negative) values of $\\hat{elpd}$ indicate a better fit. Darker cells represent a higher concentration of observations with a given fit.')}
data_elpds <- data.frame(AB = kfold_ab$pointwise[,1],
DA = kfold_da$pointwise[,1])
y_axis_name <- expression(hat(elpd)[direct~~access~~model])
x_axis_name <- expression(hat(elpd)[activation-based~~model])
elpds <- ggplot(data_elpds, aes(x = AB, y = DA)) + theme_bw() +
scale_y_continuous(name = y_axis_name,
breaks = seq(-18, -2 ,2)) +
scale_x_continuous(name = x_axis_name,
breaks = seq(-18, -2, 2)) +
geom_hex(bins = 50) +
scale_fill_gradientn(colours = c("skyblue","darkblue"),
name="Number of\nobservations") +
theme(legend.key.size = unit(0.3, "cm"),
legend.title = element_text(size = 9),
legend.text = element_text(size = 8),
legend.position = c(.8,.2)) +
geom_abline(slope = 1 ,intercept = 0,linetype = "dotted") +
ggtitle("Activation-based vs. direct access models")
print(elpds)
```
```{r kfold_diff_elpd, echo=F, fig.cap = capFig('Comparison of the activation-based and direct access models in terms of their predictive accuracy for each observation depending on its log-transformed reading time (x-axis) and accuracy (left panel showing correct responses, and the right panel showing any of the possible incorrect responses). The y-axis shows the difference between the expected pointwise contributions to k-fold cross validation for each model ($\\hat{elpd}$ stands for the expected log pointwise predictive density of each observation); that is, positive values represent an advantage for the direct access model while negative values represent an advantage for the activation-based model. Darker cells represent a higher concentration of observations with a given fit.')}
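# Pointwise difference (as a plain vector): positive values favor the direct access model
dexp$diff_elpd <- kfold_da$pointwise[, 1] -
kfold_ab$pointwise[, 1]
# We group the responses into correct (1) and incorrect (2-4)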
dexp$response <- factor(ifelse(dexp$winner==1,"Correct", "Incorrect")
,levels=c("Correct", "Incorrect"))
# Readable labels when back-converted from log-scale:
RT_labels <- c(seq(200, 1000, 100), rep("", 9), 2000, rep("", 9), 3000, rep("",9), 4000,rep("",9), 5000)
# Label for the y axis:
y_axis_da_ab <- expression(hat(elpd)[direct~~access~~model] -
hat(elpd)[activation-based~~model])
# Define the plot:
elpds_diff <- ggplot(dexp,aes(x = RT, y = diff_elpd)) +
scale_x_continuous(name= "log-scaled RT",
trans = log_trans(),
breaks=seq(200, 5000, 100),labels = RT_labels) + theme_bw() +
theme(axis.text.x =
element_text(angle = 90, vjust = 1,hjust = 1),
legend.key.size = unit(0.3, "cm"),
legend.title = element_text(size = 9),
legend.text = element_text(size = 8),
legend.position = c(.9, .80)) +
scale_y_continuous(name = y_axis_da_ab, breaks = seq(-4,4,.5)) +
geom_hex(bins = 50) +
scale_fill_gradientn(colours = c("skyblue","darkblue"),
name = "Number of\nobservations") +
facet_grid(.~response) +
geom_hline(yintercept = 0, linetype = "dotted") +
ggtitle("Comparison of models")
print(elpds_diff)
```
# References
<!-- This comment causes section to be numbered -->