Skip to content

Commit

Permalink
done for now
Browse files Browse the repository at this point in the history
  • Loading branch information
AstitvaAggarwal committed Mar 26, 2024
1 parent 585a4f5 commit cf77408
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 38 deletions.
12 changes: 8 additions & 4 deletions src/BPINN_ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ struct BNNODE{C, K, IT <: NamedTuple,
init_params::I
Adaptorkwargs::A
Integratorkwargs::IT
numensemble::Int64
estim_collocate::Bool
autodiff::Bool
progress::Bool
verbose::Bool
Expand All @@ -112,13 +114,16 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
Metric = DiagEuclideanMetric,
targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
numensemble = floor(Int, alg.draw_samples / 3),
estim_collocate = false,
autodiff = false, progress = false, verbose = false)
!(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
BNNODE(chain, Kernel, strategy,
draw_samples, priorsNNw, param, l2std,
phystd, dataset, physdt, MCMCkwargs,
nchains, init_params,
Adaptorkwargs, Integratorkwargs,
numensemble, estim_collocate,
autodiff, progress, verbose)
end

Expand Down Expand Up @@ -177,13 +182,12 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
reltol = 1.0f-3,
verbose = false,
saveat = 1 / 50.0,
maxiters = nothing,
numensemble = floor(Int, alg.draw_samples / 3),
estim_collocate = false)
maxiters = nothing)

@unpack chain, l2std, phystd, param, priorsNNw, Kernel, strategy,
draw_samples, dataset, init_params,
nchains, physdt, Adaptorkwargs, Integratorkwargs,
MCMCkwargs, autodiff, progress, verbose = alg
MCMCkwargs, numensemble, estim_collocate, autodiff, progress, verbose = alg

# ahmc_bayesian_pinn_ode needs param=[] for easier vcat operation for full vector of parameters
param = param === nothing ? [] : param
Expand Down
4 changes: 2 additions & 2 deletions src/PDE_BPINN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -394,15 +394,15 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
datafree_colloc_loss_functions[i],
nothing;
train_sets_pde = colloc_train_sets[i],
train_sets_bc = nothing)
train_sets_bc = nothing)[1]
for i in eachindex(datafree_colloc_loss_functions)]

function L2_loss2(θ, allstd)
stdpdesnew = allstd[4]

# first vector of losses,from tuple -> pde losses, first[1] pde loss
pde_loglikelihoods = [sum([pde_loss_function(θ, stdpdesnew[i])
for (i, pde_loss_function) in enumerate(pde_loss_functions[1])])
for (i, pde_loss_function) in enumerate(pde_loss_functions)])
for pde_loss_functions in pde_loss_function_points]

# bc_loglikelihoods = [sum([bc_loss_function(θ, stdpdesnew[i]) for (i, bc_loss_function) in enumerate(pde_loss_function_points[1])]) for pde_loss_function_points in pde_loss_functions]
Expand Down
2 changes: 1 addition & 1 deletion test/BPINN_PDE_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Random.seed!(100)
u_predict = pmean(sol1.ensemblesol[1])

@test u_predictu_real atol=0.05
@test mean(u_predict .- u_real) < 1e-5
@test mean(u_predict .- u_real) < 1e-3
end

@testset "Example 2: 1D ODE" begin
Expand Down
94 changes: 93 additions & 1 deletion test/BPINN_PDEinvsol_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,4 +152,96 @@ end
p_ = sol1.estimated_de_params[1]
@test sum(abs, pmean(p_) - 10.00) < 0.3 * idealp[1]
# @test sum(abs, pmean(p_[2]) - (8 / 3)) < 0.3 * idealp[2]
end
end

function recur_expression(exp, Dict_differentials)
for in_exp in exp.args
if !(in_exp isa Expr)
# skip +,== symbols, characters etc
continue

elseif in_exp.args[1] isa ModelingToolkit.Differential
# first symbol of differential term
# Dict_differentials for masking differential terms
# and resubstituting differentials in equations after putting in interpolations
# temp = in_exp.args[end]
Dict_differentials[eval(in_exp)] = Symbolics.variable("diff_$(length(Dict_differentials) + 1)")
return
else
recur_expression(in_exp, Dict_differentials)
end
end
end

