Skip to content

Latest commit

 

History

History
578 lines (490 loc) · 20.1 KB

README.md

File metadata and controls

578 lines (490 loc) · 20.1 KB

An implementation of the model by Lourenço et al.

The original paper is available here: https://www.dropbox.com/s/oxmu2rwsnhi9j9c/Draft-COVID-19-Model%20%2813%29.pdf?dl=0

Note that you will also need to download the death data, amazingly compiled by Johns Hopkins University (JHU CSSE): https://github.com/CSSEGISandData/COVID-19

## install.packages("deSolve")
## devtools::install_github("jameshay218/lazymcmc")
library(tidyverse)
#> ── Attaching packages ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── tidyverse 1.3.0 ──
#> ✓ ggplot2 3.2.1     ✓ purrr   0.3.3
#> ✓ tibble  2.1.3     ✓ dplyr   0.8.4
#> ✓ tidyr   1.0.2     ✓ stringr 1.4.0
#> ✓ readr   1.3.1     ✓ forcats 0.4.0
#> ── Conflicts ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── tidyverse_conflicts() ──
#> x dplyr::filter() masks stats::filter()
#> x dplyr::lag()    masks stats::lag()
library(lazymcmc)
library(doParallel)
#> Loading required package: foreach
#> 
#> Attaching package: 'foreach'
#> The following objects are masked from 'package:purrr':
#> 
#>     accumulate, when
#> Loading required package: iterators
#> Loading required package: parallel
library(patchwork)
setwd("~/Documents/GitHub/gupta_model_check")

## Set to TRUE to re-run every MCMC fit; when FALSE, previously saved chains in chains/ are read
rerun_fits <- FALSE

## Worker processes are only needed for the optional fitting step, so only start them then
## (the fitting block is responsible for stopping them again)
n_clusters <- 9
if (rerun_fits) {
  cl <- makeCluster(n_clusters)
  registerDoParallel(cl)
}

## lazymcmc control settings for the two fitting stages: stage 1 is run with mvrPars = NULL,
## stage 2 with a proposal covariance estimated from stage 1.
## NOTE(review): "popt" is presumably the target acceptance rate - confirm in lazymcmc docs
mcmcPars1 <- c("iterations"=100000,"popt"=0.44,"opt_freq"=1000,
              "thin"=10,"adaptive_period"=50000,"save_block"=1000)
mcmcPars2 <- c("iterations"=200000,"popt"=0.234,"opt_freq"=1000,
               "thin"=10,"adaptive_period"=100000,"save_block"=1000)

## Extract the parameter values of one MCMC sample as a named numeric vector.
##   chain: data frame with a `sampno` column; assumes (as the chains read below are laid out)
##          that the first column is the sample number and the last is the log-likelihood,
##          so every column in between is a model parameter.
##   index: value of chain$sampno to extract.
## Stops with an informative error if no row has that sample number; if several rows
## share it, the first is used.
get_index_par <- function(chain, index){
  par_cols <- 2:(ncol(chain)-1)
  row <- which(chain$sampno == index)
  if (length(row) == 0) {
    stop("No row in chain has sampno == ", index, call. = FALSE)
  }
  ## A single row keeps as.numeric() well-defined even if sample numbers are duplicated
  par <- as.numeric(chain[row[1], par_cols])
  names(par) <- colnames(chain)[par_cols]
  par
}

## Six run configurations: each of three prior sets (1 = Lourenco et al. low R0, 2 = Lourenco
## et al. high R0, 3 = my priors) is run once sampling the prior only and once fitted to data
runnames <- c("prior_only_lourenco1","prior_only_lourenco2","prior_only_me",
              "fitting_lourenco1","fitting_lourenco2","fitting_me")
n_runs <- length(runnames)
prior_controls <- rep(c(TRUE, FALSE), each = 3)
prior_use <- rep(c(1, 2, 3), times = 2)
n_chains <- 3

