Skip to content

Commit

Permalink
Merge pull request #73 from CDCgov/72-replace-named-iterator-for-prio…
Browse files Browse the repository at this point in the history
…rs-with-dict-objects

Changed process priors into Dicts rather than NamedTuples
  • Loading branch information
seabbs authored Feb 26, 2024
2 parents af4b679 + d1c7bf5 commit 6fd0d38
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 15 deletions.
10 changes: 5 additions & 5 deletions EpiAware/src/latent-processes.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
function default_rw_priors()
return (
var_RW_prior = truncated(Normal(0.0, 0.05), 0.0, Inf),
init_rw_value_prior = Normal()
)
:var_RW_prior => truncated(Normal(0.0, 0.05), 0.0, Inf),
:init_rw_value_prior => Normal()
) |> Dict
end

@model function random_walk(n; var_RW_prior, init_rw_value_prior)
Expand All @@ -29,9 +29,9 @@ A struct representing a latent process with its priors.
- `latent_process_priors`: NamedTuple containing the priors for the latent process.
"""
struct LatentProcess{F <: Function}
struct LatentProcess{F <: Function, D <: Distribution}
latent_process::F
latent_process_priors::NamedTuple
latent_process_priors::Dict{Symbol, D}
end

"""
Expand Down
14 changes: 7 additions & 7 deletions EpiAware/src/observation-processes.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function default_delay_obs_priors()
return (neg_bin_cluster_factor_prior = Gamma(3, 0.05 / 3),)
return (:neg_bin_cluster_factor_prior => Gamma(3, 0.05 / 3),) |> Dict
end

@model function delay_observations(
Expand All @@ -23,18 +23,18 @@ end
end

"""
struct ObservationModel{F<:Function}
struct ObservationModel{F <: Function, D<:Distribution}
A struct representing an observation model with its priors.
A struct representing an observation model.
# Fields
- `observation_model`: The observation model function for a `Turing` model.
- `observation_model_priors`: NamedTuple containing the priors for the observation model.
- `observation_model`: The observation model function.
- `observation_model_priors`: A dictionary of prior distributions for the observation model parameters.
"""
struct ObservationModel{F <: Function}
struct ObservationModel{F <: Function, D <: Distribution}
observation_model::F
observation_model_priors::NamedTuple
observation_model_priors::Dict{Symbol, D}
end

"""
Expand Down
4 changes: 2 additions & 2 deletions EpiAware/test/test_latent-processes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ end
@testitem "Testing default_rw_priors" begin
@testset "var_RW_prior" begin
priors = default_rw_priors()
var_RW = rand(priors.var_RW_prior)
var_RW = rand(priors[:var_RW_prior])
@test var_RW >= 0.0
end

@testset "init_rw_value_prior" begin
priors = default_rw_priors()
init_rw_value = rand(priors.init_rw_value_prior)
init_rw_value = rand(priors[:init_rw_value_prior])
@test typeof(init_rw_value) == Float64
end
end
1 change: 0 additions & 1 deletion EpiAware/test/test_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ end
# Define test inputs
y_t = missing # Data will be generated from the model
data = EpiData([0.2, 0.3, 0.5], [0.1, 0.4, 0.5], 0.8, 10, exp)
process_priors = merge(default_rw_priors(), default_delay_obs_priors())
pos_shift = 1e-6

epimodel = Renewal(data)
Expand Down

0 comments on commit 6fd0d38

Please sign in to comment.