diff --git a/EpiAware/src/EpiObsModels/StackObservationModels.jl b/EpiAware/src/EpiObsModels/StackObservationModels.jl index 440112867..778de5e9e 100644 --- a/EpiAware/src/EpiObsModels/StackObservationModels.jl +++ b/EpiAware/src/EpiObsModels/StackObservationModels.jl @@ -78,7 +78,7 @@ Generate observations from a stack of observation models. Assumes a 1 to 1 mappi @assert obs_model.model_names==keys(y_t) .|> string |> collect "The model names must match the keys of the observation datasets." @assert keys(y_t)==keys(Y_t) "The keys of the observed and true values must match." - obs = map((obs_model.models, obs_model.model_names)) do (model, model_name) + obs = map(zip(obs_model.models, obs_model.model_names)) do (model, model_name) @submodel obs_tmp = generate_observations( model, y_t[Symbol(model_name)], Y_t[Symbol(model_name)]) return obs_tmp diff --git a/EpiAware/src/EpiObsModels/modifiers/ascertainment/Ascertainment.jl b/EpiAware/src/EpiObsModels/modifiers/ascertainment/Ascertainment.jl index 88ce3fc62..bde596953 100644 --- a/EpiAware/src/EpiObsModels/modifiers/ascertainment/Ascertainment.jl +++ b/EpiAware/src/EpiObsModels/modifiers/ascertainment/Ascertainment.jl @@ -69,7 +69,7 @@ Generates observations based on the `LatentDelay` observation model. - `obs_aux`: Additional observation-related variables. " @model function EpiAwareBase.generate_observations(obs_model::Ascertainment, y_t, Y_t) - @submodel expected_obs_mod, expected_aux = generate_latent( + @submodel expected_obs_mod = generate_latent( obs_model.latent_model, length(Y_t)) expected_obs = Y_t .* obs_model.link(expected_obs_mod) diff --git a/EpiAware/test/EpiObsModels/ObservationErrorModels/methods.jl b/EpiAware/test/EpiObsModels/ObservationErrorModels/methods.jl index 8867680d5..780e9ffb5 100644 --- a/EpiAware/test/EpiObsModels/ObservationErrorModels/methods.jl +++ b/EpiAware/test/EpiObsModels/ObservationErrorModels/methods.jl @@ -12,14 +12,14 @@ @testset "Test with entirely missing data" begin mdl = generate_observations(obs_model, missing, I_t) - @test isapprox(mdl()[1], I_t, atol = 1e-3) + @test isapprox(mdl(), I_t, atol = 1e-3) end missing_I_t = vcat(missing, I_t) @testset "Test with leading missing expected observations" begin mdl = generate_observations(obs_model, missing_I_t, vcat(20, I_t)) - draw = mdl()[1] + draw = mdl() @test draw[2:end] == I_t @test abs(draw[1] - 20) > 0 @test isapprox(draw[1], 20, atol = 1e-3) diff --git a/EpiAware/test/EpiObsModels/modifiers/LatentDelay.jl b/EpiAware/test/EpiObsModels/modifiers/LatentDelay.jl index 91ad3273d..2dca40f32 100644 --- a/EpiAware/test/EpiObsModels/modifiers/LatentDelay.jl +++ b/EpiAware/test/EpiObsModels/modifiers/LatentDelay.jl @@ -56,7 +56,7 @@ end n_samples = 1000 first_obs = sample(mdl, Prior(), n_samples; progress = false) |> chn -> generated_quantities(fix_mdl, chn) .|> - (gen -> gen[1][1]) |> + (gen -> gen[1]) |> vec direct_samples = EpiAware.EpiObsModels.NegativeBinomialMeanClust( I_t[1], neg_bin_cf^2) |> @@ -93,7 +93,6 @@ end mdl = generate_observations(delay_obs, y_t_scenario, I_t) sampled_obs = sample(mdl, Prior(), 1000; progress = false) |> chn -> generated_quantities(mdl, chn) .|> - (gen -> gen[1]) |> collect # Calculate mean of generated quantities @@ -134,20 +133,20 @@ end @testset "Test with entirely missing data" begin mdl = generate_observations(obs_model, missing, I_t) - @test mdl()[1][3:end] == expected_obs[3:end] - @test sum(mdl()[1] .|> ismissing) == 2 + @test mdl()[3:end] == expected_obs[3:end] + @test sum(mdl() .|> ismissing) == 2 end @testset "Test with missing data defined as a vector" begin mdl = generate_observations( obs_model, [missing, missing, missing, missing, missing], I_t) - @test mdl()[1][3:end] == expected_obs[3:end] - @test sum(mdl()[1] .|> ismissing) == 2 + @test mdl()[3:end] == expected_obs[3:end] + @test sum(mdl() .|> ismissing) == 2 end @testset "Test with data" begin pois_obs_model = LatentDelay(PoissonError(), delay_int) mdl = generate_observations(pois_obs_model, [10.0, 20.0, 30.0, 40.0, 50.0], I_t) - @test mdl()[1] == [10.0, 20.0, 30.0, 40.0, 50] + @test mdl() == [10.0, 20.0, 30.0, 40.0, 50] end end diff --git a/EpiAware/test/EpiObsModels/modifiers/ascertainment/Ascertainment.jl b/EpiAware/test/EpiObsModels/modifiers/ascertainment/Ascertainment.jl index 8afdc4b39..c0214078a 100644 --- a/EpiAware/test/EpiObsModels/modifiers/ascertainment/Ascertainment.jl +++ b/EpiAware/test/EpiObsModels/modifiers/ascertainment/Ascertainment.jl @@ -21,11 +21,21 @@ end # make a test based on above example @testitem "Test Ascertainment generate_observations" begin using Turing, DynamicPPL - obs = Ascertainment(NegativeBinomialError(), FixedIntercept(0.1); link = x -> x) + + struct ExpectedObs <: AbstractTuringObservationModel + model::AbstractTuringObservationModel + end + + @model EpiAware.EpiAwareBase.generate_observations(model::ExpectedObs, y_t, Y_t) = begin + expected_obs := Y_t + @submodel y_t = generate_observations(model.model, y_t, Y_t) + end + + obs = Ascertainment( + ExpectedObs(NegativeBinomialError()), FixedIntercept(0.1); link = x -> x) gen_obs = generate_observations(obs, missing, fill(100, 10)) samples = sample(gen_obs, Prior(), 100; progress = false) - gen = mapreduce(vcat, generated_quantities(gen_obs, samples)) do gen - gen[2][:expected_obs] - end + gen = get(samples, :expected_obs).expected_obs |> + x -> vcat(x...) @test all(gen .== 10.0) end diff --git a/EpiAware/test/EpiObsModels/modifiers/ascertainment/helpers.jl b/EpiAware/test/EpiObsModels/modifiers/ascertainment/helpers.jl index 0fdae8669..bcf7ce20e 100644 --- a/EpiAware/test/EpiObsModels/modifiers/ascertainment/helpers.jl +++ b/EpiAware/test/EpiObsModels/modifiers/ascertainment/helpers.jl @@ -1,6 +1,16 @@ @testitem "ascertainment_dayofweek correctly constructs a day of week ascertainment model" begin - using DynamicPPL, LogExpFunctions - obs = ascertainment_dayofweek(PoissonError()) + using DynamicPPL, LogExpFunctions, Turing, DataFrames + + struct ExpectedObs <: AbstractTuringObservationModel + model::AbstractTuringObservationModel + end + + @model EpiAware.EpiAwareBase.generate_observations(model::ExpectedObs, y_t, Y_t) = begin + expected_obs := Y_t + @submodel y_t = generate_observations(model.model, y_t, Y_t) + end + + obs = ascertainment_dayofweek(ExpectedObs(PoissonError())) incidence_each_ts = 100.0 nweeks = 2 @@ -10,6 +20,12 @@ expected_obs = repeat(7 * softmax(dayofweek_effect) .* incidence_each_ts, nweeks) fix_obs_model = fix( obs_model, (var"DayofWeek.ϵ_t" = dayofweek_effect, var"DayofWeek.std" = 1)) - gq_expected_obs = fix_obs_model()[2].expected_obs - @test expected_obs ≈ gq_expected_obs + samples = sample(fix_obs_model, Prior(), 10; progress = false) + + gq_expected_obs = get(samples, :expected_obs).expected_obs |> + x -> hcat(x...) |> + #iterate by row of a matrix + x -> map(eachrow(x)) do row + @test row ≈ expected_obs + end end