## Expand to one entry per (configuration, chain) pair
runnames <- rep(runnames, each = n_chains)
prior_controls <- rep(prior_controls, each = n_chains)
prior_use <- rep(prior_use, each = n_chains)
chain_nos <- rep(seq_len(n_chains), times = n_runs)

## Define the set of ODEs for the model. Return such that we can solve with deSolve.
## State x, as per-capita proportions:
##   x[1] = y, proportion currently infectious
##   x[2] = z, cumulative proportion ever infected (no longer susceptible)
## params is read positionally: params[1] = R0, params[2] = sigma (1 / infectious period).
## Any further entries (rho, theta, N, ...) are not used by the ODEs; deaths are derived
## from z afterwards and shifted forward by psi days post-hoc.
## Returns list(c(dY, dZ)), the form deSolve::ode expects.
SIR_odes <- function(t, x, params) {
  y <- x[1]
  z <- x[2]

  R0 <- params[1]
  sigma <- params[2]

  ## Transmission rate implied by R0 and the infectious period 1/sigma
  beta <- R0*sigma

  dY <- beta*y*(1-z) - sigma*y
  dZ <- beta*y*(1-z)

  list(c(dY,dZ))
}

## ROUGH PARAMETERS FROM THE PAPER
sigma <- 1/4.5 ## 1/Infectious period
R0 <- 2.5 ## Basic reproductive number
rho <- 0.01 ## Proportion of population at risk of severe disease
theta <- 0.14 ## Probability of dying with severe disease
psi <- 17 ## Time between infection and death (days)
t0 <- 31 + 28 ## Seed time in days after 2019-12-01, the date origin used throughout (2020-01-29)

## Population of UK
N <- 66870000

## EXTRACT AND CLEAN DATA
## Cumulative reported deaths in the UK, from the JHU CSSE global time series
## (one column per date, named in m/d/yy form)
obs_deaths <- read_csv("~/Documents/GitHub/COVID-19/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_deaths_global.csv")
#> Parsed with column specification:
#> cols(
#>   .default = col_double(),
#>   `Province/State` = col_character(),
#>   `Country/Region` = col_character()
#> )
#> See spec(...) for full column specifications.
## Sum every row whose Country/Region is the UK.
## NOTE(review): if territories are listed as separate UK rows they are included here,
## while N above is the UK population only - confirm this is intended
obs_deaths <- obs_deaths %>% filter(`Country/Region` == "United Kingdom") %>% select(-c("Province/State","Country/Region","Lat","Long"))
obs_deaths <- colSums(obs_deaths)
obs_deaths <- data.frame(time=names(obs_deaths),dead=obs_deaths)
## Dates are parsed from strings, so no origin is needed
obs_deaths$time <- as.Date(as.character(obs_deaths$time), format="%m/%d/%y")
## Restrict to the fitting window, comparing against real Date objects rather than strings
obs_deaths <- obs_deaths %>% filter(time >= as.Date("2020-03-05") & time <= as.Date("2020-03-19"))
data <- obs_deaths

## Take a look to see that model makes sense:
## solve once at the rough paper parameters over days 0-100 and plot the result
t <- seq(0,100,by=1)

## Seed a single infectious person; all compartments are per-capita proportions of N
y0 <- 1/N
results <- deSolve::ode(y = c(y = y0, z = y0), times = t, func = SIR_odes,
                        parms = c(R0, sigma, rho, theta, N))
results <- as.data.frame(results)

## Model day d is calendar day t0 + d after 2019-12-01; deaths from infections on that
## day are reported psi days later
results$time_deaths <- as.Date(results$time + psi + t0, origin = "2019-12-01")
results$time <- as.Date(results$time + t0, origin = "2019-12-01")
results$susc <- 1 - results$z
results$dead <- N * rho * theta * results$z

par(mfrow = c(2, 1))
plot(results[, c("time", "susc")], type = 'l', col = "green",
     xlab = "Time (days)", ylab = "Proportion of population", ylim = c(0, 1))
