Namedtuples for Adaptor and Integrators
AstitvaAggarwal committed Oct 1, 2023
1 parent b2116c3 commit e573f2c
Showing 3 changed files with 66 additions and 66 deletions.
64 changes: 28 additions & 36 deletions src/BPINN_ode.jl
@@ -4,12 +4,10 @@
```julia
BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05],
phystd = [0.05], dataset = [nothing],
init_params = nothing, physdt = 1 / 20.0, nchains = 1,
autodiff = false, Integrator = Leapfrog,
Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8,
Metric = DiagEuclideanMetric, jitter_rate = 3.0,
tempering_rate = 3.0, MCMCargs = (n_leapfrog=30),
phystd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
MCMCkwargs = (n_leapfrog = 30,), nchains = 1, init_params = nothing,
Adaptorkwargs = (Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8, Metric = DiagEuclideanMetric),
Integratorkwargs = (Integrator = Leapfrog,), autodiff = false,
progress = false, verbose = false)
```
@@ -50,17 +48,16 @@ chainlux = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh), Lux.Dense(6,
alg = NeuralPDE.BNNODE(chainlux, draw_samples = 2000,
l2std = [0.05], phystd = [0.05],
priorsNNw = (0.0, 3.0),
n_leapfrog = 30, progress = true)
priorsNNw = (0.0, 3.0), progress = true)
sol_lux = solve(prob, alg)
# with parameter estimation
alg = NeuralPDE.BNNODE(chainlux,dataset = dataset,
draw_samples = 2000,l2std = [0.05],
phystd = [0.05],priorsNNw = (0.0, 10.0),
param = [Normal(6.5, 0.5), Normal(-3, 0.5)],
n_leapfrog = 30, progress = true)
param = [Normal(6.5, 0.5), Normal(-3, 0.5)],
progress = true)
sol_lux_pestim = solve(prob, alg)
```
@@ -83,7 +80,8 @@ Kevin Linka, Amelie Schäfer, Xuhui Meng, Zongren Zou, George Em Karniadakis, El
"Bayesian Physics Informed Neural Networks for real-world nonlinear dynamical systems"
"""
struct BNNODE{C, K, IT, A, M, H <: Union{Int64, NamedTuple},
struct BNNODE{C, K, IT <: NamedTuple,
A <: NamedTuple, H <: NamedTuple,
ST <: Union{Nothing, AbstractTrainingStrategy},
I <: Union{Nothing, Vector{<:AbstractFloat}},
P <: Union{Nothing, Vector{<:Distribution}},
@@ -100,31 +98,29 @@ struct BNNODE{C, K, IT, A, M, H <: Union{Int64, NamedTuple},
phystd::Vector{Float64}
dataset::D
physdt::Float64
MCMCargs::H
MCMCkwargs::H
nchains::Int64
init_params::I
Integrator::IT
Adaptor::A
Metric::M
targetacceptancerate::Float64
jitter_rate::Float64
tempering_rate::Float64
Adaptorkwargs::A
Integratorkwargs::IT
autodiff::Bool
progress::Bool
verbose::Bool
end
function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = nothing, l2std = [0.05], phystd = [0.05],
dataset = [nothing], physdt = 1 / 20.0, MCMCargs = (n_leapfrog = 30), nchains = 1,
init_params = nothing, Integrator = Leapfrog, Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8, jitter_rate = 3.0,
tempering_rate = 3.0, autodiff = false, progress = false, verbose = false)
dataset = [nothing], physdt = 1 / 20.0, MCMCkwargs = (n_leapfrog = 30,), nchains = 1,
init_params = nothing,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric,
targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
autodiff = false, progress = false, verbose = false)
BNNODE(chain, Kernel, strategy,
draw_samples, priorsNNw, param, l2std,
phystd, dataset, physdt, MCMCargs,
nchains, init_params, Integrator,
Adaptor, Metric, targetacceptancerate,
jitter_rate, tempering_rate,
phystd, dataset, physdt, MCMCkwargs,
nchains, init_params,
Adaptorkwargs, Integratorkwargs,
autodiff, progress, verbose)
end

@@ -184,9 +180,9 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
maxiters = nothing,
numensemble = floor(Int, alg.draw_samples / 3))
@unpack chain, l2std, phystd, param, priorsNNw, Kernel, strategy,
draw_samples, dataset, init_params, Integrator, Adaptor, Metric,
nchains, physdt, targetacceptancerate, jitter_rate, tempering_rate,
MCMCargs, autodiff, progress, verbose = alg
draw_samples, dataset, init_params,
nchains, physdt, Adaptorkwargs, Integratorkwargs,
MCMCkwargs, autodiff, progress, verbose = alg

# ahmc_bayesian_pinn_ode needs param=[] for easier vcat operation for full vector of parameters
param = param === nothing ? [] : param
@@ -207,13 +203,9 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
nchains = nchains,
autodiff = autodiff,
Kernel = Kernel,
Integrator = Integrator,
Adaptor = Adaptor,
targetacceptancerate = targetacceptancerate,
Metric = Metric,
jitter_rate = jitter_rate,
tempering_rate = tempering_rate,
MCMCargs = MCMCargs,
Adaptorkwargs = Adaptorkwargs,
Integratorkwargs = Integratorkwargs,
MCMCkwargs = MCMCkwargs,
progress = progress,
verbose = verbose)

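For reference, a minimal sketch of how a caller passes the regrouped keyword arguments to `BNNODE` after this change. It assumes `prob` (an `ODEProblem`) and `dataset` are set up as in the docstring example above, and that `StanHMCAdaptor`, `DiagEuclideanMetric`, and `Leapfrog` are available from AdvancedHMC.jl; the values shown are the documented defaults, not requirements.

```julia
using NeuralPDE, Lux, Distributions
using AdvancedHMC  # StanHMCAdaptor, DiagEuclideanMetric, Leapfrog

# `prob` and `dataset` are assumed to exist, as in the docstring example
chainlux = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh), Lux.Dense(6, 1))

alg = NeuralPDE.BNNODE(chainlux;
                       dataset = dataset,
                       draw_samples = 2000,
                       l2std = [0.05], phystd = [0.05],
                       priorsNNw = (0.0, 3.0),
                       param = [Normal(6.5, 0.5), Normal(-3, 0.5)],
                       # sampler configuration is now grouped into NamedTuples
                       Adaptorkwargs = (Adaptor = StanHMCAdaptor,
                                        Metric = DiagEuclideanMetric,
                                        targetacceptancerate = 0.8),
                       Integratorkwargs = (Integrator = Leapfrog,),
                       MCMCkwargs = (n_leapfrog = 30,),
                       progress = true)

sol = solve(prob, alg)
```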
56 changes: 34 additions & 22 deletions src/advancedHMC_MCMC.jl
@@ -335,25 +335,27 @@ function NNodederi(phi::LogTargetDensity, t::AbstractVector, θ, autodiff::Bool)
end
end

function kernelchoice(Kernel, MCMCargs)
function kernelchoice(Kernel, MCMCkwargs)
if Kernel == HMCDA
δ, λ = MCMCargs[:δ], MCMCargs[:λ]
δ, λ = MCMCkwargs[:δ], MCMCkwargs[:λ]
Kernel(δ, λ)
elseif Kernel == NUTS
δ, max_depth, Δ_max = MCMCargs[:δ], MCMCargs[:max_depth], MCMCargs[:Δ_max]
δ, max_depth, Δ_max = MCMCkwargs[:δ], MCMCkwargs[:max_depth], MCMCkwargs[:Δ_max]
Kernel(δ, max_depth = max_depth, Δ_max = Δ_max)
else
# HMC
n_leapfrog = MCMCargs
n_leapfrog = MCMCkwargs[:n_leapfrog]
Kernel(n_leapfrog)
end
end

function integratorchoice(Integrator, initial_ϵ, jitter_rate,
tempering_rate)
function integratorchoice(Integratorkwargs, initial_ϵ)
Integrator = Integratorkwargs[:Integrator]
if Integrator == JitteredLeapfrog
jitter_rate = Integratorkwargs[:jitter_rate]
Integrator(initial_ϵ, jitter_rate)
elseif Integrator == TemperedLeapfrog
tempering_rate = Integratorkwargs[:tempering_rate]
Integrator(initial_ϵ, tempering_rate)
else
Integrator(initial_ϵ)
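To make the dispatch above concrete, here is a rough sketch of the NamedTuple shapes whose fields `kernelchoice` and `integratorchoice` read, based only on the branches shown in this diff; the numeric values are illustrative (the 3.0 rates mirror the old defaults), not a complete specification.

```julia
using AdvancedHMC  # Leapfrog, JitteredLeapfrog, TemperedLeapfrog

# MCMCkwargs fields consumed per kernel (see kernelchoice above)
mcmc_hmc   = (n_leapfrog = 30,)                          # HMC:   Kernel(n_leapfrog)
mcmc_nuts  = (δ = 0.65, max_depth = 10, Δ_max = 1000.0)  # NUTS:  Kernel(δ, max_depth = ..., Δ_max = ...)
mcmc_hmcda = (δ = 0.65, λ = 0.3)                         # HMCDA: Kernel(δ, λ)

# Integratorkwargs fields consumed per integrator (see integratorchoice above)
int_plain    = (Integrator = Leapfrog,)
int_jittered = (Integrator = JitteredLeapfrog, jitter_rate = 3.0)
int_tempered = (Integrator = TemperedLeapfrog, tempering_rate = 3.0)
```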
@@ -374,11 +376,12 @@ ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining,
dataset = [nothing],init_params = nothing,
draw_samples = 1000, physdt = 1 / 20.0f0,l2std = [0.05],
phystd = [0.05], priorsNNw = (0.0, 2.0),
param = [],nchains = 1,autodiff = false, Kernel = HMC,
Integrator = Leapfrog, Adaptor = StanHMCAdaptor,
targetacceptancerate = 0.8, Metric = DiagEuclideanMetric,
jitter_rate = 3.0, tempering_rate = 3.0,
MCMCargs = (n_leapfrog = 30), progress = false,verbose = false)
param = [], nchains = 1, autodiff = false, Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
MCMCkwargs = (n_leapfrog = 30,),
progress = false, verbose = false)
```
!!! warn
@@ -444,8 +447,14 @@ In case you are only solving the Equations for solution, do not provide dataset
# AdvancedHMC.jl is still developing convenience structs so might need changes on new releases.
* `Kernel`: Choice of MCMC Sampling Algorithm (AdvancedHMC.jl implementations HMC/NUTS/HMCDA)
* `targetacceptancerate`: Target percentage(in decimal) of iterations in which the proposals were accepted(0.8 by default)
* `Integrator(jitter_rate, tempering_rate), Metric, Adaptor`: https://turinglang.org/AdvancedHMC.jl/stable/
* `Integratorkwargs`: A NamedTuple containing the chosen integrator and its keyword arguments, as follows:
* `Integrator`: https://turinglang.org/AdvancedHMC.jl/stable/
* `jitter_rate`: https://turinglang.org/AdvancedHMC.jl/stable/
* `tempering_rate`: https://turinglang.org/AdvancedHMC.jl/stable/
* `Adaptorkwargs`: A NamedTuple containing the chosen Adaptor, its Metric, and targetacceptancerate, as follows:
* `Adaptor`: https://turinglang.org/AdvancedHMC.jl/stable/
* `Metric`: https://turinglang.org/AdvancedHMC.jl/stable/
* `targetacceptancerate`: Target percentage (in decimal) of iterations in which the proposals were accepted (0.8 by default)
* `MCMCkwargs`: A NamedTuple containing all the chosen MCMC kernel's (HMC/NUTS/HMCDA) arguments, as follows:
* `n_leapfrog`: number of leapfrog steps for HMC
* `δ`: target acceptance probability for NUTS and HMCDA
@@ -467,10 +476,11 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
physdt = 1 / 20.0, l2std = [0.05],
phystd = [0.05], priorsNNw = (0.0, 2.0),
param = [], nchains = 1, autodiff = false,
Kernel = HMC, Integrator = Leapfrog,
Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8,
Metric = DiagEuclideanMetric, jitter_rate = 3.0,
tempering_rate = 3.0, MCMCargs = (n_leapfrog = 30),
Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
MCMCkwargs = (n_leapfrog = 30,),
progress = false, verbose = false)

# NN parameter prior mean and variance(PriorsNN must be a tuple)
@@ -544,6 +554,9 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
end
end

Adaptor, Metric, targetacceptancerate = Adaptorkwargs[:Adaptor],
Adaptorkwargs[:Metric], Adaptorkwargs[:targetacceptancerate]

# Define Hamiltonian system (nparameters ~ dimensionality of the sampling space)
metric = Metric(nparameters)
hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff)
@@ -560,12 +573,11 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
initial_θ = vcat(randn(nparameters - ninv),
initial_θ[(nparameters - ninv + 1):end])
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
integrator = integratorchoice(Integrator, initial_ϵ, jitter_rate,
tempering_rate)
integrator = integratorchoice(Integratorkwargs, initial_ϵ)
adaptor = adaptorchoice(Adaptor, MassMatrixAdaptor(metric),
StepSizeAdaptor(targetacceptancerate, integrator))

MCMC_alg = kernelchoice(Kernel, MCMCargs)
MCMC_alg = kernelchoice(Kernel, MCMCkwargs)
Kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator)
samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, adaptor;
progress = progress, verbose = verbose)
@@ -579,11 +591,11 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
return chains, samplesc, statsc
else
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
integrator = integratorchoice(Integrator, initial_ϵ, jitter_rate, tempering_rate)
integrator = integratorchoice(Integratorkwargs, initial_ϵ)
adaptor = adaptorchoice(Adaptor, MassMatrixAdaptor(metric),
StepSizeAdaptor(targetacceptancerate, integrator))

MCMC_alg = kernelchoice(Kernel, MCMCargs)
MCMC_alg = kernelchoice(Kernel, MCMCkwargs)
Kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator)
samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples,
adaptor; progress = progress, verbose = verbose)
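A minimal sketch of a direct `ahmc_bayesian_pinn_ode` call with the regrouped keyword arguments, this time selecting NUTS so the kernel-specific fields travel through `MCMCkwargs`. It assumes `prob`, `chainlux1`, and `dataset` are defined as in the tests below; the NUTS settings are illustrative values, not library defaults.

```julia
using NeuralPDE, AdvancedHMC, Distributions

# `prob`, `chainlux1`, and `dataset` are assumed, as in test/BPINN_Tests.jl
mcmc_chain, samples, stats = ahmc_bayesian_pinn_ode(prob, chainlux1,
    dataset = dataset,
    draw_samples = 1000,
    l2std = [0.05], phystd = [0.05],
    priorsNNw = (0.0, 3.0),
    param = [LogNormal(9, 0.5)],
    Kernel = NUTS,
    # NUTS-specific settings now live in MCMCkwargs
    MCMCkwargs = (δ = 0.65, max_depth = 10, Δ_max = 1000.0),
    Adaptorkwargs = (Adaptor = StanHMCAdaptor,
                     Metric = DiagEuclideanMetric,
                     targetacceptancerate = 0.8),
    Integratorkwargs = (Integrator = Leapfrog,))
```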
12 changes: 4 additions & 8 deletions test/BPINN_Tests.jl
@@ -133,22 +133,19 @@ fh_mcmc_chain1, fhsamples1, fhstats1 = ahmc_bayesian_pinn_ode(prob, chainflux1,
param = [
LogNormal(9,
0.5),
],
Metric = DiagEuclideanMetric)
])

fh_mcmc_chain2, fhsamples2, fhstats2 = ahmc_bayesian_pinn_ode(prob, chainlux1,
dataset = dataset,
draw_samples = 2500,
physdt = 1 / 50.0,
priorsNNw = (0.0, 3.0),
param = [LogNormal(9, 0.5)],
Metric = DiagEuclideanMetric)
param = [LogNormal(9, 0.5)])

alg = NeuralPDE.BNNODE(chainflux1, dataset = dataset,
draw_samples = 2500, physdt = 1 / 50.0,
priorsNNw = (0.0, 3.0),
param = [LogNormal(9, 0.5)],
Metric = DiagEuclideanMetric)
param = [LogNormal(9, 0.5)])

sol2flux = solve(prob, alg)

@@ -160,8 +157,7 @@ alg = NeuralPDE.BNNODE(chainlux1, dataset = dataset,
param = [
LogNormal(9,
0.5),
],
Metric = DiagEuclideanMetric)
])

sol2lux = solve(prob, alg)

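For call sites being updated like the tests above, a rough before/after sketch of the keyword migration. The "before" shape is reconstructed from the deleted lines and the old defaults, so treat it as an assumption about prior usage rather than a verbatim quote; the "after" shape writes the grouped defaults out explicitly.

```julia
using NeuralPDE, AdvancedHMC, Distributions
# assumes `prob`, `chainlux1`, and `dataset` exist as in test/BPINN_Tests.jl

# before: flat sampler keywords on BNNODE
# alg = NeuralPDE.BNNODE(chainlux1, dataset = dataset,
#                        draw_samples = 2500, physdt = 1 / 50.0,
#                        priorsNNw = (0.0, 3.0), param = [LogNormal(9, 0.5)],
#                        Metric = DiagEuclideanMetric, n_leapfrog = 30)

# after: the same options grouped into NamedTuples
alg = NeuralPDE.BNNODE(chainlux1, dataset = dataset,
                       draw_samples = 2500, physdt = 1 / 50.0,
                       priorsNNw = (0.0, 3.0), param = [LogNormal(9, 0.5)],
                       Adaptorkwargs = (Adaptor = StanHMCAdaptor,
                                        Metric = DiagEuclideanMetric,
                                        targetacceptancerate = 0.8),
                       Integratorkwargs = (Integrator = Leapfrog,),
                       MCMCkwargs = (n_leapfrog = 30,))
sol = solve(prob, alg)
```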
