diff --git a/EpiAware/src/latent-processes.jl b/EpiAware/src/latent-processes.jl index eda5df082..eabf12ec8 100644 --- a/EpiAware/src/latent-processes.jl +++ b/EpiAware/src/latent-processes.jl @@ -1,8 +1,8 @@ function default_rw_priors() return ( - var_RW_prior = truncated(Normal(0.0, 0.05), 0.0, Inf), - init_rw_value_prior = Normal() - ) + :var_RW_prior => truncated(Normal(0.0, 0.05), 0.0, Inf), + :init_rw_value_prior => Normal() + ) |> Dict end @model function random_walk(n; var_RW_prior, init_rw_value_prior) @@ -29,9 +29,9 @@ A struct representing a latent process with its priors. - `latent_process_priors`: NamedTuple containing the priors for the latent process. """ -struct LatentProcess{F <: Function} +struct LatentProcess{F <: Function, D <: Distribution} latent_process::F - latent_process_priors::NamedTuple + latent_process_priors::Dict{Symbol, D} end """ diff --git a/EpiAware/src/observation-processes.jl b/EpiAware/src/observation-processes.jl index cc5528d13..023c67dcc 100644 --- a/EpiAware/src/observation-processes.jl +++ b/EpiAware/src/observation-processes.jl @@ -1,5 +1,5 @@ function default_delay_obs_priors() - return (neg_bin_cluster_factor_prior = Gamma(3, 0.05 / 3),) + return (:neg_bin_cluster_factor_prior => Gamma(3, 0.05 / 3),) |> Dict end @model function delay_observations( @@ -23,18 +23,18 @@ end end """ - struct ObservationModel{F<:Function} + struct ObservationModel{F <: Function, D<:Distribution} -A struct representing an observation model with its priors. +A struct representing an observation model. # Fields -- `observation_model`: The observation model function for a `Turing` model. -- `observation_model_priors`: NamedTuple containing the priors for the observation model. +- `observation_model`: The observation model function. +- `observation_model_priors`: A dictionary of prior distributions for the observation model parameters. """ -struct ObservationModel{F <: Function} +struct ObservationModel{F <: Function, D <: Distribution} observation_model::F - observation_model_priors::NamedTuple + observation_model_priors::Dict{Symbol, D} end """ diff --git a/EpiAware/test/test_latent-processes.jl b/EpiAware/test/test_latent-processes.jl index 208574557..9177ada74 100644 --- a/EpiAware/test/test_latent-processes.jl +++ b/EpiAware/test/test_latent-processes.jl @@ -27,13 +27,13 @@ end @testitem "Testing default_rw_priors" begin @testset "var_RW_prior" begin priors = default_rw_priors() - var_RW = rand(priors.var_RW_prior) + var_RW = rand(priors[:var_RW_prior]) @test var_RW >= 0.0 end @testset "init_rw_value_prior" begin priors = default_rw_priors() - init_rw_value = rand(priors.init_rw_value_prior) + init_rw_value = rand(priors[:init_rw_value_prior]) @test typeof(init_rw_value) == Float64 end end diff --git a/EpiAware/test/test_models.jl b/EpiAware/test/test_models.jl index 61e56b2e3..14ea43345 100644 --- a/EpiAware/test/test_models.jl +++ b/EpiAware/test/test_models.jl @@ -59,7 +59,6 @@ end # Define test inputs y_t = missing # Data will be generated from the model data = EpiData([0.2, 0.3, 0.5], [0.1, 0.4, 0.5], 0.8, 10, exp) - process_priors = merge(default_rw_priors(), default_delay_obs_priors()) pos_shift = 1e-6 epimodel = Renewal(data)