invisible(Map(function(col_name, colour) lines(results[, c("time", col_name)], col = colour),
              c("z", "y"), c("blue", "red")))
legend("topright", title = "Key",
       legend = c("Proportion infectious", "Proportion no longer susceptible", "Proportion susceptible"),
       col = c("red", "blue", "green"),
       lty = rep(1, 3))
plot(results[, c("time_deaths", "dead")], type = 'l')

##################################
## NOW SET UP CODE TO FIT MODEL
## During fitting the model is solved over a full year, days 0-365 after seeding
t_long <- seq(from = 0, to = 365, by = 1)
## Starting parameter vector, named so that functions below can look values up by name
pars <- c(R0 = R0, sigma = sigma, rho = rho, theta = theta, N = N, psi = psi, t0 = t0)
## Putting model solving code in a function for later use.
## Solves SIR_odes over times t, seeded with one infectious person in a population of
## pars["N"]; the state is per capita. pars must be named (at least "N"); SIR_odes itself
## reads its rates positionally.
## Returns a data frame with columns time, y.N and z.N. The ".N" suffix arises because the
## initial values inherit the name "N" from pars["N"]; callers in this script index these
## columns, so keep the naming as is.
solve_model <- function(pars, t){
  pop_size <- pars["N"]
  init_state <- c(y = 1/pop_size, z = 1/pop_size)
  ode_out <- deSolve::ode(y = init_state, times = t, func = SIR_odes, parms = pars)
  as.data.frame(ode_out)
}

## We need to put our likelihood function in a closure environment.
## It's important to write the function in this form!
## create_lik builds the posterior function handed to lazymcmc::run_MCMC.
##   parTab:     parameter table; parTab$names gives the order of the pars vector
##   data:       data frame with columns time (Date) and dead (cumulative deaths)
##   PRIOR_FUNC: NULL, or function(pars) returning a log prior density
##   t:          model times (days since seeding) passed to solve_model
##   PRIOR_ONLY: if TRUE, skip the model and data and return only the log prior
## Returns function(pars): Poisson log likelihood of the cumulative deaths plus the log
## prior. Returns -Inf when the shifted model output does not cover every observed date.
create_lik <- function(parTab, data, PRIOR_FUNC,t, PRIOR_ONLY=FALSE){
  par_names <- parTab$names

  ## Observed cumulative deaths and the dates they were reported on
  obs_dead <- data$dead
  obs_times <- data$time

  likelihood_func <- function(pars){
    names(pars) <- par_names
    lik <- 0
    if(!PRIOR_ONLY){
      results <- solve_model(pars, t)

      psi <- pars["psi"]
      t0 <- pars["t0"]
      theta <- pars["theta"]
      rho <- pars["rho"]
      ## Use the population size carried in pars rather than a global variable
      N <- pars["N"]

      ## Model day d is calendar day t0 + d after 2019-12-01; deaths from infections on
      ## that day are reported psi days later
      time_deaths <- as.Date(results$time + as.integer(psi + t0), origin="2019-12-01")

      ## solve_model names the cumulative-infection column "z.N"; select it explicitly
      ## instead of relying on `$` partial matching
      z <- results[[grep("^z", colnames(results))[1]]]
      predicted <- unname(N*rho*theta*z)

      ## Align predictions with observations by date. If any observed date has no
      ## prediction, positional alignment would silently recycle lambda (or, with no
      ## overlap at all, give a log likelihood of 0), so reject these parameters instead.
      idx <- match(obs_times, time_deaths)
      if(anyNA(idx)){
        lik <- -Inf
      } else {
        lik <- sum(dpois(x=obs_dead, lambda=predicted[idx], log=TRUE))
      }
    }
    if(!is.null(PRIOR_FUNC)) lik <- lik + PRIOR_FUNC(pars)
    lik
  }
  likelihood_func
}

