From 14e68a9f8cefc85f507ed28ed6031d5e1d641428 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 10 Dec 2024 19:38:09 +0000 Subject: [PATCH] also try injecting Mooncake and Enzyme into case studies --- EpiAware/docs/Project.toml | 2 ++ .../src/showcase/replications/chatzilena-2019/index.jl | 6 +++--- .../docs/src/showcase/replications/mishra-2020/index.jl | 7 +++++-- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/EpiAware/docs/Project.toml b/EpiAware/docs/Project.toml index 76b276e49..ff27d90d8 100644 --- a/EpiAware/docs/Project.toml +++ b/EpiAware/docs/Project.toml @@ -8,9 +8,11 @@ DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EpiAware = "b2eeebe4-5992-4301-9193-7ebc9f62c855" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" PairPlots = "43a3c2be-4208-490b-832a-a21dcd55d7da" Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781" diff --git a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl index 7363732cd..3c117f511 100644 --- a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl +++ b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl @@ -31,7 +31,7 @@ using CSV, DataFramesMeta #Data wrangling using CairoMakie, PairPlots # ╔═╡ 14641441-dbea-4fdf-88e0-64a57da60ef7 -using ReverseDiff #Automatic differentiation backend +using ADTypes, Mooncake #Automatic differentiation backend # ╔═╡ a0d91258-8ab5-4adc-98f2-8f17b4bd685c begin #Date utility and set Random seed @@ -591,7 +591,7 @@ Starting from the initial guess, the MAP point is calculated rapidly in one pass # ╔═╡ 6796ae76-bc2d-4895-ba0a-5e2c23c50dfb map_fit_stoch_mdl = maximum_a_posteriori(stochastic_mdl; - adtype = AutoReverseDiff(), + adtype = AutoMooncake(config = nothing), initial_params = initial_guess ) @@ -603,7 +603,7 @@ Now we can run NUTS, sampling 1000 posterior draws per chain for 4 chains. # ╔═╡ 156272d7-56c4-4ac4-bf3e-7882f4edc144 chn2 = sample( stochastic_mdl, - NUTS(; adtype = AutoReverseDiff(true)), + NUTS(; adtype = AutoMooncake(config = nothing)), MCMCThreads(), 1000, 4; initial_params = fill(map_fit_stoch_mdl.values.array, 4) ) diff --git a/EpiAware/docs/src/showcase/replications/mishra-2020/index.jl b/EpiAware/docs/src/showcase/replications/mishra-2020/index.jl index 1a7453b9b..50dd0c119 100644 --- a/EpiAware/docs/src/showcase/replications/mishra-2020/index.jl +++ b/EpiAware/docs/src/showcase/replications/mishra-2020/index.jl @@ -28,7 +28,10 @@ using CSV, DataFramesMeta #Data wrangling using CairoMakie, PairPlots, TimeSeries #Plotting backend # ╔═╡ 97b5374e-7653-4b3b-98eb-d8f73aa30580 -using ReverseDiff #Automatic differentiation backend +begin + using ADTypes, Enzyme #Automatic differentiation backend + Enzyme.API.runtimeActivity!(true) +end # ╔═╡ 1642dbda-4915-4e29-beff-bca592f3ec8d begin #Date utility and set Random seed @@ -391,7 +394,7 @@ num_threads = min(10, Threads.nthreads()) inference_method = EpiMethod( pre_sampler_steps = [ManyPathfinder(nruns = 4, maxiters = 100)], sampler = NUTSampler( - adtype = AutoReverseDiff(compile = true), + adtype = AutoEnzyme(), ndraws = 2000, nchains = num_threads, mcmc_parallel = MCMCThreads())