diff --git a/EpiAware/src/latent-processes.jl b/EpiAware/src/latent-processes.jl index b85a775eb..27b68ed0f 100644 --- a/EpiAware/src/latent-processes.jl +++ b/EpiAware/src/latent-processes.jl @@ -5,10 +5,10 @@ function default_rw_priors() ) end -@model function random_walk(n; latent_process_priors = default_rw_priors()) +@model function random_walk(n; kwargs...) ϵ_t ~ MvNormal(ones(n)) - σ²_RW ~ latent_process_priors.var_RW_dist - init ~ latent_process_priors.init_rw_value_dist + σ²_RW ~ kwargs[:var_RW_dist] + init ~ kwargs[:init_rw_value_dist] σ_RW = sqrt(σ²_RW) rw = Vector{eltype(ϵ_t)}(undef, n) diff --git a/EpiAware/src/models.jl b/EpiAware/src/models.jl index 8d896de85..8b3e73064 100644 --- a/EpiAware/src/models.jl +++ b/EpiAware/src/models.jl @@ -9,7 +9,7 @@ time_steps = epimodel.data.time_horizon @submodel latent_process, init, latent_process_aux = latent_process_obj.latent_process( time_steps; - latent_process_priors = latent_process_obj.latent_process_priors + latent_process_obj.latent_process_priors... ) #Transform into infections @@ -20,8 +20,8 @@ y_t, I_t, epimodel::AbstractEpiModel; - observation_process_priors = observation_process_obj.observation_model_priors, - pos_shift = pos_shift + pos_shift = pos_shift, + observation_process_obj.observation_model_priors... ) #Generate quantities diff --git a/EpiAware/src/observation-processes.jl b/EpiAware/src/observation-processes.jl index 3389a5cfd..34171c55b 100644 --- a/EpiAware/src/observation-processes.jl +++ b/EpiAware/src/observation-processes.jl @@ -6,14 +6,13 @@ end y_t, I_t, epimodel::AbstractEpiModel; - observation_process_priors = default_delay_obs_priors(), - pos_shift = 1e-6 + kwargs... ) #Parameters - neg_bin_cluster_factor ~ observation_process_priors.neg_bin_cluster_factor_prior + neg_bin_cluster_factor ~ kwargs[:neg_bin_cluster_factor_prior] #Predictive distribution - case_pred_dists = (epimodel.data.delay_kernel * I_t) .+ pos_shift .|> + case_pred_dists = (epimodel.data.delay_kernel * I_t) .+ kwargs[:pos_shift] .|> μ -> mean_cc_neg_bin(μ, neg_bin_cluster_factor) #Likelihood diff --git a/EpiAware/test/test_latent-processes.jl b/EpiAware/test/test_latent-processes.jl index 5541e515a..867b90640 100644 --- a/EpiAware/test/test_latent-processes.jl +++ b/EpiAware/test/test_latent-processes.jl @@ -2,7 +2,8 @@ @testitem "Testing random_walk against theoretical properties" begin using DynamicPPL, Turing n = 5 - model = EpiAware.random_walk(n) + priors = EpiAware.default_rw_priors() + model = EpiAware.random_walk(n; priors...) fixed_model = fix(model, (σ²_RW = 1.0, init_rw_value = 0.0)) #Fixing the standard deviation of the random walk process n_samples = 1000 samples_day_5 = sample(fixed_model, Prior(), n_samples) |> diff --git a/EpiAware/test/test_observation-processes.jl b/EpiAware/test/test_observation-processes.jl index 1386d3bb5..ae239ab70 100644 --- a/EpiAware/test/test_observation-processes.jl +++ b/EpiAware/test/test_observation-processes.jl @@ -8,14 +8,15 @@ data = EpiData([0.2, 0.3, 0.5], [1.0], 0.8, 3, exp) epimodel = DirectInfections(data) # Set up priors - observation_process_priors = default_delay_obs_priors() + priors = default_delay_obs_priors() # Call the function mdl = EpiAware.delay_observations( missing, I_t, epimodel; - observation_process_priors = observation_process_priors + pos_shift = 1e-6, + priors... ) fix_mdl = fix(mdl, neg_bin_cluster_factor = 0.00001) # Effectively Poisson sampling