## Control parameters in MCMC: one row per model parameter, in the order of `pars`.
## fixed = 1 holds a parameter at its value (only N here). steps, lower_bound and
## upper_bound are presumably lazymcmc's initial proposal step sizes and hard bounds.
parTab <- data.frame(names=c("R0","sigma", "rho", "theta", "N", "psi", "t0"),
                     values=pars,
                     fixed=c(0,0,0,0,1,0,0),
                     steps=rep(0.1, 7),
                     lower_bound=c(1,0,0,0,0,0,0),
                     upper_bound=c(10,10,0.1,10,1e8,25,75))


## Test the model solves: evaluate the log likelihood once at the starting values
f <- create_lik(parTab, data, NULL, t=t_long)
f(pars)
#> [1] -54.46621

## Starting points for the chains: begin at the parTab values, i.e. the rough paper
## estimates (seeding chains for SIR models far from plausible values is hard with deSolve)
startTab <- parTab

## Log prior for the Lourenco et al. low-R0 scenario (R0 around 2.25).
## pars is positional, in parTab$names order. The infectious period 1/sigma gets the normal
## prior; rho gets a gamma prior with mean 0.001. Returns the summed log density.
prior_func_lourenco1 <- function(pars){
  names(pars) <- parTab$names
  log_priors <- c(
    dnorm(pars["R0"], mean = 2.25, sd = 0.025, log = TRUE),
    dnorm(1/pars["sigma"], mean = 4.5, sd = 1, log = TRUE),
    dnorm(pars["psi"], mean = 17, sd = 2, log = TRUE),
    dgamma(pars["rho"], shape = 5, rate = 5/0.001, log = TRUE),
    dnorm(pars["theta"], mean = 0.14, sd = 0.007, log = TRUE)
  )
  sum(log_priors)
}
## Log prior for the Lourenco et al. high-R0 scenario (R0 around 2.75).
## pars is positional, in parTab$names order. The infectious period 1/sigma gets the normal
## prior; rho gets a gamma prior with mean 0.01. Returns the summed log density.
prior_func_lourenco2 <- function(pars){
  names(pars) <- parTab$names
  log_priors <- c(
    dnorm(pars["R0"], mean = 2.75, sd = 0.025, log = TRUE),
    dnorm(1/pars["sigma"], mean = 4.5, sd = 1, log = TRUE),
    dnorm(pars["psi"], mean = 17, sd = 2, log = TRUE),
    dgamma(pars["rho"], shape = 5, rate = 5/0.01, log = TRUE),
    dnorm(pars["theta"], mean = 0.14, sd = 0.007, log = TRUE)
  )
  sum(log_priors)
}

## Log prior for my own analysis: a much wider prior on R0 (2.5 +/- 0.5) and no prior on
## rho beyond its parTab bounds (it contributes 0). pars is positional, in parTab$names order.
## Returns the summed log density.
prior_func_me <- function(pars){
  names(pars) <- parTab$names
  log_priors <- c(
    dnorm(pars["R0"], mean = 2.5, sd = 0.5, log = TRUE),
    dnorm(1/pars["sigma"], mean = 4.5, sd = 1, log = TRUE),
    dnorm(pars["psi"], mean = 17, sd = 2, log = TRUE),
    0,
    dnorm(pars["theta"], mean = 0.14, sd = 0.007, log = TRUE)
  )
  sum(log_priors)
}
top_wd <- getwd()

