-
-
Notifications
You must be signed in to change notification settings - Fork 202
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #724 from AstitvaAggarwal/develop
[WIP] API for BNNODE
- Loading branch information
Showing
6 changed files
with
995 additions
and
350 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
# Bayesian Physics informed Neural Network ODEs Solvers | ||
|
||
Bayesian inference for PINNs provides an approach to ODE solution finding and parameter estimation with quantified uncertainty. | ||
|
||
## The Lotka-Volterra Model | ||
|
||
The Lotka–Volterra equations, also known as the predator–prey equations, are a pair of first-order nonlinear differential equations. | ||
These differential equations are frequently used to describe the dynamics of biological systems in which two species interact, one as a predator and the other as prey. | ||
The populations change through time according to the pair of equations | ||
|
||
$$ | ||
\begin{aligned} | ||
\frac{\mathrm{d}x}{\mathrm{d}t} &= (\alpha - \beta y(t))x(t), \\ | ||
\frac{\mathrm{d}y}{\mathrm{d}t} &= (\delta x(t) - \gamma)y(t) | ||
\end{aligned} | ||
$$ | ||
|
||
where $x(t)$ and $y(t)$ denote the populations of prey and predator at time $t$, respectively, and $\alpha, \beta, \gamma, \delta$ are positive parameters. | ||
|
||
We implement the Lotka-Volterra model and simulate it with ideal parameters $\alpha = 1.5$, $\beta = 1$, $\gamma = 3$, and $\delta = 1$ and initial conditions $x(0) = y(0) = 1$. | ||
|
||
We then solve the equations and estimate the parameters of the model with priors for $\alpha$, $\beta$, $\gamma$ and $\delta$ as Normal(1,2), Normal(2,2), Normal(2,2) and Normal(0,2) using a Flux.jl Neural Network, chain_flux. | ||
|
||
And also solve the equations for the constructed ODEProblem's provided ideal `p` values using a Lux.jl Neural Network, chain_lux. | ||
|
||
```julia | ||
function lotka_volterra(u, p, t) | ||
# Model parameters. | ||
α, β, γ, δ = p | ||
# Current state. | ||
x, y = u | ||
|
||
# Evaluate differential equations. | ||
dx = (α - β * y) * x # prey | ||
dy = (δ * x - γ) * y # predator | ||
|
||
return [dx, dy] | ||
end | ||
|
||
# initial-value problem. | ||
u0 = [1.0, 1.0] | ||
p = [1.5, 1.0, 3.0, 1.0] | ||
tspan = (0.0, 6.0) | ||
prob = ODEProblem(lotka_volterra, u0, tspan, p) | ||
|
||
# Plot simulation. | ||
``` | ||
With the [`saveat` argument](https://docs.sciml.ai/latest/basics/common_solver_opts/) we can specify that the solution is stored only at `saveat` time units(default saveat=1 / 50.0). | ||
|
||
```julia | ||
solution = solve(prob, Tsit5(); saveat = 0.05) | ||
plot(solve(prob, Tsit5())) | ||
|
||
``` | ||
|
||
We generate noisy observations to use for the parameter estimation tasks in this tutorial. | ||
To make the example more realistic we add random normally distributed noise to the simulation. | ||
|
||
|
||
```julia | ||
# Dataset creation for parameter estimation | ||
time = solution.t | ||
u = hcat(solution.u...) | ||
x = u[1, :] + 0.5 * randn(length(u[1, :])) | ||
y = u[2, :] + 0.5 * randn(length(u[1, :])) | ||
dataset = [x, y, time] | ||
|
||
# Neural Networks must have 2 outputs as u -> [dx,dy] in function lotka_volterra() | ||
chainflux = Flux.Chain(Flux.Dense(1, 6, tanh), Flux.Dense(6, 6, tanh), Flux.Dense(6, 2)) |> Flux.f64 | ||
|
||
chainlux = Lux.Chain(Lux.Dense(1, 6, Lux.tanh), Lux.Dense(6, 6, Lux.tanh), Lux.Dense(6, 2)) | ||
``` | ||
A Dataset is required as parameter estimation is being done using provided priors in `param` keyword argument for BNNODE. | ||
|
||
```julia | ||
alg1 = NeuralPDE.BNNODE(chainflux, | ||
dataset = dataset, | ||
draw_samples = 1000, | ||
l2std = [ | ||
0.05, | ||
0.05, | ||
], | ||
phystd = [ | ||
0.05, | ||
0.05, | ||
], | ||
priorsNNw = (0.0, | ||
3.0), | ||
param = [ | ||
Normal(1, | ||
2), | ||
Normal(2, | ||
2), | ||
Normal(2, | ||
2), | ||
Normal(0, | ||
2), | ||
], | ||
n_leapfrog = 30, progress = true) | ||
|
||
sol_flux_pestim = solve(prob, alg1) | ||
|
||
# Dataset not needed as we are solving the equation with ideal parameters | ||
alg2 = NeuralPDE.BNNODE(chainlux, | ||
draw_samples = 1000, | ||
l2std = [ | ||
0.05, | ||
0.05, | ||
], | ||
phystd = [ | ||
0.05, | ||
0.05, | ||
], | ||
priorsNNw = (0.0, | ||
3.0), | ||
n_leapfrog = 30, progress = true) | ||
|
||
sol_lux = solve(prob, alg2) | ||
|
||
#testing timepoints must match keyword arg `saveat`` timepoints of solve() call | ||
t=collect(Float64,prob.tspan[1]:1/50.0:prob.tspan[2]) | ||
|
||
``` | ||
|
||
the solution for the ODE is retured as a nested vector sol_flux_pestim.ensemblesol. | ||
here, [$x$ , $y$] would be returned | ||
All estimated ode parameters are returned as a vector sol_flux_pestim.estimated_ode_params. | ||
here, [$\alpha$, $\beta$, $\gamma$, $\delta$] | ||
|
||
```julia | ||
# plotting solution for x,y for chain_flux | ||
plot(t,sol_flux_pestim.ensemblesol[1]) | ||
plot!(t,sol_flux_pestim.ensemblesol[2]) | ||
|
||
# estimated ODE parameters by .estimated_ode_params, weights and biases by .estimated_nn_params | ||
sol_flux_pestim.estimated_nn_params | ||
sol_flux_pestim.estimated_ode_params | ||
|
||
# plotting solution for x,y for chain_lux | ||
plot(t,sol_lux_pestim.ensemblesol[1]) | ||
plot!(t,sol_lux_pestim.ensemblesol[2]) | ||
|
||
# estimated weights and biases by .estimated_nn_params for chain_lux | ||
sol_lux_pestim.estimated_nn_params | ||
``` |
Oops, something went wrong.