Docs modifications to use QuadratureTraining instead of GridTraining #729

Merged · 143 commits · Oct 7, 2023
Changes from 1 commit
Commits
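For context on the change this PR makes throughout the tutorials: GridTraining samples the loss on a fixed grid and requires step sizes (dx/dt), while QuadratureTraining integrates the residual adaptively and needs no step size. A minimal sketch of the pattern, assuming a trial network `chain` is already defined (illustrative, not the exact tutorial code):

```julia
using NeuralPDE

# Before: fixed-grid collocation; a step size must be chosen by hand
discretization = PhysicsInformedNN(chain, GridTraining(0.05))

# After: adaptive quadrature over the domain; no dx/dt to tune
discretization = PhysicsInformedNN(chain, QuadratureTraining())
```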
b179412
Docs modifications
Sep 3, 2023
70609cc
Removed last instances of GridTraining in tutorials
Sep 4, 2023
e1996c6
Removed dx and dt
Sep 5, 2023
b41c094
docs build fail cleanup
Sep 6, 2023
98638a5
Docs modifications
Sep 3, 2023
d2bd11a
Removed last instances of GridTraining in tutorials
Sep 4, 2023
136d52f
Removed dx and dt
Sep 5, 2023
aa87003
docs build fail cleanup
Sep 6, 2023
87433bc
tried adding a compat version to project.toml
Sep 27, 2023
60218bc
Merge branch 'docs_update' of github.com:sdesai1287/NeuralPDE.jl into…
Sep 27, 2023
ce0bd7d
Merge branch 'SciML:master' into docs_update
sdesai1287 Oct 2, 2023
7837d10
update Project.toml
Oct 2, 2023
5e20ed1
New PR
AstitvaAggarwal Aug 18, 2023
99c7384
Almost done ig
AstitvaAggarwal Aug 20, 2023
1009bff
ready player 1
AstitvaAggarwal Aug 21, 2023
6420f0d
added docs, minor changes, more tests
AstitvaAggarwal Aug 22, 2023
ffd2514
prev tests did not pass the vibe check
AstitvaAggarwal Aug 23, 2023
1cdd5d5
tests
AstitvaAggarwal Aug 23, 2023
444af19
test should pass
AstitvaAggarwal Aug 23, 2023
3f63ab0
ready player one
AstitvaAggarwal Aug 24, 2023
54dd7ee
reduced iters
AstitvaAggarwal Aug 24, 2023
fab3ced
more changes
AstitvaAggarwal Aug 24, 2023
bf06750
optimizing tests
AstitvaAggarwal Aug 24, 2023
037bffe
yuh
AstitvaAggarwal Aug 24, 2023
5fc60b5
.......
AstitvaAggarwal Aug 25, 2023
c25a0ee
pls work man
AstitvaAggarwal Aug 25, 2023
278beab
|TT|
AstitvaAggarwal Aug 25, 2023
2cabddf
[TT]
AstitvaAggarwal Aug 25, 2023
99a5862
statistics dependancy compatib
AstitvaAggarwal Aug 25, 2023
cae3107
im back
AstitvaAggarwal Aug 25, 2023
ecda132
Flux changes
AstitvaAggarwal Aug 25, 2023
c5ef920
.
AstitvaAggarwal Aug 25, 2023
75505e6
less std for weights
AstitvaAggarwal Aug 26, 2023
b43c531
less std for weights
AstitvaAggarwal Aug 26, 2023
1cd1c64
Cleaner Tests, all pass, handled edge cases
AstitvaAggarwal Aug 29, 2023
1d3553a
minor changes
AstitvaAggarwal Aug 29, 2023
183dca0
Julia versions affect accuracy
AstitvaAggarwal Aug 29, 2023
0ac0714
Added my suggested missing Loss function part, adjusted tests
AstitvaAggarwal Sep 1, 2023
71d9127
minor changes
AstitvaAggarwal Sep 1, 2023
faa0b8f
added example, fixed multi dependant variable case, verfied performan…
AstitvaAggarwal Sep 1, 2023
a56f960
fixed tests
AstitvaAggarwal Sep 2, 2023
346e863
float 64 flux layers
AstitvaAggarwal Sep 3, 2023
5ac7676
Bump actions/checkout from 3 to 4
dependabot[bot] Sep 4, 2023
dfec9e6
now uses diff training strategies,
AstitvaAggarwal Sep 9, 2023
5bb907b
tests
AstitvaAggarwal Sep 10, 2023
f2b57c0
relaxed tests
AstitvaAggarwal Sep 10, 2023
d3f3d87
relaxed tests
AstitvaAggarwal Sep 11, 2023
ba4b624
added docs
AstitvaAggarwal Sep 13, 2023
e755bc6
CompatHelper: bump compat for ComponentArrays to 0.15, (keep existing…
Sep 9, 2023
40faafa
Add BPINNs tutorial to docs pages
xtalax Sep 15, 2023
1813449
CompatHelper: add new compat entry for MonteCarloMeasurements at vers…
Sep 15, 2023
c95323c
CompatHelper: bump compat for SciMLBase to 2, (keep existing compat)
Sep 22, 2023
cd3e8b0
Update Project.toml
ChrisRackauckas Sep 22, 2023
87a1cf8
Remove v1.6 from CI
ChrisRackauckas Oct 1, 2023
09f84ff
Update pipeline.yml
ChrisRackauckas Oct 1, 2023
54f5917
Update Project.toml
ChrisRackauckas Oct 1, 2023
8f3252a
Update Project.toml
ChrisRackauckas Oct 1, 2023
c5f7a1b
update Project.toml
Oct 2, 2023
d3ca8b3
Merge branch 'docs_update' of github.com:sdesai1287/NeuralPDE.jl into…
Oct 2, 2023
b133d83
New PR
AstitvaAggarwal Aug 18, 2023
a626cad
Almost done ig
AstitvaAggarwal Aug 20, 2023
3db9f82
ready player 1
AstitvaAggarwal Aug 21, 2023
5cbc0ec
added docs, minor changes, more tests
AstitvaAggarwal Aug 22, 2023
486b876
prev tests did not pass the vibe check
AstitvaAggarwal Aug 23, 2023
cd6ceab
tests
AstitvaAggarwal Aug 23, 2023
bd20f19
test should pass
AstitvaAggarwal Aug 23, 2023
2b172bf
ready player one
AstitvaAggarwal Aug 24, 2023
947dd0a
reduced iters
AstitvaAggarwal Aug 24, 2023
c26e000
more changes
AstitvaAggarwal Aug 24, 2023
45dd116
optimizing tests
AstitvaAggarwal Aug 24, 2023
9fc1561
yuh
AstitvaAggarwal Aug 24, 2023
70917e2
.......
AstitvaAggarwal Aug 25, 2023
bfa76c4
pls work man
AstitvaAggarwal Aug 25, 2023
8b53efd
|TT|
AstitvaAggarwal Aug 25, 2023
1adf3e8
[TT]
AstitvaAggarwal Aug 25, 2023
64aafa0
statistics dependancy compatib
AstitvaAggarwal Aug 25, 2023
6af3bb0
im back
AstitvaAggarwal Aug 25, 2023
8fc955d
Flux changes
AstitvaAggarwal Aug 25, 2023
b46e478
.
AstitvaAggarwal Aug 25, 2023
98bda3c
less std for weights
AstitvaAggarwal Aug 26, 2023
6db19cb
less std for weights
AstitvaAggarwal Aug 26, 2023
4c0732b
Cleaner Tests, all pass, handled edge cases
AstitvaAggarwal Aug 29, 2023
dc4c02f
minor changes
AstitvaAggarwal Aug 29, 2023
a065b8b
Julia versions affect accuracy
AstitvaAggarwal Aug 29, 2023
1ecf50e
Added my suggested missing Loss function part, adjusted tests
AstitvaAggarwal Sep 1, 2023
1362f18
minor changes
AstitvaAggarwal Sep 1, 2023
e62d6dc
added example, fixed multi dependant variable case, verfied performan…
AstitvaAggarwal Sep 1, 2023
47994c9
fixed tests
AstitvaAggarwal Sep 2, 2023
a93f23b
float 64 flux layers
AstitvaAggarwal Sep 3, 2023
db6f4b4
now uses diff training strategies,
AstitvaAggarwal Sep 9, 2023
ef814f1
tests
AstitvaAggarwal Sep 10, 2023
3cee89d
relaxed tests
AstitvaAggarwal Sep 10, 2023
597f08f
relaxed tests
AstitvaAggarwal Sep 11, 2023
7264aab
added docs
AstitvaAggarwal Sep 13, 2023
6569a4e
CompatHelper: add new compat entry for MonteCarloMeasurements at vers…
Sep 15, 2023
61f1e86
CompatHelper: bump compat for SciMLBase to 2, (keep existing compat)
Sep 22, 2023
cc6e9c1
Docs modifications
Sep 3, 2023
894a4b5
Removed last instances of GridTraining in tutorials
Sep 4, 2023
87aa3e4
update Project.toml
Oct 2, 2023
0a97d60
Merge branch 'docs_update' of github.com:sdesai1287/NeuralPDE.jl into…
Oct 2, 2023
352e3e8
New PR
AstitvaAggarwal Aug 18, 2023
4b4a77c
Almost done ig
AstitvaAggarwal Aug 20, 2023
9f78a3a
ready player 1
AstitvaAggarwal Aug 21, 2023
652f8f4
added docs, minor changes, more tests
AstitvaAggarwal Aug 22, 2023
6bb3df3
prev tests did not pass the vibe check
AstitvaAggarwal Aug 23, 2023
7a8f4b5
tests
AstitvaAggarwal Aug 23, 2023
b46c211
test should pass
AstitvaAggarwal Aug 23, 2023
647aae0
ready player one
AstitvaAggarwal Aug 24, 2023
613228b
reduced iters
AstitvaAggarwal Aug 24, 2023
9329a4b
more changes
AstitvaAggarwal Aug 24, 2023
03be17c
optimizing tests
AstitvaAggarwal Aug 24, 2023
d2ceedd
yuh
AstitvaAggarwal Aug 24, 2023
e05ed86
.......
AstitvaAggarwal Aug 25, 2023
726cdb0
pls work man
AstitvaAggarwal Aug 25, 2023
ffcb277
|TT|
AstitvaAggarwal Aug 25, 2023
c5daa60
[TT]
AstitvaAggarwal Aug 25, 2023
4ba95ad
statistics dependancy compatib
AstitvaAggarwal Aug 25, 2023
464b82c
im back
AstitvaAggarwal Aug 25, 2023
2b5a211
Flux changes
AstitvaAggarwal Aug 25, 2023
186e326
.
AstitvaAggarwal Aug 25, 2023
9a1f9aa
less std for weights
AstitvaAggarwal Aug 26, 2023
f22481d
less std for weights
AstitvaAggarwal Aug 26, 2023
e735c84
Cleaner Tests, all pass, handled edge cases
AstitvaAggarwal Aug 29, 2023
63de20d
minor changes
AstitvaAggarwal Aug 29, 2023
9a94743
Julia versions affect accuracy
AstitvaAggarwal Aug 29, 2023
d4786c7
Added my suggested missing Loss function part, adjusted tests
AstitvaAggarwal Sep 1, 2023
9c38687
minor changes
AstitvaAggarwal Sep 1, 2023
ba58a2a
added example, fixed multi dependant variable case, verfied performan…
AstitvaAggarwal Sep 1, 2023
70c7175
fixed tests
AstitvaAggarwal Sep 2, 2023
345067a
float 64 flux layers
AstitvaAggarwal Sep 3, 2023
ab5700f
now uses diff training strategies,
AstitvaAggarwal Sep 9, 2023
37d5e34
tests
AstitvaAggarwal Sep 10, 2023
da7f26f
relaxed tests
AstitvaAggarwal Sep 10, 2023
324ae46
relaxed tests
AstitvaAggarwal Sep 11, 2023
5ca7d77
added docs
AstitvaAggarwal Sep 13, 2023
7831268
CompatHelper: add new compat entry for MonteCarloMeasurements at vers…
Sep 15, 2023
790ac82
CompatHelper: bump compat for SciMLBase to 2, (keep existing compat)
Sep 22, 2023
e6f588e
Docs modifications
Sep 3, 2023
aa39260
Removed last instances of GridTraining in tutorials
Sep 4, 2023
61a3c6e
update Project.toml
Oct 2, 2023
c9ab921
Merge branch 'docs_update' of github.com:sdesai1287/NeuralPDE.jl into…
Oct 2, 2023
3560a97
Merge branch 'master' into docs_update
ChrisRackauckas Oct 6, 2023
0528d75
Merge branch 'master' into docs_update
ChrisRackauckas Oct 6, 2023
Cleaner Tests, all pass, handled edge cases
AstitvaAggarwal authored and Samedh Desai committed Oct 2, 2023
commit e735c84908ee00f3fe17a4e6dcd99b8d460e08ec
133 changes: 66 additions & 67 deletions src/BPINN_ode.jl
@@ -45,27 +45,26 @@ linear_analytic = (u0, p, t) -> exp(-t / 5) * (u0 + sin(t))
sol = solve(prob, Tsit5(); saveat = 0.05)
u = sol.u[1:100]
time = sol.t[1:100]
x̂ = collect(Float64, Array(u) + 0.05 * randn(size(u)))
x̂ = u .+ (u .* 0.2) .* randn(size(u))
dataset = [x̂, time]

chainflux12 = Flux.Chain(Flux.Dense(1, 6, tanh), Flux.Dense(6, 6, tanh),
Flux.Dense(6, 1)) |> f64
chainlux = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh), Lux.Dense(6, 1))

alg = NeuralPDE.BNNODE(chainlux12, draw_samples = 2000,
alg = NeuralPDE.BNNODE(chainlux, draw_samples = 2000,
l2std = [0.05], phystd = [0.05],
priorsNNw = (0.0, 3.0),
n_leapfrog = 30, progress = true)

sol3lux = solve(prob, alg)
sol_lux = solve(prob, alg)

# parameter estimation
alg = NeuralPDE.BNNODE(chainlux12,dataset = dataset,
# with parameter estimation
alg = NeuralPDE.BNNODE(chainlux,dataset = dataset,
draw_samples = 2000,l2std = [0.05],
phystd = [0.05],priorsNNw = (0.0, 3.0),
phystd = [0.05],priorsNNw = (0.0, 10.0),
param = [Normal(6.5, 0.5), Normal(-3, 0.5)],
n_leapfrog = 30, progress = true)

sol3lux_pestim = solve(prob, alg)
sol_lux_pestim = solve(prob, alg)
```

## Solution Notes
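Two substantive changes in this example sit among the renames: the synthetic data now carries noise proportional to the signal instead of a fixed σ, and the weight prior in the parameter-estimation call widens from (0.0, 3.0) to (0.0, 10.0). The two noise models side by side, with `u` as in the example above:

```julia
x̂_old = collect(Float64, Array(u) + 0.05 * randn(size(u)))  # additive noise, fixed σ = 0.05
x̂_new = u .+ (u .* 0.2) .* randn(size(u))                   # ~20% relative noise, scales with u
```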
@@ -87,10 +86,10 @@ Kevin Linka, Amelie Schäfer, Xuhui Meng, Zongren Zou, George Em Karniadakis, El

"""
struct BNNODE{C, K, IT, A, M,
I <: Union{Nothing, Vector{<:AbstractFloat}},
P <: Union{Vector{Nothing}, Vector{<:Distribution}},
D <:
Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}} <:
I <: Union{Nothing, Vector{<:AbstractFloat}},
P <: Union{Vector{Nothing}, Vector{<:Distribution}},
D <:
Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}} <:
NeuralPDEAlgorithm
chain::C
Kernel::K
@@ -119,26 +118,26 @@ struct BNNODE{C, K, IT, A, M,
verbose::Bool

function BNNODE(chain, Kernel = HMC; draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05],
phystd = [0.05], dataset = [nothing],
init_params = nothing,
physdt = 1 / 20.0, nchains = 1,
autodiff = false, Integrator = Leapfrog,
Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8,
Metric = DiagEuclideanMetric, jitter_rate = 3.0,
tempering_rate = 3.0, max_depth = 10, Δ_max = 1000,
n_leapfrog = 20, δ = 0.65, λ = 0.3, progress = false,
verbose = false)
priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05],
phystd = [0.05], dataset = [nothing],
init_params = nothing,
physdt = 1 / 20.0, nchains = 1,
autodiff = false, Integrator = Leapfrog,
Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8,
Metric = DiagEuclideanMetric, jitter_rate = 3.0,
tempering_rate = 3.0, max_depth = 10, Δ_max = 1000,
n_leapfrog = 20, δ = 0.65, λ = 0.3, progress = false,
verbose = false)
new{typeof(chain), typeof(Kernel), typeof(Integrator), typeof(Adaptor),
typeof(Metric), typeof(init_params), typeof(param),
typeof(dataset)}(chain, Kernel, draw_samples,
priorsNNw, param, l2std,
phystd, dataset, init_params,
physdt, nchains, autodiff, Integrator,
Adaptor, targetacceptancerate,
Metric, jitter_rate, tempering_rate,
max_depth, Δ_max, n_leapfrog,
δ, λ, progress, verbose)
priorsNNw, param, l2std,
phystd, dataset, init_params,
physdt, nchains, autodiff, Integrator,
Adaptor, targetacceptancerate,
Metric, jitter_rate, tempering_rate,
max_depth, Δ_max, n_leapfrog,
δ, λ, progress, verbose)
end
end

@@ -164,14 +163,14 @@ end

"""
BPINN Solution contains the original solution from AdvancedHMC.jl sampling(BPINNstats contains fields related to that)
> ensemblesol is the Probabilistic Etimate(MonteCarloMeasurements.jl Particles type) of Ensemble solution from All Neural Network's(made using all sampled parameters) output's.
> ensemblesol is the Probabilistic Estimate(MonteCarloMeasurements.jl Particles type) of Ensemble solution from All Neural Network's(made using all sampled parameters) output's.
> estimated_nn_params - Probabilistic Estimate of NN params from sampled weights,biases
> estimated_ode_params - Probabilistic Estimate of ODE params from sampled unknown ode paramters
"""
struct BPINNsolution{O <: BPINNstats, E,
NP <: Vector{<:MonteCarloMeasurements.Particles{<:Float64}},
OP <: Union{Vector{Nothing},
Vector{<:MonteCarloMeasurements.Particles{<:Float64}}}}
NP <: Vector{<:MonteCarloMeasurements.Particles{<:Float64}},
OP <: Union{Vector{Nothing},
Vector{<:MonteCarloMeasurements.Particles{<:Float64}}}}
original::O
ensemblesol::E
estimated_nn_params::NP
@@ -180,23 +179,23 @@ struct BPINNsolution{O <: BPINNstats, E,
function BPINNsolution(original, ensemblesol, estimated_nn_params, estimated_ode_params)
new{typeof(original), typeof(ensemblesol), typeof(estimated_nn_params),
typeof(estimated_ode_params)}(original, ensemblesol, estimated_nn_params,
estimated_ode_params)
estimated_ode_params)
end
end

function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
alg::BNNODE,
args...;
dt = nothing,
timeseries_errors = true,
save_everystep = true,
adaptive = false,
abstol = 1.0f-6,
reltol = 1.0f-3,
verbose = false,
saveat = 1 / 50.0,
maxiters = nothing,
numensemble = 500)
alg::BNNODE,
args...;
dt = nothing,
timeseries_errors = true,
save_everystep = true,
adaptive = false,
abstol = 1.0f-6,
reltol = 1.0f-3,
verbose = false,
saveat = 1 / 50.0,
maxiters = nothing,
numensemble = floor(Int, alg.draw_samples / 3))
@unpack chain, l2std, phystd, param, priorsNNw, Kernel,
draw_samples, dataset, init_params, Integrator, Adaptor, Metric,
nchains, max_depth, Δ_max, n_leapfrog, physdt, targetacceptancerate,
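Note the behavioral change at the end of the keyword list: `numensemble` previously defaulted to a fixed 500 and now scales with the sampler budget. With the docstring's `draw_samples = 2000`, for instance:

```julia
numensemble = floor(Int, 2000 / 3)  # 666 ensemble members, instead of a fixed 500
```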
@@ -210,26 +209,26 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
end

mcmcchain, samples, statistics = ahmc_bayesian_pinn_ode(prob, chain, dataset = dataset,
draw_samples = draw_samples,
init_params = init_params,
physdt = physdt, l2std = l2std,
phystd = phystd,
priorsNNw = priorsNNw,
param = param,
nchains = nchains,
autodiff = autodiff,
Kernel = Kernel,
Integrator = Integrator,
Adaptor = Adaptor,
targetacceptancerate = targetacceptancerate,
Metric = Metric,
jitter_rate = jitter_rate,
tempering_rate = tempering_rate,
max_depth = max_depth,
Δ_max = Δ_max,
n_leapfrog = n_leapfrog, δ = δ,
λ = λ, progress = progress,
verbose = verbose)
draw_samples = draw_samples,
init_params = init_params,
physdt = physdt, l2std = l2std,
phystd = phystd,
priorsNNw = priorsNNw,
param = param,
nchains = nchains,
autodiff = autodiff,
Kernel = Kernel,
Integrator = Integrator,
Adaptor = Adaptor,
targetacceptancerate = targetacceptancerate,
Metric = Metric,
jitter_rate = jitter_rate,
tempering_rate = tempering_rate,
max_depth = max_depth,
Δ_max = Δ_max,
n_leapfrog = n_leapfrog, δ = δ,
λ = λ, progress = progress,
verbose = verbose)

fullsolution = BPINNstats(mcmcchain, samples, statistics)
ninv = length(param)
91 changes: 37 additions & 54 deletions src/advancedHMC_MCMC.jl
@@ -1,7 +1,7 @@
mutable struct LogTargetDensity{C, S, I, P <: Vector{<:Distribution},
D <:
Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}
}
D <:
Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}},
}
dim::Int
prob::DiffEqBase.ODEProblem
chain::C
@@ -17,21 +17,13 @@ mutable struct LogTargetDensity{C, S, I, P <: Vector{<:Distribution},
extraparams::Int
init_params::I

function LogTargetDensity(dim, prob, chain::Optimisers.Restructure, st, strategy,
dataset,
function LogTargetDensity(dim, prob, chain::Optimisers.Restructure, st, dataset,
priors, phystd, l2std, autodiff, physdt, extraparams,
init_params::AbstractVector)
new{
typeof(chain),
Nothing,
typeof(strategy),
typeof(init_params),
typeof(priors),
typeof(dataset),
}(dim,
new{typeof(chain), Nothing, typeof(init_params), typeof(priors), typeof(dataset)}(dim,
prob,
chain,
nothing, strategy,
nothing,
dataset,
priors,
phystd,
@@ -41,20 +33,13 @@ mutable struct LogTargetDensity{C, S, I, P <: Vector{<:Distribution},
extraparams,
init_params)
end
function LogTargetDensity(dim, prob, chain::Lux.AbstractExplicitLayer, st, strategy,
dataset,
function LogTargetDensity(dim, prob, chain::Lux.AbstractExplicitLayer, st, dataset,
priors, phystd, l2std, autodiff, physdt, extraparams,
init_params::NamedTuple)
new{
typeof(chain),
typeof(st),
typeof(strategy),
typeof(init_params),
typeof(priors),
typeof(dataset),
new{typeof(chain), typeof(st), typeof(init_params), typeof(priors), typeof(dataset)
}(dim,
prob,
chain, st, strategy,
chain, st,
dataset, priors,
phystd, l2std,
autodiff,
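Net effect of the two hunks above: the `strategy` argument and its type parameter are dropped from both `LogTargetDensity` inner constructors, so the density object no longer carries a training strategy; the matching construction inside `ahmc_bayesian_pinn_ode` at the end of this diff drops the argument as well.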
@@ -356,7 +341,7 @@ function physloglikelihood(Tar::LogTargetDensity, θ)
t = collect(eltype(dt), Tar.prob.tspan[1]:dt:Tar.prob.tspan[2])
else
t = vcat(collect(eltype(dt), Tar.prob.tspan[1]:dt:Tar.prob.tspan[2]),
Tar.dataset[end])
Tar.dataset[end])
end

# parameter estimation chosen or not
@@ -382,13 +367,13 @@ function physloglikelihood(Tar::LogTargetDensity, θ)
# this is a vector{vector{dx,dy}}(handle case single u(float passed))
if length(out[:, 1]) == 1
physsol = [f(out[:, i][1],
ode_params,
t[i])
ode_params,
t[i])
for i in 1:length(out[1, :])]
else
physsol = [f(out[:, i],
ode_params,
t[i])
ode_params,
t[i])
for i in 1:length(out[1, :])]
end
physsol = reduce(hcat, physsol)
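The comprehension above calls the ODE right-hand side once per time point, with network outputs stored column-wise. A self-contained sketch of the same pattern with a toy scalar RHS (all names illustrative):

```julia
f(u, p, t) = -p[1] * u                      # toy RHS standing in for the ODE function
out = rand(1, 5)                            # network output, one column per time point
ode_params, t = [0.3], range(0, 1, length = 5)

physsol = [f(out[:, i][1], ode_params, t[i]) for i in 1:length(out[1, :])]
physsol = reduce(hcat, physsol)             # back to a 1×5 matrix, as in the code above
```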
@@ -400,10 +385,10 @@ function physloglikelihood(Tar::LogTargetDensity, θ)
for i in 1:length(Tar.prob.u0)
# can add phystd[i] for u[i]
physlogprob += logpdf(MvNormal(nnsol[i, :],
LinearAlgebra.Diagonal(map(abs2,
Tar.phystd[i] .*
ones(length(physsol[i, :]))))),
physsol[i, :])
LinearAlgebra.Diagonal(map(abs2,
Tar.phystd[i] .*
ones(length(physsol[i, :]))))),
physsol[i, :])
end
return physlogprob
end
@@ -421,10 +406,10 @@ function L2LossData(Tar::LogTargetDensity, θ)
for i in 1:length(Tar.prob.u0)
# for u[i] ith vector must be added to dataset,nn[1,:] is the dx in lotka_volterra
L2logprob += logpdf(MvNormal(nn[i, :],
LinearAlgebra.Diagonal(map(abs2,
Tar.l2std[i] .*
ones(length(Tar.dataset[i]))))),
Tar.dataset[i])
LinearAlgebra.Diagonal(map(abs2,
Tar.l2std[i] .*
ones(length(Tar.dataset[i]))))),
Tar.dataset[i])
end
return L2logprob
end
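Both `physloglikelihood` and `L2LossData` accumulate the same kind of term: the log-density of observed values under a diagonal Gaussian centered at the network prediction, with one standard deviation per output. A minimal standalone sketch of that pattern (assumes Distributions.jl and LinearAlgebra; the values are illustrative):

```julia
using Distributions, LinearAlgebra

μ = randn(10)                           # prediction, e.g. nn[i, :] or nnsol[i, :]
σ = 0.05                                # per-output std, e.g. Tar.l2std[i] or Tar.phystd[i]
Σ = Diagonal(map(abs2, σ .* ones(10)))  # same diagonal covariance construction as above
x = μ .+ σ .* randn(10)                 # observations, e.g. Tar.dataset[i]

logprob = logpdf(MvNormal(μ, Σ), x)     # the term added into the running log-likelihood
```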
@@ -565,26 +550,23 @@ n_leapfrog -> number of leapfrog steps for HMC
λ -> target trajectory length for HMCDA
progress -> controls whether to show the progress meter or not.
verbose -> controls the verbosity. (Sample call args in AHMC)

"""

"""

# dataset would be (x̂,t)
# priors: pdf for W,b + pdf for ODE params
function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
dataset = [nothing],
init_params = nothing, draw_samples = 1000,
physdt = 1 / 20.0, l2std = [0.05],
phystd = [0.05], priorsNNw = (0.0, 2.0),
param = [], nchains = 1,
autodiff = false,
Kernel = HMC, Integrator = Leapfrog,
Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8,
Metric = DiagEuclideanMetric, jitter_rate = 3.0,
tempering_rate = 3.0, max_depth = 10, Δ_max = 1000,
n_leapfrog = 10, δ = 0.65, λ = 0.3, progress = false,
verbose = false)
dataset = [nothing],
init_params = nothing, draw_samples = 1000,
physdt = 1 / 20.0, l2std = [0.05],
phystd = [0.05], priorsNNw = (0.0, 2.0),
param = [], nchains = 1,
autodiff = false,
Kernel = HMC, Integrator = Leapfrog,
Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8,
Metric = DiagEuclideanMetric, jitter_rate = 3.0,
tempering_rate = 3.0, max_depth = 10, Δ_max = 1000,
n_leapfrog = 10, δ = 0.65, λ = 0.3, progress = false,
verbose = false)

# NN parameter prior mean and variance(PriorsNN must be a tuple)
if isinplace(prob)
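For reference, a hedged usage sketch of this entry point exactly as the signature above defines it; `prob`, `chain`, and `dataset` are assumed set up as in the BPINN_ode.jl docstring example, and the three return values follow the call site shown earlier in this diff:

```julia
mcmcchain, samples, statistics = ahmc_bayesian_pinn_ode(prob, chain;
    dataset = dataset,
    draw_samples = 1000,
    l2std = [0.05], phystd = [0.05],
    priorsNNw = (0.0, 2.0),
    param = [],            # no unknown ODE parameters in this sketch
    n_leapfrog = 10)
```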
@@ -642,8 +624,9 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;

t0 = prob.tspan[1]
# dimensions would be total no of params,initial_nnθ for Lux namedTuples
ℓπ = LogTargetDensity(nparameters, prob, recon, st, strategy, dataset, priors,
phystd, l2std, autodiff, physdt, ninv, initial_nnθ)
ℓπ = LogTargetDensity(nparameters, prob, recon, st, dataset, priors,
phystd, l2std, autodiff, physdt, ninv,
initial_nnθ)

try
ℓπ(t0, initial_θ[1:(nparameters - ninv)])