if(rerun_fits){
  ## Fit every (run, chain) combination in parallel. Each fit has two lazymcmc stages:
  ## stage 1 with mvrPars = NULL, then stage 2 restarted from stage 1's best parameters
  ## with a proposal covariance estimated from stage 1's post-adaptive samples.
  res <- foreach(i=seq_along(runnames),.packages=c("lazymcmc","tidyverse")) %dopar% {
    ## Chains for each run are written to chains/<runname>/. Create the directory
    ## (and chains/ itself) if needed; showWarnings=FALSE tolerates workers racing to create it.
    run_dir <- file.path(top_wd, "chains", runnames[i])
    dir.create(run_dir, recursive=TRUE, showWarnings=FALSE)
    setwd(run_dir)

    filename_tmp <- paste0(runnames[i],"_",chain_nos[i])
    prior_only <- prior_controls[i]
    prior_use_tmp <- prior_use[i]

    if(prior_use_tmp == 1){
      prior_func <- prior_func_lourenco1
    } else if(prior_use_tmp == 2) {
      prior_func <- prior_func_lourenco2
    } else {
      prior_func <- prior_func_me
    }

    output <- run_MCMC(parTab=startTab, data=obs_deaths, mcmcPars=mcmcPars1, filename=filename_tmp,
                       CREATE_POSTERIOR_FUNC = create_lik, mvrPars = NULL, PRIOR_FUNC=prior_func, t=t_long,
                       PRIOR_ONLY=prior_only)
    chain <- read.csv(output$file)
    best_pars <- get_best_pars(chain)
    ## Covariance of the post-adaptive samples (dropping first and last columns,
    ## as get_index_par does) seeds the stage-2 proposal
    chain <- chain[chain$sampno >= mcmcPars1["adaptive_period"],2:(ncol(chain)-1)]
    covMat <- cov(chain)
    ## NOTE(review): presumably (covariance, initial scale, weighting w) - confirm in lazymcmc docs
    mvrPars <- list(covMat,0.5,w=0.8)

    ## Start from best location of previous chain
    startTab$values <- best_pars

    ## Run second chain
    output <- run_MCMC(parTab=startTab, data=obs_deaths, mcmcPars=mcmcPars2, filename=filename_tmp,
                       CREATE_POSTERIOR_FUNC = create_lik,PRIOR_FUNC=prior_func, t=t_long,
                       PRIOR_ONLY=prior_only, mvrPars=mvrPars)
  }
  ## Release the worker processes once fitting is finished
  stopCluster(cl)
}
  
## Read in the MCMC chains
setwd(top_wd)
setwd("chains")

## Each sub-directory of chains/ holds the chains of one run. Listing directories only
## (and skipping hidden ones) stops stray files from being treated as runs.
all_runs <- list.dirs(file.path(top_wd, "chains"), full.names=FALSE, recursive=FALSE)
all_runs <- all_runs[!startsWith(all_runs, ".")]