println("Example 3: 2D Periodic System with New parameter estimation")
@parameters t, p
@variables u(..)

Dt = Differential(t)
eqs = Dt(u(t)) - cos(p * t) * u(t) ~ 0
bcs = [u(0) ~ 0.0]
domains = [t Interval(0.0, 2.0)]

chainl = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 1))
initl, st = Lux.setup(Random.default_rng(), chainl)

@named pde_system = PDESystem(eqs,
bcs,
domains,
[t],
[u(t)],
[p],
defaults = Dict([p => 4.0]))

analytic_sol_func1(u0, t) = u0 + sin(2 * π * t) / (2 * π)
timepoints = collect(0.0:(1 / 100.0):2.0)
u1 = [analytic_sol_func1(0.0, timepoint) for timepoint in timepoints]
u1 = u1 .+ (u1 .* 0.2) .* randn(size(u1))
dataset = [hcat(u1, timepoints)]

discretization = BayesianPINN([chainl], GridTraining([0.02]), param_estim = true,
dataset = [dataset, nothing])

# creating dictionary for masking equations
eqs = pde_system.eqs
Dict_differentials = Dict()
exps = toexpr.(eqs)
nullobj = [recur_expression(exp, Dict_differentials) for exp in exps]

sol1 = ahmc_bayesian_pinn_pde(pde_system,
discretization;
draw_samples = 1500,
bcstd = [0.05],
phystd = [0.01], l2std = [0.01], phystdnew = [0.05],
priorsNNw = (0.0, 1.0),
saveats = [1 / 50.0],
param = [LogNormal(6.0, 0.5)],
Dict_differentials = Dict_differentials,
progress = true)

sol2 = ahmc_bayesian_pinn_pde(pde_system,
discretization;
draw_samples = 1500,
bcstd = [0.05],
phystd = [0.01], l2std = [0.01],
priorsNNw = (0.0, 1.0),
saveats = [1 / 50.0],
param = [LogNormal(6.0, 0.5)],
progress = true)

param = 2 * π
ts = vec(sol1.timepoints[1])
u_real = [analytic_sol_func1(0.0, t) for t in ts]
u_predict = pmean(sol1.ensemblesol[1])

@test u_predictu_real atol=1.5
@test mean(u_predict .- u_real) < 0.1
@test sol1.estimated_de_params[1]param atol=param * 0.3

ts = vec(sol2.timepoints[1])
u_real = [analytic_sol_func1(0.0, t) for t in ts]
u_predict = pmean(sol2.ensemblesol[1])

@test u_predictu_real atol=1.5
@test mean(u_predict .- u_real) < 0.1
@test sol1.estimated_de_params[1]param atol=param * 0.3
82 changes: 52 additions & 30 deletions test/bpinnexperimental.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,32 @@ plot!(solution, labels = ["x" "y"])
chain = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh),
Lux.Dense(6, 2))

alg = BNNODE(chain;
dataset = dataset,
draw_samples = 1000,
l2std = [0.1, 0.1],
phystd = [0.1, 0.1],
priorsNNw = (0.0, 3.0),
param = [
Normal(1, 2),
Normal(2, 2),
Normal(2, 2),
Normal(0, 2)], progress = true)

@time sol_pestim1 = solve(prob, alg; saveat = dt,)
@time sol_pestim2 = solve(prob, alg; estim_collocate = true, saveat = dt)
alg1 = BNNODE(chain;
dataset = dataset,
draw_samples = 1000,
l2std = [0.1, 0.1],
phystd = [0.1, 0.1],
priorsNNw = (0.0, 3.0),
param = [
Normal(1, 2),
Normal(2, 2),
Normal(2, 2),
Normal(0, 2)], progress = true)

alg2 = BNNODE(chain;
dataset = dataset,
draw_samples = 1000,
l2std = [0.1, 0.1],
phystd = [0.1, 0.1],
priorsNNw = (0.0, 3.0),
param = [
Normal(1, 2),
Normal(2, 2),
Normal(2, 2),
Normal(0, 2)], estim_collocate = true, progress = true)

