Skip to content

Commit

Permalink
add a new constructor to get old default behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs committed Dec 10, 2024
1 parent 2be792b commit d4023db
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
5 changes: 5 additions & 0 deletions EpiAware/src/EpiLatentModels/models/HierarchicalNormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ The `HierarchicalNormal` struct represents a non-centered hierarchical normal di
- `HierarchicalNormal(mean, std_prior)`: Constructs a `HierarchicalNormal` instance with the specified mean and standard deviation prior.
- `HierarchicalNormal(; mean = 0.0, std_prior = truncated(Normal(0,0.1), 0, Inf))`: Constructs a `HierarchicalNormal` instance with the specified mean and standard deviation prior using named arguments and with default values.
- `HierarchicalNormal(std_prior)`: Constructs a `HierarchicalNormal` instance with the specified standard deviation prior.
- `HierarchicalNormal(mean, std_prior)`: Constructs a `HierarchicalNormal` instance with the specified mean and standard deviation prior.
## Examples
Expand Down Expand Up @@ -43,6 +44,10 @@ function HierarchicalNormal(std_prior::Distribution)
return HierarchicalNormal(; std_prior = std_prior)
end

function HierarchicalNormal(mean::Real, std_prior::Distribution)
return HierarchicalNormal(mean, std_prior, mean != 0)
end

@doc raw"
function EpiAwareBase.generate_latent(obs_model::HierarchicalNormal, n)
Expand Down
4 changes: 2 additions & 2 deletions EpiAware/test/EpiLatentModels/models/HierarchicalNormal.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@testitem "HierarchicalNormal constructor" begin
using Distributions
int = HierarchicalNormal(0.1, truncated(Normal(0, 2), 0, Inf), true)
int = HierarchicalNormal(0.1, truncated(Normal(0, 2), 0, Inf))
@test typeof(int) <: AbstractTuringLatentModel
@test int.mean == 0.1
@test int.std_prior == truncated(Normal(0, 2), 0, Inf)
Expand All @@ -19,7 +19,7 @@ end
using HypothesisTests: ExactOneSampleKSTest, pvalue
using Distributions

hnorm = HierarchicalNormal(0.2, truncated(Normal(0, 1), 0, Inf), true)
hnorm = HierarchicalNormal(0.2, truncated(Normal(0, 1), 0, Inf))
hnorm_model = generate_latent(hnorm, 10)
hnorm_model_out = hnorm_model()
@test length(hnorm_model_out) == 10
Expand Down

0 comments on commit d4023db

Please sign in to comment.