From 4ea4c3cad95d442e331cc2c88a071ffb7bc8107b Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Fri, 23 Feb 2024 15:11:18 +0000 Subject: [PATCH] rename `_dist` postfixes to `_prior` when used as a prior --- EpiAware/src/latent-processes.jl | 10 +++++----- .../prior_predictive_checking/ppc-latent-processes.jl | 6 +++--- EpiAware/test/test_latent-processes.jl | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/EpiAware/src/latent-processes.jl b/EpiAware/src/latent-processes.jl index c6ecd9ebd..eda5df082 100644 --- a/EpiAware/src/latent-processes.jl +++ b/EpiAware/src/latent-processes.jl @@ -1,14 +1,14 @@ function default_rw_priors() return ( - var_RW_dist = truncated(Normal(0.0, 0.05), 0.0, Inf), - init_rw_value_dist = Normal() + var_RW_prior = truncated(Normal(0.0, 0.05), 0.0, Inf), + init_rw_value_prior = Normal() ) end -@model function random_walk(n; var_RW_dist, init_rw_value_dist) +@model function random_walk(n; var_RW_prior, init_rw_value_prior) ϵ_t ~ MvNormal(ones(n)) - σ²_RW ~ var_RW_dist - init ~ init_rw_value_dist + σ²_RW ~ var_RW_prior + init ~ init_rw_value_prior σ_RW = sqrt(σ²_RW) rw = Vector{eltype(ϵ_t)}(undef, n) diff --git a/EpiAware/test/prior_predictive_checking/ppc-latent-processes.jl b/EpiAware/test/prior_predictive_checking/ppc-latent-processes.jl index 4ec65f2e8..462b267d8 100644 --- a/EpiAware/test/prior_predictive_checking/ppc-latent-processes.jl +++ b/EpiAware/test/prior_predictive_checking/ppc-latent-processes.jl @@ -11,7 +11,7 @@ using Plots.PlotMeasures using EpiAware Random.seed!(0) n = 30 -latent_process_priors = (var_RW_dist = truncated(Normal(0.0, 0.5), 0.0, Inf),) +latent_process_priors = (var_RW_prior = truncated(Normal(0.0, 0.5), 0.0, Inf),) model = random_walk(n; latent_process_priors = latent_process_priors) n_samples = 2000 @@ -20,7 +20,7 @@ sampled_walks = prior_chn |> chn -> mapreduce(hcat, generated_quantities(model, gen[1] end ## From law of total variance and known mean of HalfNormal distribution -theoretical_std = [t * latent_process_priors.var_RW_dist.untruncated.σ * sqrt(2) / sqrt(π) +theoretical_std = [t * latent_process_priors.var_RW_prior.untruncated.σ * sqrt(2) / sqrt(π) for t in 1:n] .|> sqrt plt_ppc_rw = plot( @@ -46,7 +46,7 @@ plot!( ) plot!( σ_hist, - latent_process_priors.var_RW_dist, + latent_process_priors.var_RW_prior, lw = 2, c = :red, alpha = 0.5, diff --git a/EpiAware/test/test_latent-processes.jl b/EpiAware/test/test_latent-processes.jl index 867b90640..208574557 100644 --- a/EpiAware/test/test_latent-processes.jl +++ b/EpiAware/test/test_latent-processes.jl @@ -25,15 +25,15 @@ (var(samples_day_5) - 5) > -5 * theoretical_std_of_empiral_var end @testitem "Testing default_rw_priors" begin - @testset "var_RW_dist" begin + @testset "var_RW_prior" begin priors = default_rw_priors() - var_RW = rand(priors.var_RW_dist) + var_RW = rand(priors.var_RW_prior) @test var_RW >= 0.0 end - @testset "init_rw_value_dist" begin + @testset "init_rw_value_prior" begin priors = default_rw_priors() - init_rw_value = rand(priors.init_rw_value_dist) + init_rw_value = rand(priors.init_rw_value_prior) @test typeof(init_rw_value) == Float64 end end