@time sol_pestim1 = solve(prob, alg1; saveat = dt)
@time sol_pestim2 = solve(prob, alg2; saveat = dt)
plot(times, sol_pestim1.ensemblesol[1], label = "estimated x1")
plot!(times, sol_pestim2.ensemblesol[1], label = "estimated x2")
plot!(times, sol_pestim1.ensemblesol[2], label = "estimated y1")
Expand All @@ -66,28 +78,29 @@ plot!(times, sol_pestim2.ensemblesol[2], label = "estimated y2")
# comparing it with the original solution
plot!(solution, labels = ["true x" "true y"])

@show sol_pestim1.estimated_ode_params
@show sol_pestim2.estimated_ode_params
@show sol_pestim1.estimated_de_params
@show sol_pestim2.estimated_de_params

function fitz(u, p , t)
function fitz(u, p, t)
v, w = u[1], u[2]
a,b,τinv,l = p[1], p[2], p[3], p[4]
dv = v - 0.33*v^3 -w + l
dw = τinv*(v + a - b*w)
a, b, τinv, l = p[1], p[2], p[3], p[4]

dv = v - 0.33 * v^3 - w + l
dw = τinv * (v + a - b * w)

return [dv, dw]
end

prob_ode_fitzhughnagumo = ODEProblem(fitz, [1.0,1.0], (0.0,10.0), [0.7,0.8,1/12.5,0.5])
prob_ode_fitzhughnagumo = ODEProblem(
fitz, [1.0, 1.0], (0.0, 10.0), [0.7, 0.8, 1 / 12.5, 0.5])
dt = 0.5
sol = solve(prob_ode_fitzhughnagumo, Tsit5(), saveat = dt)

sig = 0.20
data = Array(sol)
dataset = [data[1,:] .+ (sig .* rand(length(sol.t))), data[2, :] .+ (sig .* rand(length(sol.t))), sol.t]
priors = [Normal(0.5,1.0), Normal(0.5,1.0), Normal(0.0,0.5), Normal(0.5,1.0)]

dataset = [data[1, :] .+ (sig .* rand(length(sol.t))),
data[2, :] .+ (sig .* rand(length(sol.t))), sol.t]
priors = [Normal(0.5, 1.0), Normal(0.5, 1.0), Normal(0.0, 0.5), Normal(0.5, 1.0)]

plot(sol.t, dataset[1], label = "noisy x")
plot!(sol.t, dataset[2], label = "noisy y")
Expand All @@ -98,7 +111,7 @@ chain = Lux.Chain(Lux.Dense(1, 10, tanh), Lux.Dense(10, 10, tanh),

Adaptorkwargs = (Adaptor = AdvancedHMC.StanHMCAdaptor,
Metric = AdvancedHMC.DiagEuclideanMetric, targetacceptancerate = 0.8)
alg = BNNODE(chain;
alg1 = BNNODE(chain;
dataset = dataset,
draw_samples = 1000,
l2std = [0.1, 0.1],
Expand All @@ -107,12 +120,21 @@ priorsNNw = (0.01, 3.0),
Adaptorkwargs = Adaptorkwargs,
param = priors, progress = true)

@time sol_pestim3 = solve(prob_ode_fitzhughnagumo, alg; saveat = dt)
@time sol_pestim4 = solve(prob_ode_fitzhughnagumo, alg; estim_collocate = true, saveat = dt)
alg2 = BNNODE(chain;
dataset = dataset,
draw_samples = 1000,
l2std = [0.1, 0.1],
phystd = [0.1, 0.1],
priorsNNw = (0.01, 3.0),
Adaptorkwargs = Adaptorkwargs,
param = priors, estim_collocate = true, progress = true)

@time sol_pestim3 = solve(prob_ode_fitzhughnagumo, alg1; saveat = dt)
@time sol_pestim4 = solve(prob_ode_fitzhughnagumo, alg2; saveat = dt)
plot!(sol.t, sol_pestim3.ensemblesol[1], label = "estimated x1")
plot!(sol.t, sol_pestim4.ensemblesol[1], label = "estimated x2")
plot!(sol.t, sol_pestim3.ensemblesol[2], label = "estimated y1")
plot!(sol.t, sol_pestim4.ensemblesol[2], label = "estimated y2")

@show sol_pestim3.estimated_ode_params
@show sol_pestim4.estimated_ode_params
@show sol_pestim3.estimated_de_params
@show sol_pestim4.estimated_de_params

0 comments on commit cf77408

Please sign in to comment.