diff --git a/EpiAware/src/EpiLatentModels/models/HierarchicalNormal.jl b/EpiAware/src/EpiLatentModels/models/HierarchicalNormal.jl index 08d55e10b..0d50f9c97 100644 --- a/EpiAware/src/EpiLatentModels/models/HierarchicalNormal.jl +++ b/EpiAware/src/EpiLatentModels/models/HierarchicalNormal.jl @@ -6,6 +6,7 @@ The `HierarchicalNormal` struct represents a non-centered hierarchical normal di - `HierarchicalNormal(mean, std_prior)`: Constructs a `HierarchicalNormal` instance with the specified mean and standard deviation prior. - `HierarchicalNormal(; mean = 0.0, std_prior = truncated(Normal(0,0.1), 0, Inf))`: Constructs a `HierarchicalNormal` instance with the specified mean and standard deviation prior using named arguments and with default values. - `HierarchicalNormal(std_prior)`: Constructs a `HierarchicalNormal` instance with the specified standard deviation prior. +- `HierarchicalNormal(mean, std_prior)`: Constructs a `HierarchicalNormal` instance with the specified mean and standard deviation prior. ## Examples @@ -43,6 +44,10 @@ function HierarchicalNormal(std_prior::Distribution) return HierarchicalNormal(; std_prior = std_prior) end +function HierarchicalNormal(mean::Real, std_prior::Distribution) + return HierarchicalNormal(mean, std_prior, mean != 0) +end + @doc raw" function EpiAwareBase.generate_latent(obs_model::HierarchicalNormal, n) diff --git a/EpiAware/test/EpiLatentModels/models/HierarchicalNormal.jl b/EpiAware/test/EpiLatentModels/models/HierarchicalNormal.jl index 1cdb42c02..4f6287447 100644 --- a/EpiAware/test/EpiLatentModels/models/HierarchicalNormal.jl +++ b/EpiAware/test/EpiLatentModels/models/HierarchicalNormal.jl @@ -1,6 +1,6 @@ @testitem "HierarchicalNormal constructor" begin using Distributions - int = HierarchicalNormal(0.1, truncated(Normal(0, 2), 0, Inf), true) + int = HierarchicalNormal(0.1, truncated(Normal(0, 2), 0, Inf)) @test typeof(int) <: AbstractTuringLatentModel @test int.mean == 0.1 @test int.std_prior == truncated(Normal(0, 2), 0, Inf) @@ -19,7 +19,7 @@ end using HypothesisTests: ExactOneSampleKSTest, pvalue using Distributions - hnorm = HierarchicalNormal(0.2, truncated(Normal(0, 1), 0, Inf), true) + hnorm = HierarchicalNormal(0.2, truncated(Normal(0, 1), 0, Inf)) hnorm_model = generate_latent(hnorm, 10) hnorm_model_out = hnorm_model() @test length(hnorm_model_out) == 10