## Build one long-format data frame per run and bind them once at the end,
## rather than growing all_chains with rbind() inside the loop
chain_list <- vector("list", length(all_runs))
for(j in seq_along(all_runs)){
  run <- all_runs[j]
  ## NOTE(review): TRUE, 1, 200000 are passed positionally - presumably unfixed, thin and
  ## burn-in; confirm against the lazymcmc load_mcmc_chains signature
  chains <- load_mcmc_chains(file.path(top_wd, "chains", run),parTab,TRUE,1,200000)
  ## Skip directories for which nothing was loaded
  if(is.null(chains)) next

  tmp_chain <- as.data.frame(chains[[2]])
  tmp_chain$sampno <- seq_len(nrow(tmp_chain))
  tmp_chain$run <- run
  chain_list[[j]] <- reshape2::melt(tmp_chain, id.vars=c("sampno","run"))

  #pdf(paste0(top_wd,"/",run,"_chains.pdf"))
  #plot(chains[[1]])
  #dev.off()
}
## rbind() ignores NULL entries from skipped runs; all_chains is NULL if nothing was loaded
all_chains <- do.call(rbind, chain_list)
#> [1] "/Users/james/Documents/GitHub/gupta_model_check/chains/fitting_lourenco1/fitting_lourenco1_1_multivariate_chain.csv"
#> [2] "/Users/james/Documents/GitHub/gupta_model_check/chains/fitting_lourenco1/fitting_lourenco1_2_multivariate_chain.csv"
#> [3] "/Users/james/Documents/GitHub/gupta_model_check/chains/fitting_lourenco1/fitting_lourenco1_3_multivariate_chain.csv"
#> [[1]]
#> [1] 10001
#> 
#> [[2]]
#> [1] 10001
#> 
#> [[3]]
#> [1] 10001
#> 
#> [1] "/Users/james/Documents/GitHub/gupta_model_check/chains/fitting_lourenco2/fitting_lourenco2_1_multivariate_chain.csv"
#> [2] "/Users/james/Documents/GitHub/gupta_model_check/chains/fitting_lourenco2/fitting_lourenco2_2_multivariate_chain.csv"
#> [3] "/Users/james/Documents/GitHub/gupta_model_check/chains/fitting_lourenco2/fitting_lourenco2_3_multivariate_chain.csv"
#> [[1]]
#> [1] 10001
#> 
#> [[2]]
#> [1] 10001
#> 
#> [[3]]
#> [1] 10001
#> 
#> [1] "/Users/james/Documents/GitHub/gupta_model_check/chains/fitting_me/fitting_me_1_multivariate_chain.csv"
#> [2] "/Users/james/Documents/GitHub/gupta_model_check/chains/fitting_me/fitting_me_2_multivariate_chain.csv"
#> [3] "/Users/james/Documents/GitHub/gupta_model_check/chains/fitting_me/fitting_me_3_multivariate_chain.csv"
#> [[1]]
#> [1] 10001
#> 
#> [[2]]
#> [1] 10001
#> 
#> [[3]]
#> [1] 10001
#> 
#> [1] "/Users/james/Documents/GitHub/gupta_model_check/chains/prior_only_lourenco1/prior_only_lourenco1_1_multivariate_chain.csv"
#> [2] "/Users/james/Documents/GitHub/gupta_model_check/chains/prior_only_lourenco1/prior_only_lourenco1_2_multivariate_chain.csv"
#> [3] "/Users/james/Documents/GitHub/gupta_model_check/chains/prior_only_lourenco1/prior_only_lourenco1_3_multivariate_chain.csv"
#> [[1]]
#> [1] 10001
#> 
#> [[2]]
#> [1] 10001
#> 
#> [[3]]
#> [1] 10001
#> 
#> [1] "/Users/james/Documents/GitHub/gupta_model_check/chains/prior_only_lourenco2/prior_only_lourenco2_1_multivariate_chain.csv"
#> [2] "/Users/james/Documents/GitHub/gupta_model_check/chains/prior_only_lourenco2/prior_only_lourenco2_2_multivariate_chain.csv"
#> [3] "/Users/james/Documents/GitHub/gupta_model_check/chains/prior_only_lourenco2/prior_only_lourenco2_3_multivariate_chain.csv"
#> [[1]]
#> [1] 10001
#> 
#> [[2]]
#> [1] 10001
#> 
#> [[3]]
#> [1] 10001
#> 
#> [1] "/Users/james/Documents/GitHub/gupta_model_check/chains/prior_only_me/prior_only_me_1_multivariate_chain.csv"
#> [2] "/Users/james/Documents/GitHub/gupta_model_check/chains/prior_only_me/prior_only_me_2_multivariate_chain.csv"
#> [3] "/Users/james/Documents/GitHub/gupta_model_check/chains/prior_only_me/prior_only_me_3_multivariate_chain.csv"
#> [[1]]
#> [1] 10001
#> 
#> [[2]]
#> [1] 10001
#> 
#> [[3]]
#> [1] 10001
## Keep an untransformed copy of the long-format chains
all_chains1 <- all_chains
## Report 1/sigma (the infectious period in days) rather than the rate sigma
all_chains <- all_chains %>% mutate(value=ifelse(variable=="sigma", 1/value, value))
## Label each row by the prior set that produced it
all_chains$group <- "new"
all_chains <- all_chains %>% mutate(group = ifelse(run %in% c("fitting_lourenco1","prior_only_lourenco1"), "original_low", group),
                                    group = ifelse(run %in% c("fitting_lourenco2","prior_only_lourenco2"), "original_high",group))

