Adaptive Reweighting of BPINN Loglikelihood #798

Draft
wants to merge 12 commits into master
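The title refers to adaptively reweighting the individual terms of the BPINN log-likelihood. As a rough, generic sketch of what such a scheme can look like (an illustration under assumptions, not the code in this diff; every name below is hypothetical):

# Hypothetical illustration only: rescale each likelihood term (PDE, BC, data)
# every few iterations so that small-magnitude terms are not drowned out.
function reweight(losses::Vector{Float64})
    s = sum(losses)
    return [s / (length(losses) * l) for l in losses]   # smaller loss ⇒ larger weight
end

iteration = [0]                  # mutable one-element counter, cf. pinnrep.iteration below
weights = ones(3)                # one weight per term: PDE, BC, data
losses = [1.2e3, 4.0, 0.3]       # made-up current magnitudes of the three terms
iteration[1] += 1
if iteration[1] % 10 == 0        # refresh the weights every 10 iterations
    weights .= reweight(losses)
end
weighted_loglik = -sum(weights .* losses)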
1 change: 1 addition & 0 deletions src/NeuralPDE.jl
@@ -25,6 +25,7 @@ using Symbolics: wrap, unwrap, arguments, operation
using SymbolicUtils
using AdvancedHMC, LogDensityProblems, LinearAlgebra, Functors, MCMCChains
using MonteCarloMeasurements

import ModelingToolkit: value, nameof, toexpr, build_expr, expand_derivatives
import DomainSets: Domain, ClosedInterval
import ModelingToolkit: Interval, infimum, supremum #,Ball
27 changes: 24 additions & 3 deletions src/PDE_BPINN.jl
@@ -62,6 +62,12 @@ mutable struct PDELogTargetDensity{
end
end

LogDensityProblems.dimension(Tar::PDELogTargetDensity) = Tar.dim

function LogDensityProblems.capabilities(::PDELogTargetDensity)
LogDensityProblems.LogDensityOrder{1}()
end
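For context, these are the standard LogDensityProblems.jl methods; a minimal toy target implementing the same contract (ToyTarget is illustrative and advertises order 0 because it only provides the density value, whereas PDELogTargetDensity declares LogDensityOrder{1}):

using LogDensityProblems

struct ToyTarget
    dim::Int
end

LogDensityProblems.dimension(t::ToyTarget) = t.dim
LogDensityProblems.capabilities(::ToyTarget) = LogDensityProblems.LogDensityOrder{0}()
LogDensityProblems.logdensity(::ToyTarget, θ) = -sum(abs2, θ) / 2   # unnormalised standard normal

ℓ = ToyTarget(2)
LogDensityProblems.logdensity(ℓ, [0.0, 1.0])   # -0.5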

function LogDensityProblems.logdensity(Tar::PDELogTargetDensity, θ)
# for parameter estimation it is necessary to use the multioutput case
return Tar.full_loglikelihood(setparameters(Tar, θ),
@@ -87,11 +93,12 @@ function setparameters(Tar::PDELogTargetDensity, θ)

a = ComponentArrays.ComponentArray(NamedTuple{Tar.names}(i for i in Luxparams))

if Tar.extraparams > 0
    b = θ[(end - Tar.extraparams + 1):end]

    return ComponentArrays.ComponentArray(;
        depvar = a,
        p = b)
else
    return ComponentArrays.ComponentArray(;
        depvar = a)
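A small sketch of the ComponentArray layout that setparameters assembles here: the network parameters go under depvar and any extra (inverse-problem) parameters are appended under p. The values and field names below are made up for illustration:

using ComponentArrays

a = ComponentArray((u = [0.1, 0.2, 0.3],))   # stand-in for the per-dependent-variable Lux parameters
b = [1.5, 0.7]                               # stand-in for Tar.extraparams extra PDE parameters

θ = ComponentArray(; depvar = a, p = b)
θ.depvar.u   # network parameters, accessed by name
θ.p          # extra parameters, appended at the end of the flat vector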
@@ -298,6 +305,12 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
Integratorkwargs = (Integrator = Leapfrog,), saveats = [1 / 10.0],
numensemble = floor(Int, draw_samples / 3), progress = false, verbose = false)
pinnrep = symbolic_discretize(pde_system, discretization)

# presumably initializes an iteration counter used for the adaptive reweighting
# of the loglikelihood; a one-element vector so the count can be mutated in place
pinnrep.iteration = [0]

dataset_pde, dataset_bc = discretization.dataset

if ((dataset_bc isa Nothing) && (dataset_pde isa Nothing))
@@ -428,12 +441,20 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
end
return bpinnsols
else
println("now 1")

println("now 1")

initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
integrator = integratorchoice(Integratorkwargs, initial_ϵ)
adaptor = adaptorchoice(Adaptor, MassMatrixAdaptor(metric),
StepSizeAdaptor(targetacceptancerate, integrator))

Kernel = AdvancedHMC.make_kernel(Kernel, integrator)
println("now 2")

println("now 2")

samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples,
adaptor; progress = progress, verbose = verbose)
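For reference, the surrounding calls follow the standard AdvancedHMC.jl pipeline; an end-to-end sketch on the ToyTarget from the earlier example, using an explicit HMCKernel instead of AdvancedHMC.make_kernel (names and settings are illustrative):

using AdvancedHMC, LogDensityProblems, ForwardDiff

ℓ = ToyTarget(2)                                   # any LogDensityProblems-compatible target
initial_θ = zeros(LogDensityProblems.dimension(ℓ))

metric = DiagEuclideanMetric(LogDensityProblems.dimension(ℓ))
hamiltonian = Hamiltonian(metric, ℓ, ForwardDiff)  # ForwardDiff supplies the gradient
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
integrator = Leapfrog(initial_ϵ)
kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))

n_samples, n_adapts = 1000, 500
samples, stats = sample(hamiltonian, kernel, initial_θ, n_samples, adaptor, n_adapts;
    progress = false)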
