From fad509b91ebb92e71f6d2893ac7f7f3430b78404 Mon Sep 17 00:00:00 2001 From: cbernalz Date: Sun, 25 Aug 2024 10:21:55 -0700 Subject: [PATCH] 2024-08-25 update : added mult. dispatch for model w/out ww. --- .../uciwwiehr_model_fitting_forecast.md | 8 +- .../uciwwiehr_model_fitting_no_forecast.md | 8 +- src/uciwweihr_fit.jl | 113 ++++++--- src/uciwweihr_gq_pp.jl | 237 +++++++++++++----- src/uciwweihr_model.jl | 142 ++++++++++- 5 files changed, 396 insertions(+), 112 deletions(-) diff --git a/docs/src/tutorials/uciwwiehr_model_fitting_forecast.md b/docs/src/tutorials/uciwwiehr_model_fitting_forecast.md index b0b1175..aa4cc5b 100644 --- a/docs/src/tutorials/uciwwiehr_model_fitting_forecast.md +++ b/docs/src/tutorials/uciwwiehr_model_fitting_forecast.md @@ -48,16 +48,16 @@ forecast_weeks = 4 samples = uciwweihr_fit( data_hosp, - data_wastewater, + data_wastewater; obstimes, param_change_times, priors_only, n_samples ) model_output = uciwweihr_gq_pp( - samples = samples, - data_hosp = data_hosp, - data_wastewater = data_wastewater, + samples, + data_hosp, + data_wastewater; obstimes = obstimes, param_change_times = param_change_times, forecast = forecast, diff --git a/docs/src/tutorials/uciwwiehr_model_fitting_no_forecast.md b/docs/src/tutorials/uciwwiehr_model_fitting_no_forecast.md index aaf0536..12c7e4d 100644 --- a/docs/src/tutorials/uciwwiehr_model_fitting_no_forecast.md +++ b/docs/src/tutorials/uciwwiehr_model_fitting_no_forecast.md @@ -53,16 +53,16 @@ n_samples = 50 samples = uciwweihr_fit( data_hosp, - data_wastewater, + data_wastewater; obstimes, param_change_times, priors_only, n_samples ) model_output = uciwweihr_gq_pp( - samples = samples, - data_hosp = data_hosp, - data_wastewater = data_wastewater, + samples, + data_hosp, + data_wastewater; obstimes = obstimes, param_change_times = param_change_times, ) diff --git a/src/uciwweihr_fit.jl b/src/uciwweihr_fit.jl index bc79cac..4934359 100644 --- a/src/uciwweihr_fit.jl +++ b/src/uciwweihr_fit.jl @@ -7,7 +7,7 @@ The defaults for this fuction will follow those of the default simulation in gen # Arguments - `data_hosp`: An array of hospital data. -- `data_wastewater`: An array of pathogen genome concentration in localized wastewater data. +- `data_wastewater`: An array of pathogen genome concentration in localized wastewater data. If this is not avaliable, the model used will be one that only uses hospital data. - `obstimes`: An array of timepoints for observed hosp/wastewater. - `priors_only::Bool=false`: A boolean to indicate if only priors are to be sampled. - `n_samples::Int64=500`: Number of samples to be drawn. @@ -48,7 +48,7 @@ The defaults for this fuction will follow those of the default simulation in gen """ function uciwweihr_fit( data_hosp, - data_wastewater, + data_wastewater; obstimes, param_change_times, priors_only::Bool=false, @@ -68,41 +68,92 @@ function uciwweihr_fit( w_init_sd::Float64=0.1, w_init_mean::Float64=log(0.35), sigma_w_sd::Float64=0.2, sigma_w_mean::Float64=-3.5 ) + println("Using uciwweihr_model with wastewater!!!") + obstimes = convert(Vector{Float64}, obstimes) + param_change_times = convert(Vector{Float64}, param_change_times) -obstimes = convert(Vector{Float64}, obstimes) -param_change_times = convert(Vector{Float64}, param_change_times) + my_model = uciwweihr_model( + data_hosp, + data_wastewater; + obstimes, + param_change_times, + E_init_sd, E_init_mean, + I_init_sd, I_init_mean, + H_init_sd, H_init_mean, + gamma_sd, log_gamma_mean, + nu_sd, log_nu_mean, + epsilon_sd, log_epsilon_mean, + rho_gene_sd, log_rho_gene_mean, + tau_sd, log_tau_mean, + df_shape, df_scale, + sigma_hosp_sd, sigma_hosp_mean, + Rt_init_sd, Rt_init_mean, + sigma_Rt_sd, sigma_Rt_mean, + w_init_sd, w_init_mean, + sigma_w_sd, sigma_w_mean + ) + + # Sample Posterior + if priors_only + Random.seed!(seed) + samples = sample(my_model, Prior(), MCMCThreads(), 400, n_chains) + else + Random.seed!(seed) + samples = sample(my_model, NUTS(), MCMCThreads(), n_samples, n_chains) + end + return(samples) +end -my_model = uciwweihr_model( - data_hosp, - data_wastewater, - obstimes, +function uciwweihr_fit( + data_hosp; + obstimes, param_change_times, - E_init_sd, E_init_mean, - I_init_sd, I_init_mean, - H_init_sd, H_init_mean, - gamma_sd, log_gamma_mean, - nu_sd, log_nu_mean, - epsilon_sd, log_epsilon_mean, - rho_gene_sd, log_rho_gene_mean, - tau_sd, log_tau_mean, - df_shape, df_scale, - sigma_hosp_sd, sigma_hosp_mean, - Rt_init_sd, Rt_init_mean, - sigma_Rt_sd, sigma_Rt_mean, - w_init_sd, w_init_mean, - sigma_w_sd, sigma_w_mean + priors_only::Bool=false, + n_samples::Int64=500, n_chains::Int64=1, seed::Int64=2024, + E_init_sd::Float64=50.0, E_init_mean::Int64=200, + I_init_sd::Float64=20.0, I_init_mean::Int64=100, + H_init_sd::Float64=5.0, H_init_mean::Int64=20, + gamma_sd::Float64=0.02, log_gamma_mean::Float64=log(1/4), + nu_sd::Float64=0.02, log_nu_mean::Float64=log(1/7), + epsilon_sd::Float64=0.02, log_epsilon_mean::Float64=log(1/5), + sigma_hosp_sd::Float64=50.0, sigma_hosp_mean::Float64=500.0, + Rt_init_sd::Float64=0.3, Rt_init_mean::Float64=0.2, + sigma_Rt_sd::Float64=0.2, sigma_Rt_mean::Float64=-3.0, + w_init_sd::Float64=0.1, w_init_mean::Float64=log(0.35), + sigma_w_sd::Float64=0.2, sigma_w_mean::Float64=-3.5 ) + println("Using uciwweihr_model without wastewater!!!") + obstimes = convert(Vector{Float64}, obstimes) + param_change_times = convert(Vector{Float64}, param_change_times) -# Sample Posterior -if priors_only -Random.seed!(seed) -samples = sample(my_model, Prior(), MCMCThreads(), 400, n_chains) -else -Random.seed!(seed) -samples = sample(my_model, NUTS(), MCMCThreads(), n_samples, n_chains) -end -return(samples) + my_model = uciwweihr_model( + data_hosp; + obstimes, + param_change_times, + E_init_sd, E_init_mean, + I_init_sd, I_init_mean, + H_init_sd, H_init_mean, + gamma_sd, log_gamma_mean, + nu_sd, log_nu_mean, + epsilon_sd, log_epsilon_mean, + sigma_hosp_sd, sigma_hosp_mean, + Rt_init_sd, Rt_init_mean, + sigma_Rt_sd, sigma_Rt_mean, + w_init_sd, w_init_mean, + sigma_w_sd, sigma_w_mean + ) + + + # Sample Posterior + if priors_only + Random.seed!(seed) + samples = sample(my_model, Prior(), MCMCThreads(), 400, n_chains) + else + Random.seed!(seed) + samples = sample(my_model, NUTS(), MCMCThreads(), n_samples, n_chains) + end + return(samples) end \ No newline at end of file diff --git a/src/uciwweihr_gq_pp.jl b/src/uciwweihr_gq_pp.jl index 57ca348..45e097e 100644 --- a/src/uciwweihr_gq_pp.jl +++ b/src/uciwweihr_gq_pp.jl @@ -8,7 +8,7 @@ The defaults for this fuction will follow those of the default simulation in gen # Arguments - `samples`: Samples from the posterior/prior distribution. - `data_hosp`: An array of hospital data. -- `data_wastewater`: An array of pathogen genome concentration in localized wastewater data. +- `data_wastewater`: An array of pathogen genome concentration in localized wastewater data. If this is not avaliable, the model used will be one that only uses hospital data. - `obstimes`: An array of timepoints for observed hosp/wastewater. - `param_change_times`: An array of timepoints where the parameters change. - `seed::Int64=2024`: Seed for the random number generator. @@ -47,10 +47,10 @@ The defaults for this fuction will follow those of the default simulation in gen - Samples from the posterior or prior distribution. """ -function uciwweihr_gq_pp(; +function uciwweihr_gq_pp( samples, data_hosp, - data_wastewater, + data_wastewater; obstimes, param_change_times, seed::Int64=2024, @@ -71,91 +71,190 @@ function uciwweihr_gq_pp(; forecast::Bool=false, forecast_weeks::Int64=4 ) -obstimes = convert(Vector{Float64}, obstimes) -param_change_times = convert(Vector{Float64}, param_change_times) + println("Using uciwweihr_model with wastewater!!!") + obstimes = convert(Vector{Float64}, obstimes) + param_change_times = convert(Vector{Float64}, param_change_times) -if forecast - last_value = obstimes[end] - for i in 1:forecast_weeks - next_value = last_value + 7 - push!(param_change_times, next_value) - push!(obstimes, next_value) - last_value = next_value + if forecast + last_value = obstimes[end] + for i in 1:forecast_weeks + next_value = last_value + 7 + push!(param_change_times, next_value) + push!(obstimes, next_value) + last_value = next_value + end + missing_data_ww = repeat([missing], length(obstimes)) + missing_data_hosp = repeat([missing], length(obstimes)) + data_hosp = vcat(data_hosp, repeat([data_hosp[end]], forecast_weeks)) + data_wastewater = vcat(data_wastewater, repeat([data_wastewater[end]], forecast_weeks)) + else + missing_data_ww = repeat([missing], length(data_wastewater)) + missing_data_hosp = repeat([missing], length(data_hosp)) end - missing_data_ww = repeat([missing], length(obstimes)) - missing_data_hosp = repeat([missing], length(obstimes)) - data_hosp = vcat(data_hosp, repeat([data_hosp[end]], forecast_weeks)) - data_wastewater = vcat(data_wastewater, repeat([data_wastewater[end]], forecast_weeks)) -else - missing_data_ww = repeat([missing], length(data_wastewater)) - missing_data_hosp = repeat([missing], length(data_hosp)) + + my_model = uciwweihr_model( + data_hosp, + data_wastewater; + obstimes, + param_change_times, + E_init_sd, E_init_mean, + I_init_sd, I_init_mean, + H_init_sd, H_init_mean, + gamma_sd, log_gamma_mean, + nu_sd, log_nu_mean, + epsilon_sd, log_epsilon_mean, + rho_gene_sd, log_rho_gene_mean, + tau_sd, log_tau_mean, + df_shape, df_scale, + sigma_hosp_sd, sigma_hosp_mean, + Rt_init_sd, Rt_init_mean, + sigma_Rt_sd, sigma_Rt_mean, + w_init_sd, w_init_mean, + sigma_w_sd, sigma_w_mean + ) + + + #indices_to_keep = .!isnothing.(generated_quantities(my_model, samples)) + #samples_randn = ChainsCustomIndex(samples, indices_to_keep) + + #Random.seed!(seed) + #gq_randn = Chains(generated_quantities(my_model, samples_randn)) + + my_model_forecast_missing = uciwweihr_model( + missing_data_hosp, + missing_data_ww; + obstimes, + param_change_times, + E_init_sd, E_init_mean, + I_init_sd, I_init_mean, + H_init_sd, H_init_mean, + gamma_sd, log_gamma_mean, + nu_sd, log_nu_mean, + epsilon_sd, log_epsilon_mean, + rho_gene_sd, log_rho_gene_mean, + tau_sd, log_tau_mean, + df_shape, df_scale, + sigma_hosp_sd, sigma_hosp_mean, + Rt_init_sd, Rt_init_mean, + sigma_Rt_sd, sigma_Rt_mean, + w_init_sd, w_init_mean, + sigma_w_sd, sigma_w_mean + ) + + + indices_to_keep = .!isnothing.(generated_quantities(my_model, samples)) + samples_randn = ChainsCustomIndex(samples, indices_to_keep) + + + Random.seed!(seed) + predictive_randn = predict(my_model_forecast_missing, samples_randn) + + Random.seed!(seed) + gq_randn = Chains(generated_quantities(my_model, samples_randn)) + + samples_df = DataFrame(samples) + + results = [DataFrame(predictive_randn), DataFrame(gq_randn), samples_df] + + + return(results) end -my_model = uciwweihr_model( - data_hosp, - data_wastewater, - obstimes, +function uciwweihr_gq_pp( + samples, + data_hosp; + obstimes, param_change_times, - E_init_sd, E_init_mean, - I_init_sd, I_init_mean, - H_init_sd, H_init_mean, - gamma_sd, log_gamma_mean, - nu_sd, log_nu_mean, - epsilon_sd, log_epsilon_mean, - rho_gene_sd, log_rho_gene_mean, - tau_sd, log_tau_mean, - df_shape, df_scale, - sigma_hosp_sd, sigma_hosp_mean, - Rt_init_sd, Rt_init_mean, - sigma_Rt_sd, sigma_Rt_mean, - w_init_sd, w_init_mean, - sigma_w_sd, sigma_w_mean + seed::Int64=2024, + E_init_sd::Float64=50.0, E_init_mean::Int64=200, + I_init_sd::Float64=20.0, I_init_mean::Int64=100, + H_init_sd::Float64=5.0, H_init_mean::Int64=20, + gamma_sd::Float64=0.02, log_gamma_mean::Float64=log(1/4), + nu_sd::Float64=0.02, log_nu_mean::Float64=log(1/7), + epsilon_sd::Float64=0.02, log_epsilon_mean::Float64=log(1/5), + sigma_hosp_sd::Float64=50.0, sigma_hosp_mean::Float64=500.0, + Rt_init_sd::Float64=0.3, Rt_init_mean::Float64=0.2, + sigma_Rt_sd::Float64=0.2, sigma_Rt_mean::Float64=-3.0, + w_init_sd::Float64=0.1, w_init_mean::Float64=log(0.35), + sigma_w_sd::Float64=0.2, sigma_w_mean::Float64=-3.5, + forecast::Bool=false, forecast_weeks::Int64=4 ) + println("Using uciwweihr_model without wastewater!!!") + obstimes = convert(Vector{Float64}, obstimes) + param_change_times = convert(Vector{Float64}, param_change_times) -#indices_to_keep = .!isnothing.(generated_quantities(my_model, samples)) -#samples_randn = ChainsCustomIndex(samples, indices_to_keep) + if forecast + last_value = obstimes[end] + for i in 1:forecast_weeks + next_value = last_value + 7 + push!(param_change_times, next_value) + push!(obstimes, next_value) + last_value = next_value + end + missing_data_hosp = repeat([missing], length(obstimes)) + data_hosp = vcat(data_hosp, repeat([data_hosp[end]], forecast_weeks)) + else + missing_data_hosp = repeat([missing], length(data_hosp)) + end -#Random.seed!(seed) -#gq_randn = Chains(generated_quantities(my_model, samples_randn)) + my_model = uciwweihr_model( + data_hosp; + obstimes, + param_change_times, + E_init_sd, E_init_mean, + I_init_sd, I_init_mean, + H_init_sd, H_init_mean, + gamma_sd, log_gamma_mean, + nu_sd, log_nu_mean, + epsilon_sd, log_epsilon_mean, + sigma_hosp_sd, sigma_hosp_mean, + Rt_init_sd, Rt_init_mean, + sigma_Rt_sd, sigma_Rt_mean, + w_init_sd, w_init_mean, + sigma_w_sd, sigma_w_mean + ) -my_model_forecast_missing = uciwweihr_model( - missing_data_hosp, - missing_data_ww, - obstimes, - param_change_times, - E_init_sd, E_init_mean, - I_init_sd, I_init_mean, - H_init_sd, H_init_mean, - gamma_sd, log_gamma_mean, - nu_sd, log_nu_mean, - epsilon_sd, log_epsilon_mean, - rho_gene_sd, log_rho_gene_mean, - tau_sd, log_tau_mean, - df_shape, df_scale, - sigma_hosp_sd, sigma_hosp_mean, - Rt_init_sd, Rt_init_mean, - sigma_Rt_sd, sigma_Rt_mean, - w_init_sd, w_init_mean, - sigma_w_sd, sigma_w_mean + + #indices_to_keep = .!isnothing.(generated_quantities(my_model, samples)) + #samples_randn = ChainsCustomIndex(samples, indices_to_keep) + + #Random.seed!(seed) + #gq_randn = Chains(generated_quantities(my_model, samples_randn)) + + my_model_forecast_missing = uciwweihr_model( + missing_data_hosp; + obstimes, + param_change_times, + E_init_sd, E_init_mean, + I_init_sd, I_init_mean, + H_init_sd, H_init_mean, + gamma_sd, log_gamma_mean, + nu_sd, log_nu_mean, + epsilon_sd, log_epsilon_mean, + sigma_hosp_sd, sigma_hosp_mean, + Rt_init_sd, Rt_init_mean, + sigma_Rt_sd, sigma_Rt_mean, + w_init_sd, w_init_mean, + sigma_w_sd, sigma_w_mean ) -indices_to_keep = .!isnothing.(generated_quantities(my_model, samples)) -samples_randn = ChainsCustomIndex(samples, indices_to_keep) + indices_to_keep = .!isnothing.(generated_quantities(my_model, samples)) + samples_randn = ChainsCustomIndex(samples, indices_to_keep) -Random.seed!(seed) -predictive_randn = predict(my_model_forecast_missing, samples_randn) + Random.seed!(seed) + predictive_randn = predict(my_model_forecast_missing, samples_randn) -Random.seed!(seed) -gq_randn = Chains(generated_quantities(my_model, samples_randn)) + Random.seed!(seed) + gq_randn = Chains(generated_quantities(my_model, samples_randn)) -samples_df = DataFrame(samples) + samples_df = DataFrame(samples) -results = [DataFrame(predictive_randn), DataFrame(gq_randn), samples_df] + results = [DataFrame(predictive_randn), DataFrame(gq_randn), samples_df] -return(results) + return(results) end \ No newline at end of file diff --git a/src/uciwweihr_model.jl b/src/uciwweihr_model.jl index 7df816d..ca3ca42 100644 --- a/src/uciwweihr_model.jl +++ b/src/uciwweihr_model.jl @@ -5,7 +5,7 @@ The defaults for this fuction will follow those of the default simulation in gen # Arguments - `data_hosp`: An array of hospital data. -- `data_wastewater`: An array of pathogen genome concentration in localized wastewater data. +- `data_wastewater`: An array of pathogen genome concentration in localized wastewater data. If this is not avaliable, the model used will be one that only uses hospital data. - `obstimes`: An array of timepoints for observed hosp/wastewater. - `param_change_times`: An array of timepoints where the parameters change. - `E_init_sd::Float64=50.0`: Standard deviation for the initial number of exposed individuals. @@ -40,7 +40,7 @@ The defaults for this fuction will follow those of the default simulation in gen """ @model function uciwweihr_model( data_hosp, - data_wastewater, + data_wastewater; obstimes, param_change_times, E_init_sd::Float64=50.0, E_init_mean::Int64=200, @@ -59,7 +59,6 @@ The defaults for this fuction will follow those of the default simulation in gen sigma_w_sd::Float64=0.2, sigma_w_mean::Float64=-3.5 ) - # Prelims max_neg_bin_sigma = 1e10 min_neg_bin_sigma = 1e-10 @@ -143,7 +142,6 @@ The defaults for this fuction will follow those of the default simulation in gen verbose=false, abstol=abstol, reltol=reltol, u0=u0, p=p0, tspan=(0.0, obstimes[end])) # If the ODE solver fails, reject the sample by adding -Inf to the likelihood if sol.retcode != :Success - println("An error occurred during ODE solution!!!") Turing.@addlogprob! -Inf return end @@ -190,3 +188,139 @@ The defaults for this fuction will follow those of the default simulation in gen end + + + +@model function uciwweihr_model( + data_hosp; + obstimes, + param_change_times, + E_init_sd::Float64=50.0, E_init_mean::Int64=200, + I_init_sd::Float64=20.0, I_init_mean::Int64=100, + H_init_sd::Float64=5.0, H_init_mean::Int64=20, + gamma_sd::Float64=0.02, log_gamma_mean::Float64=log(1/4), + nu_sd::Float64=0.02, log_nu_mean::Float64=log(1/7), + epsilon_sd::Float64=0.02, log_epsilon_mean::Float64=log(1/5), + sigma_hosp_sd::Float64=50.0, sigma_hosp_mean::Float64=500.0, + Rt_init_sd::Float64=0.3, Rt_init_mean::Float64=0.2, + sigma_Rt_sd::Float64=0.2, sigma_Rt_mean::Float64=-3.0, + w_init_sd::Float64=0.04, w_init_mean::Float64=logit(0.35), + sigma_w_sd::Float64=0.2, sigma_w_mean::Float64=-3.5 + ) + + + # Prelims + max_neg_bin_sigma = 1e10 + min_neg_bin_sigma = 1e-10 + + + # Calculate number of observed datapoints timepoints + l_obs = length(obstimes) + l_param_change_times = length(param_change_times) + + + # PRIORS----------------------------- + # Compartments + E_init_non_centered ~ Normal() + I_init_non_centered ~ Normal() + H_init_non_centered ~ Normal() + # Parameters for compartments + gamma_non_centered ~ Normal() + nu_non_centered ~ Normal() + epsilon_non_centered ~ Normal() + # Parameters for hospital + sigma_hosp_non_centered ~ Normal() + # Non-constant Rt + Rt_params_non_centered ~ MvNormal(zeros(l_param_change_times + 2), I) # +2 for sigma and init + sigma_Rt_non_centered = Rt_params_non_centered[1] + Rt_init_non_centered = Rt_params_non_centered[2] + log_Rt_steps_non_centered = Rt_params_non_centered[3:end] + # Non-constant Hosp Rate w + w_params_non_centered ~ MvNormal(zeros(l_param_change_times + 2), I) # +2 for sigma and init + sigma_w_non_centered = w_params_non_centered[1] + w_init_non_centered = w_params_non_centered[2] + logit_w_steps_non_centered = w_params_non_centered[3:end] + + + # TRANSFORMATIONS-------------------- + # Compartments + E_init = E_init_non_centered * E_init_sd + E_init_mean + I_init = I_init_non_centered * I_init_sd + I_init_mean + H_init = H_init_non_centered * H_init_sd + H_init_mean + # Parameters for compartments + gamma = exp(gamma_non_centered * gamma_sd + log_gamma_mean) + nu = exp(nu_non_centered * nu_sd + log_nu_mean) + epsilon = exp(epsilon_non_centered * epsilon_sd + log_epsilon_mean) + # Parameters for hospital + sigma_hosp = clamp.(sigma_hosp_non_centered * sigma_hosp_sd + sigma_hosp_mean, min_neg_bin_sigma, max_neg_bin_sigma) + # Non-constant Rt + Rt_init = exp(Rt_init_non_centered * Rt_init_sd + Rt_init_mean) + sigma_Rt = exp(sigma_Rt_non_centered * sigma_Rt_sd + sigma_Rt_mean) + alpha_t_no_init = exp.(log(Rt_init) .+ cumsum(log_Rt_steps_non_centered) * sigma_Rt) * nu + alpha_init = Rt_init * nu + alpha_t = vcat(alpha_init, alpha_t_no_init) + # Non-constant Hosp Prob w + w_init_logit = w_init_non_centered * w_init_sd + w_init_mean + sigma_w = exp(sigma_w_non_centered * sigma_w_sd + sigma_w_mean) + logit_w_no_init = w_init_logit .+ cumsum(logit_w_steps_non_centered) * sigma_w + w_init = logistic(w_init_logit) + w_no_init = logistic.(logit_w_no_init) + w_t = vcat(w_init, w_no_init) + + + # ODE SETUP-------------------------- + prob = ODEProblem{true}(eihr_ode!, zeros(3), (0.0, obstimes[end]), ones(5)) + function param_affect_beta!(integrator) + ind_t = searchsortedfirst(param_change_times, integrator.t) # Find the index of param_change_times that contains the current timestep + integrator.p[1] = alpha_t_no_init[ind_t] # Replace alpha with a new value from alpha_t_no_init + integrator.p[4] = w_no_init[ind_t] # Replace w with a new value from w_no_init + end + param_callback = PresetTimeCallback(param_change_times, param_affect_beta!, save_positions=(false, false)) + u0 = [E_init, I_init, H_init] + p0 = [alpha_init, gamma, nu, w_init, epsilon] + extra_ode_precision = false + abstol = extra_ode_precision ? 1e-11 : 1e-9 + reltol = extra_ode_precision ? 1e-8 : 1e-6 + sol = solve(prob, Tsit5(); callback=param_callback, saveat=obstimes, save_start=true, + verbose=false, abstol=abstol, reltol=reltol, u0=u0, p=p0, tspan=(0.0, obstimes[end])) + # If the ODE solver fails, reject the sample by adding -Inf to the likelihood + if sol.retcode != :Success + Turing.@addlogprob! -Inf + return + end + sol_array = Array(sol) + I_comp_sol = clamp.(sol_array[2,2:end],1, 1e10) + + + # Likelihood calculations------------ + sol_hosp = clamp.(sol_array[3,2:end], 1, 1e10) + for i in 1:l_obs + data_hosp[i] ~ NegativeBinomial2(sol_hosp[i], sigma_hosp) + end + + + # Generated quantities + H_comp = sol_array[3, :] + rt_vals = alpha_t_no_init / nu + rt_init = alpha_init / nu + w_t = w_no_init + + return ( + E_init, + I_init, + H_init, + alpha_t = alpha_t, + gamma = gamma, + nu = nu, + w_t = w_t, + epsilon = epsilon, + rt_vals = rt_vals, + sigma_hosp = sigma_hosp, + H = H_comp, + rt_init = rt_init, + w_init = w_init + ) + + + end + \ No newline at end of file