## Facet labels (parsed as plotmath by label_parsed) for each parameter
var_key <- c("R0"="R[0]",
             "sigma"="sigma",
             "rho"="rho",
             "theta"="theta",
             "psi"="phi",
             "t0"="Seed_date",
             "lnlike"="Posterior_probability")

## Legend labels for each run
run_names_key <- c("fitting_lourenco1"="After fitting to data, original",
                   "prior_only_lourenco1"="Before fitting to data, original",
                   "fitting_lourenco2"="After fitting to data, original",
                   "prior_only_lourenco2"="Before fitting to data, original",
                   "fitting_me"="After fitting to data, my analysis",
                   "prior_only_me"="Before fitting to data, new")

all_chains$run <- unname(run_names_key[all_chains$run])

## Look labels up by name. Indexing a named vector with a factor uses the factor's integer
## codes, which only gives the right labels if the level order happens to match var_key.
all_chains$variable <- unname(var_key[as.character(all_chains$variable)])
all_chains$variable <- factor(all_chains$variable, levels=var_key)

## Rename the run column by name rather than by its position
names(all_chains)[names(all_chains) == "run"] <- "Version"

## Invisible points that widen the x range of the R[0] and rho panels in the plots
blank_limits <- data.frame(variable=c("R[0]","R[0]","rho","rho"),x=c(1,4,0,0.1))

## Plot the posteriors and priors.
## One figure per prior group, one panel per parameter, with before- and after-fitting
## densities overlaid. geom_blank(blank_limits) makes the R[0] panels span at least 1-4
## and the rho panels at least 0-0.1.
plot_group_density <- function(group_name) {
  ggplot(all_chains[all_chains$group == group_name,]) +
    geom_density(aes(x=value,fill=Version), alpha=0.5) +
    geom_blank(data=blank_limits,aes(x=x)) +
    scale_fill_manual(values=c("#E69F00", "#56B4E9")) +
    facet_wrap(~variable,scales="free",ncol=3, labeller=label_parsed) +
    xlab("Estimate") + ylab("Posterior density") +
    theme_bw() +
    theme(legend.position=c(0.7,0.2))
}
p1 <- plot_group_density("original_low")
p2 <- plot_group_density("original_high")
p3 <- plot_group_density("new")

#png(paste0(top_wd,"/plots/densities_original_low.png"),height=5,width=8,res=300,units="in")
#p1
#dev.off()

#png(paste0(top_wd,"/plots/densities_original_high.png"),height=5,width=8,res=300,units="in")
#p2
#dev.off()

#png(paste0(top_wd,"/plots/densities_new.png"),height=5,width=8,res=300,units="in")
#p3
#dev.off()
p1

p2

p3

## Post-fitting samples of R0, sigma (shown as the infectious period), rho and t0,
## spread to one column per parameter so pairs can be plotted against each other
all_chains_subset <- all_chains %>%
  filter(Version %in% c("After fitting to data, original", "After fitting to data, my analysis"),
         variable %in% var_key[c("R0","sigma","rho","t0")]) %>%
  mutate(variable = as.character(variable)) %>%
  pivot_wider(names_from=variable, values_from=value)

## Plot the 2D densities
p4 <- ggplot(all_chains_subset, aes(x=`sigma`, y=`R[0]`)) +
  geom_hex(aes(fill=..density..), bins=100) +
  xlab("Infectious period in days") +
  ylab("Basic reproductive number, R[0]") +
  facet_wrap(~Version) +
  scale_fill_gradient2(low="#5E4FA2", mid="#FAFDB8", high="#9E0142", midpoint=0.008) +
  theme_bw()

p5 <- ggplot(all_chains_subset, aes(x=as.Date(`Seed_date`, origin="2019-12-01"), y=`rho`)) +
  geom_hex(aes(fill=..density..), bins=100) +
  xlab("Seed date") +
  ylab("Proportion of population at\n risk of severe disease, ρ") +
  scale_fill_gradient2(low="#5E4FA2", mid="#FAFDB8", high="#9E0142", midpoint=0.008) +
  facet_wrap(~Version) +
  theme_bw()

#png(paste0(top_wd,"/plots/correlations.png"),height=6,width=8,res=300,units="in")
#p4/p5
#dev.off()
p4/p5

## Calculate 95% CI on proportion infected (immune) by 2020-03-19 under my analysis
setwd(top_wd)
setwd("chains")
run <- "fitting_me/"
## NOTE(review): the third argument FALSE presumably keeps fixed parameters (N), which
## solve_model needs - confirm against lazymcmc load_mcmc_chains
chains <- load_mcmc_chains(paste0(top_wd, "/chains/",run),parTab,FALSE,1,200000)
#> [1] "/Users/james/Documents/GitHub/gupta_model_check/chains/fitting_me//fitting_me_1_multivariate_chain.csv"
#> [2] "/Users/james/Documents/GitHub/gupta_model_check/chains/fitting_me//fitting_me_2_multivariate_chain.csv"
#> [3] "/Users/james/Documents/GitHub/gupta_model_check/chains/fitting_me//fitting_me_3_multivariate_chain.csv"
#> [[1]]
#> [1] 10001
#> 
#> [[2]]
#> [1] 10001
#> 
#> [[3]]
#> [1] 10001
chain <- as.data.frame(chains[[2]])
chain$sampno <- seq_len(nrow(chain))

## Draw up to 10000 posterior samples without replacement. No seed is set, so the
## quantiles below vary slightly between runs.
n_samps <- min(10000, nrow(chain))
samps <- chain$sampno[sample.int(nrow(chain), n_samps)]

prop_infected <- numeric(n_samps)
rhos <- numeric(n_samps)
t0s <- numeric(n_samps)
target_date <- as.Date("2020-03-19")
for(i in seq_along(samps)){
  pars <- get_index_par(chain, samps[i])
  results <- solve_model(pars, t_long)
  rhos[i] <- pars["rho"]
  t0s[i] <- pars["t0"]

  ## Calendar date of each model day (model day d is t0 + d days after 2019-12-01)
  model_dates <- as.Date(results$time + as.integer(pars["t0"]), origin="2019-12-01")
  ## Cumulative proportion ever infected; solve_model names this column "z.N"
  z_col <- grep("^z", colnames(results))[1]
  hit <- which(model_dates == target_date)
  prop_infected[i] <- if(length(hit) > 0) results[hit[1], z_col] else NA_real_
}
print(paste0("Proportion immune: ", quantile(prop_infected,c(0.025,0.5, 0.975), na.rm=TRUE)*100))
#> [1] "Proportion immune: 1.91350113599975" "Proportion immune: 11.2128233814482"
#> [3] "Proportion immune: 53.6527219983157"

## Joint distribution of rho and the proportion of the UK immune by 19/03/2020,
## across the posterior draws sampled above
final_dat <- data.frame(rho = rhos, prop_infected = prop_infected)

p7 <- ggplot(final_dat, aes(x = prop_infected, y = rho)) +
  geom_hex(aes(fill = ..density..), bins = 100) +
  xlab("Proportion of UK population immune by 19/03/2020") +
  ylab("Proportion of population at\n risk of severe disease") +
  scale_fill_gradient2(low = "#5E4FA2", mid = "#FAFDB8", high = "#9E0142", midpoint = 0.0015) +
  theme_bw() +
  theme(legend.position = c(0.8, 0.7))

#png(paste0(top_wd,"/plots/final_new.png"),height=4,width=5,res=300,units="in")
#p7
#dev.off()
p7